integrate_module 0.99.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,210 @@
1
+ #!/usr/bin/env python
2
+ """
3
+ INTEGRATE Rejection Sampling CLI
4
+
5
+ Command-line interface for probabilistic inversion using rejection sampling.
6
+ Provides access to the integrate_rejection function with various options for
7
+ Bayesian inversion and posterior sampling.
8
+
9
+ Author: Thomas Mejer Hansen
10
+ Email: tmeha@geo.au.dk
11
+ """
12
+
13
+ import argparse
14
+ import sys
15
+ import os
16
+ import multiprocessing
17
+
18
+ # Import the integrate module
19
+ try:
20
+ import integrate as ig
21
+ except ImportError:
22
+ print("Error: Could not import integrate module. Please ensure it is properly installed.")
23
+ sys.exit(1)
24
+
25
def _build_parser():
    """Return the argparse parser for the integrate_rejection command.

    --prior and --data are declared optional here and checked in main() so
    that ``--version`` can be used on its own.
    """
    epilog = (
        "\n"
        "Examples:\n"
        "  integrate_rejection --prior prior.h5 --data data.h5 --output post.h5\n"
        "  integrate_rejection --prior prior.h5 --data data.h5 --samples 1000000\n"
        "  integrate_rejection --prior prior.h5 --data data.h5 --auto-temp --cpus 4\n"
        "\n"
        "For more information, see the INTEGRATE documentation.\n"
    )
    parser = argparse.ArgumentParser(
        description='INTEGRATE rejection sampling for Bayesian inversion',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=epilog,
    )

    # Required arguments (enforced in main(), see docstring above)
    parser.add_argument('--prior', '-p',
                        type=str,
                        default=None,
                        help='Path to HDF5 file containing prior model and data samples (required)')

    parser.add_argument('--data', '-d',
                        type=str,
                        default=None,
                        help='Path to HDF5 file containing observed data for inversion (required)')

    # Optional arguments
    parser.add_argument('--output', '-o',
                        type=str,
                        default='',
                        help='Output path for posterior samples (auto-generated if not specified)')

    parser.add_argument('--samples', '-n',
                        type=int,
                        default=100000000,
                        help='Maximum number of prior samples to use for inversion (default: 100000000)')

    parser.add_argument('--auto-temp', '-T',
                        action='store_true',
                        help='Enable automatic temperature estimation (default: disabled)')

    parser.add_argument('--temp-base',
                        type=float,
                        default=1.0,
                        help='Base temperature for sampling (default: 1.0)')

    parser.add_argument('--nr',
                        type=int,
                        default=400,
                        help='Number of resamples for temperature estimation (default: 400)')

    parser.add_argument('--cpus', '-c',
                        type=int,
                        default=0,
                        help='Number of CPU cores to use (0 = auto-detect, default: 0)')

    parser.add_argument('--no-parallel',
                        action='store_true',
                        help='Disable parallel processing')

    parser.add_argument('--chunks',
                        type=int,
                        default=0,
                        help='Number of chunks for processing (0 = auto, default: 0)')

    parser.add_argument('--id-use',
                        type=str,
                        default='',
                        help='Comma-separated list of data IDs to use for inversion')

    parser.add_argument('--ip-range',
                        type=str,
                        default='',
                        help='Comma-separated list of data-point indices to invert (default: all)')

    parser.add_argument('--use-n-best',
                        type=int,
                        default=0,
                        help='Use N best samples for analysis (default: 0)')

    parser.add_argument('--backend',
                        choices=['numpy', 'jax'],
                        default='numpy',
                        help='Rejection sampling backend: numpy (default) or jax')

    parser.add_argument('--verbose', '-v',
                        action='store_true',
                        help='Enable verbose output')

    parser.add_argument('--version',
                        action='store_true',
                        help='Show version information')
    return parser


def _parse_int_list(text):
    """Parse a comma-separated list of integers.

    An empty string yields []. Raises ValueError if any element is not an
    integer (including empty elements such as in "1,,2").
    """
    if not text:
        return []
    return [int(item.strip()) for item in text.split(',')]


def main():
    """Entry point for the integrate_rejection command.

    Parses the command line, validates the input files and index lists, and
    runs ``ig.integrate_rejection`` with the requested options.

    Returns
    -------
    int
        0 on success or after printing the version; 1 if an input file is
        missing, an index list cannot be parsed, or the inversion raises.
        A missing --prior/--data is reported via ``parser.error``, which
        exits with status 2 (the status argparse uses for required options).
    """
    # Set up multiprocessing support
    multiprocessing.freeze_support()

    parser = _build_parser()
    args = parser.parse_args()

    # Checked before the --prior/--data requirement so `--version` works alone.
    if args.version:
        try:
            from integrate import __version__
            print(f"INTEGRATE version: {__version__}")
        except ImportError:
            print("INTEGRATE version: unknown")
        return 0

    missing = [flag for flag, value in (('--prior/-p', args.prior),
                                        ('--data/-d', args.data))
               if not value]
    if missing:
        parser.error("the following arguments are required: " + ", ".join(missing))

    # Validate input files
    for label, path in (("Prior", args.prior), ("Data", args.data)):
        if not os.path.exists(path):
            print(f"Error: {label} file not found: {path}", file=sys.stderr)
            return 1

    # Parse comma-separated arguments
    try:
        id_use = _parse_int_list(args.id_use)
    except ValueError:
        print(f"Error: Invalid ID list format: {args.id_use}", file=sys.stderr)
        return 1

    # ip_range holds data-point indices, so it must be integers like id_use.
    try:
        ip_range = _parse_int_list(args.ip_range)
    except ValueError:
        print(f"Error: Invalid data-point index list: {args.ip_range}", file=sys.stderr)
        return 1

    show_info = 1 if args.verbose else 0

    try:
        # Parallel unless disabled; ig.use_parallel checks whether parallel
        # processing is supported.
        parallel = (not args.no_parallel) and ig.use_parallel(showInfo=show_info)

        # Print configuration if verbose
        if args.verbose:
            print("Configuration:")
            print(f" Prior file: {args.prior}")
            print(f" Data file: {args.data}")
            print(f" Output file: {args.output if args.output else 'auto-generated'}")
            print(f" Max samples: {args.samples}")
            print(f" Auto temperature: {args.auto_temp}")
            print(f" Base temperature: {args.temp_base}")
            print(f" Parallel processing: {parallel}")
            print(f" CPU cores: {args.cpus if args.cpus > 0 else 'auto-detect'}")
            print(f" Backend: {args.backend}")
            print("")

        f_post_h5 = ig.integrate_rejection(
            f_prior_h5=args.prior,
            f_data_h5=args.data,
            f_post_h5=args.output,
            N_use=args.samples,
            id_use=id_use,
            ip_range=ip_range,
            nr=args.nr,
            autoT=1 if args.auto_temp else 0,
            T_base=args.temp_base,
            Nchunks=args.chunks,
            Ncpu=args.cpus,
            parallel=parallel,
            use_N_best=args.use_n_best,
            backend=args.backend,
            showInfo=show_info,
        )

        print("Rejection sampling completed successfully.")
        print(f"Posterior samples saved to: {f_post_h5}")
        return 0

    except Exception as e:
        print(f"Error during rejection sampling: {e}", file=sys.stderr)
        if args.verbose:
            import traceback
            traceback.print_exc()
        return 1
208
+
209
if __name__ == "__main__":
    # Propagate main()'s integer return value as the process exit status.
    raise SystemExit(main())
@@ -0,0 +1,494 @@
1
+ """
2
+ JAX backend for integrate_rejection likelihood calculations.
3
+
4
+ All computation — likelihood evaluation AND post-processing (temperature
5
+ estimation, weighted sampling, evidence, CHI2) — is performed on-device.
6
+ Only the tiny final arrays are transferred back to the host:
7
+
8
+ Old: (bsz, N) likelihood matrix → ~256 MB per batch (N=1 M, bsz=64)
9
+ New: i_use + scalars → ~52 KB per batch at nr=200 (~5000× less)
10
+
11
+ This eliminates the PCIe bottleneck that made the old GPU path ~49× slower
12
+ than CPU despite the faster kernel.
13
+
14
+ Usage
15
+ -----
16
+ from integrate.integrate_rejection_jax import integrate_rejection_range_jax
17
+ # or via integrate_rejection(backend='jax', ...)
18
+ """
19
+
20
+ import os
21
+ import functools
22
+ import numpy as np
23
+ from tqdm import tqdm
24
+
25
# Disable JAX's default behaviour of pre-allocating ~75 % of GPU VRAM upfront.
# Without this, JAX tries to grab ~18 GB on a 24 GB card at import time, which
# fails when the display driver or other processes already occupy some VRAM.
# Must be set before `import jax`.
os.environ.setdefault("XLA_PYTHON_CLIENT_PREALLOCATE", "false")

# JAX availability is decided by the import alone; the compilation-cache setup
# below is an optimisation and must not make this module unimportable.
try:
    import jax
    import jax.numpy as jnp
    _JAX_AVAILABLE = True
except ImportError:
    _JAX_AVAILABLE = False

if _JAX_AVAILABLE:
    # Cache compiled GPU kernels to disk. GPU kernel compilation for large
    # static shapes (N=1M sort, cumsum, searchsorted) takes ~40s on first run.
    # With the cache, every subsequent run reloads compiled kernels and warmup
    # drops to ~1s. Set via jax.config.update (not an env var); the cache is
    # keyed on kernel + GPU arch so it is safe to share.
    _cache_dir = os.path.expanduser("~/.cache/jax_xla_gpu")
    try:
        os.makedirs(_cache_dir, exist_ok=True)
        jax.config.update("jax_compilation_cache_dir", _cache_dir)
    except (OSError, AttributeError):
        # Best effort: an unwritable home directory, or a JAX version that
        # does not know this option, just means kernels are recompiled.
        pass
45
+
46
+
47
def _check_jax():
    """Raise ImportError with installation hints when JAX could not be imported.

    Returns None silently when the module-level JAX import succeeded.
    """
    if _JAX_AVAILABLE:
        return
    hint = (
        "JAX is required for backend='jax'.\n"
        "Install with: pip install jax (CPU)\n"
        " or: pip install jax[cuda12] (GPU)"
    )
    raise ImportError(hint)
54
+
55
+
56
# ---------------------------------------------------------------------------
# JAX likelihood kernels (built lazily on first use)
# ---------------------------------------------------------------------------

# Module-level cache filled by _get_jax_kernels(); None until its first call.
_single_kernel = _batch_kernel = None
62
+
63
+
64
def _get_jax_kernels():
    """
    Lazily build and memoise the Gaussian-diagonal log-likelihood kernels.

    Returns
    -------
    (single, batch) : tuple of JIT-compiled callables
        single(D, d_obs, d_std) -> (N,) log-likelihood per prior sample.
        batch(D, d_obs_b, d_std_b) -> vmapped over the leading axis of
        d_obs_b / d_std_b while D is shared (in_axes=(None, 0, 0)).

    The pair is stored in the module globals ``_single_kernel`` and
    ``_batch_kernel``, so later calls return the same objects.
    """
    global _single_kernel, _batch_kernel

    if _single_kernel is None:

        def _gauss_diag_loglik(D, d_obs, d_std):
            """
            Gaussian diagonal log-likelihood for one data point.

            Parameters
            ----------
            D     : (N, Nf) — prior forward-model predictions
            d_obs : (Nf,)   — observed data (may contain NaN)
            d_std : (Nf,)   — per-feature standard deviation (may contain NaN)

            Features where d_obs or d_std is NaN contribute nothing; they are
            replaced by safe values (0 and 1) and then masked out.
            """
            observed = jnp.logical_not(jnp.isnan(d_obs) | jnp.isnan(d_std))
            residual = D - jnp.where(observed, d_obs, 0.0)
            scaled = residual / jnp.where(observed, d_std, 1.0)
            return -0.5 * jnp.sum(observed * scaled ** 2, axis=1)

        single = jax.jit(_gauss_diag_loglik)
        batch = jax.jit(jax.vmap(single, in_axes=(None, 0, 0)))
        _single_kernel, _batch_kernel = single, batch

    return _single_kernel, _batch_kernel
99
+
100
+
101
+ # ---------------------------------------------------------------------------
102
+ # JAX post-processing kernels — temperature, sampling, EV, CHI2 on-device
103
+ # ---------------------------------------------------------------------------
104
+
105
def _logl_T_est_jax(L, N_above, P_acc_lev):
    """
    JAX port of integrate.logl_T_est: annealing temperature from log-likelihoods.

    The max-shifted log-likelihoods are sorted ascending (NaN sorts last), and
    the value with N_above non-NaN entries above it is taken as the level
    logL_lev. The temperature is logL_lev / log(P_acc_lev), clamped to >= 1.
    The index is data-dependent but JIT-safe via jax.lax.dynamic_index_in_dim.

    Returns jnp.inf when every entry of L is NaN.
    """
    n_non_nan = jnp.sum(jnp.logical_not(jnp.isnan(L))).astype(jnp.int32)
    ranked = jnp.sort(L - jnp.nanmax(L))
    level_pos = jnp.maximum(jnp.array(0, jnp.int32), n_non_nan - N_above - 1)
    level = jax.lax.dynamic_index_in_dim(ranked, level_pos, axis=0, keepdims=False)
    T_est = jnp.maximum(jnp.array(1.0), level / jnp.log(P_acc_lev))
    return jnp.where(n_non_nan > 0, T_est, jnp.inf)
122
+
123
+
124
@functools.lru_cache(maxsize=8)
def _get_postprocess_kernel(nr):
    """
    Build and cache a JIT-compiled post-processing kernel for nr samples.

    Keyed on `nr` because it determines the output shape of the uniform draws.
    Called once per data point in a Python loop — compiles for (N,) shaped
    tensors rather than (bsz, N), which avoids the minutes-long XLA fusion that
    vmap over large N would trigger.

    Parameters
    ----------
    nr : int — number of posterior draws per data point (static output shape)

    Returns
    -------
    callable — jitted _postprocess_single (see its docstring)
    """

    def _postprocess_single(key, L, L_per_type_b, n_data_b, idx_jax,
                            N_above, P_acc_lev, autoT, T_base):
        """
        Full post-processing for one data point.

        Parameters
        ----------
        key           : (2,)      — PRNG key
        L             : (N,)      — combined log-likelihood
        L_per_type_b  : (Ndt, N)  — per-type log-likelihoods
        n_data_b      : (Ndt,)    — non-NaN observation count per type
        idx_jax       : (N,)      — maps sample position → original prior idx
        N_above, P_acc_lev        — forwarded to _logl_T_est_jax
        autoT         : scalar    — 1 = use estimated T, otherwise T_base
        T_base        : scalar    — temperature used when autoT != 1

        Returns
        -------
        i_use    : (nr,)  — idx_jax values of the sampled positions
        T        : scalar — temperature used
        EV       : scalar — max_L + log(nanmean(exp(L - max_L)))
        CHI2     : (Ndt,) — nanmean(-2 * L_per_type over draws) / n_data,
                            NaN where n_data == 0
        N_UNIQUE : float32 scalar — number of distinct sampled positions
        """
        N = L.shape[0]

        # 1. Temperature estimation
        T_auto = _logl_T_est_jax(L, N_above, P_acc_lev)
        T = jnp.where(autoT == 1, T_auto, T_base)

        # 2. Acceptance probabilities (numerically stable: shifted by max_L)
        max_L = jnp.nanmax(L)
        P_acc = jnp.exp((1.0 / T) * (L - max_L))
        P_acc = jnp.where(jnp.isnan(P_acc), 0.0, P_acc)
        p_sum = jnp.sum(P_acc)
        has_mass = p_sum > 0.0
        # Fall back to uniform when all weights collapse to zero. The
        # denominator is 1.0 in that branch so the unused side of the
        # where() never divides by zero.
        safe_sum = jnp.where(has_mass, p_sum, 1.0)
        p = jnp.where(has_mass, P_acc / safe_sum, jnp.ones(N) / N)

        # 3. Weighted sampling with replacement — inverse-CDF method.
        # jax.random.choice(p=p) uses gumbel-max + top_k, which generates a
        # huge XLA graph for N=1M and is very slow to compile. The inverse-CDF
        # approach (cumsum + searchsorted) uses only two efficient GPU primitives
        # and compiles in milliseconds regardless of N.
        cdf = jnp.cumsum(p)
        # Renormalise so cdf[-1] is exactly 1.0. A float32 running sum over
        # many terms can end slightly below 1; draws u >= cdf[-1] would then
        # run off the end and be clipped onto the last sample even when its
        # weight is zero (e.g. a NaN likelihood).
        cdf = cdf / cdf[-1]
        u = jax.random.uniform(key, shape=(nr,))              # draws in [0, 1)
        i_use_raw = jnp.searchsorted(cdf, u, side='right')    # first cdf > u
        i_use_raw = jnp.clip(i_use_raw, 0, N - 1)             # defensive guard

        # 4. Evidence (log-mean-exp trick for numerical stability)
        EV = max_L + jnp.log(jnp.nanmean(jnp.exp(L - max_L)))

        # 5. Reduced chi-squared per data type
        L_accepted = L_per_type_b[:, i_use_raw]                # (Ndt, nr)
        n_data_safe = jnp.where(n_data_b > 0, n_data_b, 1.0)
        chi2_vals = jnp.nanmean(-2.0 * L_accepted, axis=1) / n_data_safe
        CHI2 = jnp.where(n_data_b > 0, chi2_vals, jnp.nan)

        # 6. Unique sample count (sort-diff avoids np.unique's dynamic shape)
        sorted_use = jnp.sort(i_use_raw)
        N_UNIQUE = (jnp.array(1, jnp.int32)
                    + jnp.sum(sorted_use[1:] != sorted_use[:-1]))

        # 7. Remap sample positions to original prior indices
        i_use = idx_jax[i_use_raw]

        return i_use, T, EV, CHI2, N_UNIQUE.astype(jnp.float32)

    return jax.jit(_postprocess_single)
194
+
195
+
196
+ # ---------------------------------------------------------------------------
197
+ # Public API
198
+ # ---------------------------------------------------------------------------
199
+
200
def integrate_rejection_range_jax(
    D,
    DATA,
    idx=None,
    N_use=None,
    id_use=None,
    ip_range=None,
    nr=1000,
    autoT=1,
    T_base=1,
    T_N_above=10,
    T_P_acc_level=0.2,
    progress_callback=None,
    Nbatch=64,
    **kwargs,
):
    """
    GPU-efficient JAX replacement for integrate_rejection_range.

    Likelihood computation and, on a GPU, all post-processing (temperature
    estimation, weighted sampling, evidence, CHI2) run on-device. Only the
    small final arrays are transferred back to the host per batch
    (example sizes for bsz=64, nr=200, 4-byte values):

        i_use (bsz × nr)     ≈ 51 KB   [was 256 MB for the likelihood matrix]
        T, EV, N_UNIQUE      ≈ 1 KB
        CHI2 (bsz × Ndt)     ≈ 1 KB

    When JAX runs on CPU, the combined likelihoods are copied to NumPy once
    per batch and post-processed with NumPy instead (jnp.sort is the
    bottleneck there).

    Full-covariance Gaussian and multinomial noise models fall back to the
    original NumPy implementations (their likelihoods are converted to JAX
    arrays once per batch before post-processing).

    Parameters
    ----------
    D                 : list of ndarray — forward-modeled data per data type.
                        Not modified: a shallow copy is taken before
                        multinomial class IDs are converted to indices.
    DATA              : dict — observed data (same format as load_data); the
                        per-type lists 'd_obs', 'd_std', 'Cd', 'noise_model'
                        and 'i_use' are indexed by data-type position i.
    idx               : sequence of int or None — prior sample index for each
                        row of D (None/empty = arange(N_use))
    N_use             : int or None — max prior samples to evaluate
    id_use            : sequence or None — data-type identifiers to include
                        (None/empty = all types in DATA['d_obs'])
    ip_range          : sequence of int or None — data-point indices to invert
                        (None/empty = all data points)
    nr                : int   — posterior samples per data point
    autoT             : int   — 1 = auto temperature, 0 = use T_base
    T_base            : float — base temperature when autoT=0
    T_N_above         : int   — passed to logl_T_est (top-k for T est.)
    T_P_acc_level     : float — passed to logl_T_est (target P_acc)
    progress_callback : callable — optional (current, total) callback
    Nbatch            : int   — data points per JAX batch (default 64)
    **kwargs          : use_N_best, showInfo, console_progress, useRandomData, …

    Returns
    -------
    Same 8-tuple as integrate_rejection_range:
        (i_use_all, T_all, EV_all, EV_post_all, EV_post_all_mean,
         CHI2_all, N_UNIQUE_all, ip_range)
    EV_post_all and EV_post_all_mean are returned all-NaN (not computed).
    """
    _check_jax()

    import integrate as ig
    from integrate.integrate_rejection import (
        likelihood_gaussian_full,
        likelihood_multinomial,
    )

    _, likelihood_gauss_diag_batch = _get_jax_kernels()
    postprocess_single = _get_postprocess_kernel(nr)

    # --- Setup (mirrors integrate_rejection_range) --------------------------

    use_N_best = kwargs.get('use_N_best', 0)
    showInfo = kwargs.get('showInfo', 0)
    console_progress = kwargs.get('console_progress', True)
    disableTqdm = not console_progress if showInfo >= 0 else True
    useRandomData = kwargs.get('useRandomData', True)

    d_obs_all = DATA['d_obs']
    Ndp = d_obs_all[0].shape[0]
    if ip_range is None or len(ip_range) == 0:
        ip_range = np.arange(Ndp)
    nump = len(ip_range)

    if id_use is None or len(id_use) == 0:
        id_use = np.arange(len(d_obs_all))
    Ndt = len(id_use)

    noise_model = DATA['noise_model']
    i_use_data = DATA['i_use']

    # Convert multinomial class IDs to indices. Work on a shallow copy so the
    # caller's list is not mutated (repeated calls would otherwise re-convert
    # already-converted data).
    D = list(D)
    class_is_idx = True
    class_id_list = []
    for i in range(Ndt):
        if noise_model[i] == 'multinomial':
            D[i], _, class_id_out = ig.class_id_to_idx(D[i])
            class_id_list.append(class_id_out)
        else:
            class_id_list.append([])

    N = D[0].shape[0]
    N_use = N if N_use is None else min(N_use, N)
    if idx is None or len(idx) == 0:
        idx = np.arange(N_use)
    # ndarray so that idx[i_use] fancy indexing works when a list is passed.
    idx = np.asarray(idx)

    # Pre-allocate output arrays
    i_use_all = np.zeros((nump, nr), dtype=np.int32)
    T_all = np.full(nump, np.nan)
    EV_all = np.full(nump, np.nan)
    EV_post_all = np.full(nump, np.nan)       # not computed (kept for API compat)
    EV_post_all_mean = np.full(nump, np.nan)  # not computed (kept for API compat)
    CHI2_all = np.full((nump, Ndt), np.nan)
    N_UNIQUE_all = np.full(nump, np.nan)

    # Transfer D to device once per data type (diagonal-Gaussian only).
    # Each type's own Cd / d_std decides which likelihood it uses.
    use_jax_diag = [
        noise_model[i] == 'gaussian'
        and DATA['Cd'][i] is None
        and DATA['d_std'][i] is not None
        for i in range(Ndt)
    ]
    D_jax = [jnp.asarray(D[i]) if use_jax_diag[i] else None for i in range(Ndt)]

    # On CPU, jnp.sort(N=1M) is ~54× slower than np.sort — the JAX on-device
    # post-processing path that is fast on GPU becomes a bottleneck on CPU.
    # Detect platform once and fall back to NumPy post-processing on CPU.
    _on_gpu = jax.local_devices()[0].platform != 'cpu'

    if _on_gpu:
        # GPU path: transfer shared scalars and idx to device once.
        idx_jax = jnp.asarray(idx) if useRandomData else jnp.arange(N, dtype=jnp.int32)
        N_above_jax = jnp.array(T_N_above, dtype=jnp.int32)
        P_acc_lev_jax = jnp.array(T_P_acc_level, dtype=jnp.float32)
        autoT_jax = jnp.array(autoT, dtype=jnp.int32)
        T_base_jax = jnp.array(float(T_base), dtype=jnp.float32)
        rng_key = jax.random.PRNGKey(np.random.randint(0, 2**31))

    # --- Batch loop ---------------------------------------------------------

    for batch_start in tqdm(
        range(0, nump, Nbatch),
        disable=disableTqdm,
        desc='Rejection Sampling (JAX)',
        leave=False,
    ):
        batch_end = min(batch_start + Nbatch, nump)
        ip_batch = [ip_range[j] for j in range(batch_start, batch_end)]
        bsz = len(ip_batch)

        # Build per-type log-likelihoods.
        # Diagonal-Gaussian: computed in JAX (no GPU→CPU round-trip on GPU).
        # Fallbacks (full-Cd, multinomial): computed in NumPy, then converted.
        L_per_type_list = []  # list of (bsz, N) JAX arrays
        n_data_per_type = np.zeros((bsz, Ndt), dtype=np.float32)

        for i in range(Ndt):
            # active[b] is True when data point ip_batch[b] has valid data for type i
            active = np.array([i_use_data[i][ip] for ip in ip_batch]).ravel().astype(bool)

            if noise_model[i] == 'gaussian':
                for b, ip in enumerate(ip_batch):
                    if active[b]:
                        n_data_per_type[b, i] = int(np.sum(~np.isnan(d_obs_all[i][ip])))

                Cd_i = DATA['Cd'][i]
                if Cd_i is not None:
                    # Full-covariance fallback — NumPy, converted once per batch.
                    # A 3-D Cd holds one covariance per data point.
                    L_np = np.zeros((bsz, N), dtype=np.float32)
                    for b, ip in enumerate(ip_batch):
                        if active[b]:
                            Cd = Cd_i[ip] if len(Cd_i.shape) == 3 else Cd_i[:]
                            L_np[b] = likelihood_gaussian_full(
                                D[i], d_obs_all[i][ip], Cd, N_app=use_N_best
                            )
                    L_per_type_list.append(jnp.asarray(L_np))

                elif DATA['d_std'][i] is not None:
                    # Diagonal case: batched JAX kernel (fast on both CPU and GPU)
                    d_obs_batch = np.array([d_obs_all[i][ip] for ip in ip_batch])
                    d_std_batch = np.array([DATA['d_std'][i][ip] for ip in ip_batch])
                    L_jax = likelihood_gauss_diag_batch(
                        D_jax[i],
                        jnp.asarray(d_obs_batch),
                        jnp.asarray(d_std_batch),
                    )  # (bsz, N)
                    # where() rather than multiplication so that NaN
                    # likelihoods of inactive points are also zeroed.
                    L_per_type_list.append(
                        jnp.where(jnp.asarray(active)[:, None], L_jax, 0.0)
                    )

                else:
                    L_per_type_list.append(jnp.zeros((bsz, N), dtype=jnp.float32))

            elif noise_model[i] == 'multinomial':
                # Multinomial fallback — NumPy, converted once per batch
                L_np = np.zeros((bsz, N), dtype=np.float32)
                for b, ip in enumerate(ip_batch):
                    if active[b]:
                        d_obs_ip = d_obs_all[i][ip]
                        n_data_per_type[b, i] = int(np.sum(~np.isnan(d_obs_ip)))
                        L_np[b] = likelihood_multinomial(
                            D[i], d_obs_ip,
                            np.array(class_id_list[i]),
                            class_is_idx=class_is_idx,
                        )
                L_per_type_list.append(jnp.asarray(L_np))

            else:
                L_per_type_list.append(jnp.zeros((bsz, N), dtype=jnp.float32))

        # Stack to (Ndt, bsz, N) and combine
        L_per_type_stacked = jnp.stack(L_per_type_list, axis=0)  # (Ndt, bsz, N)
        L_combined = jnp.sum(L_per_type_stacked, axis=0)          # (bsz, N)

        if _on_gpu:
            # GPU path: all post-processing on-device, one point at a time.
            # Avoids the ~256 MB PCIe transfer per batch. A Python loop rather
            # than vmap avoids minutes-long XLA fusion for (bsz, N) kernels.
            rng_key, batch_key = jax.random.split(rng_key)
            keys = jax.random.split(batch_key, bsz)  # (bsz, 2)
            n_data_jax = jnp.asarray(n_data_per_type)  # (bsz, Ndt)

            for b in range(bsz):
                i_use_b, T_b, EV_b, CHI2_b, N_UNIQUE_b = postprocess_single(
                    keys[b],
                    L_combined[b],
                    L_per_type_stacked[:, b, :],
                    n_data_jax[b],
                    idx_jax,
                    N_above_jax,
                    P_acc_lev_jax,
                    autoT_jax,
                    T_base_jax,
                )
                j = batch_start + b
                i_use_all[j] = np.asarray(i_use_b)
                T_all[j] = float(T_b)
                EV_all[j] = float(EV_b)
                CHI2_all[j] = np.asarray(CHI2_b)
                N_UNIQUE_all[j] = float(N_UNIQUE_b)
        else:
            # CPU path: transfer combined likelihoods to NumPy once per batch,
            # then post-process with NumPy. Avoids jnp.sort(N=1M) which is
            # ~54× slower than np.sort on CPU and dominates wall-clock time.
            L_combined_np = np.asarray(L_combined)         # (bsz, N)
            L_per_type_np = np.asarray(L_per_type_stacked)  # (Ndt, bsz, N)

            for b in range(bsz):
                j = batch_start + b
                L = L_combined_np[b]  # (N,)
                max_L = np.nanmax(L)

                if autoT == 1:
                    T = ig.logl_T_est(L, N_above=T_N_above, P_acc_lev=T_P_acc_level)
                else:
                    T = float(T_base)

                P_acc = np.exp((1.0 / T) * (L - max_L))
                P_acc[np.isnan(P_acc)] = 0.0
                p_sum = P_acc.sum()
                if p_sum > 0:
                    i_use = np.random.choice(N, nr, p=P_acc / p_sum)
                else:
                    i_use = np.random.choice(N, nr)

                CHI2_current = np.full(Ndt, np.nan)
                for i in range(Ndt):
                    if n_data_per_type[b, i] > 0:
                        L_acc = L_per_type_np[i, b, i_use]
                        CHI2_current[i] = (
                            np.nanmean(-2.0 * L_acc) / n_data_per_type[b, i]
                        )

                if useRandomData:
                    i_use = idx[i_use]

                # Evidence (log-mean-exp trick for numerical stability)
                EV = max_L + np.log(np.nanmean(np.exp(L - max_L)))

                i_use_all[j] = i_use
                T_all[j] = T
                EV_all[j] = EV
                CHI2_all[j] = CHI2_current
                N_UNIQUE_all[j] = len(np.unique(i_use))

        if progress_callback is not None:
            progress_callback(batch_end, nump)

    return (
        i_use_all, T_all, EV_all, EV_post_all, EV_post_all_mean,
        CHI2_all, N_UNIQUE_all, ip_range,
    )