jaxprop 0.3.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2025 Roberto Agromayor
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
jaxprop-0.3.1/PKG-INFO ADDED
@@ -0,0 +1,49 @@
1
+ Metadata-Version: 2.1
2
+ Name: jaxprop
3
+ Version: 0.3.1
4
+ Summary: JAX-compatible thermodynamic property calculations.
5
+ Home-page: https://github.com/turbo-sim/jaxprop
6
+ License: MIT
7
+ Author: Roberto Agromayor
8
+ Requires-Python: >=3.11,<3.14
9
+ Classifier: License :: OSI Approved :: MIT License
10
+ Classifier: Programming Language :: Python :: 3
11
+ Classifier: Programming Language :: Python :: 3.11
12
+ Classifier: Programming Language :: Python :: 3.12
13
+ Requires-Dist: CoolProp (>=7.0.0,<8.0.0)
14
+ Requires-Dist: diffrax (>=0.7.0,<0.8.0)
15
+ Requires-Dist: equinox (>=0.12.0,<0.13.0)
16
+ Requires-Dist: jax (>=0.6.0,<0.7.0)
17
+ Requires-Dist: jaxlib (>=0.6.0,<0.7.0)
18
+ Requires-Dist: matplotlib (>=3.10.6,<4.0.0)
19
+ Requires-Dist: numpy (>=2.3.0,<3.0.0)
20
+ Requires-Dist: optimistix (>=0.0.10,<0.0.11)
21
+ Requires-Dist: pysolver_view (>=0.6.8,<0.7.0)
22
+ Requires-Dist: scipy (>=1.16.1,<2.0.0)
23
+ Project-URL: Documentation, https://github.com/turbo-sim/jaxprop
24
+ Project-URL: Repository, https://github.com/turbo-sim/jaxprop
25
+ Description-Content-Type: text/markdown
26
+
27
+ # JAXprop
28
+
29
+ `jaxprop` provides JAX-compatible thermodynamic property calculations with support for automatic differentiation, vectorization, and JIT compilation.
30
+
31
+ 🔗 **Docs**: [turbo-sim.github.io/jaxprop](https://turbo-sim.github.io/jaxprop/)
32
+ 📦 **PyPI**: [pypi.org/project/jaxprop](https://pypi.org/project/jaxprop/)
33
+
34
+ **Note**: This project is based on the [CoolProp](https://www.coolprop.org) library but is not affiliated with or endorsed by the CoolProp project.
35
+
36
+ ## Key features
37
+
38
+ - Compute and plot phase envelopes and spinodal lines for pure fluids.
39
+ - Evaluate thermodynamic properties from Helmholtz energy–based equations of state, including metastable states inside the two-phase region.
40
+ - Perform flash calculations for any input pair with a custom solver and user-defined initial guesses.
41
+ - Work with structured property dictionaries and immutable `FluidState` objects.
42
+ - Evaluate properties over arrays of input conditions for efficient parametric studies and plotting.
43
+ - Full JAX compatibility: supports `jit`, `grad`, `vmap`, and parallel evaluation.
44
+
45
+ ## Installation
46
+
47
+ ```bash
48
+ pip install jaxprop
49
+ ```
@@ -0,0 +1,22 @@
1
+ # JAXprop
2
+
3
+ `jaxprop` provides JAX-compatible thermodynamic property calculations with support for automatic differentiation, vectorization, and JIT compilation.
4
+
5
+ 🔗 **Docs**: [turbo-sim.github.io/jaxprop](https://turbo-sim.github.io/jaxprop/)
6
+ 📦 **PyPI**: [pypi.org/project/jaxprop](https://pypi.org/project/jaxprop/)
7
+
8
+ **Note**: This project is based on the [CoolProp](https://www.coolprop.org) library but is not affiliated with or endorsed by the CoolProp project.
9
+
10
+ ## Key features
11
+
12
+ - Compute and plot phase envelopes and spinodal lines for pure fluids.
13
+ - Evaluate thermodynamic properties from Helmholtz energy–based equations of state, including metastable states inside the two-phase region.
14
+ - Perform flash calculations for any input pair with a custom solver and user-defined initial guesses.
15
+ - Work with structured property dictionaries and immutable `FluidState` objects.
16
+ - Evaluate properties over arrays of input conditions for efficient parametric studies and plotting.
17
+ - Full JAX compatibility: supports `jit`, `grad`, `vmap`, and parallel evaluation.
18
+
19
+ ## Installation
20
+
21
+ ```bash
22
+ pip install jaxprop
+ ```
@@ -0,0 +1,46 @@
1
+ # Highlight exception messages
2
+ # https://stackoverflow.com/questions/25109105/how-to-colorize-the-output-of-python-errors-in-the-gnome-terminal/52797444#52797444
3
+ try:
4
+ import IPython.core.ultratb
5
+ except ImportError:
6
+ # No IPython. Use default exception printing.
7
+ pass
8
+ else:
9
+ import sys
10
+ sys.excepthook = IPython.core.ultratb.FormattedTB(color_scheme='linux', call_pdb=False)
11
+
12
+
13
+ import os
14
+ os.environ["JAX_PLATFORM_NAME"] = "cpu"
15
+ import jax
16
+ jax.config.update("jax_enable_x64", True)
17
+
18
+
19
+
20
+ from .graphics import *
21
+ from .utils import *
22
+ from .helpers_jax import *
23
+ from .helpers_coolprop import *
24
+
25
+ # Import subpackages
26
+ from . import coolprop
27
+ from . import perfect_gas
28
+ # from . import bicubic
29
+
30
+ # Import API classes
31
+ from .perfect_gas import FluidPerfectGas
32
+ from .coolprop import Fluid, FluidJAX
33
+
34
+
35
+ from . import components
36
+
37
+
38
+ # Package info
39
+ __version__ = "0.3.1"
40
+ PACKAGE_NAME = "jaxprop"
41
+ URL_GITHUB = "https://github.com/turbo-sim/jaxprop"
42
+ URL_DOCS = "https://turbo-sim.github.io/jaxprop/"
43
+ URL_PYPI = "https://pypi.org/project/jaxprop/"
44
+ URL_DTU = "https://thermalpower.dtu.dk/"
45
+ BREAKLINE = 80 * "-"
46
+
@@ -0,0 +1,3 @@
1
+ from .generate_tables import *
2
+ from .bicubic_interpolant_property import *
3
+ from .jax_bicubic_HEOS_interpolation_1 import *
@@ -0,0 +1,69 @@
1
+ import numpy as np
2
+ # import jax.numpy as jnp
3
+ from ..jax_import import jnp
4
+ import pickle
5
+
6
+ from .jax_bicubic_HEOS_interpolation_1 import compute_bicubic_coefficients_of_ij, bicubic_interpolant
7
+
8
+
9
def bicubic_interpolant_property(h, P, table):
    """
    Interpolate every tabulated property at a single (h, P) point.

    The table is treated as uniform in enthalpy and in log(pressure). The
    enclosing grid cell is located (clamped to the last valid cell), the 16
    bicubic coefficients of that cell are built for each property, and the
    bicubic polynomial is evaluated at the local cell coordinates.

    Args:
        h (float): Enthalpy in J/kg
        P (float): Pressure in Pa
        table (dict): Property table with the grid axes under 'h' and 'P';
            every other key maps to a dict holding the 2D grids 'value',
            'd_dh', 'd_dP' and 'd2_dhdP'.

    Returns:
        dict: Dictionary of interpolated property values at (h, P)
    """
    h_vals = jnp.array(table['h'])
    P_vals = jnp.array(table['P'])

    Nh, Np = len(h_vals), len(P_vals)
    hmin, hmax = float(h_vals[0]), float(h_vals[-1])
    Lmin, Lmax = float(jnp.log(P_vals[0])), float(jnp.log(P_vals[-1]))

    # Fractional grid coordinates of the query point
    ii = float((h - hmin) / (hmax - hmin) * (Nh - 1))
    jj = float((jnp.log(P) - Lmin) / (Lmax - Lmin) * (Np - 1))

    # Identify cell (i, j), clamped so that (i + 1, j + 1) is always a valid node
    i = min(max(int(ii), 0), Nh - 2)
    j = min(max(int(jj), 0), Np - 2)

    # Local coordinates measured from the *clamped* cell origin. Using the
    # clamped indices here is what keeps h == hmax / P == Pmax inside the last
    # cell (x == 1 or y == 1) instead of addressing a cell that has no
    # coefficients.
    x = ii - i
    y = jj - j

    deltah = float(h_vals[1] - h_vals[0])
    deltaL = float(jnp.log(P_vals[1]) - jnp.log(P_vals[0]))

    # Monomial basis x**m * y**n stored at index 4*n + m, matching the
    # coefficient layout returned by compute_bicubic_coefficients_of_ij.
    basis = np.array([x**m * y**n for n in range(4) for m in range(4)], dtype=np.float64)

    interpolated_props = {}

    for prop, prop_data in table.items():
        if prop in ('h', 'P'):
            continue  # skip grid axes

        f_grid = prop_data['value']
        fx_grid = prop_data['d_dh']
        fy_grid = prop_data['d_dP']
        fxy_grid = prop_data['d2_dhdP']

        # Derivatives are rescaled to unit-cell coordinates.
        # NOTE(review): the cell is uniform in log(P) but 'd_dP' is scaled only
        # by deltaL; if the table stores derivatives with respect to P (not
        # log P) a chain-rule factor P is missing here - confirm against the
        # table generator.
        coeffs_local = compute_bicubic_coefficients_of_ij(
            i, j,
            f_grid,
            fx_grid * deltah,
            fy_grid * deltaL,
            fxy_grid * deltah * deltaL
        )

        # Evaluate the single-cell polynomial directly; no full-grid
        # coefficient array is needed for one query point.
        cell_coeffs = np.asarray(coeffs_local, dtype=np.float64)
        interpolated_props[prop] = float(cell_coeffs @ basis)

    return interpolated_props
@@ -0,0 +1,85 @@
1
+ import numpy as np
2
+
3
+ from ..jax_import import jnp
4
+ # import jax.numpy as jnp
5
+ import CoolProp.CoolProp as cp
6
+ # import turboflow as tf
7
+ import os
8
+ import pickle
9
+ from jaxprop.fluid_properties import Fluid
10
+
11
def generate_property_table(hmin, hmax, Pmin, Pmax, fluid_name, Nh, Np, outdir='fluid_tables'):
    """
    Tabulate fluid properties and forward-difference derivatives on an (h, P) grid.

    The grid has Nh enthalpy nodes spaced uniformly between hmin and hmax and
    Np pressure nodes spaced uniformly in log(P) between Pmin and Pmax. At each
    node the state is evaluated at (h, P), (h + eps_h, P), (h, P + eps_P) and
    (h + eps_h, P + eps_P) to form the value, both first derivatives and the
    mixed second derivative of every property. Nodes where any state
    evaluation raises are skipped and keep their initial zeros.

    The finished table is pickled to ``<outdir>/<fluid_name>_<Nh>_x_<Np>.pkl``
    and also returned.
    """
    fluid = Fluid(fluid_name)
    # fluid = tf.Fluid(fluid_name)

    # Grid axes: enthalpy is linear, pressure is linear in log space
    Lmin = jnp.log(Pmin)
    Lmax = jnp.log(Pmax)
    h_vals = jnp.linspace(hmin, hmax, Nh)
    P_vals = jnp.linspace(Lmin, Lmax, Np)

    # Finite-difference steps
    deltah = float(h_vals[1] - h_vals[0])
    deltaL = P_vals[1] - P_vals[0]
    eps_h = 0.001 * deltah
    eps_P = 1e-6 * float(Pmin)

    properties = {
        'T': 'T',   # Temperature [K]
        'd': 'D',   # Density [kg/m³]
        's': 'S',   # Entropy [J/kg/K]
        'mu': 'V',  # Viscosity [Pa·s]
        'k': 'L',   # Thermal conductivity [W/m/K]
    }

    # Grid axes first, then one zero-initialised group of grids per property
    table = {
        'h': np.array(h_vals),
        'P': np.array(jnp.exp(P_vals)),
    }
    grid_names = ('value', 'd_dh', 'd_dP', 'd2_dhdP')
    for name in properties:
        table[name] = {
            grid: np.zeros((Nh, Np), dtype=np.float64) for grid in grid_names
        }

    def _state(h_in, p_in):
        # Single (h, P) state evaluation with plain Python floats
        return fluid.get_state(cp.HmassP_INPUTS, float(h_in), float(p_in))

    # Loop over grid and populate values
    for i, h in enumerate(h_vals):
        for j, logP in enumerate(P_vals):
            hf = float(h)
            Pf = jnp.exp(logP)

            try:
                f_0 = _state(hf, Pf)
                f_h = _state(hf + eps_h, Pf)
                f_p = _state(hf, Pf + eps_P)
                f_hp = _state(hf + eps_h, Pf + eps_P)
            except Exception:
                continue  # Skip invalid points (entries stay at zero)

            for name in properties:
                base = f_0[name]
                slope_h = (f_h[name] - f_0[name]) / eps_h
                slope_P = (f_p[name] - f_0[name]) / eps_P
                cross = (f_hp[name] - f_h[name] - f_p[name] + f_0[name]) / (eps_h * eps_P)

                grids = table[name]
                grids['value'][i, j] = base
                grids['d_dh'][i, j] = slope_h
                grids['d_dP'][i, j] = slope_P
                grids['d2_dhdP'][i, j] = cross

    # Save as pickle only (most useful for JAX processing)
    os.makedirs(outdir, exist_ok=True)
    pkl_path = os.path.join(outdir, f'{fluid_name}_{Nh}_x_{Np}.pkl')

    with open(pkl_path, 'wb') as handle:
        pickle.dump(table, handle)

    print(f" Saved the table to:\n -> Pickle: {pkl_path}")

    return table
@@ -0,0 +1,292 @@
1
+ # ======================== Imports ========================
2
+ # import os
3
+ # import psutil
4
+ import time
5
+ # import jax
6
+ # import jax.numpy as jnp
7
+ # from jax import jit
8
+
9
+ from ..jax_import import jax, jnp, jit
10
+ # from jax.experimental import mesh_utils
11
+ # from jax.sharding import Mesh, PartitionSpec, NamedSharding
12
+ import CoolProp.CoolProp as cp
13
+ import numpy as np
14
+ import matplotlib.pyplot as plt
15
+ from functools import partial
16
+
17
+ # ======================== Config ========================
18
+ # NCORES = psutil.cpu_count(logical=False)
19
+ # os.environ["XLA_FLAGS"] = f"--xla_force_host_platform_device_count={NCORES}"
20
+ # jax.config.update("jax_enable_x64", True)
21
+
22
+ # Global precision
23
+ float64 = jnp.dtype("float64")
24
+ complex128 = jnp.dtype("complex128")
25
+
26
+ # =================== Functions to Export ===================
27
+ # @jax.jit
28
def compute_bicubic_coefficients_of_ij(i, j, f, fx, fy, fxy):
    """
    Return the 16 bicubic coefficients of grid cell (i, j).

    The cell spans nodes (i, j) .. (i + 1, j + 1). The inputs are 2D grids of
    the function value and of its x, y and mixed xy derivatives, already
    expressed in unit-cell coordinates. The returned vector stores the
    coefficient of x**m * y**n at index 4*n + m.
    """
    # Corner order used for every quantity: (0,0), (1,0), (0,1), (1,1)
    corner_offsets = ((0, 0), (1, 0), (0, 1), (1, 1))
    nodal_data = [
        grid[i + di, j + dj]
        for grid in (f, fx, fy, fxy)
        for di, dj in corner_offsets
    ]

    # Inverse of the bicubic Hermite system on the unit square
    hermite_inverse = (
        (1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0),
        (0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0),
        (-3, 3, 0, 0, -2, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0),
        (2, -2, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0),
        (0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0),
        (0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0),
        (0, 0, 0, 0, 0, 0, 0, 0, -3, 3, 0, 0, -2, -1, 0, 0),
        (0, 0, 0, 0, 0, 0, 0, 0, 2, -2, 0, 0, 1, 1, 0, 0),
        (-3, 0, 3, 0, 0, 0, 0, 0, -2, 0, -1, 0, 0, 0, 0, 0),
        (0, 0, 0, 0, -3, 0, 3, 0, 0, 0, 0, 0, -2, 0, -1, 0),
        (9, -9, -9, 9, 6, 3, -6, -3, 6, -6, 3, -3, 4, 2, 2, 1),
        (-6, 6, 6, -6, -3, -3, 3, 3, -4, 4, -2, 2, -2, -2, -1, -1),
        (2, 0, -2, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0),
        (0, 0, 0, 0, 2, 0, -2, 0, 0, 0, 0, 0, 1, 0, 1, 0),
        (-6, 6, 6, -6, -4, -2, 4, 2, -3, 3, -3, 3, -2, -1, -2, -1),
        (4, -4, -4, 4, 2, 2, -2, -2, 2, -2, 2, -2, 1, 1, 1, 1),
    )

    matrix = jnp.array(hermite_inverse, dtype=f.dtype)
    rhs = jnp.array(nodal_data, dtype=f.dtype)
    return jnp.matmul(matrix, rhs)
48
+
49
+
50
+ # def bicubic_interpolant(h_vals, P_vals, coeffs, hmin, hmax, Lmin, Lmax, Nh, Np):
51
+ # """
52
+ # Create a bicubic interpolant function using precomputed coefficients.
53
+ # """
54
+ # # Ensure that coeffs is in the right shape (Nh, Np, 16)
55
+ # assert coeffs.shape == (Nh, Np, 16), f"Expected coeffs to have shape (Nh, Np, 16), but got {coeffs.shape}"
56
+
57
+ # # Normalized grid values
58
+ # def normalize(value, min_val, max_val):
59
+ # return (value - min_val) / (max_val - min_val)
60
+
61
+ # # The actual interpolant function, it takes h and P as arguments
62
+ # @jit
63
+ # def interpolant_fn(h, P):
64
+ # # Ensure h_vals and P_vals are 1D arrays
65
+ # h_vals_flat = jnp.ravel(h_vals) # Flatten the h_vals to 1D
66
+ # P_vals_flat = jnp.ravel(P_vals) # Flatten the P_vals to 1D
67
+
68
+ # # Normalize h and P
69
+ # norm_h = normalize(h, hmin, hmax)
70
+ # norm_P = normalize(P, Lmin, Lmax)
71
+
72
+ # # Identify the grid points surrounding h, P
73
+ # i = jnp.clip(jnp.searchsorted(h_vals_flat, h) - 1, 0, Nh - 2)
74
+ # j = jnp.clip(jnp.searchsorted(P_vals_flat, P) - 1, 0, Np - 2)
75
+
76
+ # # Extract the coefficients for the surrounding grid points
77
+ # coeff = coeffs[i, j]
78
+
79
+ # # Interpolate in both directions
80
+ # h_diff = norm_h - h_vals_flat[i]
81
+ # P_diff = norm_P - P_vals_flat[j]
82
+
83
+ # # Calculate the interpolant using the bicubic coefficients
84
+ # result = (
85
+ # coeff[0] + coeff[1] * h_diff + coeff[2] * P_diff + coeff[3] * h_diff * P_diff +
86
+ # coeff[4] * h_diff**2 + coeff[5] * h_diff * P_diff**2 +
87
+ # coeff[6] * P_diff**2 + coeff[7] * h_diff**2 * P_diff +
88
+ # coeff[8] * h_diff**3 + coeff[9] * h_diff**2 * P_diff +
89
+ # coeff[10] * h_diff * P_diff**2 + coeff[11] * h_diff**3 * P_diff +
90
+ # coeff[12] * P_diff**3 + coeff[13] * h_diff * P_diff**3 +
91
+ # coeff[14] * h_diff**3 * P_diff**2 + coeff[15] * h_diff**2 * P_diff**3
92
+ # )
93
+
94
+ # return result
95
+
96
+ # return interpolant_fn
97
+
98
+ # @partial(jit, static_argnums=(5, 6)) # Nh and Np are static # static arguments are all except h, P
99
+ # def bicubic_interpolant(h, P, h_vals, P_vals, coeffs, Nh, Np, hmin, hmax, Lmin, Lmax):
100
+ # """
101
+ # Evaluate the bicubic interpolant at (h, P) using precomputed coefficients.
102
+ # """
103
+ # # Normalize
104
+ # norm_h = (h - hmin) / (hmax - hmin)
105
+ # norm_P = (P - Lmin) / (Lmax - Lmin)
106
+
107
+ # # Flatten grid arrays
108
+ # h_vals_flat = jnp.ravel(h_vals)
109
+ # P_vals_flat = jnp.ravel(P_vals)
110
+
111
+ # # Find surrounding grid indices
112
+ # i = jnp.clip(jnp.searchsorted(h_vals_flat, h) - 1, 0, Nh - 2)
113
+ # j = jnp.clip(jnp.searchsorted(P_vals_flat, P) - 1, 0, Np - 2)
114
+
115
+ # # Relative differences in normalized space
116
+ # h_base = (h_vals_flat[i] - hmin) / (hmax - hmin)
117
+ # P_base = (P_vals_flat[j] - Lmin) / (Lmax - Lmin)
118
+ # h_diff = norm_h - h_base
119
+ # P_diff = norm_P - P_base
120
+
121
+ # # Fetch bicubic coefficients
122
+ # coeff = coeffs[i, j]
123
+
124
+ # # Evaluate bicubic polynomial
125
+ # result = (
126
+ # coeff[0] + coeff[1] * h_diff + coeff[2] * P_diff + coeff[3] * h_diff * P_diff +
127
+ # coeff[4] * h_diff**2 + coeff[5] * h_diff * P_diff**2 +
128
+ # coeff[6] * P_diff**2 + coeff[7] * h_diff**2 * P_diff +
129
+ # coeff[8] * h_diff**3 + coeff[9] * h_diff**2 * P_diff +
130
+ # coeff[10] * h_diff * P_diff**2 + coeff[11] * h_diff**3 * P_diff +
131
+ # coeff[12] * P_diff**3 + coeff[13] * h_diff * P_diff**3 +
132
+ # coeff[14] * h_diff**3 * P_diff**2 + coeff[15] * h_diff**2 * P_diff**3
133
+ # )
134
+
135
+ # return result
136
+
137
+ # @partial(jit, static_argnums=(5, 6))
138
def bicubic_interpolant(h, P, h_vals, P_vals, coeffs, Nh, Np, hmin, hmax, Lmin, Lmax):
    """
    Evaluate the bicubic interpolant at (h, P) using precomputed coefficients.

    The grid is assumed uniform in h over [hmin, hmax] (Nh nodes) and uniform
    in log(P) over [Lmin, Lmax] (Np nodes). ``coeffs[i, j, 4*n + m]`` is the
    coefficient of x**m * y**n in cell (i, j), where x and y are the local
    coordinates measured from the cell origin in units of one cell.

    Args:
        h: Enthalpy (array-like supporting ``.astype``).
        P: Pressure; its natural logarithm is taken internally.
        h_vals, P_vals: Grid axes. Not used by the evaluation; kept for
            interface compatibility.
        coeffs: Coefficient array indexed as ``coeffs[i, j, k]`` with k in 0..15.
        Nh, Np: Number of grid nodes along h and log(P).
        hmin, hmax: Enthalpy range of the grid.
        Lmin, Lmax: log(P) range of the grid.

    Returns:
        Interpolated value with the shape of ``h``.
    """
    # Log-transform P
    L = jnp.log(P)

    # Fractional grid coordinates, split into cell index and local coordinate.
    # The index is clamped to the last valid cell: without it, a query on the
    # upper grid edge (or outside the grid) addresses cell Nh-1 / Np-1 or a
    # negative cell, which JAX indexing silently clamps/wraps instead of
    # raising. With the clamp, edge points evaluate the neighbouring cell at
    # x == 1 / y == 1 and out-of-range points extrapolate that cell.
    ii = (h - hmin) / (hmax - hmin) * (Nh - 1)
    i = jnp.clip(ii.astype(int), 0, Nh - 2)
    x = ii - i

    jj = (L - Lmin) / (Lmax - Lmin) * (Np - 1)
    j = jnp.clip(jj.astype(int), 0, Np - 2)
    y = jj - j

    # Evaluate bicubic polynomial: sum over m, n of c[4n + m] * x^m * y^n
    result = jnp.zeros_like(h)  # use h shape
    x_pow = jnp.ones_like(h)    # x^0

    for m in range(4):              # m = x power
        y_pow = jnp.ones_like(h)    # y^0 initially
        for n in range(4):          # n = y power
            c = coeffs[i, j, 4 * n + m]
            result = result + c * x_pow * y_pow
            y_pow = y_pow * y
        x_pow = x_pow * x

    return result
169
+
170
+
171
+
172
+ # @jax.jit
173
def inverse_interpolant_scalar_hD(h, D):
    """
    Invert the bicubic interpolant: given h and a property value D, return P.

    NOTE(review): this function reads the module-level names ``hmin``,
    ``hmax``, ``Lmin``, ``Lmax``, ``N``, ``M``, ``bicubic_coefficients`` and
    ``iD``, none of which are defined in the visible part of this module -
    they presumably have to be set by the caller before use; confirm.

    Steps: locate the h cell, evaluate the interpolant at the pressure nodes
    for this h, bracket D between two nodes, then solve the cubic in the local
    pressure coordinate with the general cubic formula and keep the root that
    is closest to being real and inside [0, 1].
    """
    # Real-valued grid index along h; integer part = cell, remainder = local x
    h_index = ((h - hmin) / (hmax - hmin) * (N - 1))
    i = h_index.astype(int)
    x = h_index - i

    # Nodal values: D(h, P_j) at the actual enthalpy for every pressure node.
    # TODO: instead of computing all the nodal values and then using a sorted
    # search to find the interval, a binary search could be used (this would
    # constrain M to be a power of 2): start at j = M/2, then update
    # j = j + M/4*(2*(Dj>D)-1), j = j + M/8*(2*(Dj>D)-1), ... converging to
    # the index after log2(M) iterations.
    x_power = jnp.ones_like(h)  # x to the 0th power
    D_nodal = jnp.zeros(M)
    for m in range(4):
        D_nodal = D_nodal + bicubic_coefficients[i, :, m] * x_power
        x_power = x_power * x

    # Bracket the solution assuming the nodal values are sorted.
    # TODO: This assumes that P has a monotonic trend with respect to D at
    # fixed h. This causes some problems and needs further investigation.
    if iD == cp.iSmass:
        search_keys, search_target = -D_nodal, -D
    else:
        search_keys, search_target = D_nodal, D
    j = jax.numpy.searchsorted(search_keys, search_target).astype(int) - 1

    # Inside cell (i, j): collapse the x dependence into 1D cubic coefficients
    # (as complex numbers to avoid promotion), b_n = sum_m a_mn * x**m, giving
    # D = b0 + b1*y + b2*y**2 + b3*y**3
    cubic = [jnp.zeros_like(h, dtype=complex128) for _ in range(4)]
    x_power = jnp.ones_like(h)
    for m in range(4):
        for n in range(4):
            cubic[n] = cubic[n] + bicubic_coefficients[i, j, 4 * n + m] * x_power
        x_power = x_power * x
    b0, b1, b2, b3 = cubic

    # Solve the cubic equation - all three solutions
    # TODO: if necessary, add solution for degenerate (quadratic and linear)
    # For more information: https://en.wikipedia.org/wiki/Cubic_equation#General_cubic_formula
    delta0 = b2 * b2 - 3 * b3 * b1
    delta1 = 2 * b2 * b2 * b2 - 9 * b3 * b2 * b1 + 27 * b3 * b3 * (b0 - D)
    C = ((delta1 + (delta1 * delta1 - 4 * delta0 * delta0 * delta0) ** 0.5) / 2) ** (1 / 3)
    delta0_over_C = jax.lax.select(C == (0 + 0j), 0 + 0j, delta0 / C)
    unity_roots = jnp.array([1, -0.5 + 0.8660254037844386j, -0.5 - 0.8660254037844386j])
    y = -1 / (3 * b3) * (b2 + C * unity_roots + delta0_over_C / unity_roots)

    # Root selection uses two criteria:
    #   - zero imaginary part
    #   - real part between 0 and 1, the bounds of the cell
    # "badness" measures the deviation from both; the lowest badness wins.
    badness = jax.nn.relu(4 * (jnp.real(y) - 0.5) ** 2 - 1) + jnp.imag(y) ** 2
    y_best = jnp.real(y[jnp.argmin(badness)])

    # Back to log-pressure, then pressure
    p_index = j + y_best
    logP = Lmin + p_index * (Lmax - Lmin) / (M - 1)
    return jnp.exp(logP)
240
+
241
+ # @jax.jit
242
def inverse_interpolant_scalar_DP(D, P):
    """
    Invert the bicubic interpolant: given a property value D and P, return h.

    NOTE(review): this function reads the module-level names ``hmin``,
    ``hmax``, ``Lmin``, ``Lmax``, ``N``, ``M``, ``bicubic_coefficients`` and
    ``iD``, none of which are defined in the visible part of this module -
    they presumably have to be set by the caller before use; confirm.

    The coefficient layout is the one used by ``bicubic_interpolant``:
    index ``4*n + m`` holds the coefficient of x**m * y**n, with x the local
    enthalpy coordinate and y the local log-pressure coordinate. Because the
    unknown here is x, the y dependence is summed out, so the coefficient
    index must be ``4*n + k`` with n the summed y power and k the x power.

    Args:
        D: Target property value.
        P: Pressure; its natural logarithm is taken internally.

    Returns:
        Enthalpy h whose interpolated property at pressure P matches D.
    """
    # Convert pressure to log space
    L = jnp.log(P)

    # Compute index along pressure grid
    jj = ((L - Lmin) / (Lmax - Lmin) * (M - 1))
    j = jj.astype(int)
    y = jj - j  # fractional position in pressure direction

    # Nodal values along h at this pressure: the interpolant at x = 0 of each
    # cell (i, j), i.e. sum_n a_0n * y**n, which lives at indices 4*n.
    yth = jnp.ones_like(D)
    D_nodal = jnp.zeros(N)
    for n in range(4):
        D_nodal = D_nodal + bicubic_coefficients[:, j, 4 * n] * yth
        yth = yth * y

    # Search h-direction to find which cell to use
    if iD == cp.iSmass:
        i = jnp.searchsorted(-D_nodal, -D).astype(int) - 1
    else:
        i = jnp.searchsorted(D_nodal, D).astype(int) - 1

    # Now build 1D cubic in x (h-direction) at fixed j:
    # b_k = sum_n a_kn * y**n, stored at index 4*n + k
    yth = jnp.ones_like(D)
    b0 = jnp.zeros_like(D, dtype=complex128)
    b1 = jnp.zeros_like(D, dtype=complex128)
    b2 = jnp.zeros_like(D, dtype=complex128)
    b3 = jnp.zeros_like(D, dtype=complex128)
    for n in range(4):
        b0 = b0 + bicubic_coefficients[i, j, 4 * n + 0] * yth
        b1 = b1 + bicubic_coefficients[i, j, 4 * n + 1] * yth
        b2 = b2 + bicubic_coefficients[i, j, 4 * n + 2] * yth
        b3 = b3 + bicubic_coefficients[i, j, 4 * n + 3] * yth
        yth = yth * y

    # Solve cubic: D = b0 + b1*x + b2*x^2 + b3*x^3
    D0 = b2*b2 - 3*b3*b1
    D1 = 2*b2**3 - 9*b3*b2*b1 + 27*b3**2*(b0 - D)
    C = ((D1 + jnp.sqrt(D1**2 - 4*D0**3)) / 2)**(1/3)
    D0C = jax.lax.select(C == 0, 0 + 0j, D0 / C)
    z = jnp.array([1, -0.5 + 0.8660254037844386j, -0.5 - 0.8660254037844386j])
    x = -1/(3*b3)*(b2 + C*z + D0C/z)

    # Pick root with lowest badness (deviation from real and from [0, 1])
    badness = jax.nn.relu(4*(jnp.real(x)-0.5)**2 - 1) + jnp.imag(x)**2
    xreal = jnp.real(x[jnp.argmin(badness)])

    # Final result: compute h from i + x
    ii = i + xreal
    h = hmin + ii * (hmax - hmin) / (N - 1)
    return h
@@ -0,0 +1,3 @@
1
+ # from ..jax_helpers import *
2
+ from .nozzle_model_core import *
3
+ from .nozzle_model_solver import *