integrate_module 0.99.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- integrate/__init__.py +144 -0
- integrate/gex.py +402 -0
- integrate/integrate.py +4063 -0
- integrate/integrate_borehole.py +1127 -0
- integrate/integrate_hdf5_info_cli.py +122 -0
- integrate/integrate_io.py +5293 -0
- integrate/integrate_plot.py +4986 -0
- integrate/integrate_query.py +1609 -0
- integrate/integrate_rejection.py +1836 -0
- integrate/integrate_rejection_cli.py +210 -0
- integrate/integrate_rejection_jax.py +494 -0
- integrate/integrate_timing_cli.py +407 -0
- integrate/integrate_www_cli.py +8 -0
- integrate_module-0.99.1.dist-info/METADATA +229 -0
- integrate_module-0.99.1.dist-info/RECORD +19 -0
- integrate_module-0.99.1.dist-info/WHEEL +5 -0
- integrate_module-0.99.1.dist-info/entry_points.txt +5 -0
- integrate_module-0.99.1.dist-info/licenses/LICENSE +21 -0
- integrate_module-0.99.1.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,210 @@
|
|
|
1
|
+
#!/usr/bin/env python
|
|
2
|
+
"""
|
|
3
|
+
INTEGRATE Rejection Sampling CLI
|
|
4
|
+
|
|
5
|
+
Command-line interface for probabilistic inversion using rejection sampling.
|
|
6
|
+
Provides access to the integrate_rejection function with various options for
|
|
7
|
+
Bayesian inversion and posterior sampling.
|
|
8
|
+
|
|
9
|
+
Author: Thomas Mejer Hansen
|
|
10
|
+
Email: tmeha@geo.au.dk
|
|
11
|
+
"""
|
|
12
|
+
|
|
13
|
+
import argparse
|
|
14
|
+
import sys
|
|
15
|
+
import os
|
|
16
|
+
import multiprocessing
|
|
17
|
+
|
|
18
|
+
# Import the integrate module
|
|
19
|
+
try:
|
|
20
|
+
import integrate as ig
|
|
21
|
+
except ImportError:
|
|
22
|
+
print("Error: Could not import integrate module. Please ensure it is properly installed.")
|
|
23
|
+
sys.exit(1)
|
|
24
|
+
|
|
25
|
+
def main(argv=None):
    """Entry point for the integrate_rejection command.

    Parses the command line, validates the input files and list-valued
    options, and forwards everything to ``ig.integrate_rejection``.

    Parameters
    ----------
    argv : list of str, optional
        Command-line arguments without the program name.  ``None`` (the
        default) lets argparse read ``sys.argv[1:]``.

    Returns
    -------
    int
        0 when rejection sampling completed, 1 on a validation error or when
        ``ig.integrate_rejection`` raised an exception.

    Raises
    ------
    SystemExit
        Raised by argparse for ``--help``, ``--version`` (exit code 0) and for
        command-line syntax errors (exit code 2).
    """

    # Set up multiprocessing support
    multiprocessing.freeze_support()

    # Resolve the version string up front so that ``--version`` can be an
    # argparse 'version' action.  That action fires during parsing, i.e. it
    # works without the otherwise required --prior/--data options.
    try:
        from integrate import __version__ as version
    except ImportError:
        version = "unknown"

    # Create argument parser
    parser = argparse.ArgumentParser(
        description='INTEGRATE rejection sampling for Bayesian inversion',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  integrate_rejection --prior prior.h5 --data data.h5 --output post.h5
  integrate_rejection --prior prior.h5 --data data.h5 --samples 1000000 --no-parallel
  integrate_rejection --prior prior.h5 --data data.h5 --auto-temp --cpus 4

For more information, see the INTEGRATE documentation.
        """
    )

    # Required arguments
    parser.add_argument('--prior', '-p',
                        type=str,
                        required=True,
                        help='Path to HDF5 file containing prior model and data samples')

    parser.add_argument('--data', '-d',
                        type=str,
                        required=True,
                        help='Path to HDF5 file containing observed data for inversion')

    # Optional arguments
    parser.add_argument('--output', '-o',
                        type=str,
                        default='',
                        help='Output path for posterior samples (auto-generated if not specified)')

    parser.add_argument('--samples', '-n',
                        type=int,
                        default=100000000,
                        help='Maximum number of prior samples to use for inversion (default: 100000000)')

    parser.add_argument('--auto-temp', '-T',
                        action='store_true',
                        help='Enable automatic temperature estimation (default: disabled)')

    parser.add_argument('--temp-base',
                        type=float,
                        default=1.0,
                        help='Base temperature for sampling (default: 1.0)')

    parser.add_argument('--nr',
                        type=int,
                        default=400,
                        help='Number of resamples for temperature estimation (default: 400)')

    parser.add_argument('--cpus', '-c',
                        type=int,
                        default=0,
                        help='Number of CPU cores to use (0 = auto-detect, default: 0)')

    parser.add_argument('--no-parallel',
                        action='store_true',
                        help='Disable parallel processing')

    parser.add_argument('--chunks',
                        type=int,
                        default=0,
                        help='Number of chunks for processing (0 = auto, default: 0)')

    parser.add_argument('--id-use',
                        type=str,
                        default='',
                        help='Comma-separated list of data IDs to use for inversion')

    parser.add_argument('--ip-range',
                        type=str,
                        default='',
                        help='Comma-separated IP range for distributed processing')

    parser.add_argument('--use-n-best',
                        type=int,
                        default=0,
                        help='Use N best samples for analysis (default: 0)')

    parser.add_argument('--backend',
                        choices=['numpy', 'jax'],
                        default='numpy',
                        help='Rejection sampling backend: numpy (default) or jax')

    parser.add_argument('--verbose', '-v',
                        action='store_true',
                        help='Enable verbose output')

    parser.add_argument('--version',
                        action='version',
                        version=f"INTEGRATE version: {version}",
                        help='Show version information')

    # Parse arguments
    args = parser.parse_args(argv)

    # Validate input files
    if not os.path.exists(args.prior):
        print(f"Error: Prior file not found: {args.prior}", file=sys.stderr)
        return 1

    if not os.path.exists(args.data):
        print(f"Error: Data file not found: {args.data}", file=sys.stderr)
        return 1

    # Parse comma-separated arguments
    id_use = []
    if args.id_use:
        try:
            id_use = [int(x.strip()) for x in args.id_use.split(',')]
        except ValueError:
            print(f"Error: Invalid ID list format: {args.id_use}", file=sys.stderr)
            return 1

    # ip_range entries are data-point indices, so they must be integers just
    # like id_use; passing the raw strings on would fail when used as indices.
    ip_range = []
    if args.ip_range:
        try:
            ip_range = [int(x.strip()) for x in args.ip_range.split(',')]
        except ValueError:
            print(f"Error: Invalid IP range format: {args.ip_range}", file=sys.stderr)
            return 1

    # Set up parallel processing
    parallel = not args.no_parallel
    if parallel:
        # Check if parallel processing is supported
        parallel = ig.use_parallel(showInfo=1 if args.verbose else 0)

    # Print configuration if verbose
    if args.verbose:
        print("Configuration:")
        print(f" Prior file: {args.prior}")
        print(f" Data file: {args.data}")
        print(f" Output file: {args.output if args.output else 'auto-generated'}")
        print(f" Max samples: {args.samples}")
        print(f" Auto temperature: {args.auto_temp}")
        print(f" Base temperature: {args.temp_base}")
        print(f" Parallel processing: {parallel}")
        print(f" CPU cores: {args.cpus if args.cpus > 0 else 'auto-detect'}")
        print(f" Backend: {args.backend}")
        print("")

    try:
        # Call the integrate_rejection function
        f_post_h5 = ig.integrate_rejection(
            f_prior_h5=args.prior,
            f_data_h5=args.data,
            f_post_h5=args.output,
            N_use=args.samples,
            id_use=id_use,
            ip_range=ip_range,
            nr=args.nr,
            autoT=1 if args.auto_temp else 0,
            T_base=args.temp_base,
            Nchunks=args.chunks,
            Ncpu=args.cpus,
            parallel=parallel,
            use_N_best=args.use_n_best,
            backend=args.backend,
            showInfo=1 if args.verbose else 0
        )

        print("Rejection sampling completed successfully.")
        print(f"Posterior samples saved to: {f_post_h5}")

        return 0

    except Exception as e:
        # Top-level CLI boundary: report the failure and turn it into an
        # exit code; the traceback is shown only in verbose mode.
        print(f"Error during rejection sampling: {str(e)}", file=sys.stderr)
        if args.verbose:
            import traceback
            traceback.print_exc()
        return 1
|
|
208
|
+
|
|
209
|
+
if __name__ == "__main__":
|
|
210
|
+
sys.exit(main())
|
|
@@ -0,0 +1,494 @@
|
|
|
1
|
+
"""
|
|
2
|
+
JAX backend for integrate_rejection likelihood calculations.
|
|
3
|
+
|
|
4
|
+
All computation — likelihood evaluation AND post-processing (temperature
|
|
5
|
+
estimation, weighted sampling, evidence, CHI2) — is performed on-device.
|
|
6
|
+
Only the tiny final arrays are transferred back to the host:
|
|
7
|
+
|
|
8
|
+
Old: (bsz, N) likelihood matrix → ~256 MB per batch (N=1 M, bsz=64)
|
|
9
|
+
New: i_use + scalars → ~52 KB per batch (~5000× less)
|
|
10
|
+
|
|
11
|
+
This eliminates the PCIe bottleneck that made the old GPU path ~49× slower
|
|
12
|
+
than CPU despite the faster kernel.
|
|
13
|
+
|
|
14
|
+
Usage
|
|
15
|
+
-----
|
|
16
|
+
from integrate.integrate_rejection_jax import integrate_rejection_range_jax
|
|
17
|
+
# or via integrate_rejection(backend='jax', ...)
|
|
18
|
+
"""
|
|
19
|
+
|
|
20
|
+
import os
|
|
21
|
+
import functools
|
|
22
|
+
import numpy as np
|
|
23
|
+
from tqdm import tqdm
|
|
24
|
+
|
|
25
|
+
# Disable JAX's default behaviour of pre-allocating ~75 % of GPU VRAM upfront.
|
|
26
|
+
# Without this, JAX tries to grab ~18 GB on a 24 GB card at import time, which
|
|
27
|
+
# fails when the display driver or other processes already occupy some VRAM.
|
|
28
|
+
# Must be set before `import jax`.
|
|
29
|
+
os.environ.setdefault("XLA_PYTHON_CLIENT_PREALLOCATE", "false")
|
|
30
|
+
|
|
31
|
+
try:
    import jax
    import jax.numpy as jnp
    _JAX_AVAILABLE = True
except ImportError:
    _JAX_AVAILABLE = False

if _JAX_AVAILABLE:
    # Cache compiled kernels on disk so that later runs can reload them
    # instead of recompiling (the original author reports ~40 s of first-run
    # compilation for N=1M sort/cumsum/searchsorted dropping to ~1 s).
    # The directory is configured through jax.config.update.
    _cache_dir = os.path.expanduser("~/.cache/jax_xla_gpu")
    try:
        os.makedirs(_cache_dir, exist_ok=True)
        jax.config.update("jax_compilation_cache_dir", _cache_dir)
    except OSError:
        # The cache is only an optimisation.  A read-only or missing home
        # directory must not make this module (and the JAX backend)
        # unimportable, so continue without a persistent cache.
        pass
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def _check_jax():
    """Raise ImportError with install hints when the JAX import failed.

    Reads the module-level ``_JAX_AVAILABLE`` flag; returns ``None`` silently
    when JAX is available.
    """
    if _JAX_AVAILABLE:
        return
    message = (
        "JAX is required for backend='jax'.\n"
        "Install with: pip install jax (CPU)\n"
        " or: pip install jax[cuda12] (GPU)"
    )
    raise ImportError(message)
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
# ---------------------------------------------------------------------------
|
|
57
|
+
# JAX likelihood kernels (built lazily on first use)
|
|
58
|
+
# ---------------------------------------------------------------------------
|
|
59
|
+
|
|
60
|
+
_single_kernel = None
|
|
61
|
+
_batch_kernel = None
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
def _get_jax_kernels():
    """Return the ``(single, batch)`` Gaussian-diagonal log-likelihood kernels.

    The kernels are created on the first call and stored in the module-level
    ``_single_kernel`` / ``_batch_kernel`` variables; later calls return the
    stored pair.
    """
    global _single_kernel, _batch_kernel

    if _single_kernel is None:

        @jax.jit
        def _likelihood_gaussian_diagonal_jax(D, d_obs, d_std):
            """
            Gaussian diagonal log-likelihood for one data point.

            Parameters
            ----------
            D : jax array (N, Nf) — prior forward-model predictions
            d_obs : jax array (Nf,) — observed data (may contain NaN)
            d_std : jax array (Nf,) — per-feature standard deviation

            Returns
            -------
            jax array (N,) — log-likelihood for each prior sample
            """
            # A feature contributes only when both its observation and its
            # standard deviation are non-NaN.
            usable = ~jnp.isnan(d_obs) & ~jnp.isnan(d_std)
            # Neutral placeholders keep NaN out of the arithmetic; the
            # affected terms are zeroed by the mask below.
            obs = jnp.where(usable, d_obs, 0.0)
            sigma = jnp.where(usable, d_std, 1.0)
            residual = D - obs
            return -0.5 * jnp.sum(usable * (residual / sigma) ** 2, axis=1)

        # Map over a batch of (d_obs, d_std) rows while sharing D
        # (in_axes=(None, 0, 0)).
        _likelihood_gaussian_diagonal_batch_jax = jax.jit(
            jax.vmap(_likelihood_gaussian_diagonal_jax, in_axes=(None, 0, 0))
        )

        _single_kernel = _likelihood_gaussian_diagonal_jax
        _batch_kernel = _likelihood_gaussian_diagonal_batch_jax

    return _single_kernel, _batch_kernel
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
# ---------------------------------------------------------------------------
|
|
102
|
+
# JAX post-processing kernels — temperature, sampling, EV, CHI2 on-device
|
|
103
|
+
# ---------------------------------------------------------------------------
|
|
104
|
+
|
|
105
|
+
def _logl_T_est_jax(L, N_above, P_acc_lev):
|
|
106
|
+
"""
|
|
107
|
+
JAX port of integrate.logl_T_est.
|
|
108
|
+
|
|
109
|
+
Estimates an annealing temperature from the log-likelihood vector L.
|
|
110
|
+
Uses jax.lax.dynamic_index_in_dim for a data-dependent index that is
|
|
111
|
+
still JIT-safe. Returns jnp.inf when all L values are NaN, enforces T>=1
|
|
112
|
+
otherwise.
|
|
113
|
+
"""
|
|
114
|
+
L_norm = L - jnp.nanmax(L) # shift so max = 0
|
|
115
|
+
sorted_L = jnp.sort(L_norm) # NaN sorts to end in XLA
|
|
116
|
+
n_valid = jnp.sum(~jnp.isnan(L)).astype(jnp.int32)
|
|
117
|
+
idx = jnp.maximum(jnp.array(0, jnp.int32), n_valid - N_above - 1)
|
|
118
|
+
logL_lev = jax.lax.dynamic_index_in_dim(sorted_L, idx, axis=0, keepdims=False)
|
|
119
|
+
T_est = logL_lev / jnp.log(P_acc_lev)
|
|
120
|
+
T_est = jnp.maximum(jnp.array(1.0), T_est)
|
|
121
|
+
return jnp.where(n_valid > 0, T_est, jnp.inf)
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
@functools.lru_cache(maxsize=8)
def _get_postprocess_kernel(nr):
    """
    Return a jitted post-processing function that draws ``nr`` samples.

    The result is memoised per ``nr`` (lru_cache, up to 8 entries) because
    ``nr`` fixes the shape of the uniform draws inside the kernel.  The
    kernel handles a single data point, i.e. (N,)-shaped inputs; the caller
    loops over data points in Python.
    """

    def _postprocess_single(key, L, L_per_type_b, n_data_b, idx_jax,
                            N_above, P_acc_lev, autoT, T_base):
        """
        Post-process the log-likelihoods of one data point.

        Parameters
        ----------
        key : (2,) — PRNG key
        L : (N,) — combined log-likelihood
        L_per_type_b : (Ndt, N) — per-type log-likelihoods
        n_data_b : (Ndt,) — non-NaN observation count per type
        idx_jax : (N,) — maps sample position → original prior idx
        N_above, P_acc_lev — forwarded to _logl_T_est_jax
        autoT : 1 selects the estimated temperature, otherwise T_base
        T_base : fixed temperature used when autoT != 1

        Returns
        -------
        (i_use, T, EV, CHI2, N_UNIQUE) with N_UNIQUE cast to float32.
        """
        n_prior = L.shape[0]
        L_peak = jnp.nanmax(L)

        # Temperature: estimated from L or the fixed base value.
        T = jnp.where(autoT == 1,
                      _logl_T_est_jax(L, N_above, P_acc_lev),
                      T_base)

        # Tempered weights relative to the best sample; NaN weights become 0.
        weights = jnp.exp((1.0 / T) * (L - L_peak))
        weights = jnp.where(jnp.isnan(weights), 0.0, weights)
        weight_total = jnp.sum(weights)
        # Uniform probabilities are used when every weight is zero.
        uniform = jnp.ones(n_prior) / n_prior
        prob = jnp.where(weight_total > 0.0,
                         weights / jnp.maximum(weight_total, 1e-300),
                         uniform)

        # Weighted sampling with replacement via the inverse CDF
        # (cumsum + searchsorted), clipped to valid positions.
        draws = jax.random.uniform(key, shape=(nr,))
        picks = jnp.searchsorted(jnp.cumsum(prob), draws, side='right')
        picks = jnp.clip(picks, 0, n_prior - 1)

        # Evidence as log-mean-exp, shifted by the maximum for stability.
        EV = L_peak + jnp.log(jnp.nanmean(jnp.exp(L - L_peak)))

        # Reduced chi-squared per data type; NaN where a type has no data.
        has_data = n_data_b > 0
        divisor = jnp.where(has_data, n_data_b, 1.0)
        chi2_raw = jnp.nanmean(-2.0 * L_per_type_b[:, picks], axis=1) / divisor
        CHI2 = jnp.where(has_data, chi2_raw, jnp.nan)

        # Number of distinct picks: one plus the number of value changes in
        # the sorted picks (static shapes, unlike a unique() call).
        ordered = jnp.sort(picks)
        n_changes = jnp.sum(ordered[1:] != ordered[:-1])
        N_UNIQUE = jnp.array(1, jnp.int32) + n_changes

        # Translate sample positions to original prior indices.
        i_use = idx_jax[picks]

        return i_use, T, EV, CHI2, N_UNIQUE.astype(jnp.float32)

    return jax.jit(_postprocess_single)
|
|
194
|
+
|
|
195
|
+
|
|
196
|
+
# ---------------------------------------------------------------------------
|
|
197
|
+
# Public API
|
|
198
|
+
# ---------------------------------------------------------------------------
|
|
199
|
+
|
|
200
|
+
def integrate_rejection_range_jax(
    D,
    DATA,
    idx=None,
    N_use=None,
    id_use=None,
    ip_range=None,
    nr=1000,
    autoT=1,
    T_base=1,
    T_N_above=10,
    T_P_acc_level=0.2,
    progress_callback=None,
    Nbatch=64,
    **kwargs,
):
    """
    JAX replacement for integrate_rejection_range.

    Data points are processed in batches of ``Nbatch``.  Per data type a
    (bsz, N) log-likelihood matrix is built: diagonal-Gaussian noise uses the
    batched JAX kernel, full-covariance Gaussian and multinomial noise use
    the NumPy implementations from ``integrate.integrate_rejection`` and are
    converted to JAX arrays once per batch.  Post-processing (temperature,
    weighted sampling, evidence, CHI2, unique count) runs through the jitted
    per-point kernel when the first local JAX device is not a CPU, and
    through NumPy otherwise.

    Parameters
    ----------
    D : list of ndarray — forward-modeled data per data type.
        Entries with multinomial noise are replaced in place by their
        class-index form.
    DATA : dict — observed data; the keys 'd_obs', 'd_std', 'Cd',
        'noise_model' and 'i_use' are read.
    idx : sequence or None — prior sample indices (None/empty = sequential)
    N_use : int or None — max prior samples to evaluate
    id_use : sequence or None — data-type identifiers (None/empty = all)
    ip_range : sequence or None — data-point indices (None/empty = all)
    nr : int — posterior samples per data point
    autoT : int — 1 = auto temperature, 0 = use T_base
    T_base : float — base temperature when autoT=0
    T_N_above : int — passed to the temperature estimator
    T_P_acc_level : float — passed to the temperature estimator
    progress_callback : callable — optional, called as
        ``progress_callback(batch_end, nump)`` after each batch
    Nbatch : int — data points per batch (default 64)
    **kwargs : use_N_best, showInfo, console_progress, useRandomData

    Returns
    -------
    tuple
        (i_use_all, T_all, EV_all, EV_post_all, EV_post_all_mean,
         CHI2_all, N_UNIQUE_all, ip_range).  ``EV_post_all`` and
        ``EV_post_all_mean`` are not computed and stay all-NaN.

    Raises
    ------
    ImportError
        When JAX is not installed (raised by ``_check_jax``).
    """
    _check_jax()

    import integrate as ig
    from integrate.integrate_rejection import (
        likelihood_gaussian_full,
        likelihood_multinomial,
    )

    _, likelihood_gauss_diag_batch = _get_jax_kernels()
    postprocess_single = _get_postprocess_kernel(nr)

    # None stands in for the former mutable ``[]`` defaults; None and an
    # empty sequence both select the "use everything" behaviour below.
    idx = [] if idx is None else idx
    id_use = [] if id_use is None else id_use
    ip_range = [] if ip_range is None else ip_range

    # --- Setup (mirrors integrate_rejection_range) --------------------------

    use_N_best = kwargs.get('use_N_best', 0)
    showInfo = kwargs.get('showInfo', 0)
    console_progress = kwargs.get('console_progress', True)
    disableTqdm = not console_progress if showInfo >= 0 else True
    useRandomData = kwargs.get('useRandomData', True)

    Ndp = DATA['d_obs'][0].shape[0]
    if len(ip_range) == 0:
        ip_range = np.arange(Ndp)
    nump = len(ip_range)

    # NOTE(review): only the length of id_use is used below; the loops index
    # data types with range(Ndt), not with the values in id_use — confirm
    # this matches the caller's expectations.
    if len(id_use) == 0:
        Ndt = len(DATA['d_obs'])
        id_use = np.arange(Ndt)
    Ndt = len(id_use)

    noise_model = DATA['noise_model']
    i_use_data = DATA['i_use']

    # Convert multinomial class IDs to indices (same as original)
    class_is_idx = True
    class_id_list = []
    updated_data_ids = []
    for i in range(Ndt):
        if noise_model[i] == 'multinomial':
            Di, class_id, class_id_out = ig.class_id_to_idx(D[i])
            if class_is_idx and i not in updated_data_ids:
                updated_data_ids.append(i)
                D[i] = Di  # in-place update of the caller's list
            class_id_list.append(class_id_out if class_is_idx else class_id)
        else:
            class_id_list.append([])

    N = D[0].shape[0]
    if N_use is None:
        N_use = N
    N_use = min(N_use, N)
    # NOTE(review): likelihoods are evaluated for all N rows of D while the
    # default idx only has N_use entries; presumably callers pass D already
    # truncated to N_use rows — confirm.
    if len(idx) == 0:
        idx = np.arange(N_use)

    # Pre-allocate output arrays (NaN marks "not computed")
    i_use_all = np.zeros((nump, nr), dtype=np.int32)
    T_all = np.full(nump, np.nan)
    EV_all = np.full(nump, np.nan)
    EV_post_all = np.full(nump, np.nan)  # not computed (kept for API compat)
    EV_post_all_mean = np.full(nump, np.nan)  # not computed (kept for API compat)
    CHI2_all = np.full((nump, Ndt), np.nan)
    N_UNIQUE_all = np.full(nump, np.nan)

    # Transfer D to device once per data type (diagonal-Gaussian only).
    # NOTE(review): 'Cd' and 'd_std' are tested at index 0 for every data
    # type here and in the batch loop, while d_obs/d_std values are read at
    # index i — looks like it may be intended to be [i]; confirm.
    use_jax_diag = [
        noise_model[i] == 'gaussian'
        and DATA['Cd'][0] is None
        and DATA['d_std'][0] is not None
        for i in range(Ndt)
    ]
    D_jax = [jnp.asarray(D[i]) if use_jax_diag[i] else None for i in range(Ndt)]

    # Per the original author's measurements the on-device post-processing
    # (jnp.sort over N) is much slower than NumPy on CPU, so the platform is
    # detected once and NumPy post-processing is used on CPU.
    _on_gpu = jax.local_devices()[0].platform != 'cpu'

    if _on_gpu:
        # Accelerator path: transfer shared scalars and idx to device once.
        idx_jax = jnp.asarray(idx) if useRandomData else jnp.arange(N, dtype=jnp.int32)
        N_above_jax = jnp.array(T_N_above, dtype=jnp.int32)
        P_acc_lev_jax = jnp.array(T_P_acc_level, dtype=jnp.float32)
        autoT_jax = jnp.array(autoT, dtype=jnp.int32)
        T_base_jax = jnp.array(float(T_base), dtype=jnp.float32)
        # Seeded from NumPy's global RNG so np.random.seed controls both paths.
        rng_key = jax.random.PRNGKey(np.random.randint(0, 2**31))

    # --- Batch loop ---------------------------------------------------------

    for batch_start in tqdm(
        range(0, nump, Nbatch),
        disable=disableTqdm,
        desc='Rejection Sampling (JAX)',
        leave=False,
    ):
        batch_end = min(batch_start + Nbatch, nump)
        batch_js = range(batch_start, batch_end)
        ip_batch = [ip_range[j] for j in batch_js]
        bsz = len(ip_batch)

        # Build per-type log-likelihoods.
        # Diagonal-Gaussian: computed in JAX.
        # Fallbacks (full-Cd, multinomial): computed in NumPy, then converted.
        L_per_type_list = []  # list of (bsz, N) JAX arrays
        n_data_per_type = np.zeros((bsz, Ndt), dtype=np.float32)

        for i in range(Ndt):
            # active[b]=1 means data point ip_batch[b] has valid data for type i
            active = np.array([i_use_data[i][ip] for ip in ip_batch]).ravel()  # (bsz,)

            if noise_model[i] == 'gaussian':
                for b, ip in enumerate(ip_batch):
                    if active[b]:
                        n_data_per_type[b, i] = int(
                            np.sum(~np.isnan(DATA['d_obs'][i][ip]))
                        )

                if DATA['Cd'][0] is not None:
                    # Full-covariance fallback — NumPy, converted once per batch
                    L_np = np.zeros((bsz, N), dtype=np.float32)
                    for b, ip in enumerate(ip_batch):
                        if active[b]:
                            # A 3-D Cd holds one covariance per data point,
                            # otherwise a single covariance is shared.
                            Cd = (DATA['Cd'][0][ip]
                                  if len(DATA['Cd'][0].shape) == 3
                                  else DATA['Cd'][0][:])
                            L_np[b] = likelihood_gaussian_full(
                                D[i], DATA['d_obs'][i][ip], Cd, N_app=use_N_best
                            )
                    L_per_type_list.append(jnp.asarray(L_np))

                elif DATA['d_std'][0] is not None:
                    # Diagonal case: batched JAX kernel
                    d_obs_batch = np.array([DATA['d_obs'][i][ip] for ip in ip_batch])
                    d_std_batch = np.array([DATA['d_std'][i][ip] for ip in ip_batch])
                    L_jax = likelihood_gauss_diag_batch(
                        D_jax[i],
                        jnp.asarray(d_obs_batch),
                        jnp.asarray(d_std_batch),
                    )  # (bsz, N)
                    # Zero the rows of data points that are inactive for type i.
                    L_per_type_list.append(L_jax * jnp.asarray(active[:, None]))

                else:
                    L_per_type_list.append(jnp.zeros((bsz, N), dtype=jnp.float32))

            elif noise_model[i] == 'multinomial':
                # Multinomial fallback — NumPy, converted once per batch
                L_np = np.zeros((bsz, N), dtype=np.float32)
                for b, ip in enumerate(ip_batch):
                    if active[b]:
                        d_obs_ip = DATA['d_obs'][i][ip]
                        n_data_per_type[b, i] = int(np.sum(~np.isnan(d_obs_ip)))
                        L_np[b] = likelihood_multinomial(
                            D[i], d_obs_ip,
                            np.array(class_id_list[i]),
                            class_is_idx=class_is_idx,
                        )
                L_per_type_list.append(jnp.asarray(L_np))

            else:
                # Unknown noise model: contributes a zero log-likelihood.
                L_per_type_list.append(jnp.zeros((bsz, N), dtype=jnp.float32))

        # Stack to (Ndt, bsz, N) and combine
        L_per_type_stacked = jnp.stack(L_per_type_list, axis=0)  # (Ndt, bsz, N)
        L_combined = jnp.sum(L_per_type_stacked, axis=0)  # (bsz, N)

        if _on_gpu:
            # Accelerator path: all post-processing on-device, one point at a
            # time, so only the small per-point results are copied to the
            # host.  A Python loop is used instead of vmap (the original
            # author reports very long XLA compilation for (bsz, N) kernels).
            rng_key, batch_key = jax.random.split(rng_key)
            keys = jax.random.split(batch_key, bsz)  # (bsz, 2)
            n_data_jax = jnp.asarray(n_data_per_type)  # (bsz, Ndt)

            for b in range(bsz):
                i_use_b, T_b, EV_b, CHI2_b, N_UNIQUE_b = postprocess_single(
                    keys[b],
                    L_combined[b],
                    L_per_type_stacked[:, b, :],
                    n_data_jax[b],
                    idx_jax,
                    N_above_jax,
                    P_acc_lev_jax,
                    autoT_jax,
                    T_base_jax,
                )
                j = batch_start + b
                i_use_all[j] = np.asarray(i_use_b)
                T_all[j] = float(T_b)
                EV_all[j] = float(EV_b)
                CHI2_all[j] = np.asarray(CHI2_b)
                N_UNIQUE_all[j] = float(N_UNIQUE_b)
        else:
            # CPU path: convert the likelihoods to NumPy once per batch,
            # then post-process with NumPy.
            L_combined_np = np.asarray(L_combined)  # (bsz, N)
            L_per_type_np = np.asarray(L_per_type_stacked)  # (Ndt, bsz, N)

            for b, j in enumerate(batch_js):
                L = L_combined_np[b]  # (N,)

                if autoT == 1:
                    T = ig.logl_T_est(L, N_above=T_N_above, P_acc_lev=T_P_acc_level)
                else:
                    T = float(T_base)

                # Tempered acceptance weights relative to the best sample.
                P_acc = np.exp((1.0 / T) * (L - np.nanmax(L)))
                P_acc[np.isnan(P_acc)] = 0.0
                p_sum = P_acc.sum()
                if p_sum > 0:
                    p = P_acc / p_sum
                    i_use = np.random.choice(N, nr, p=p)
                else:
                    # All weights collapsed to zero: sample uniformly.
                    i_use = np.random.choice(N, nr)

                # Reduced chi-squared per data type (NaN where no data).
                CHI2_current = np.full(Ndt, np.nan)
                for i in range(Ndt):
                    if n_data_per_type[b, i] > 0:
                        L_acc = L_per_type_np[i, b, i_use]
                        CHI2_current[i] = (
                            np.nanmean(-2.0 * L_acc) / n_data_per_type[b, i]
                        )

                # Remap sample positions to original prior indices.
                if useRandomData:
                    i_use = idx[i_use]

                # Evidence via log-mean-exp.
                max_L = np.nanmax(L)
                EV = max_L + np.log(np.nanmean(np.exp(L - max_L)))

                i_use_all[j] = i_use
                T_all[j] = T
                EV_all[j] = EV
                CHI2_all[j] = CHI2_current
                N_UNIQUE_all[j] = len(np.unique(i_use))

        if progress_callback is not None:
            progress_callback(batch_end, nump)

    return (
        i_use_all, T_all, EV_all, EV_post_all, EV_post_all_mean,
        CHI2_all, N_UNIQUE_all, ip_range,
    )
|