jaxspec 0.1.0__py3-none-any.whl → 0.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,73 +0,0 @@
1
- """This module contains the functions that load the APEC tables from the HDF5 file. They are implemented as JAX
2
- pure callback to enable reading data from the files without saturating the memory."""
3
-
4
- import h5netcdf
5
- import jax
6
- import jax.numpy as jnp
7
-
8
- from ...util.online_storage import table_manager
9
-
10
-
11
@jax.jit
def temperature_table_getter():
    """Load the APEC temperature grid from ``apec.nc``.

    Returns
    -------
    The ``/temperature`` variable of the file, converted with ``jnp.asarray``.

    NOTE(review): because this is wrapped in ``jax.jit``, the file is read
    while the function is traced and the array is captured in the compiled
    computation rather than re-read on every call.
    """
    apec_path = table_manager.fetch("apec.nc")
    with h5netcdf.File(apec_path, "r") as apec_file:
        # Convert while the file is still open; the handle is closed on exit.
        return jnp.asarray(apec_file["/temperature"])
17
-
18
-
19
@jax.jit
def get_temperature(kT, temperature=None):
    """Find the temperature-grid interval that brackets ``kT``.

    Parameters
    ----------
    kT
        Temperature value(s) to locate on the grid.
    temperature
        Optional sorted 1-D temperature grid with at least two points.
        Defaults to the APEC grid returned by ``temperature_table_getter()``.

    Returns
    -------
    idx, t_low, t_high
        ``idx`` is the lower grid index, ``t_low = temperature[idx]`` and
        ``t_high = temperature[idx + 1]``. ``idx`` is clipped to
        ``[0, len(temperature) - 2]``. Values on or below the first grid point
        therefore map to the first interval, and values above the last point
        map to the last interval.
    """
    if temperature is None:
        temperature = temperature_table_getter()

    # searchsorted (side="left") returns 0 for kT <= temperature[0], which
    # would give idx = -1. That negative index wraps to the last grid point
    # and selects the wrong table rows downstream. Above the grid it would
    # give idx = len - 1, and temperature[idx + 1] would be clamped by JAX to
    # the last point (a zero-width interval). Clip to a valid lower index.
    idx = jnp.searchsorted(temperature, kT) - 1
    idx = jnp.clip(idx, 0, temperature.shape[0] - 2)

    return idx, temperature[idx], temperature[idx + 1]
25
-
26
-
27
@jax.jit
def continuum_table_getter():
    """Load the three continuum tables from ``apec.nc``.

    Returns
    -------
    A tuple ``(continuum_energy, continuum_emissivity, continuum_end_index)``,
    each being the corresponding file variable converted with ``jnp.asarray``.

    NOTE(review): wrapped in ``jax.jit``, so the read happens at trace time
    and the arrays are captured in the compiled computation.
    """
    variable_paths = ("/continuum_energy", "/continuum_emissivity", "/continuum_end_index")
    with h5netcdf.File(table_manager.fetch("apec.nc"), "r") as apec_file:
        # All conversions happen before the file handle is closed.
        return tuple(jnp.asarray(apec_file[path]) for path in variable_paths)
35
-
36
-
37
@jax.jit
def pseudo_table_getter():
    """Load the three pseudo-continuum tables from ``apec.nc``.

    Returns
    -------
    A tuple ``(pseudo_energy, pseudo_emissivity, pseudo_end_index)``, each
    being the corresponding file variable converted with ``jnp.asarray``.

    NOTE(review): wrapped in ``jax.jit``, so the read happens at trace time
    and the arrays are captured in the compiled computation.
    """
    variable_paths = ("/pseudo_energy", "/pseudo_emissivity", "/pseudo_end_index")
    with h5netcdf.File(table_manager.fetch("apec.nc"), "r") as apec_file:
        # All conversions happen before the file handle is closed.
        return tuple(jnp.asarray(apec_file[path]) for path in variable_paths)
45
-
46
-
47
@jax.jit
def line_table_getter():
    """Load the four emission-line tables from ``apec.nc``.

    Returns
    -------
    A tuple ``(line_energy, line_element, line_emissivity, line_end_index)``,
    each being the corresponding file variable converted with ``jnp.asarray``.

    NOTE(review): wrapped in ``jax.jit``, so the read happens at trace time
    and the arrays are captured in the compiled computation.
    """
    variable_paths = ("/line_energy", "/line_element", "/line_emissivity", "/line_end_index")
    with h5netcdf.File(table_manager.fetch("apec.nc"), "r") as apec_file:
        # All conversions happen before the file handle is closed.
        return tuple(jnp.asarray(apec_file[path]) for path in variable_paths)
56
-
57
-
58
@jax.jit
def get_continuum(idx):
    """Select entry ``idx`` of each continuum table.

    ``idx`` is presumably the lower temperature index produced by
    ``get_temperature``; confirm against callers.

    Returns
    -------
    A tuple of the ``idx``-th entries of the continuum energy, emissivity and
    end-index tables, in that order.
    """
    return tuple(table[idx] for table in continuum_table_getter())
62
-
63
-
64
@jax.jit
def get_pseudo(idx):
    """Select entry ``idx`` of each pseudo-continuum table.

    ``idx`` is presumably the lower temperature index produced by
    ``get_temperature``; confirm against callers.

    Returns
    -------
    A tuple of the ``idx``-th entries of the pseudo energy, emissivity and
    end-index tables, in that order.
    """
    return tuple(table[idx] for table in pseudo_table_getter())
68
-
69
-
70
@jax.jit
def get_lines(idx):
    """Select entry ``idx`` of each emission-line table.

    ``idx`` is presumably the lower temperature index produced by
    ``get_temperature``; confirm against callers.

    Returns
    -------
    A tuple of the ``idx``-th entries of the line energy, element, emissivity
    and end-index tables, in that order.
    """
    return tuple(table[idx] for table in line_table_getter())