integrate_module 0.99.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1609 @@
1
+ """
2
+ Query posterior realizations based on geophysical constraints.
3
+
4
+ This module provides tools to compute probabilities that posterior realizations
5
+ from Bayesian inversion satisfy user-defined constraints (e.g., thickness of
6
+ lithology classes, resistivity thresholds).
7
+ """
8
+
9
+ import json
10
+ import os
11
+ import numpy as np
12
+ import h5py
13
+ from tqdm import tqdm
14
+
15
+
16
+ def _get_clipped_thickness(z, depth_min, depth_max):
17
+ """
18
+ Compute per-layer thickness clipped to a depth interval.
19
+
20
+ Parameters
21
+ ----------
22
+ z : ndarray (N_depth,)
23
+ Depth values [m] at layer interfaces.
24
+ depth_min : float, None, or ndarray (N_samples,)
25
+ Lower depth bound [m]. Array triggers per-sample mode.
26
+ depth_max : float, None, or ndarray (N_samples,)
27
+ Upper depth bound [m]. Array triggers per-sample mode.
28
+
29
+ Returns
30
+ -------
31
+ t : ndarray (N_depth - 1,) or (N_samples, N_depth - 1)
32
+ Clipped thickness per layer. 2-D when bounds are per-sample.
33
+ mask : ndarray (N_depth - 1,) bool or None
34
+ True where clipped thickness > 0 (scalar bounds). None when 2-D.
35
+ """
36
+ n_layers = len(z) - 1
37
+ top = z[:n_layers]
38
+ bot = z[1:]
39
+ lo = depth_min if depth_min is not None else -np.inf
40
+ hi = depth_max if depth_max is not None else np.inf
41
+
42
+ if np.ndim(lo) > 0 or np.ndim(hi) > 0:
43
+ # Per-sample bounds: broadcast (N_samples,1) against (N_layers,)
44
+ lo2 = np.asarray(lo)[:, np.newaxis] if np.ndim(lo) > 0 else lo
45
+ hi2 = np.asarray(hi)[:, np.newaxis] if np.ndim(hi) > 0 else hi
46
+ t = np.maximum(np.minimum(bot, hi2) - np.maximum(top, lo2), 0.0)
47
+ return t, None # 2-D t; mask handled downstream
48
+ else:
49
+ t = np.maximum(np.minimum(bot, hi) - np.maximum(top, lo), 0.0)
50
+ return t, t > 0
51
+
52
+
53
+ def _first_occurrence_thickness(condition, t):
54
+ """
55
+ Thickness of the first contiguous True run per sample row.
56
+
57
+ Parameters
58
+ ----------
59
+ condition : ndarray (N_samples, N_layers) bool
60
+ t : ndarray (N_layers,) float
61
+
62
+ Returns
63
+ -------
64
+ metric : ndarray (N_samples,) float
65
+ """
66
+ metric = np.zeros(condition.shape[0])
67
+ has_true = condition.any(axis=1)
68
+ first_idx = np.argmax(condition, axis=1)
69
+ t_2d = t.ndim == 2
70
+ for i in np.where(has_true)[0]:
71
+ j = first_idx[i]
72
+ while j < condition.shape[1] and condition[i, j]:
73
+ metric[i] += t[i, j] if t_2d else t[j]
74
+ j += 1
75
+ return metric
76
+
77
+
78
def _compute_metric(M_samples, z, metric_def, scalar_model_values=None):
    """
    Compute the raw per-realization metric value for a model.

    Parameters
    ----------
    M_samples : ndarray (N_samples, N_depth)
        Model values for the sampled realizations.
    z : ndarray (N_depth,)
        Depth values [m].
    metric_def : dict
        Metric definition — same fields as a constraint minus comparison fields.
        Supports: im, classes, value_comparison, value_threshold, thickness_mode,
        depth_min, depth_max, depth_max_im, depth_min_im.
        ``value_comparison`` may be '<' (default), '<=', '>' or '>='. Any other
        value is treated as '>'.
    scalar_model_values : dict {im: ndarray (N_samples,)}, optional
        Per-sample scalar model values for dynamic depth bounds.

    Returns
    -------
    metric : ndarray (N_samples,)
        For scalar models: raw model value per realization.
        For depth models: cumulative or first-occurrence thickness [m] of
        layers satisfying the class/value condition.
    """
    depth_min = metric_def.get('depth_min', None)
    depth_max = metric_def.get('depth_max', None)

    # Dynamic bounds: per-sample values of a scalar model override the static limits.
    if scalar_model_values:
        if 'depth_max_im' in metric_def:
            depth_max = scalar_model_values.get(metric_def['depth_max_im'], depth_max)
        if 'depth_min_im' in metric_def:
            depth_min = scalar_model_values.get(metric_def['depth_min_im'], depth_min)

    n_layers = len(z) - 1

    # Scalar model: return the single value directly (no thickness concept)
    if (z[-1] - z[0]) == 0 or n_layers == 0:
        return M_samples[:, 0].copy()

    t, layer_mask = _get_clipped_thickness(z, depth_min, depth_max)
    M_layers = M_samples[:, :n_layers]
    if layer_mask is None:
        # Per-sample bounds: keep every layer. t is 2-D and already zero outside
        # each sample's depth window.
        M_sel, t_sel = M_layers, t
    else:
        M_sel, t_sel = M_layers[:, layer_mask], t[layer_mask]

    if 'classes' in metric_def:
        condition = np.isin(np.round(M_sel).astype(int), metric_def['classes'])
    else:
        v_cmp = metric_def.get('value_comparison', '<')
        v_thr = metric_def.get('value_threshold', 0.0)
        value_ops = {
            '<': np.less,
            '<=': np.less_equal,
            '>': np.greater,
            '>=': np.greater_equal,
        }
        # Unknown operators keep the historical fallback of '>'.
        condition = value_ops.get(v_cmp, np.greater)(M_sel, v_thr)

    mode = metric_def.get('thickness_mode', 'cumulative')
    if mode == 'cumulative':
        return (condition * t_sel).sum(axis=1)
    return _first_occurrence_thickness(condition, t_sel)
137
+
138
+
139
def _evaluate_constraint(M_samples, z, constraint, scalar_model_values=None):
    """
    Evaluate one constraint for a batch of realizations.

    Parameters
    ----------
    M_samples : ndarray (N_samples, N_depth)
        Model values for the sampled realizations.
    z : ndarray (N_depth,)
        Depth values [m].
    constraint : dict
        Constraint definition.
        - Scalar models compare the raw value using ``classes``, or
          ``value_comparison``/``value_threshold``.
        - Depth models compare the thickness metric using
          ``thickness_comparison``/``thickness_threshold``.
        Comparisons accept '<', '<=', '>', '>='. Unknown operators fall back
        to '>'. ``negate=True`` inverts the result.
    scalar_model_values : dict {im: ndarray (N_samples,)}, optional
        Per-sample values of scalar models. Used to resolve `depth_max_im`
        and `depth_min_im` constraint fields into per-sample depth bounds.

    Returns
    -------
    valid : ndarray (N_samples,) bool
        True for each realization that satisfies the constraint.
    """
    comparison_ops = {
        '<': np.less,
        '<=': np.less_equal,
        '>': np.greater,
        '>=': np.greater_equal,
    }

    n_layers = len(z) - 1
    is_scalar = (z[-1] - z[0]) == 0 or n_layers == 0

    raw = _compute_metric(M_samples, z, constraint, scalar_model_values)

    if is_scalar:
        if 'classes' in constraint:
            valid = np.isin(np.round(raw).astype(int), constraint['classes'])
        else:
            v_cmp = constraint.get('value_comparison', '<')
            v_thr = constraint.get('value_threshold', 0.0)
            valid = comparison_ops.get(v_cmp, np.greater)(raw, v_thr)
    else:
        t_cmp = constraint.get('thickness_comparison', '>')
        t_thr = constraint.get('thickness_threshold', 0.0)
        valid = comparison_ops.get(t_cmp, np.greater)(raw, t_thr)

    return ~valid if constraint.get('negate', False) else valid
184
+
185
+
186
+ def _collect_needed_ims(items):
187
+ """Return the set of model indices (im) referenced in a constraint list or metric dict."""
188
+ needed = set()
189
+ if isinstance(items, dict):
190
+ items = [items]
191
+ for c in items:
192
+ needed.add(c['im'])
193
+ for field in ('depth_max_im', 'depth_min_im'):
194
+ if field in c:
195
+ needed.add(c[field])
196
+ return needed
197
+
198
+
199
def _load_query_inputs(f_post_h5, needed_ims):
    """
    Load posterior indices, coordinates, and prior model arrays.

    Parameters
    ----------
    f_post_h5 : str
        Path to the posterior HDF5 file.
    needed_ims : set of int
        Model indices to pre-load from the prior file.

    Returns
    -------
    i_use : ndarray (N_data, N_post)
    X, Y : ndarray or None
        Coordinates from the posterior file, or from the data file named by its
        ``f5_data`` attribute when absent.
    prior_models : dict {im: (M, z, is_discrete)}
    N_data, N_post : int

    Raises
    ------
    ValueError
        If the posterior file has no ``f5_prior`` attribute.
    FileNotFoundError
        If the referenced prior file does not exist.
    """
    def _attr_str(value):
        # String attributes may be returned as bytes (e.g. fixed-length HDF5
        # strings). str(b'...') would give "b'...'" instead of the path.
        if isinstance(value, bytes):
            return value.decode('utf-8')
        return str(value)

    with h5py.File(f_post_h5, 'r') as f:
        i_use = f['i_use'][:]
        f_prior_h5 = _attr_str(f.attrs.get('f5_prior', ''))
        f_data_h5 = _attr_str(f.attrs.get('f5_data', ''))
        X = f['UTMX'][:] if 'UTMX' in f else None
        Y = f['UTMY'][:] if 'UTMY' in f else None

    if not f_prior_h5:
        raise ValueError("Posterior file missing 'f5_prior' attribute.")
    if not os.path.isfile(f_prior_h5):
        raise FileNotFoundError(f"Prior file not found: {f_prior_h5}")

    # Fall back to the data file for any coordinate missing from the posterior.
    if (X is None or Y is None) and f_data_h5 and os.path.isfile(f_data_h5):
        with h5py.File(f_data_h5, 'r') as f:
            if X is None and 'UTMX' in f:
                X = f['UTMX'][:]
            if Y is None and 'UTMY' in f:
                Y = f['UTMY'][:]

    N_data, N_post = i_use.shape

    # Open the prior file once and read every requested model from it.
    prior_models = {}
    with h5py.File(f_prior_h5, 'r') as f:
        for im in needed_ims:
            ds = f[f'M{im}']
            M = ds[:]
            z = ds.attrs['x'].astype(float)
            is_discrete = bool(ds.attrs.get('is_discrete', 0))
            prior_models[im] = (M, z, is_discrete)

    return i_use, X, Y, prior_models, N_data, N_post
248
+
249
+
250
+ def _build_scalar_vals(prior_models, idx):
251
+ """Return per-sample values for all scalar models in prior_models."""
252
+ scalar_vals = {}
253
+ for im_s, (M_s, z_s, _) in prior_models.items():
254
+ if (z_s[-1] - z_s[0]) == 0 or len(z_s) - 1 == 0:
255
+ scalar_vals[im_s] = M_s[idx, 0]
256
+ return scalar_vals
257
+
258
+
259
def query_probability(f_post_h5, query_dict):
    """
    Compute per-data-point probability that posterior realizations satisfy a query.

    Parameters
    ----------
    f_post_h5 : str
        Path to the posterior HDF5 file.
    query_dict : str or dict
        Path to a JSON file (read as UTF-8), or a dict with a ``"constraints"`` key.
        A realization is accepted only if it satisfies every constraint.

    Returns
    -------
    P : ndarray (N_data,)
        Probability [0, 1] for each data location.
    meta : dict
        Keys: 'X', 'Y', 'N_data', 'N_post', 'i_use', 'i_use_query'.

    Examples
    --------
    >>> query_def = {
    ...     "constraints": [{
    ...         "im": 2, "classes": [2],
    ...         "thickness_mode": "cumulative",
    ...         "thickness_comparison": ">",
    ...         "thickness_threshold": 10.0,
    ...         "depth_min": 0.0, "depth_max": 30.0
    ...     }]
    ... }
    >>> P, meta = query_probability('f_post.h5', query_def)
    """
    if isinstance(query_dict, str):
        # JSON is UTF-8 by specification; do not depend on the locale encoding.
        with open(query_dict, 'r', encoding='utf-8') as fh:
            query_dict = json.load(fh)

    constraints = query_dict['constraints']
    needed_ims = _collect_needed_ims(constraints)
    i_use, X, Y, prior_models, N_data, N_post = _load_query_inputs(f_post_h5, needed_ims)

    P = np.zeros(N_data)
    i_use_query = []
    for i in tqdm(range(N_data), desc='Evaluating probability query', unit='location'):
        idx = i_use[i]
        scalar_vals = _build_scalar_vals(prior_models, idx)
        # Logical AND across all constraints.
        valid = np.ones(N_post, dtype=bool)
        for c in constraints:
            M, z, _ = prior_models[c['im']]
            valid &= _evaluate_constraint(M[idx, :], z, c, scalar_model_values=scalar_vals)
        P[i] = valid.mean()
        i_use_query.append(idx[valid])

    meta = {
        'X': X, 'Y': Y,
        'N_data': N_data, 'N_post': N_post,
        'i_use': i_use, 'i_use_query': i_use_query,
    }
    return P, meta
316
+
317
+
318
def query_percentile(f_post_h5, query_dict):
    """
    Compute per-data-point percentiles of a metric over posterior realizations.

    Rather than asking "what fraction of realizations satisfy condition X?", this
    asks "what is the p5/p50/p95 of metric X across realizations?". The metric
    is defined by the same fields as a probability constraint, minus the comparison
    fields (``thickness_comparison``, ``thickness_threshold``, ``negate``).

    Parameters
    ----------
    f_post_h5 : str
        Path to the posterior HDF5 file.
    query_dict : str or dict
        Path to a JSON file (read as UTF-8), or a dict with a ``"metric"`` key and
        an optional ``"percentiles"`` key (default ``[5, 50, 95]``).

    Returns
    -------
    percentile_values : ndarray (N_data, n_percentiles)
        Requested percentile values for each data location.
    meta : dict
        Keys: 'X', 'Y', 'N_data', 'N_post', 'i_use', 'percentiles'.

    Examples
    --------
    >>> query_def = {
    ...     "metric": {
    ...         "im": 2, "classes": [1, 2],
    ...         "thickness_mode": "cumulative",
    ...         "depth_max": 30.0
    ...     },
    ...     "percentiles": [5, 50, 95]
    ... }
    >>> pct_values, meta = query_percentile('f_post.h5', query_def)
    >>> # pct_values shape: (N_data, 3) — p5, p50, p95 per location
    """
    if isinstance(query_dict, str):
        # JSON is UTF-8 by specification; do not depend on the locale encoding.
        with open(query_dict, 'r', encoding='utf-8') as fh:
            query_dict = json.load(fh)

    metric_def = query_dict['metric']
    percentiles = query_dict.get('percentiles', [5, 50, 95])
    needed_ims = _collect_needed_ims(metric_def)
    i_use, X, Y, prior_models, N_data, N_post = _load_query_inputs(f_post_h5, needed_ims)

    M_main, z_main, _ = prior_models[metric_def['im']]
    result = np.zeros((N_data, len(percentiles)))

    for i in tqdm(range(N_data), desc='Evaluating percentile query', unit='location'):
        idx = i_use[i]
        scalar_vals = _build_scalar_vals(prior_models, idx)
        values = _compute_metric(M_main[idx, :], z_main, metric_def,
                                 scalar_model_values=scalar_vals)
        result[i, :] = np.percentile(values, percentiles)

    meta = {
        'X': X, 'Y': Y,
        'N_data': N_data, 'N_post': N_post,
        'i_use': i_use,
        'percentiles': percentiles,
    }
    return result, meta
382
+
383
+
384
def query(f_post_h5, query_dict):
    """
    Dispatcher: route to query_probability() or query_percentile() based on query_dict.

    If ``query_dict`` contains a ``"metric"`` key, calls :func:`query_percentile`.
    Otherwise calls :func:`query_probability` (backward compatible with all existing
    ``"constraints"``-based dicts).

    Parameters
    ----------
    f_post_h5 : str
        Path to the posterior HDF5 file.
    query_dict : str or dict
        Query definition, or a path to a JSON file (read as UTF-8). See
        :func:`query_probability` and :func:`query_percentile` for the respective
        schemas.

    Returns
    -------
    result, meta
        See the delegated function for details. If the posterior file does not
        exist, a message is printed and ``(None, {})`` is returned.
    """
    if not os.path.isfile(str(f_post_h5)):
        print(f"[ig.query] Posterior file not found: {f_post_h5}")
        return None, {}
    if isinstance(query_dict, str):
        # JSON is UTF-8 by specification; do not depend on the locale encoding.
        with open(query_dict, 'r', encoding='utf-8') as fh:
            query_dict = json.load(fh)
    if 'metric' in query_dict:
        return query_percentile(f_post_h5, query_dict)
    return query_probability(f_post_h5, query_dict)
414
+
415
+
416
def query_plot(P, meta, ip=None, query_dict=None, f_prior_h5=None, f_post_h5=None, title=None,
               query_text=None, interpretation=None, text_panel=False, hardcopy=False, **kwargs):
    """
    Plot query results and optionally detailed model visualization for a data point.

    If ip is None, displays the XY probability map showing P(x, y).
    If ip is provided (together with query_dict and f_prior_h5/f_post_h5), skips the
    probability map and shows only the detailed single-point visualization of all
    posterior realizations and the query-matching subset.

    Parameters
    ----------
    P : ndarray (N_data,)
        Probability array from query().
    meta : dict
        Metadata dict from query() containing 'X', 'Y', 'i_use', 'i_use_query'.
    ip : int, optional
        Data point index to visualize in detail. If None, only shows probability map.
    query_dict : dict, optional
        Query dict used in query(). Required for detailed visualization. The model
        of its first constraint is the one displayed.
    f_prior_h5 : str, optional
        Path to prior HDF5 file. If not provided, will be extracted from f_post_h5.
    f_post_h5 : str, optional
        Path to posterior HDF5 file. Used to automatically extract the prior file
        path if f_prior_h5 is not provided. If the posterior has no 'f5_prior'
        attribute, a warning is printed and the detailed view is skipped.
    title : str, optional
        Custom title for the probability map, used whenever given (also with
        text_panel=True). If None, a title is built from query_text and
        interpretation when they are not shown in the text panel; otherwise it is
        'Query Probability Map'.
    query_text : str, optional
        The original natural-language query string. Shown in the figure title,
        or in the text panel if text_panel=True.
    interpretation : str, optional
        The LLM interpretation string returned by query_from_text(). Shown as a
        second line in the figure title, or in the text panel if text_panel=True.
    text_panel : bool, optional
        If True and query_text or interpretation is provided, adds a narrow text
        column to the right of the probability map. The query text appears at the
        top and the interpretation below. Default False.
    hardcopy : bool or str, optional
        Save the figure drawn by this call: the probability map, or the detailed
        view when ip is given.
        - True saves as 'query_plot.png'.
        - A string has ':' and '/' replaced by '_'. '.png' is appended unless the
          name already ends in a common image extension.
        Default False.
    **kwargs
        All remaining keyword arguments are forwarded to :func:`plot_xy`, giving
        full control over ``cmap``, ``clim``, ``uselog``, ``colorbar``,
        ``colorbar_label``, ``plotPoints``, ``plotPoints_color``,
        ``plotPoints_marker``, ``s``, etc.
        ``cmap`` defaults to ``'hot_r'`` and ``clim`` defaults to ``[0, 1]``.

    Examples
    --------
    >>> P, meta = query(f_post_h5, query_def)
    >>> query_plot(P, meta)  # Just probability map
    >>> query_plot(P, meta, title='Custom Query Title')  # Custom title
    >>> query_plot(P, meta, ip=1000, query_dict=query_def, f_post_h5='posterior.h5')
    >>> query_plot(P, meta, ip=1000, query_dict=query_def, f_prior_h5='prior.h5')
    >>> # With LLM query text and interpretation:
    >>> query_dict, interp = ig.query_from_text(text, f_prior_h5)
    >>> P, meta = ig.query(f_post_h5, query_dict)
    >>> ig.query_plot(P, meta, query_text=text, interpretation=interp)
    """
    import matplotlib.pyplot as plt

    # Auto-extract prior file path from posterior file if needed
    if f_prior_h5 is None and f_post_h5 is not None:
        with h5py.File(f_post_h5, 'r') as f:
            f_prior_h5 = _query_plot_as_str(f.attrs.get('f5_prior', ''))
        if not f_prior_h5:
            print("Warning: Could not extract f5_prior attribute from posterior file")
            # An empty path cannot be opened. Treat it as "no prior" so the
            # detailed view is skipped instead of failing inside h5py.File('').
            f_prior_h5 = None

    if ip is None:
        # Plot XY probability map only when no specific point is requested
        _query_plot_map(P, meta, title, query_text, interpretation, text_panel, kwargs)
    elif query_dict is not None and f_prior_h5 is not None:
        _query_plot_detail(P, meta, ip, query_dict, f_prior_h5)
    else:
        print("Warning: detailed view requires query_dict and f_prior_h5 (or f_post_h5)")

    _VALID_EXTS = {'.png', '.jpg', '.jpeg', '.pdf', '.svg', '.eps', '.tif', '.tiff', '.webp'}
    if hardcopy is not False and hardcopy is not None:
        if isinstance(hardcopy, str):
            safe = hardcopy.replace(':', '_').replace('/', '_')
            f_png = safe if os.path.splitext(safe)[1].lower() in _VALID_EXTS else safe + '.png'
        else:
            f_png = 'query_plot.png'
        plt.savefig(f_png)
        print(f"Figure saved to {f_png}")

    plt.show()


def _query_plot_as_str(value):
    """Return value as str, decoding bytes (e.g. HDF5 string attributes) as UTF-8."""
    if isinstance(value, bytes):
        return value.decode('utf-8', errors='replace')
    return str(value)


def _query_plot_map(P, meta, title, query_text, interpretation, text_panel, plot_kwargs):
    """Draw the XY probability map for query_plot, optionally with a text column."""
    import textwrap
    import matplotlib.pyplot as plt
    from integrate.integrate_plot import plot_xy

    plot_kwargs = dict(plot_kwargs)
    X = meta['X']
    Y = meta['Y']

    has_text = text_panel and (query_text is not None or interpretation is not None)
    if has_text:
        fig = plt.figure(figsize=(11, 6))
        gs = fig.add_gridspec(1, 2, width_ratios=[3, 1], wspace=0.05)
        ax = fig.add_subplot(gs[0])
        ax_text = fig.add_subplot(gs[1])
    else:
        fig, ax = plt.subplots(figsize=(8, 6))

    # An explicit title always wins. Otherwise the query text goes into the title,
    # unless the text panel already shows it.
    if title is not None:
        _title = title
    elif not has_text and (query_text is not None or interpretation is not None):
        parts = []
        if query_text is not None:
            parts.append(f"Query: {query_text}")
        if interpretation is not None:
            parts.append(f"Interpreted as: {interpretation}")
        _title = '\n'.join(parts)
    else:
        _title = 'Query Probability Map'

    _title = '\n'.join(
        '\n'.join(textwrap.wrap(line, width=60)) if len(line) > 60 else line
        for line in _title.splitlines()
    )

    # Probability-oriented defaults. plotPoints draws background dots so that
    # P=0 (white) areas remain visible.
    cmap = plot_kwargs.pop('cmap', 'hot_r')
    clim = plot_kwargs.pop('clim', [0, 1])
    marker_size = plot_kwargs.pop('s', 1)
    colorbar = plot_kwargs.pop('colorbar', True)
    colorbar_label = plot_kwargs.pop('colorbar_label', 'Probability')
    plot_points = plot_kwargs.pop('plotPoints', True)
    _, ax, _ = plot_xy(P, X=X, Y=Y,
                       cmap=cmap, clim=clim,
                       title=_title, colorbar=colorbar, colorbar_label=colorbar_label,
                       ax=ax, s=marker_size, plotPoints=plot_points, **plot_kwargs)
    ax.set_xlabel('UTMX [m]')
    ax.set_ylabel('UTMY [m]')

    if has_text:
        ax_text.set_axis_off()
        chars = 36          # characters per wrapped line
        line_h = 0.062      # axes-fraction height per text line (fontsize 8)
        label_gap = 0.03    # gap between bold label and text box
        section_gap = 0.07  # gap between sections

        y = 0.97
        if query_text is not None:
            ax_text.text(0.02, y, "Query:", transform=ax_text.transAxes,
                         fontsize=8, fontweight='bold', va='top')
            y -= line_h + label_gap
            wrapped_q = textwrap.fill(query_text, chars)
            n_q = wrapped_q.count('\n') + 1
            ax_text.text(0.02, y, wrapped_q, transform=ax_text.transAxes,
                         fontsize=7.5, va='top',
                         bbox=dict(boxstyle='round,pad=0.4', facecolor='#f0f0f0', edgecolor='none'))
            y -= n_q * line_h + section_gap
        if interpretation is not None:
            ax_text.text(0.02, y, "Interpretation:", transform=ax_text.transAxes,
                         fontsize=8, fontweight='bold', va='top')
            y -= line_h + label_gap
            wrapped_i = textwrap.fill(interpretation, chars)
            ax_text.text(0.02, y, wrapped_i, transform=ax_text.transAxes,
                         fontsize=7.5, va='top',
                         bbox=dict(boxstyle='round,pad=0.4', facecolor='#e8f4e8', edgecolor='none'))

    plt.tight_layout()


def _query_plot_detail(P, meta, ip, query_dict, f_prior_h5):
    """Show all posterior realizations at point ip above the query-matching subset."""
    import matplotlib.pyplot as plt
    from matplotlib.colors import ListedColormap

    # Load prior model for the first constraint
    im = query_dict['constraints'][0]['im']
    with h5py.File(f_prior_h5, 'r') as f:
        ds = f[f'M{im}']
        M = ds[:]
        attrs = ds.attrs
        is_discrete = bool(attrs.get('is_discrete', 0))
        class_id = None
        class_name = None
        prior_cmap = None

        # Always try to read colormap from prior file if available
        if 'cmap' in attrs:
            try:
                # Stored as [3, nlev] or [4, nlev]; ListedColormap expects [nlev, 3|4].
                prior_cmap = ListedColormap(attrs['cmap'][:].T)
            except Exception:
                prior_cmap = None

        # Read class information for discrete models
        if is_discrete:
            if 'class_id' in attrs:
                class_id = attrs['class_id'][:].flatten()
            if 'class_name' in attrs:
                # Decode bytes so tick labels read "Sand", not "b'Sand'".
                class_name = [_query_plot_as_str(n) for n in attrs['class_name'][:].flatten()]

    # Get posterior and query-matching indices
    i_use = meta['i_use'][ip, :]
    i_use_query = meta['i_use_query'][ip]
    n_total = len(i_use)
    n_accepted = len(i_use_query)
    probability = P[ip]

    M_use_all = M[i_use]
    # Float copy so non-matching realizations can be blanked out with NaN.
    M_use_filtered = M_use_all.astype(float)
    matching_mask = np.isin(i_use, i_use_query)
    M_use_filtered[~matching_mask, :] = np.nan

    # Use the full range of class IDs from the prior file for discrete colour limits
    if is_discrete and class_id is not None:
        vmin_plot = np.min(class_id) - 0.5
        vmax_plot = np.max(class_id) + 0.5
    else:
        vmin_plot = None
        vmax_plot = None
    show_class_ticks = is_discrete and class_id is not None and class_name is not None

    panels = [
        (M_use_all,
         f'All Posterior Realizations (Point {ip})\n'
         f'Total Realizations: {n_total} | Accepted: {n_accepted} | Probability: {probability:.3f}'),
        (M_use_filtered, 'Query-Matching Realizations Only (non-matching set to NaN)'),
    ]

    plt.figure(figsize=(12, 8))
    for position, (data, panel_title) in enumerate(panels, start=1):
        plt.subplot(2, 1, position)
        if prior_cmap is not None:
            image = plt.imshow(data.T, aspect='auto', cmap=prior_cmap, interpolation='nearest',
                               vmin=vmin_plot, vmax=vmax_plot)
        else:
            image = plt.imshow(data.T, aspect='auto', cmap='jet', interpolation='nearest')

        plt.title(panel_title)
        plt.xlabel('Realization index')
        plt.ylabel('Layer index')

        if show_class_ticks:
            cbar = plt.colorbar(image)
            cbar.set_ticks(class_id)
            # Tick labels in the format "ClassName (ID)"
            cbar.set_ticklabels([f'{name} ({int(cid)})' for name, cid in zip(class_name, class_id)])
            cbar.ax.invert_yaxis()
        else:
            plt.colorbar(image, label='Model value')

    plt.tight_layout()
692
+
693
+
694
def query_percentile_plot(percentile_values, meta, query_text=None, interpretation=None,
                          text_panel=False, hardcopy=False, **kwargs):
    """
    Plot one map per requested percentile as side-by-side subplots.

    Parameters
    ----------
    percentile_values : ndarray (N_data, n_percentiles)
        Output of query_percentile().
    meta : dict
        Metadata dict from query_percentile() containing 'X', 'Y', 'percentiles'.
        Without X/Y, each subplot is a line plot over location index.
    query_text : str, optional
        Original query string — shown as figure suptitle.
    interpretation : str, optional
        LLM interpretation string — shown below query_text if provided.
    text_panel : bool, optional
        If True, add a narrow text column to the right of the maps.
    hardcopy : bool or str, optional
        Save figure to disk.
        - True saves as 'query_percentile_plot.png'.
        - A string is used as the filename. '.png' is appended unless the name
          already ends in a format matplotlib can write (so 'run1.v2' becomes
          'run1.v2.png').
    **kwargs
        All remaining keyword arguments are forwarded to :func:`plot_xy`, giving
        full control over ``cmap``, ``clim``, ``uselog``, ``colorbar``,
        ``colorbar_label``, ``plotPoints``, ``plotPoints_color``,
        ``plotPoints_marker``, ``s``, etc.
        ``clim`` defaults to the NaN-ignoring ``[min, max]`` of
        ``percentile_values``, so that all subplots share the same colour scale.
        ``cmap`` defaults to ``'viridis'``.

    Returns
    -------
    fig : matplotlib Figure
    """
    import matplotlib.pyplot as plt
    from integrate.integrate_plot import plot_xy

    percentiles = meta.get('percentiles', [5, 50, 95])
    n_pct = len(percentiles)
    X = meta.get('X')
    Y = meta.get('Y')

    # Shared colour scale across all subplots unless caller overrides.
    cmap = kwargs.pop('cmap', 'viridis')
    if 'clim' in kwargs:
        clim = kwargs.pop('clim')
    else:
        # NaN-ignoring, so a single undefined location does not blank the scale.
        clim = [float(np.nanmin(percentile_values)), float(np.nanmax(percentile_values))]

    ncols = n_pct + (1 if text_panel else 0)
    width_ratios = [4] * n_pct + ([1.5] if text_panel else [])
    fig, axes = plt.subplots(1, ncols, figsize=(4.5 * n_pct + (1.5 if text_panel else 0), 5),
                             gridspec_kw={'width_ratios': width_ratios} if text_panel else {},
                             squeeze=False)

    map_axes = axes[0, :n_pct]

    for k, (pct, ax) in enumerate(zip(percentiles, map_axes)):
        vals = percentile_values[:, k]
        if X is not None and Y is not None:
            _, ax, _ = plot_xy(vals, X=X, Y=Y,
                               cmap=cmap, clim=clim,
                               ax=ax, **kwargs)
            ax.set_xlabel('UTMX')
            if k == 0:
                ax.set_ylabel('UTMY')
        else:
            ax.plot(vals)
            ax.set_xlabel('Location index')
        ax.set_title(f'P{pct} (median={np.median(vals):.1f})')

    if text_panel:
        tax = axes[0, n_pct]
        tax.axis('off')
        txt = ''
        if query_text:
            txt += f'Query:\n{query_text}\n\n'
        if interpretation:
            txt += f'Interpretation:\n{interpretation}'
        if txt:
            tax.text(0.05, 0.95, txt, transform=tax.transAxes,
                     fontsize=7, va='top', wrap=True)

    suptitle_parts = [t for t in [query_text, interpretation] if t]
    if suptitle_parts and not text_panel:
        fig.suptitle('\n'.join(suptitle_parts), fontsize=8, y=1.01)

    plt.tight_layout()

    if hardcopy:
        fname = hardcopy if isinstance(hardcopy, str) else 'query_percentile_plot'
        ext = os.path.splitext(fname)[1].lower().lstrip('.')
        # Keep the suffix only if it names a real output format. Otherwise savefig
        # would reject it as an unsupported format.
        if ext not in fig.canvas.get_supported_filetypes():
            fname += '.png'
        plt.savefig(fname, bbox_inches='tight', dpi=150)
        print(f"Figure saved to {fname}")

    plt.show()
    return fig
788
+
789
+
790
def save_query(query, path):
    """
    Write a query definition to disk as pretty-printed JSON.

    Parameters
    ----------
    query : dict
        Query definition dictionary (must be JSON-serialisable).
    path : str
        Destination file path; an existing file is overwritten.
    """
    with open(path, 'w') as handle:
        json.dump(query, handle, indent=2)
    message = f"Query saved to {path}"
    print(message)
804
+
805
+
806
def load_query(path):
    """
    Load a query dict from a JSON file.

    Parameters
    ----------
    path : str
        Input JSON file path. The file is decoded as UTF-8.

    Returns
    -------
    query : dict
        Query definition dictionary.
    """
    # JSON is UTF-8 by specification; don't depend on the platform's locale
    # encoding (e.g. cp1252 on Windows), which mis-decodes non-ASCII text.
    with open(path, 'r', encoding='utf-8') as f:
        return json.load(f)
822
+
823
+
824
def get_prior_model_info(f_prior_h5, im):
    """
    Return metadata for prior model im.

    Parameters
    ----------
    f_prior_h5 : str
        Path to the prior HDF5 file.
    im : int
        Model index (dataset ``M{im}``).

    Returns
    -------
    info : dict
        - 'name': the 'name' attribute (decoded to str if stored as bytes),
          or 'M{im}' when absent.
        - 'is_discrete': bool.
        - 'z': float ndarray from the 'x' attribute.
        - 'class_id', 'class_name': attribute values as stored, or None.
    """
    key = f'M{im}'
    with h5py.File(f_prior_h5, 'r') as f:
        ds = f[key]
        name = ds.attrs.get('name', key)
        # String attributes may be returned as bytes; decode so callers can
        # format the name as text.
        if isinstance(name, bytes):
            name = name.decode('utf-8', errors='replace')
        info = {
            'name': name,
            'is_discrete': bool(ds.attrs.get('is_discrete', 0)),
            'z': ds.attrs['x'].astype(float),
            'class_id': ds.attrs.get('class_id', None),
            'class_name': ds.attrs.get('class_name', None),
        }
    return info
851
+
852
+
853
def prior_describe(f_prior_h5):
    """
    Print a human-readable summary of all models in a prior HDF5 file.

    Models are the datasets named ``M<int>``, listed in numeric order. Each one
    is reported as scalar, discrete or continuous. Class ids and names are listed
    for discrete models when both attributes exist.

    Parameters
    ----------
    f_prior_h5 : str
        Path to the prior HDF5 file.

    Examples
    --------
    >>> ig.prior_describe('prior.h5')
    Prior file: prior.h5
    N realizations: 1000000
      im=1  Resistivity   CONTINUOUS  depth 0–89 m  (89 layers)
      im=2  Lithology     DISCRETE    depth 0–89 m  (89 layers)
            class 1 = Sand
            class 2 = Grus
      im=3  Waterlevel    SCALAR
    """
    def _text(value):
        # HDF5 string attributes may come back as bytes; print them as text.
        if isinstance(value, bytes):
            return value.decode('utf-8', errors='replace')
        return str(value)

    with h5py.File(f_prior_h5, 'r') as f:
        model_keys = sorted(
            (k for k in f.keys() if k.startswith('M') and k[1:].isdigit()),
            key=lambda k: int(k[1:]),
        )
        N = f[model_keys[0]].shape[0] if model_keys else 0

    print(f"Prior file: {f_prior_h5}")
    print(f"N realizations: {N}")

    for key in model_keys:
        im = int(key[1:])
        info = get_prior_model_info(f_prior_h5, im)
        z = info['z']
        depth_min, depth_max = float(z[0]), float(z[-1])
        n_layers = len(z) - 1
        is_scalar = (depth_max - depth_min) == 0 or n_layers == 0
        name = _text(info['name'])

        if is_scalar:
            kind = 'SCALAR-DISCRETE' if info['is_discrete'] else 'SCALAR'
            print(f" im={im} {name:<20s} {kind}")
        elif info['is_discrete']:
            print(f" im={im} {name:<20s} DISCRETE depth {depth_min:.0f}–{depth_max:.0f} m ({n_layers} layers)")
        else:
            print(f" im={im} {name:<20s} CONTINUOUS depth {depth_min:.0f}–{depth_max:.0f} m ({n_layers} layers)")

        if info['is_discrete'] and info['class_id'] is not None and info['class_name'] is not None:
            for cid, cname in zip(np.ravel(info['class_id']), np.ravel(info['class_name'])):
                print(f" class {int(cid):>3} = {_text(cname)}")
903
+
904
+
905
def _build_llm_system_prompt(f_prior_h5):
    """
    Build a system prompt for the LLM that includes the query schema and prior model context.

    Each model ``M<im>`` in the prior file is described by name, kind, depth
    range, layer count and (for discrete models) class IDs, so that the LLM
    only uses valid model indices and class IDs. Models are listed in numeric
    ``im`` order.

    Parameters
    ----------
    f_prior_h5 : str
        Path to the prior HDF5 file.

    Returns
    -------
    prompt : str
        System prompt string.
    """
    # Sort numerically (M2 before M10), matching prior_describe().
    with h5py.File(f_prior_h5, 'r') as f:
        model_keys = sorted(
            (k for k in f.keys() if k.startswith('M') and k[1:].isdigit()),
            key=lambda k: int(k[1:]),
        )

    model_sections = []
    for key in model_keys:
        im = int(key[1:])
        info = get_prior_model_info(f_prior_h5, im)
        z = info['z']
        n_layers = len(z) - 1
        depth_min = float(z[0])
        depth_max = float(z[-1])
        name = info['name']
        is_discrete = info['is_discrete']
        has_classes = info['class_id'] is not None and info['class_name'] is not None
        # A zero-width depth range or no layers means one value per realization.
        is_scalar = (depth_max - depth_min) == 0 or n_layers == 0

        if is_scalar:
            kind = 'SCALAR-DISCRETE' if is_discrete else 'SCALAR'
            lines = [
                f"  Model im={im}: {name} ({kind}) — single value per realization, no depth profile",
                "    Use only 'value_comparison' and 'value_threshold'. Do NOT include any thickness fields.",
            ]
        elif is_discrete:
            lines = [f"  Model im={im}: {name} (DISCRETE), depth {depth_min:.1f}–{depth_max:.1f} m, {n_layers} layers"]
            if not has_classes:
                lines.append("    (Class IDs not available in prior file)")
        else:
            lines = [
                f"  Model im={im}: {name} (CONTINUOUS), depth {depth_min:.1f}–{depth_max:.1f} m, {n_layers} layers",
                "    Use 'value_comparison' ('<' or '>') and 'value_threshold' for this model.",
            ]

        # Class listing is shared by the DISCRETE and SCALAR-DISCRETE cases.
        if is_discrete and has_classes:
            lines.append("    Classes (use these integer IDs in the 'classes' field):")
            lines.extend(
                f"      {int(cid)} = {cname}"
                for cid, cname in zip(info['class_id'].flatten(), info['class_name'].flatten())
            )

        model_sections.append('\n'.join(lines))

    models_text = '\n'.join(model_sections)

    # Literal braces in the f-string below are doubled ({{ }}); {models_text}
    # is the only interpolation.
    prompt = f"""You are a geophysics query assistant for the INTEGRATE probabilistic inversion module.
Your task is to translate a natural-language query about geological or geophysical properties
into a valid JSON query dict that can be executed by the query() function.

## Query types

Choose the query type based on the user's intent:

**"probability"** — the user asks for a probability, likelihood, or yes/no fraction.
  Example: "What is the probability that clay thickness exceeds 10 m?"
  Response includes "query_type": "probability" and a "constraints" list.

**"percentile"** — the user asks for a distribution, typical value, or p5/p50/p95.
  Example: "What are the p5, p50, p95 of the cumulative thickness of sand above 10 m depth?"
  Response includes "query_type": "percentile", a single "metric" object, and a "percentiles" list.

## Response format

You must always respond with a single JSON object. Required top-level keys:
- "interpretation": 1–2 sentence plain-English confirmation of what you understood.
- "query_type": either "probability" or "percentile".
- For probability: "constraints" (list of constraint objects — see below).
- For percentile: "metric" (a single metric object — see below) and "percentiles" (list of ints,
  default [5, 50, 95] if not specified by the user).

Probability response structure:
```json
{{
  "interpretation": "...",
  "query_type": "probability",
  "constraints": [ {{ ... }} ]
}}
```

Percentile response structure:
```json
{{
  "interpretation": "...",
  "query_type": "percentile",
  "metric": {{ ... }},
  "percentiles": [5, 50, 95]
}}
```

## Constraint fields (for probability queries)

A constraint list contains one or more constraint objects combined with logical AND
(every constraint must be satisfied).

| Field | Type | Required | Valid values | Description |
|----------------------|-------------|-------------------|-------------------------------------|--------------------------------------------------|
| im | int | always | 1, 2, 3, ... | Prior model index (see Available Prior Models below) |
| classes | list[int] | discrete only | class IDs from the model | Match any of these class IDs (discrete models) |
| value_comparison | str | continuous only | "<" or ">" | Compare model value against threshold |
| value_threshold | float | continuous only | any float | Threshold for continuous value comparison |
| thickness_mode | str | depth models only | "cumulative" or "first_occurrence" | How to aggregate thickness of matching layers |
| thickness_comparison | str | depth models only | ">", "<", ">=", "<=" | Operator applied to the computed thickness |
| thickness_threshold | float | depth models only | any float (meters) | Thickness threshold in meters |
| depth_min | float | optional | any float | Upper boundary of depth interval [m] |
| depth_max | float | optional | any float | Lower boundary of depth interval [m] |
| depth_max_im | int | optional | SCALAR model im | Per-realization depth_max from a scalar model |
| depth_min_im | int | optional | SCALAR model im | Per-realization depth_min from a scalar model |
| negate | bool | optional | true or false (default: false) | If true, invert the constraint result |

> **Cross-model depth bounds:** `depth_max_im` / `depth_min_im` take the `im` index of a
> SCALAR model and use its per-realization value as the depth boundary. This enables
> queries like "Sand above the water table" where the cutoff depth varies per realization.
> Use `depth_max_im` to cut at the scalar model's value from above; use `depth_min_im`
> to cut from below. These may be combined with fixed `depth_min` / `depth_max`.

### thickness_mode explained
- "cumulative": sum the thickness of ALL matching layers within the depth interval
- "first_occurrence": thickness of the FIRST contiguous block of matching layers

### Scalar models (marked SCALAR or SCALAR-DISCRETE under Available Prior Models below)
These store a single value per realization, not a depth profile. For scalar models:
- Omit ALL thickness fields (`thickness_mode`, `thickness_comparison`, `thickness_threshold`, `depth_min`, `depth_max`).
- Use only `im`, `value_comparison`, `value_threshold`, and optionally `negate`.

## Metric fields (for percentile queries)

A metric object defines WHAT to measure per realization (no comparison or threshold).
It uses the same fields as a constraint, minus: thickness_comparison, thickness_threshold, negate.

| Field | Type | Required | Valid values | Description |
|------------------|-----------|-------------------|-------------------------------------|-----------------------------------------------|
| im | int | always | 1, 2, 3, ... | Prior model index |
| classes | list[int] | DISCRETE only | class IDs from the model | Thickness of these classes is measured |
| value_comparison | str | CONTINUOUS only | "<" or ">" | Condition on value before measuring thickness |
| value_threshold | float | CONTINUOUS only | any float | Threshold for the value condition |
| thickness_mode | str | depth models only | "cumulative" or "first_occurrence" | How to aggregate thickness |
| depth_min | float | optional | any float | Upper depth boundary [m] |
| depth_max | float | optional | any float | Lower depth boundary [m] |
| depth_max_im | int | optional | SCALAR model im | Per-realization depth_max from scalar model |
| depth_min_im | int | optional | SCALAR model im | Per-realization depth_min from scalar model |

For SCALAR models used as a metric: returns the raw scalar value; no thickness fields needed.

## Available Prior Models

{models_text}

## Examples

### Example 1: Discrete cumulative constraint
Query: "Probability that cumulative clay thickness exceeds 10 m within 0–30 m depth"
```json
{{
  "interpretation": "Probability that the cumulative thickness of clay (class 2) exceeds 10 m within 0–30 m depth.",
  "query_type": "probability",
  "constraints": [
    {{
      "im": 2,
      "classes": [2],
      "thickness_mode": "cumulative",
      "thickness_comparison": ">",
      "thickness_threshold": 10.0,
      "depth_min": 0.0,
      "depth_max": 30.0,
      "negate": false
    }}
  ]
}}
```

### Example 2: Continuous cumulative constraint
Query: "Probability that resistivity is below 100 ohm-m for at least 25 m within 0–50 m"
```json
{{
  "interpretation": "Probability that resistivity (im=1) is below 100 ohm-m for a cumulative thickness of at least 25 m within 0–50 m depth.",
  "query_type": "probability",
  "constraints": [
    {{
      "im": 1,
      "value_comparison": "<",
      "value_threshold": 100.0,
      "thickness_mode": "cumulative",
      "thickness_comparison": ">",
      "thickness_threshold": 25.0,
      "depth_min": 0.0,
      "depth_max": 50.0,
      "negate": false
    }}
  ]
}}
```

### Example 3: Multi-constraint AND
Query: "Probability that clay > 5 m within 0–20 m AND resistivity > 500 ohm-m for >= 1 m within 20–60 m"
```json
{{
  "interpretation": "Probability that cumulative clay (class 2) thickness exceeds 5 m within 0–20 m AND resistivity (im=1) exceeds 500 ohm-m for at least 1 m within 20–60 m. Both constraints must hold simultaneously.",
  "query_type": "probability",
  "constraints": [
    {{
      "im": 2,
      "classes": [2],
      "thickness_mode": "cumulative",
      "thickness_comparison": ">",
      "thickness_threshold": 5.0,
      "depth_min": 0.0,
      "depth_max": 20.0,
      "negate": false
    }},
    {{
      "im": 1,
      "value_comparison": ">",
      "value_threshold": 500.0,
      "thickness_mode": "cumulative",
      "thickness_comparison": ">",
      "thickness_threshold": 1.0,
      "depth_min": 20.0,
      "depth_max": 60.0,
      "negate": false
    }}
  ]
}}
```

### Example 4: First-occurrence thickness
Query: "Probability that the first occurrence of clay at the surface is less than 3 m thick"
```json
{{
  "interpretation": "Probability that the first contiguous block of clay (class 2) starting from the surface is less than 3 m thick, within 0–30 m depth.",
  "query_type": "probability",
  "constraints": [
    {{
      "im": 2,
      "classes": [2],
      "thickness_mode": "first_occurrence",
      "thickness_comparison": "<",
      "thickness_threshold": 3.0,
      "depth_min": 0.0,
      "depth_max": 30.0,
      "negate": false
    }}
  ]
}}
```

### Example 5: Scalar model query (no thickness fields)
Query: "Probability that the water table is shallower than 5 m"
```json
{{
  "interpretation": "Probability that the water table depth (im=3, SCALAR) is less than 5 m.",
  "query_type": "probability",
  "constraints": [
    {{
      "im": 3,
      "value_comparison": "<",
      "value_threshold": 5.0,
      "negate": false
    }}
  ]
}}
```

### Example 6: Cross-model depth constraint (dynamic depth bound from scalar model)
Query: "Probability that Sand and Grus have a cumulative thickness above the water table exceeding 5 m"
```json
{{
  "interpretation": "Probability that Sand (class 1) and Grus (class 2) have a cumulative thickness exceeding 5 m within the zone above the water table (im=3), starting from the surface.",
  "query_type": "probability",
  "constraints": [
    {{
      "im": 2,
      "classes": [1, 2],
      "thickness_mode": "cumulative",
      "thickness_comparison": ">",
      "thickness_threshold": 5.0,
      "depth_min": 0.0,
      "depth_max_im": 3,
      "negate": false
    }}
  ]
}}
```

### Example 7: Percentile query — thickness distribution
Query: "What are the p5, p50, and p95 of the cumulative thickness of Sand and Grus within 0 to 30 m depth?"
```json
{{
  "interpretation": "P5/P50/P95 of the cumulative thickness of Sand (class 1) and Grus (class 2) within 0–30 m depth.",
  "query_type": "percentile",
  "metric": {{
    "im": 2,
    "classes": [1, 2],
    "thickness_mode": "cumulative",
    "depth_min": 0.0,
    "depth_max": 30.0
  }},
  "percentiles": [5, 50, 95]
}}
```

### Example 8: Percentile query — cross-model depth bound
Query: "What is the typical (median) thickness of Sand and Grus above the water table?"
```json
{{
  "interpretation": "P5/P50/P95 of the cumulative thickness of Sand (class 1) and Grus (class 2) above the water table (depth bounded per realization by im=3).",
  "query_type": "percentile",
  "metric": {{
    "im": 2,
    "classes": [1, 2],
    "thickness_mode": "cumulative",
    "depth_min": 0.0,
    "depth_max_im": 3
  }},
  "percentiles": [5, 50, 95]
}}
```

## Instructions

- Respond with ONLY a valid JSON object. No markdown fences, no extra commentary.
- Always include "interpretation" (1–2 sentences) and "query_type" ("probability" or "percentile").
- Use only the model indices (im) and class IDs listed under Available Prior Models above.
- If the query cannot be expressed with the available schema and models, respond with exactly:
  UNSUPPORTED: <brief reason>
- Do not invent class IDs or model indices that are not listed above.
"""
    return prompt
1247
+
1248
+
1249
+ def _litellm_extra(model):
1250
+ """Return extra_body kwargs for litellm.completion to disable thinking on Ollama models."""
1251
+ if model.startswith('ollama'):
1252
+ return {'extra_body': {'think': False}}
1253
+ return {}
1254
+
1255
+
1256
def query_from_text(text, f_prior_h5, model='anthropic/claude-sonnet-4-6', api_key=None, max_tokens=4096, verbose=False):
    """
    Translate a natural-language query into a query dict using an LLM.

    Uses LiteLLM to interpret the user's text query in the context of the
    available prior models and the integrate query schema. Returns a query
    dict and a plain-English interpretation of what the LLM understood. If
    the first reply is not valid JSON, the LLM is asked once more to return
    only the JSON object.

    Parameters
    ----------
    text : str
        Natural language description of the query, e.g.
        "What is the probability that cumulative clay thickness exceeds 10 m?".
    f_prior_h5 : str
        Path to the prior HDF5 file. Model metadata (class names, depth ranges,
        discrete/continuous type) is read automatically and included in the
        LLM prompt so the model knows what constraints are valid.
    model : str, optional
        LiteLLM model string (default: 'anthropic/claude-sonnet-4-6'). Any
        LiteLLM-supported model works, e.g. 'openai/gpt-4o'.
    api_key : str, optional
        Provider API key. If None, the relevant environment variable
        (e.g. ANTHROPIC_API_KEY) is used.
    max_tokens : int, optional
        Maximum number of tokens the LLM may generate per call (default 4096).
        Thinking models may need a larger value.
    verbose : bool, optional
        If True, print the system prompt and LLM response for inspection.

    Returns
    -------
    query_dict : dict
        Query dict ready to pass to ig.query(f_post_h5, query_dict):
        ``{'constraints': [...]}`` for probability queries, or
        ``{'metric': {...}, 'percentiles': [...]}`` for percentile queries.
    interpretation : str
        Plain English confirmation of what the LLM understood the query to mean.
        Check this before running ig.query() to catch misunderstandings cheaply.
        It is also printed.
    system_prompt : str
        The full system prompt sent to the LLM. Useful for inspection and debugging.

    Raises
    ------
    ImportError
        If the litellm package is not installed.
    ValueError
        Raised in any of these cases:
        - the LLM returns empty content;
        - the LLM reports the query is unsupported;
        - the response cannot be parsed as a JSON object, even after one retry.

    Notes
    -----
    Requires either the api_key parameter or the relevant provider environment
    variable to be set. Install the dependency with: pip install litellm

    Examples
    --------
    >>> import integrate as ig
    >>> query_dict, interpretation, system_prompt = ig.query_from_text(
    ...     "Probability that cumulative clay thickness > 10 m within 0-30 m",
    ...     f_prior_h5='prior.h5',
    ...     api_key='sk-ant-...',
    ... )
    >>> print(interpretation)
    >>> P, meta = ig.query('posterior.h5', query_dict)
    >>> ig.query_plot(P, meta)
    """
    try:
        import litellm
    except ImportError:
        raise ImportError(
            "The 'litellm' package is required for query_from_text(). "
            "Install it with: pip install litellm"
        )

    system_prompt = _build_llm_system_prompt(f_prior_h5)

    if verbose:
        print("=== SYSTEM PROMPT ===")
        print(system_prompt)
        print("=== USER TEXT ===")
        print(text)

    def _strip_fences(s):
        # Drop a leading ``` / ```json line and a trailing ``` fence, if present.
        if s.startswith("```"):
            s = s.split("\n", 1)[-1]
        if s.endswith("```"):
            s = s.rsplit("```", 1)[0].strip()
        return s

    def _complete(messages):
        # One LLM call. Returns the stripped, de-fenced reply text; a None
        # content (possible with thinking models) becomes ''.
        response_obj = litellm.completion(
            model=model,
            max_tokens=max_tokens,
            messages=messages,
            api_key=api_key,
            **_litellm_extra(model),
        )
        content = response_obj.choices[0].message.content
        return _strip_fences((content or '').strip())

    def _raise_if_unsupported(response):
        # The system prompt asks the LLM to reply "UNSUPPORTED: <reason>" when
        # the query cannot be expressed with the schema.
        if response.startswith("UNSUPPORTED:"):
            reason = response[len("UNSUPPORTED:"):].strip()
            raise ValueError(f"Query cannot be expressed with the current schema: {reason}")

    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": text},
    ]
    response = _complete(messages)

    if not response:
        raise ValueError(
            f"Model '{model}' returned empty content. "
            "If this is a thinking model (e.g. Qwen3, DeepSeek-R1), ensure /no_think "
            "is in the prompt or increase max_tokens."
        )

    if verbose:
        print("=== LLM RESPONSE ===")
        print(response)

    _raise_if_unsupported(response)

    try:
        parsed = json.loads(response)
    except json.JSONDecodeError:
        if verbose:
            print("=== JSON PARSE FAILED — RETRYING ===")
        response = _complete(messages + [
            {"role": "assistant", "content": response},
            {"role": "user", "content": "Your response was not valid JSON. Output ONLY the JSON object with no extra text, no markdown fences, no explanation."},
        ])
        if verbose:
            print("=== RETRY RESPONSE ===")
            print(response)
        _raise_if_unsupported(response)
        try:
            parsed = json.loads(response)
        except json.JSONDecodeError as e2:
            raise ValueError(
                f"LLM response could not be parsed as JSON after retry: {e2}\nRaw response:\n{response}"
            ) from e2

    # Valid JSON but not an object (e.g. a bare list) cannot be a query.
    if not isinstance(parsed, dict):
        raise ValueError(
            f"LLM response is not a JSON object (got {type(parsed).__name__}).\nRaw response:\n{response}"
        )

    interpretation = parsed.pop('interpretation', '')
    query_type = parsed.pop('query_type', 'probability')
    print(f"Interpretation: {interpretation}")

    # Build the canonical query dict based on query_type
    if query_type == 'percentile':
        query_dict = {
            'metric': parsed.get('metric', {}),
            'percentiles': parsed.get('percentiles', [5, 50, 95]),
        }
    else:
        # probability (default, backward compatible)
        query_dict = {'constraints': parsed.get('constraints', [])}

    return query_dict, interpretation, system_prompt
1412
+
1413
+
1414
def title_from_json(file_json, f_prior_h5=None, model='anthropic/claude-sonnet-4-6',
                    api_key=None, showInfo=1):
    """
    Suggest a short figure title describing what a query JSON dict computes.

    Uses an LLM to produce a short, human-readable title (asked for at most
    ~10 words, in title case) suitable for a figure or log message. Returns
    an empty string instead of raising if any of these occur:
    - the LLM is unavailable (missing package, no API key, network error);
    - the query file cannot be read.

    Parameters
    ----------
    file_json : str, os.PathLike or dict
        Path to a query JSON file, or a query dict directly (e.g. from
        ``ig.load_query()``).
    f_prior_h5 : str, optional
        Path to the prior HDF5 file. When provided, real model names, depth
        ranges, and class labels are included in the prompt so the title
        uses geological names instead of numeric model/class IDs. This
        context is best-effort and is omitted if it cannot be built.
    model : str, optional
        LiteLLM model string (default: 'anthropic/claude-sonnet-4-6').
    api_key : str, optional
        Provider API key. If None, the relevant environment variable is used.
    showInfo : int, optional
        0 = silent; 1 = print a message when the LLM cannot be reached (default);
        2 = also print the exception detail.

    Returns
    -------
    description : str
        The LLM's suggested title (surrounding whitespace stripped), or an
        empty string if the LLM could not be reached or the file could not
        be read.

    Examples
    --------
    >>> description = ig.title_from_json('my_query.json')
    >>> description = ig.title_from_json('my_query.json', f_prior_h5='prior.h5')
    >>> query = ig.load_query('query_ex1.json')
    >>> title = ig.title_from_json(query, f_prior_h5='prior.h5')
    >>> title = ig.title_from_json(query, showInfo=0)  # silent on failure
    """
    try:
        import litellm
    except ImportError:
        if showInfo >= 1:
            print("[ig.title_from_json] LLM unavailable: 'litellm' package not installed "
                  "(pip install litellm). Returning empty description.")
        return ''

    # Accept pathlib.Path and other path-like objects as well as str paths;
    # anything else is treated as a query mapping.
    if isinstance(file_json, (str, os.PathLike)):
        try:
            with open(file_json, 'r', encoding='utf-8') as fh:
                query_dict = json.load(fh)
        except Exception as e:
            if showInfo >= 1:
                print(f"[ig.title_from_json] Could not read query file: {e}")
            return ''
    else:
        query_dict = dict(file_json)

    # Probability queries carry a 'constraints' list; percentile queries a single 'metric'.
    if 'constraints' in query_dict:
        items = query_dict['constraints']
    elif 'metric' in query_dict:
        items = [query_dict['metric']]
    else:
        items = []

    # The model context only decorates the prompt. If the referenced im
    # indices cannot be collected, still ask the LLM for a title instead of
    # raising.
    needed_ims = set()
    if items:
        try:
            needed_ims = _collect_needed_ims(items)
        except Exception as e:
            if showInfo >= 2:
                print(f"[ig.title_from_json] Could not collect model indices: {e}")
            needed_ims = set()

    # Build an optional model-context block from the prior file (best-effort)
    model_context = ''
    if f_prior_h5 and needed_ims:
        try:
            lines = ['Available models referenced in this query:']
            for im in sorted(needed_ims):
                info = get_prior_model_info(f_prior_h5, im)
                z = info['z']
                depth_min, depth_max = float(z[0]), float(z[-1])
                is_scalar = (depth_max - depth_min) == 0 or len(z) - 1 == 0
                name = info['name']
                if is_scalar:
                    kind = 'scalar-discrete' if info['is_discrete'] else 'scalar'
                    lines.append(f"  im={im}: '{name}' ({kind})")
                elif info['is_discrete']:
                    lines.append(f"  im={im}: '{name}' (discrete), depth {depth_min:.1f}–{depth_max:.1f} m")
                    if info['class_id'] is not None and info['class_name'] is not None:
                        ids = info['class_id'].flatten()
                        names = info['class_name'].flatten()
                        for cid, cname in zip(ids, names):
                            lines.append(f"    class {int(cid)} = {cname}")
                else:
                    lines.append(f"  im={im}: '{name}' (continuous), depth {depth_min:.1f}–{depth_max:.1f} m")
            model_context = '\n'.join(lines)
        except Exception as e:
            if showInfo >= 2:
                print(f"[ig.title_from_json] Could not read prior model info: {e}")
            model_context = ''

    system_prompt = (
        "You are a geophysics assistant for the INTEGRATE probabilistic inversion module. "
        "The user will provide a query dict in JSON format. "
        "Suggest a short figure title (max ~10 words) for the result this query produces. "
        "Use geological language and real model/class names where available. "
        "Write the title in title case. Do not start with 'Computes', 'Shows', or 'Displays'. "
        "Do not include JSON syntax. Reply with only the title — no preamble, no full stop."
    )
    if model_context:
        system_prompt += f"\n\n{model_context}"

    try:
        response_obj = litellm.completion(
            model=model,
            max_tokens=128,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": json.dumps(query_dict)},
            ],
            api_key=api_key,
            **_litellm_extra(model),
        )
        description = (response_obj.choices[0].message.content or '').strip()
        return description
    except Exception as e:
        if showInfo >= 1:
            print(f"[ig.title_from_json] LLM call failed — returning empty description. "
                  f"Use ig.query_test_llm() to diagnose. (model='{model}')")
        if showInfo >= 2:
            print(f"  Detail: {e}")
        return ''
1540
+
1541
+
1542
def query_test_llm(model='anthropic/claude-sonnet-4-6', api_key=None, verbose=1):
    """
    Check that an LLM model / API key combination answers with valid JSON.

    Sends a one-line prompt that asks for ``{"status": "ok"}``. Markdown
    fences are stripped from the reply before it is parsed with
    ``json.loads``. Any exception raised by the call, by an empty reply or by
    invalid JSON is caught and reported in the returned dict.

    Parameters
    ----------
    model : str, optional
        LiteLLM model string (default: 'anthropic/claude-sonnet-4-6').
    api_key : str, optional
        Provider API key. If None, the relevant environment variable is used.
    verbose : int, optional
        0 = silent, 1 = summary only (default), 2 = full response included.

    Returns
    -------
    result : dict
        Keys: 'ok' (bool), 'model', 'response' (str or None), 'error' (str or None).

    Raises
    ------
    ImportError
        If the litellm package is not installed.
    """
    try:
        import litellm
    except ImportError:
        raise ImportError(
            "The 'litellm' package is required. Install it with: pip install litellm"
        )

    status = {'ok': False, 'model': model, 'response': None, 'error': None}
    prompt = 'Reply with exactly this JSON and nothing else: {"status": "ok"}'

    try:
        reply = litellm.completion(
            model=model,
            max_tokens=256,
            messages=[{"role": "user", "content": prompt}],
            api_key=api_key,
            **_litellm_extra(model),
        ).choices[0].message
        text = (reply.content or '').strip()

        # Remove a leading ``` / ```json line, then a trailing ``` fence.
        if text.startswith("```"):
            text = text.split("\n", 1)[-1]
        if text.endswith("```"):
            text = text.rsplit("```", 1)[0].strip()
        status['response'] = text

        if not text:
            # Thinking models can spend the whole token budget on reasoning.
            thoughts = getattr(reply, 'reasoning_content', None)
            if thoughts:
                hint = f" (content was empty; reasoning_content present, length {len(thoughts)})"
            else:
                hint = " (model returned empty content — it may be a thinking model with all tokens used for reasoning)"
            raise ValueError(f"Empty response from model{hint}")

        json.loads(text)  # raises on invalid JSON
        status['ok'] = True
        if verbose >= 1:
            print(f"[query_test_llm] OK — model '{model}' responded with valid JSON.")
        if verbose >= 2:
            print(f" Response: {text}")
    except Exception as err:
        status['error'] = str(err)
        if verbose >= 1:
            print(f"[query_test_llm] FAILED — model '{model}'")
            print(f" Error: {err}")
        if verbose >= 2 and status['response']:
            print(f" Raw response: {status['response']}")

    return status