jaxprop 0.3.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxprop-0.3.1/LICENSE.md +21 -0
- jaxprop-0.3.1/PKG-INFO +49 -0
- jaxprop-0.3.1/README.md +22 -0
- jaxprop-0.3.1/jaxprop/__init__.py +46 -0
- jaxprop-0.3.1/jaxprop/bicubic/__init__.py +3 -0
- jaxprop-0.3.1/jaxprop/bicubic/bicubic_interpolant_property.py +69 -0
- jaxprop-0.3.1/jaxprop/bicubic/generate_tables.py +85 -0
- jaxprop-0.3.1/jaxprop/bicubic/jax_bicubic_HEOS_interpolation_1.py +292 -0
- jaxprop-0.3.1/jaxprop/components/__init__.py +3 -0
- jaxprop-0.3.1/jaxprop/components/nozzle_model_core.py +247 -0
- jaxprop-0.3.1/jaxprop/components/nozzle_model_solver.py +573 -0
- jaxprop-0.3.1/jaxprop/coolprop/__init__.py +8 -0
- jaxprop-0.3.1/jaxprop/coolprop/core_calculations.py +1596 -0
- jaxprop-0.3.1/jaxprop/coolprop/fluid_properties.py +1954 -0
- jaxprop-0.3.1/jaxprop/coolprop/jax_wrapper.py +177 -0
- jaxprop-0.3.1/jaxprop/coolprop/jax_wrapper_working.py +177 -0
- jaxprop-0.3.1/jaxprop/graphics.py +338 -0
- jaxprop-0.3.1/jaxprop/helpers_coolprop.py +196 -0
- jaxprop-0.3.1/jaxprop/helpers_jax.py +69 -0
- jaxprop-0.3.1/jaxprop/math.py +445 -0
- jaxprop-0.3.1/jaxprop/perfect_gas/__init__.py +1 -0
- jaxprop-0.3.1/jaxprop/perfect_gas/perfect_gas.py +361 -0
- jaxprop-0.3.1/jaxprop/utils.py +177 -0
- jaxprop-0.3.1/pyproject.toml +52 -0
jaxprop-0.3.1/LICENSE.md
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2025 Roberto Agromayor
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
jaxprop-0.3.1/PKG-INFO
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
1
|
+
Metadata-Version: 2.1
|
|
2
|
+
Name: jaxprop
|
|
3
|
+
Version: 0.3.1
|
|
4
|
+
Summary: JAX-compatible thermodynamic property calculations.
|
|
5
|
+
Home-page: https://github.com/turbo-sim/jaxprop
|
|
6
|
+
License: MIT
|
|
7
|
+
Author: Roberto Agromayor
|
|
8
|
+
Requires-Python: >=3.11,<3.14
|
|
9
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
10
|
+
Classifier: Programming Language :: Python :: 3
|
|
11
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
12
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
13
|
+
Requires-Dist: CoolProp (>=7.0.0,<8.0.0)
|
|
14
|
+
Requires-Dist: diffrax (>=0.7.0,<0.8.0)
|
|
15
|
+
Requires-Dist: equinox (>=0.12.0,<0.13.0)
|
|
16
|
+
Requires-Dist: jax (>=0.6.0,<0.7.0)
|
|
17
|
+
Requires-Dist: jaxlib (>=0.6.0,<0.7.0)
|
|
18
|
+
Requires-Dist: matplotlib (>=3.10.6,<4.0.0)
|
|
19
|
+
Requires-Dist: numpy (>=2.3.0,<3.0.0)
|
|
20
|
+
Requires-Dist: optimistix (>=0.0.10,<0.0.11)
|
|
21
|
+
Requires-Dist: pysolver_view (>=0.6.8,<0.7.0)
|
|
22
|
+
Requires-Dist: scipy (>=1.16.1,<2.0.0)
|
|
23
|
+
Project-URL: Documentation, https://turbo-sim.github.io/jaxprop/
|
|
24
|
+
Project-URL: Repository, https://github.com/turbo-sim/jaxprop
|
|
25
|
+
Description-Content-Type: text/markdown
|
|
26
|
+
|
|
27
|
+
# JAXprop
|
|
28
|
+
|
|
29
|
+
`jaxprop` provides JAX-compatible thermodynamic property calculations with support for automatic differentiation, vectorization, and JIT compilation.
|
|
30
|
+
|
|
31
|
+
🔗 **Docs**: [turbo-sim.github.io/jaxprop](https://turbo-sim.github.io/jaxprop/)
|
|
32
|
+
📦 **PyPI**: [pypi.org/project/jaxprop](https://pypi.org/project/jaxprop/)
|
|
33
|
+
|
|
34
|
+
**Note**: This project is based on the [CoolProp](https://www.coolprop.org) library but is not affiliated with or endorsed by the CoolProp project.
|
|
35
|
+
|
|
36
|
+
## Key features
|
|
37
|
+
|
|
38
|
+
- Compute and plot phase envelopes and spinodal lines for pure fluids.
|
|
39
|
+
- Evaluate thermodynamic properties from Helmholtz energy–based equations of state, including metastable states inside the two-phase region.
|
|
40
|
+
- Perform flash calculations for any input pair with a custom solver and user-defined initial guesses.
|
|
41
|
+
- Work with structured property dictionaries and immutable `FluidState` objects.
|
|
42
|
+
- Evaluate properties over arrays of input conditions for efficient parametric studies and plotting.
|
|
43
|
+
- Full JAX compatibility: supports `jit`, `grad`, `vmap`, and parallel evaluation.
|
|
44
|
+
|
|
45
|
+
## Installation
|
|
46
|
+
|
|
47
|
+
```bash
|
|
48
|
+
pip install jaxprop
```
|
|
49
|
+
|
jaxprop-0.3.1/README.md
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
# JAXprop
|
|
2
|
+
|
|
3
|
+
`jaxprop` provides JAX-compatible thermodynamic property calculations with support for automatic differentiation, vectorization, and JIT compilation.
|
|
4
|
+
|
|
5
|
+
🔗 **Docs**: [turbo-sim.github.io/jaxprop](https://turbo-sim.github.io/jaxprop/)
|
|
6
|
+
📦 **PyPI**: [pypi.org/project/jaxprop](https://pypi.org/project/jaxprop/)
|
|
7
|
+
|
|
8
|
+
**Note**: This project is based on the [CoolProp](https://www.coolprop.org) library but is not affiliated with or endorsed by the CoolProp project.
|
|
9
|
+
|
|
10
|
+
## Key features
|
|
11
|
+
|
|
12
|
+
- Compute and plot phase envelopes and spinodal lines for pure fluids.
|
|
13
|
+
- Evaluate thermodynamic properties from Helmholtz energy–based equations of state, including metastable states inside the two-phase region.
|
|
14
|
+
- Perform flash calculations for any input pair with a custom solver and user-defined initial guesses.
|
|
15
|
+
- Work with structured property dictionaries and immutable `FluidState` objects.
|
|
16
|
+
- Evaluate properties over arrays of input conditions for efficient parametric studies and plotting.
|
|
17
|
+
- Full JAX compatibility: supports `jit`, `grad`, `vmap`, and parallel evaluation.
|
|
18
|
+
|
|
19
|
+
## Installation
|
|
20
|
+
|
|
21
|
+
```bash
|
|
22
|
+
pip install jaxprop
```
|
|
@@ -0,0 +1,46 @@
|
|
|
1
|
+
# Highlight exception messages when IPython is available.
# https://stackoverflow.com/questions/25109105/how-to-colorize-the-output-of-python-errors-in-the-gnome-terminal/52797444#52797444
# NOTE(review): this runs at import time and replaces the process-wide
# sys.excepthook for every program that imports jaxprop, not only for jaxprop's
# own errors. Newer IPython releases may have deprecated/renamed the
# `color_scheme` argument of FormattedTB - confirm against the installed version.
try:
    import IPython.core.ultratb
except ImportError:
    # No IPython. Use default exception printing.
    pass
else:
    import sys
    sys.excepthook = IPython.core.ultratb.FormattedTB(color_scheme='linux', call_pdb=False)


# JAX configuration, applied at import time. The platform variable is set
# before `import jax` (presumably so JAX sees it when it initializes).
# NOTE(review): this overrides any value already set by the user, and recent JAX
# versions document JAX_PLATFORMS as the platform selector - confirm that
# JAX_PLATFORM_NAME is still honored by the pinned jax version.
import os
os.environ["JAX_PLATFORM_NAME"] = "cpu"
import jax
# Enable 64-bit (double precision) arrays; JAX defaults to 32-bit otherwise.
jax.config.update("jax_enable_x64", True)



# Re-export helper modules at package level. Order matters for star imports:
# a name defined in several of these modules resolves to the last one imported.
from .graphics import *
from .utils import *
from .helpers_jax import *
from .helpers_coolprop import *

# Import subpackages
from . import coolprop
from . import perfect_gas
# NOTE(review): jaxprop.bicubic imports `..jax_import` and
# `jaxprop.fluid_properties`, neither of which appears in this release's file
# list - presumably why it is disabled; confirm before re-enabling.
# from . import bicubic

# Import API classes
from .perfect_gas import FluidPerfectGas
from .coolprop import Fluid, FluidJAX


from . import components


# Package info
__version__ = "0.3.1"
PACKAGE_NAME = "jaxprop"
URL_GITHUB = "https://github.com/turbo-sim/jaxprop"
URL_DOCS = "https://turbo-sim.github.io/jaxprop/"
URL_PYPI = "https://pypi.org/project/jaxprop/"
URL_DTU = "https://thermalpower.dtu.dk/"
# Horizontal rule used when printing separators (80 dashes)
BREAKLINE = 80 * "-"
|
|
46
|
+
|
|
@@ -0,0 +1,69 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
# import jax.numpy as jnp
|
|
3
|
+
from ..jax_import import jnp
|
|
4
|
+
import pickle
|
|
5
|
+
|
|
6
|
+
from .jax_bicubic_HEOS_interpolation_1 import compute_bicubic_coefficients_of_ij, bicubic_interpolant
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def bicubic_interpolant_property(h, P, table):
    """
    Interpolates all properties at a given enthalpy (h) and pressure (P) using bicubic interpolation.

    The grid is assumed uniform in enthalpy and in log(P), as produced by
    ``generate_property_table``. Queries on or beyond the grid boundary use the
    nearest boundary cell (its polynomial is extrapolated).

    Args:
        h (float): Enthalpy in J/kg
        P (float): Pressure in Pa
        table (dict): Property table. ``table['h']`` and ``table['P']`` are the
            1D grid axes; every other key maps to a dict of (Nh, Np) nodal
            arrays ``'value'``, ``'d_dh'`` (derivative per J/kg), ``'d_dP'``
            (derivative per Pa) and ``'d2_dhdP'``.

    Returns:
        dict: Dictionary of interpolated property values at (h, P)
    """
    h_vals = jnp.array(table['h'])
    P_vals = jnp.array(table['P'])
    L_vals = jnp.log(P_vals)

    Nh, Np = len(h_vals), len(P_vals)
    hmin, hmax = float(h_vals[0]), float(h_vals[-1])
    Lmin, Lmax = float(L_vals[0]), float(L_vals[-1])
    deltah = float(h_vals[1] - h_vals[0])
    deltaL = float(L_vals[1] - L_vals[0])

    # Fractional node coordinates of the query point
    ii = float((h - hmin) / (hmax - hmin) * (Nh - 1))
    jj = float((jnp.log(P) - Lmin) / (Lmax - Lmin) * (Np - 1))

    # Identify cell (i, j), clipped to the valid cells so that a query on the
    # upper edge uses the last cell with x (or y) = 1
    i = int(np.clip(np.floor(ii), 0, Nh - 2))
    j = int(np.clip(np.floor(jj), 0, Np - 2))
    x = ii - i
    y = jj - j

    # Monomials x**m * y**n, ordered like the coefficient index 4*n + m
    monomials = jnp.array([x**m * y**n for n in range(4) for m in range(4)])

    # The table stores derivatives w.r.t. P [Pa], but the cell coordinate is
    # y = (log(P) - L_j) / deltaL, so d/dy = deltaL * P * d/dP (chain rule).
    cell = (slice(i, i + 2), slice(j, j + 2))
    P_cell = P_vals[j:j + 2][None, :]  # pressure at the two pressure nodes of the cell

    interpolated_props = {}

    for prop, prop_data in table.items():
        if prop in ('h', 'P'):
            continue  # skip grid axes

        f_cell = prop_data['value'][cell]
        fx_cell = prop_data['d_dh'][cell] * deltah
        fy_cell = prop_data['d_dP'][cell] * P_cell * deltaL
        fxy_cell = prop_data['d2_dhdP'][cell] * P_cell * deltah * deltaL

        # The 2x2 slices hold the cell corners, so the cell is (0, 0) locally
        coeffs = compute_bicubic_coefficients_of_ij(0, 0, f_cell, fx_cell, fy_cell, fxy_cell)

        interpolated_props[prop] = float(jnp.dot(coeffs, monomials))

    return interpolated_props
|
|
@@ -0,0 +1,85 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
|
|
3
|
+
from ..jax_import import jnp
|
|
4
|
+
# import jax.numpy as jnp
|
|
5
|
+
import CoolProp.CoolProp as cp
|
|
6
|
+
# import turboflow as tf
|
|
7
|
+
import os
|
|
8
|
+
import pickle
|
|
9
|
+
from jaxprop.fluid_properties import Fluid
|
|
10
|
+
|
|
11
|
+
def generate_property_table(hmin, hmax, Pmin, Pmax, fluid_name, Nh, Np, outdir='fluid_tables'):
    """
    Tabulate fluid properties on an (h, log P) grid and save the table as a pickle.

    The grid has ``Nh`` nodes uniformly spaced in enthalpy between ``hmin`` and
    ``hmax`` and ``Np`` nodes uniformly spaced in log(P) between ``Pmin`` and
    ``Pmax``. At every node the properties are evaluated with an HmassP flash.
    Their first derivatives and mixed second derivative are estimated with
    one-sided finite differences.

    Args:
        hmin, hmax (float): Enthalpy range of the grid [J/kg].
        Pmin, Pmax (float): Pressure range of the grid [Pa].
        fluid_name (str): Fluid name passed to ``Fluid``.
        Nh, Np (int): Number of grid nodes in enthalpy and in log-pressure (>= 2).
        outdir (str): Output directory for the pickle file (created if missing).

    Returns:
        dict: ``table['h']`` and ``table['P']`` (in Pa) hold the 1D grid axes.
        Each property key maps to a dict of (Nh, Np) arrays: ``'value'``,
        ``'d_dh'`` (per J/kg), ``'d_dP'`` (per Pa) and ``'d2_dhdP'``.
        Nodes where the flash calculation failed are left at zero.

    Raises:
        ValueError: If ``Nh`` or ``Np`` is smaller than 2.
    """
    if Nh < 2 or Np < 2:
        raise ValueError(f"Nh and Np must both be at least 2, got Nh={Nh}, Np={Np}")

    fluid = Fluid(fluid_name)

    h_vals = jnp.linspace(hmin, hmax, Nh)
    Lmin = jnp.log(Pmin)
    Lmax = jnp.log(Pmax)
    logP_vals = jnp.linspace(Lmin, Lmax, Np)  # grid is uniform in log(P), not in P

    # Host-side grid axes: stored in the table AND used as flash inputs, so the
    # tabulated pressures are exactly the ones that were evaluated
    h_grid = np.array(h_vals)
    P_grid = np.array(jnp.exp(logP_vals))

    deltah = float(h_vals[1] - h_vals[0])
    eps_h = 0.001 * deltah  # finite-difference step in h [J/kg]
    eps_P = 1e-6 * float(Pmin)  # finite-difference step in P [Pa], absolute

    properties = {
        'T': 'T',  # Temperature [K]
        'd': 'D',  # Density [kg/m³]
        's': 'S',  # Entropy [J/kg/K]
        'mu': 'V',  # Viscosity [Pa·s]
        'k': 'L',  # Thermal conductivity [W/m/K]
    }

    # Initialize property grids
    table = {
        'h': h_grid,
        'P': P_grid,
    }
    for key in properties:
        table[key] = {
            field: np.zeros((Nh, Np), dtype=np.float64)
            for field in ('value', 'd_dh', 'd_dP', 'd2_dhdP')
        }

    # Loop over grid and populate values
    n_failed = 0
    for i, h in enumerate(h_grid):
        hf = float(h)
        for j, P in enumerate(P_grid):
            Pf = float(P)

            try:
                f_0 = fluid.get_state(cp.HmassP_INPUTS, hf, Pf)
                f_h = fluid.get_state(cp.HmassP_INPUTS, hf + eps_h, Pf)
                f_p = fluid.get_state(cp.HmassP_INPUTS, hf, Pf + eps_P)
                f_hp = fluid.get_state(cp.HmassP_INPUTS, hf + eps_h, Pf + eps_P)
            except Exception:
                # Best effort: leave this node at zero but keep track of it
                n_failed += 1
                continue

            for key in properties:
                val = f_0[key]
                dval_dh = (f_h[key] - f_0[key]) / eps_h
                dval_dP = (f_p[key] - f_0[key]) / eps_P
                d2val_dhdP = (f_hp[key] - f_h[key] - f_p[key] + f_0[key]) / (eps_h * eps_P)

                table[key]['value'][i, j] = val
                table[key]['d_dh'][i, j] = dval_dh
                table[key]['d_dP'][i, j] = dval_dP
                table[key]['d2_dhdP'][i, j] = d2val_dhdP

    if n_failed:
        print(f" Warning: {n_failed} of {Nh * Np} grid points could not be evaluated and were left as zeros")

    # Save as pickle only (most useful for JAX processing)
    os.makedirs(outdir, exist_ok=True)
    pkl_path = os.path.join(outdir, f'{fluid_name}_{Nh}_x_{Np}.pkl')

    with open(pkl_path, 'wb') as f:
        pickle.dump(table, f)

    print(f" Saved the table to:\n -> Pickle: {pkl_path}")

    return table
|
|
@@ -0,0 +1,292 @@
|
|
|
1
|
+
# ======================== Imports ========================
|
|
2
|
+
# import os
|
|
3
|
+
# import psutil
|
|
4
|
+
import time
|
|
5
|
+
# import jax
|
|
6
|
+
# import jax.numpy as jnp
|
|
7
|
+
# from jax import jit
|
|
8
|
+
|
|
9
|
+
from ..jax_import import jax, jnp, jit
|
|
10
|
+
# from jax.experimental import mesh_utils
|
|
11
|
+
# from jax.sharding import Mesh, PartitionSpec, NamedSharding
|
|
12
|
+
import CoolProp.CoolProp as cp
|
|
13
|
+
import numpy as np
|
|
14
|
+
import matplotlib.pyplot as plt
|
|
15
|
+
from functools import partial
|
|
16
|
+
|
|
17
|
+
# ======================== Config ========================
|
|
18
|
+
# NCORES = psutil.cpu_count(logical=False)
|
|
19
|
+
# os.environ["XLA_FLAGS"] = f"--xla_force_host_platform_device_count={NCORES}"
|
|
20
|
+
# jax.config.update("jax_enable_x64", True)
|
|
21
|
+
|
|
22
|
+
# Global precision: NumPy dtype objects for double-precision real and complex values
float64, complex128 = map(jnp.dtype, ("float64", "complex128"))
|
|
25
|
+
|
|
26
|
+
# =================== Functions to Export ===================
|
|
27
|
+
# Inverse of the bicubic Hermite system on the unit cell. Row k yields
# coefficient k as a weighted sum of the 16 corner samples, ordered as
# f, f_x, f_y, f_xy, each at the corners (0,0), (1,0), (0,1), (1,1).
_BICUBIC_HERMITE_INVERSE = (
    ( 1.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.),
    ( 0.,  0.,  0.,  0.,  1.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.),
    (-3.,  3.,  0.,  0., -2., -1.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.),
    ( 2., -2.,  0.,  0.,  1.,  1.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.),
    ( 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  1.,  0.,  0.,  0.,  0.,  0.,  0.,  0.),
    ( 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  1.,  0.,  0.,  0.),
    ( 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0., -3.,  3.,  0.,  0., -2., -1.,  0.,  0.),
    ( 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  2., -2.,  0.,  0.,  1.,  1.,  0.,  0.),
    (-3.,  0.,  3.,  0.,  0.,  0.,  0.,  0., -2.,  0., -1.,  0.,  0.,  0.,  0.,  0.),
    ( 0.,  0.,  0.,  0., -3.,  0.,  3.,  0.,  0.,  0.,  0.,  0., -2.,  0., -1.,  0.),
    ( 9., -9., -9.,  9.,  6.,  3., -6., -3.,  6., -6.,  3., -3.,  4.,  2.,  2.,  1.),
    (-6.,  6.,  6., -6., -3., -3.,  3.,  3., -4.,  4., -2.,  2., -2., -2., -1., -1.),
    ( 2.,  0., -2.,  0.,  0.,  0.,  0.,  0.,  1.,  0.,  1.,  0.,  0.,  0.,  0.,  0.),
    ( 0.,  0.,  0.,  0.,  2.,  0., -2.,  0.,  0.,  0.,  0.,  0.,  1.,  0.,  1.,  0.),
    (-6.,  6.,  6., -6., -4., -2.,  4.,  2., -3.,  3., -3.,  3., -2., -1., -2., -1.),
    ( 4., -4., -4.,  4.,  2.,  2., -2., -2.,  2., -2.,  2., -2.,  1.,  1.,  1.,  1.),
)


# @jax.jit
def compute_bicubic_coefficients_of_ij(i, j, f, fx, fy, fxy):
    """Return the 16 bicubic coefficients of the grid cell whose lower-left node is (i, j).

    Args:
        i, j: Indices of the lower-left node of the cell.
        f, fx, fy, fxy: 2D nodal arrays of the function value and of its x, y
            and mixed xy derivatives. Callers pass derivatives already scaled
            by the grid spacing, i.e. expressed in unit-cell coordinates.

    Returns:
        1D array of 16 coefficients with dtype ``f.dtype``; the evaluator in
        this module multiplies entry ``4*n + m`` by ``x**m * y**n``.
    """
    corners = ((i, j), (i + 1, j), (i, j + 1), (i + 1, j + 1))
    # Gather samples grid-by-grid, corner-by-corner (f, f_x, f_y, f_xy)
    samples = [grid[a, b] for grid in (f, fx, fy, fxy) for a, b in corners]
    weights = jnp.array(_BICUBIC_HERMITE_INVERSE, dtype=f.dtype)
    return weights @ jnp.array(samples, dtype=f.dtype)
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
# def bicubic_interpolant(h_vals, P_vals, coeffs, hmin, hmax, Lmin, Lmax, Nh, Np):
|
|
51
|
+
# """
|
|
52
|
+
# Create a bicubic interpolant function using precomputed coefficients.
|
|
53
|
+
# """
|
|
54
|
+
# # Ensure that coeffs is in the right shape (Nh, Np, 16)
|
|
55
|
+
# assert coeffs.shape == (Nh, Np, 16), f"Expected coeffs to have shape (Nh, Np, 16), but got {coeffs.shape}"
|
|
56
|
+
|
|
57
|
+
# # Normalized grid values
|
|
58
|
+
# def normalize(value, min_val, max_val):
|
|
59
|
+
# return (value - min_val) / (max_val - min_val)
|
|
60
|
+
|
|
61
|
+
# # The actual interpolant function, it takes h and P as arguments
|
|
62
|
+
# @jit
|
|
63
|
+
# def interpolant_fn(h, P):
|
|
64
|
+
# # Ensure h_vals and P_vals are 1D arrays
|
|
65
|
+
# h_vals_flat = jnp.ravel(h_vals) # Flatten the h_vals to 1D
|
|
66
|
+
# P_vals_flat = jnp.ravel(P_vals) # Flatten the P_vals to 1D
|
|
67
|
+
|
|
68
|
+
# # Normalize h and P
|
|
69
|
+
# norm_h = normalize(h, hmin, hmax)
|
|
70
|
+
# norm_P = normalize(P, Lmin, Lmax)
|
|
71
|
+
|
|
72
|
+
# # Identify the grid points surrounding h, P
|
|
73
|
+
# i = jnp.clip(jnp.searchsorted(h_vals_flat, h) - 1, 0, Nh - 2)
|
|
74
|
+
# j = jnp.clip(jnp.searchsorted(P_vals_flat, P) - 1, 0, Np - 2)
|
|
75
|
+
|
|
76
|
+
# # Extract the coefficients for the surrounding grid points
|
|
77
|
+
# coeff = coeffs[i, j]
|
|
78
|
+
|
|
79
|
+
# # Interpolate in both directions
|
|
80
|
+
# h_diff = norm_h - h_vals_flat[i]
|
|
81
|
+
# P_diff = norm_P - P_vals_flat[j]
|
|
82
|
+
|
|
83
|
+
# # Calculate the interpolant using the bicubic coefficients
|
|
84
|
+
# result = (
|
|
85
|
+
# coeff[0] + coeff[1] * h_diff + coeff[2] * P_diff + coeff[3] * h_diff * P_diff +
|
|
86
|
+
# coeff[4] * h_diff**2 + coeff[5] * h_diff * P_diff**2 +
|
|
87
|
+
# coeff[6] * P_diff**2 + coeff[7] * h_diff**2 * P_diff +
|
|
88
|
+
# coeff[8] * h_diff**3 + coeff[9] * h_diff**2 * P_diff +
|
|
89
|
+
# coeff[10] * h_diff * P_diff**2 + coeff[11] * h_diff**3 * P_diff +
|
|
90
|
+
# coeff[12] * P_diff**3 + coeff[13] * h_diff * P_diff**3 +
|
|
91
|
+
# coeff[14] * h_diff**3 * P_diff**2 + coeff[15] * h_diff**2 * P_diff**3
|
|
92
|
+
# )
|
|
93
|
+
|
|
94
|
+
# return result
|
|
95
|
+
|
|
96
|
+
# return interpolant_fn
|
|
97
|
+
|
|
98
|
+
# @partial(jit, static_argnums=(5, 6)) # Nh and Np are static # static arguments are all except h, P
|
|
99
|
+
# def bicubic_interpolant(h, P, h_vals, P_vals, coeffs, Nh, Np, hmin, hmax, Lmin, Lmax):
|
|
100
|
+
# """
|
|
101
|
+
# Evaluate the bicubic interpolant at (h, P) using precomputed coefficients.
|
|
102
|
+
# """
|
|
103
|
+
# # Normalize
|
|
104
|
+
# norm_h = (h - hmin) / (hmax - hmin)
|
|
105
|
+
# norm_P = (P - Lmin) / (Lmax - Lmin)
|
|
106
|
+
|
|
107
|
+
# # Flatten grid arrays
|
|
108
|
+
# h_vals_flat = jnp.ravel(h_vals)
|
|
109
|
+
# P_vals_flat = jnp.ravel(P_vals)
|
|
110
|
+
|
|
111
|
+
# # Find surrounding grid indices
|
|
112
|
+
# i = jnp.clip(jnp.searchsorted(h_vals_flat, h) - 1, 0, Nh - 2)
|
|
113
|
+
# j = jnp.clip(jnp.searchsorted(P_vals_flat, P) - 1, 0, Np - 2)
|
|
114
|
+
|
|
115
|
+
# # Relative differences in normalized space
|
|
116
|
+
# h_base = (h_vals_flat[i] - hmin) / (hmax - hmin)
|
|
117
|
+
# P_base = (P_vals_flat[j] - Lmin) / (Lmax - Lmin)
|
|
118
|
+
# h_diff = norm_h - h_base
|
|
119
|
+
# P_diff = norm_P - P_base
|
|
120
|
+
|
|
121
|
+
# # Fetch bicubic coefficients
|
|
122
|
+
# coeff = coeffs[i, j]
|
|
123
|
+
|
|
124
|
+
# # Evaluate bicubic polynomial
|
|
125
|
+
# result = (
|
|
126
|
+
# coeff[0] + coeff[1] * h_diff + coeff[2] * P_diff + coeff[3] * h_diff * P_diff +
|
|
127
|
+
# coeff[4] * h_diff**2 + coeff[5] * h_diff * P_diff**2 +
|
|
128
|
+
# coeff[6] * P_diff**2 + coeff[7] * h_diff**2 * P_diff +
|
|
129
|
+
# coeff[8] * h_diff**3 + coeff[9] * h_diff**2 * P_diff +
|
|
130
|
+
# coeff[10] * h_diff * P_diff**2 + coeff[11] * h_diff**3 * P_diff +
|
|
131
|
+
# coeff[12] * P_diff**3 + coeff[13] * h_diff * P_diff**3 +
|
|
132
|
+
# coeff[14] * h_diff**3 * P_diff**2 + coeff[15] * h_diff**2 * P_diff**3
|
|
133
|
+
# )
|
|
134
|
+
|
|
135
|
+
# return result
|
|
136
|
+
|
|
137
|
+
# @partial(jit, static_argnums=(5, 6))
def bicubic_interpolant(h, P, h_vals, P_vals, coeffs, Nh, Np, hmin, hmax, Lmin, Lmax):
    """
    Evaluate the bicubic interpolant at (h, P) using precomputed coefficients.

    The grid is uniform in ``h`` and in ``L = log(P)``. For the cell whose
    lower-left node is ``(i, j)``, ``coeffs[i, j, 4*n + m]`` multiplies
    ``x**m * y**n``, where ``x`` and ``y`` are the local cell coordinates.
    Queries on or beyond the grid boundary are evaluated with the nearest
    boundary cell (its polynomial is extrapolated) instead of indexing outside
    the valid cells.

    Args:
        h: Enthalpy, scalar or array.
        P: Pressure (positive), scalar or array broadcastable with ``h``.
        h_vals, P_vals: Grid axes. Unused; kept for interface compatibility.
        coeffs: Coefficients indexed as ``coeffs[i, j, k]`` with
            ``0 <= i <= Nh - 2``, ``0 <= j <= Np - 2`` and ``0 <= k < 16``.
        Nh, Np: Number of grid nodes along h and along log(P).
        hmin, hmax: Enthalpy range of the grid.
        Lmin, Lmax: log(P) range of the grid.

    Returns:
        Interpolated value(s), broadcast over ``h`` and ``P``.
    """
    # Accept plain Python scalars as well as arrays
    h = jnp.asarray(h)

    # Log-transform P
    L = jnp.log(P)

    # Fractional node coordinates along each axis
    ii = (h - hmin) / (hmax - hmin) * (Nh - 1)
    jj = (L - Lmin) / (Lmax - Lmin) * (Np - 1)

    # Clip to the valid cells: a query exactly on the upper edge maps to the
    # last cell with x (or y) = 1 instead of a cell without coefficients, and
    # points below the grid cannot produce negative indices (which JAX would
    # wrap around to the far end of the table)
    i = jnp.clip(jnp.floor(ii).astype(int), 0, Nh - 2)
    j = jnp.clip(jnp.floor(jj).astype(int), 0, Np - 2)
    x = ii - i
    y = jj - j

    # Evaluate sum_{m,n} coeffs[i, j, 4*n + m] * x**m * y**n
    result = jnp.zeros_like(h)
    x_pow = jnp.ones_like(h)  # x**m

    for m in range(4):  # m = x power
        y_pow = jnp.ones_like(h)  # y**n
        for n in range(4):  # n = y power
            result += coeffs[i, j, 4 * n + m] * x_pow * y_pow
            y_pow = y_pow * y
        x_pow = x_pow * x

    return result
|
|
169
|
+
|
|
170
|
+
|
|
171
|
+
|
|
172
|
+
# @jax.jit
def inverse_interpolant_scalar_hD(h, D):
    """Invert the bicubic interpolant for pressure, given enthalpy h and a property value D.

    At fixed enthalpy ``h`` the interpolated property is a piecewise cubic in
    the log-pressure cell coordinate ``y``. The function locates the pressure
    cell by searching the nodal values. It then solves the cubic
    ``D = b0 + b1*y + b2*y**2 + b3*y**3`` in closed form and keeps the root
    that is closest to being real and inside [0, 1].

    NOTE(review): reads ``hmin``, ``hmax``, ``Lmin``, ``Lmax``, ``N``, ``M``,
    ``bicubic_coefficients`` and ``iD`` as module-level names. This module
    does not define them, so calling it raises NameError unless they are
    assigned on the module first.
    NOTE(review): ``ii.astype`` requires ``h`` to be an array-like with an
    ``astype`` method (a plain Python float has none). The indices ``i`` and
    ``j`` are not clipped to the valid cell range - confirm that callers stay
    strictly inside the grid.

    Args:
        h: Enthalpy.
        D: Target value of the tabulated property.

    Returns:
        Pressure ``P = exp(L)``, where L is the log-pressure of the selected root.
    """
    # Fractional enthalpy node coordinate of h (grid uniform in h)
    ii=((h-hmin)/(hmax-hmin)*(N-1))
    # The integer part is the cell index
    i=ii.astype(int)
    # The remainder is the position within the interpolation cell; it is
    # computed as a difference, which is numerically more stable.
    x=ii-i
    # Find the pressure interval that contains the solution
    xth=jnp.ones_like(h) # x**0
    # First compute the nodal values D(h, P_j), where h is the actual enthalpy
    # and P_j are the grid pressures (y = 0): D_nodal[j] = sum_m a[m, 0] * x**m.
    # TODO: instead of computing all the nodal values and then using searchsorted
    # to find the correct interval, we could do a binary search. This would
    # constrain M to be a power of 2.
    # Possible scheme (to be refined): start with the node j=M/2, then compute
    # the new index j=j+M/4*(2*(Dj>D)-1), then j=j+M/8*(2*(Dj>D)-1) and so on.
    # After log2(M) iterations we have converged to the index j.
    D_nodal=jnp.zeros(M)
    for m in range(4):
        D_nodal+=bicubic_coefficients[i,:,m]*xth
        xth=xth*x
    # searchsorted expects an ascending array. The entropy branch negates both
    # sides, i.e. it assumes D_nodal decreases with P when iD is entropy and
    # increases with P otherwise.
    # TODO: This assumes that D has a monotonic trend with respect to P at fixed
    # h. This causes some problems and needs further investigation.
    # NOTE(review): searchsorted(...)-1 is -1 when D lies below every nodal
    # value, and a negative index wraps to the last cell in JAX.
    if iD==cp.iSmass:
        j=jax.numpy.searchsorted(-D_nodal,-D).astype(int)-1
    else:
        j=jax.numpy.searchsorted(D_nodal,D).astype(int)-1

    # With i and j known (inside the unit square), collapse the bicubic to a 1D
    # cubic in y: b_n = sum_m a[m, n] * x**m, giving D = b0 + b1*y + b2*y**2 + b3*y**3.
    # The b_n are complex so the closed-form root formula needs no dtype promotion.
    xth=jnp.ones_like(h)
    b0 =jnp.zeros_like(h,dtype=complex128)
    b1 =jnp.zeros_like(h,dtype=complex128)
    b2 =jnp.zeros_like(h,dtype=complex128)
    b3 =jnp.zeros_like(h,dtype=complex128)
    for m in range(4):
        b0 +=bicubic_coefficients[i,j,4*0+m]*xth
        b1 +=bicubic_coefficients[i,j,4*1+m]*xth
        b2 +=bicubic_coefficients[i,j,4*2+m]*xth
        b3 +=bicubic_coefficients[i,j,4*3+m]*xth
        xth=xth*x
    # Solve the cubic equation for all three roots (general cubic formula):
    # https://en.wikipedia.org/wiki/Cubic_equation#General_cubic_formula
    # TODO: if necessary, add solutions for the degenerate (quadratic and linear) cases
    D0=b2*b2-3*b3*b1
    D1=2*b2*b2*b2-9*b3*b2*b1+27*b3*b3*(b0-D)
    C=((D1+(D1*D1-4*D0*D0*D0)**0.5)/2)**(1/3)
    D0C=jax.lax.select(C==(0+0j),0+0j,D0/C)  # avoid 0/0 when C vanishes
    z=jnp.array([1,-0.5+0.8660254037844386j,-0.5-0.8660254037844386j])  # cube roots of unity
    y=-1/(3*b3)*(b2+C*z+D0C/z)
    # The physical solution must satisfy two criteria:
    #  - zero imaginary part
    #  - real part between 0 and 1, the bounds of our cell
    # "badness" measures the deviation from both criteria (the relu term is zero
    # exactly when the real part lies in [0, 1]); pick the root with the lowest.
    badness=jax.nn.relu(4*(jnp.real(y)-0.5)**2-1)+jnp.imag(y)**2
    yreal=jnp.real(y[jnp.argmin(badness)])
    # Map the cell coordinate back to log-pressure, then to pressure
    jj=j+yreal
    L=Lmin+jj*(Lmax-Lmin)/(M-1)
    P=jnp.exp(L)
    return P
|
|
240
|
+
|
|
241
|
+
# @jax.jit
def inverse_interpolant_scalar_DP(D, P):
    """Invert the bicubic interpolant for enthalpy, given a property value D and pressure P.

    At fixed pressure ``P`` the interpolated property is a piecewise cubic in
    the enthalpy cell coordinate ``x``. The function locates the enthalpy cell
    by searching the nodal values. It then solves the cubic
    ``D = b0 + b1*x + b2*x**2 + b3*x**3`` in closed form and keeps the root
    that is closest to being real and inside [0, 1].

    Coefficient layout: ``bicubic_coefficients[i, j, 4*n + m]`` multiplies
    ``x**m * y**n``, with x along enthalpy and y along log-pressure.

    NOTE(review): the table and grid description are read from module-level
    names (``bicubic_coefficients``, ``hmin``, ``hmax``, ``Lmin``, ``Lmax``,
    ``N``, ``M``). This module does not define them, so they must be assigned
    on the module before calling; otherwise a NameError is raised.

    Args:
        D: Target value of the tabulated property (scalar).
        P: Pressure (scalar, positive).

    Returns:
        Enthalpy h such that the interpolant at (h, P) equals D.
    """
    # Convert pressure to log space
    L = jnp.log(P)

    # Pressure cell index j (clipped to valid cells) and local coordinate y
    jj = (L - Lmin) / (Lmax - Lmin) * (M - 1)
    j = jnp.clip(jnp.floor(jj).astype(int), 0, M - 2)
    y = jj - j

    # Property at every enthalpy node (x = 0) along the line of constant P:
    # D(x=0, y) = sum_n a[0, n] * y**n, i.e. coefficient index 4*n
    yth = jnp.ones_like(D)
    D_nodal = jnp.zeros(N)
    for n in range(4):
        D_nodal += bicubic_coefficients[:, j, 4 * n] * yth
        yth = yth * y

    # Enthalpy cell containing the solution. D_nodal is assumed monotonic in h;
    # its direction is read from the data rather than from the property type
    # (e.g. entropy increases with h at constant P since (ds/dh)_P = 1/T > 0,
    # whereas density typically decreases).
    descending = D_nodal[-1] < D_nodal[0]
    i = jnp.where(
        descending,
        jnp.searchsorted(-D_nodal, -D),
        jnp.searchsorted(D_nodal, D),
    ).astype(int) - 1
    i = jnp.clip(i, 0, N - 2)

    # 1D cubic in x at fixed y: b_m = sum_n a[m, n] * y**n (index 4*n + m).
    # Complex dtype so the closed-form root formula needs no dtype promotion.
    yth = jnp.ones_like(D)
    b0 = jnp.zeros_like(D, dtype=complex128)
    b1 = jnp.zeros_like(D, dtype=complex128)
    b2 = jnp.zeros_like(D, dtype=complex128)
    b3 = jnp.zeros_like(D, dtype=complex128)
    for n in range(4):
        b0 += bicubic_coefficients[i, j, 4 * n + 0] * yth
        b1 += bicubic_coefficients[i, j, 4 * n + 1] * yth
        b2 += bicubic_coefficients[i, j, 4 * n + 2] * yth
        b3 += bicubic_coefficients[i, j, 4 * n + 3] * yth
        yth = yth * y

    # Solve D = b0 + b1*x + b2*x^2 + b3*x^3 for all three roots (general cubic formula)
    # TODO: degenerate (quadratic/linear, b3 == 0) cases are not handled
    D0 = b2 * b2 - 3 * b3 * b1
    D1 = 2 * b2**3 - 9 * b3 * b2 * b1 + 27 * b3**2 * (b0 - D)
    C = ((D1 + jnp.sqrt(D1**2 - 4 * D0**3)) / 2) ** (1 / 3)
    D0C = jnp.where(C == 0, 0j, D0 / C)  # avoid 0/0 when C vanishes
    z = jnp.array([1, -0.5 + 0.8660254037844386j, -0.5 - 0.8660254037844386j])  # cube roots of unity
    roots = -1 / (3 * b3) * (b2 + C * z + D0C / z)

    # Pick the root closest to being real with real part in [0, 1]
    badness = jax.nn.relu(4 * (jnp.real(roots) - 0.5) ** 2 - 1) + jnp.imag(roots) ** 2
    xreal = jnp.real(roots[jnp.argmin(badness)])

    # Final result: compute h from i + x
    ii = i + xreal
    h = hmin + ii * (hmax - hmin) / (N - 1)
    return h
|