integrate_module 0.99.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- integrate/__init__.py +144 -0
- integrate/gex.py +402 -0
- integrate/integrate.py +4063 -0
- integrate/integrate_borehole.py +1127 -0
- integrate/integrate_hdf5_info_cli.py +122 -0
- integrate/integrate_io.py +5293 -0
- integrate/integrate_plot.py +4986 -0
- integrate/integrate_query.py +1609 -0
- integrate/integrate_rejection.py +1836 -0
- integrate/integrate_rejection_cli.py +210 -0
- integrate/integrate_rejection_jax.py +494 -0
- integrate/integrate_timing_cli.py +407 -0
- integrate/integrate_www_cli.py +8 -0
- integrate_module-0.99.1.dist-info/METADATA +229 -0
- integrate_module-0.99.1.dist-info/RECORD +19 -0
- integrate_module-0.99.1.dist-info/WHEEL +5 -0
- integrate_module-0.99.1.dist-info/entry_points.txt +5 -0
- integrate_module-0.99.1.dist-info/licenses/LICENSE +21 -0
- integrate_module-0.99.1.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,1609 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Query posterior realizations based on geophysical constraints.
|
|
3
|
+
|
|
4
|
+
This module provides tools to compute probabilities that posterior realizations
|
|
5
|
+
from Bayesian inversion satisfy user-defined constraints (e.g., thickness of
|
|
6
|
+
lithology classes, resistivity thresholds).
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
import json
|
|
10
|
+
import os
|
|
11
|
+
import numpy as np
|
|
12
|
+
import h5py
|
|
13
|
+
from tqdm import tqdm
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def _get_clipped_thickness(z, depth_min, depth_max):
|
|
17
|
+
"""
|
|
18
|
+
Compute per-layer thickness clipped to a depth interval.
|
|
19
|
+
|
|
20
|
+
Parameters
|
|
21
|
+
----------
|
|
22
|
+
z : ndarray (N_depth,)
|
|
23
|
+
Depth values [m] at layer interfaces.
|
|
24
|
+
depth_min : float, None, or ndarray (N_samples,)
|
|
25
|
+
Lower depth bound [m]. Array triggers per-sample mode.
|
|
26
|
+
depth_max : float, None, or ndarray (N_samples,)
|
|
27
|
+
Upper depth bound [m]. Array triggers per-sample mode.
|
|
28
|
+
|
|
29
|
+
Returns
|
|
30
|
+
-------
|
|
31
|
+
t : ndarray (N_depth - 1,) or (N_samples, N_depth - 1)
|
|
32
|
+
Clipped thickness per layer. 2-D when bounds are per-sample.
|
|
33
|
+
mask : ndarray (N_depth - 1,) bool or None
|
|
34
|
+
True where clipped thickness > 0 (scalar bounds). None when 2-D.
|
|
35
|
+
"""
|
|
36
|
+
n_layers = len(z) - 1
|
|
37
|
+
top = z[:n_layers]
|
|
38
|
+
bot = z[1:]
|
|
39
|
+
lo = depth_min if depth_min is not None else -np.inf
|
|
40
|
+
hi = depth_max if depth_max is not None else np.inf
|
|
41
|
+
|
|
42
|
+
if np.ndim(lo) > 0 or np.ndim(hi) > 0:
|
|
43
|
+
# Per-sample bounds: broadcast (N_samples,1) against (N_layers,)
|
|
44
|
+
lo2 = np.asarray(lo)[:, np.newaxis] if np.ndim(lo) > 0 else lo
|
|
45
|
+
hi2 = np.asarray(hi)[:, np.newaxis] if np.ndim(hi) > 0 else hi
|
|
46
|
+
t = np.maximum(np.minimum(bot, hi2) - np.maximum(top, lo2), 0.0)
|
|
47
|
+
return t, None # 2-D t; mask handled downstream
|
|
48
|
+
else:
|
|
49
|
+
t = np.maximum(np.minimum(bot, hi) - np.maximum(top, lo), 0.0)
|
|
50
|
+
return t, t > 0
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def _first_occurrence_thickness(condition, t):
|
|
54
|
+
"""
|
|
55
|
+
Thickness of the first contiguous True run per sample row.
|
|
56
|
+
|
|
57
|
+
Parameters
|
|
58
|
+
----------
|
|
59
|
+
condition : ndarray (N_samples, N_layers) bool
|
|
60
|
+
t : ndarray (N_layers,) float
|
|
61
|
+
|
|
62
|
+
Returns
|
|
63
|
+
-------
|
|
64
|
+
metric : ndarray (N_samples,) float
|
|
65
|
+
"""
|
|
66
|
+
metric = np.zeros(condition.shape[0])
|
|
67
|
+
has_true = condition.any(axis=1)
|
|
68
|
+
first_idx = np.argmax(condition, axis=1)
|
|
69
|
+
t_2d = t.ndim == 2
|
|
70
|
+
for i in np.where(has_true)[0]:
|
|
71
|
+
j = first_idx[i]
|
|
72
|
+
while j < condition.shape[1] and condition[i, j]:
|
|
73
|
+
metric[i] += t[i, j] if t_2d else t[j]
|
|
74
|
+
j += 1
|
|
75
|
+
return metric
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def _compute_metric(M_samples, z, metric_def, scalar_model_values=None):
    """
    Compute the raw per-realization metric value for a model.

    Parameters
    ----------
    M_samples : ndarray (N_samples, N_depth)
        Model values for the sampled realizations.
    z : ndarray (N_depth,)
        Depth values [m].
    metric_def : dict
        Metric definition — same fields as a constraint minus comparison fields.
        Supports: im, classes, value_comparison, value_threshold, thickness_mode,
        depth_min, depth_max, depth_max_im, depth_min_im.
        ``value_comparison`` may be '<' (default), '>', '<=' or '>='. Any other
        string falls back to '>'.
    scalar_model_values : dict {im: ndarray (N_samples,)}, optional
        Per-sample scalar model values for dynamic depth bounds.

    Returns
    -------
    metric : ndarray (N_samples,)
        For scalar models: raw model value per realization.
        For depth models: cumulative or first-occurrence thickness [m] of
        layers satisfying the class/value condition.
    """
    n_layers = len(z) - 1

    # Scalar model: return the single value directly (no thickness concept)
    if (z[-1] - z[0]) == 0 or n_layers == 0:
        return M_samples[:, 0].copy()

    depth_min = metric_def.get('depth_min', None)
    depth_max = metric_def.get('depth_max', None)
    if scalar_model_values:
        # Dynamic bounds taken from another (scalar) model, per realization.
        if 'depth_max_im' in metric_def:
            depth_max = scalar_model_values.get(metric_def['depth_max_im'], depth_max)
        if 'depth_min_im' in metric_def:
            depth_min = scalar_model_values.get(metric_def['depth_min_im'], depth_min)

    t, layer_mask = _get_clipped_thickness(z, depth_min, depth_max)

    M_layers = M_samples[:, :n_layers]
    if layer_mask is None:
        # Per-sample bounds: keep all layers, t is already (N_samples, N_layers).
        M_sel, t_sel = M_layers, t
    else:
        M_sel, t_sel = M_layers[:, layer_mask], t[layer_mask]

    if 'classes' in metric_def:
        condition = np.isin(np.round(M_sel).astype(int), metric_def['classes'])
    else:
        value_ops = {
            '<': np.less,
            '>': np.greater,
            '<=': np.less_equal,
            '>=': np.greater_equal,
        }
        v_cmp = metric_def.get('value_comparison', '<')
        v_thr = metric_def.get('value_threshold', 0.0)
        # Unknown operators keep the historical fallback of '>'.
        condition = value_ops.get(v_cmp, np.greater)(M_sel, v_thr)

    mode = metric_def.get('thickness_mode', 'cumulative')
    if mode == 'cumulative':
        return (condition * t_sel).sum(axis=1)
    return _first_occurrence_thickness(condition, t_sel)
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
def _evaluate_constraint(M_samples, z, constraint, scalar_model_values=None):
    """
    Evaluate one constraint for a batch of realizations.

    Parameters
    ----------
    M_samples : ndarray (N_samples, N_depth)
        Model values for the sampled realizations.
    z : ndarray (N_depth,)
        Depth values [m].
    constraint : dict
        Constraint definition. For scalar models the test uses ``classes`` or
        ``value_comparison``/``value_threshold``. For depth models the metric
        from `_compute_metric` is tested with
        ``thickness_comparison``/``thickness_threshold``. Both comparison fields
        accept '<', '>', '<=' and '>='; unknown strings fall back to '>'.
        ``negate=True`` inverts the result.
    scalar_model_values : dict {im: ndarray (N_samples,)}, optional
        Per-sample values of scalar models. Used to resolve `depth_max_im`
        and `depth_min_im` constraint fields into per-sample depth bounds.

    Returns
    -------
    valid : ndarray (N_samples,) bool
        True for each realization that satisfies the constraint.
    """
    comparison_ops = {
        '<': np.less,
        '>': np.greater,
        '<=': np.less_equal,
        '>=': np.greater_equal,
    }

    n_layers = len(z) - 1
    is_scalar = (z[-1] - z[0]) == 0 or n_layers == 0

    raw = _compute_metric(M_samples, z, constraint, scalar_model_values)

    if is_scalar:
        if 'classes' in constraint:
            valid = np.isin(np.round(raw).astype(int), constraint['classes'])
        else:
            v_cmp = constraint.get('value_comparison', '<')
            v_thr = constraint.get('value_threshold', 0.0)
            valid = comparison_ops.get(v_cmp, np.greater)(raw, v_thr)
    else:
        t_cmp = constraint.get('thickness_comparison', '>')
        t_thr = constraint.get('thickness_threshold', 0.0)
        valid = comparison_ops.get(t_cmp, np.greater)(raw, t_thr)

    return ~valid if constraint.get('negate', False) else valid
|
|
184
|
+
|
|
185
|
+
|
|
186
|
+
def _collect_needed_ims(items):
|
|
187
|
+
"""Return the set of model indices (im) referenced in a constraint list or metric dict."""
|
|
188
|
+
needed = set()
|
|
189
|
+
if isinstance(items, dict):
|
|
190
|
+
items = [items]
|
|
191
|
+
for c in items:
|
|
192
|
+
needed.add(c['im'])
|
|
193
|
+
for field in ('depth_max_im', 'depth_min_im'):
|
|
194
|
+
if field in c:
|
|
195
|
+
needed.add(c[field])
|
|
196
|
+
return needed
|
|
197
|
+
|
|
198
|
+
|
|
199
|
+
def _load_query_inputs(f_post_h5, needed_ims):
    """
    Load posterior indices, coordinates, and prior model arrays.

    Coordinates are read from the posterior file. If they are missing there,
    they are read from the data file named by the posterior's ``f5_data``
    attribute (when that file exists).

    Parameters
    ----------
    f_post_h5 : str
        Path to the posterior HDF5 file.
    needed_ims : set of int
        Model indices to pre-load from the prior file.

    Returns
    -------
    i_use : ndarray (N_data, N_post)
    X, Y : ndarray or None
    prior_models : dict {im: (M, z, is_discrete)}
    N_data, N_post : int

    Raises
    ------
    ValueError
        If the posterior file has no ``f5_prior`` attribute.
    FileNotFoundError
        If the referenced prior file does not exist.
    KeyError
        If a requested model ``M{im}`` is not present in the prior file.
    """
    def _attr_str(value):
        # Fixed-length HDF5 string attributes come back as bytes (numpy.bytes_);
        # plain str() would turn them into "b'...'" and break the path.
        return value.decode('utf-8') if isinstance(value, bytes) else str(value)

    with h5py.File(f_post_h5, 'r') as f:
        i_use = f['i_use'][:]
        f_prior_h5 = _attr_str(f.attrs.get('f5_prior', ''))
        f_data_h5 = _attr_str(f.attrs.get('f5_data', ''))
        X = f['UTMX'][:] if 'UTMX' in f else None
        Y = f['UTMY'][:] if 'UTMY' in f else None

    if not f_prior_h5:
        raise ValueError("Posterior file missing 'f5_prior' attribute.")
    if not os.path.isfile(f_prior_h5):
        raise FileNotFoundError(f"Prior file not found: {f_prior_h5}")

    if (X is None or Y is None) and f_data_h5 and os.path.isfile(f_data_h5):
        with h5py.File(f_data_h5, 'r') as f:
            if X is None and 'UTMX' in f:
                X = f['UTMX'][:]
            if Y is None and 'UTMY' in f:
                Y = f['UTMY'][:]

    N_data, N_post = i_use.shape

    prior_models = {}
    # Open the prior once for all requested models instead of once per model.
    with h5py.File(f_prior_h5, 'r') as f:
        for im in needed_ims:
            key = f'M{im}'
            if key not in f:
                raise KeyError(f"Model '{key}' not found in prior file {f_prior_h5}")
            dset = f[key]
            M = dset[:]
            z = dset.attrs['x'].astype(float)
            is_discrete = bool(dset.attrs.get('is_discrete', 0))
            prior_models[im] = (M, z, is_discrete)

    return i_use, X, Y, prior_models, N_data, N_post
|
|
248
|
+
|
|
249
|
+
|
|
250
|
+
def _build_scalar_vals(prior_models, idx):
|
|
251
|
+
"""Return per-sample values for all scalar models in prior_models."""
|
|
252
|
+
scalar_vals = {}
|
|
253
|
+
for im_s, (M_s, z_s, _) in prior_models.items():
|
|
254
|
+
if (z_s[-1] - z_s[0]) == 0 or len(z_s) - 1 == 0:
|
|
255
|
+
scalar_vals[im_s] = M_s[idx, 0]
|
|
256
|
+
return scalar_vals
|
|
257
|
+
|
|
258
|
+
|
|
259
|
+
def query_probability(f_post_h5, query_dict):
    """
    Estimate, per data location, how likely the posterior satisfies a query.

    For every location, each posterior realization is tested against all
    constraints (logical AND). The probability is the fraction that passes.

    Parameters
    ----------
    f_post_h5 : str
        Path to the posterior HDF5 file.
    query_dict : str or dict
        Path to a JSON file, or a dict with a ``"constraints"`` key.

    Returns
    -------
    P : ndarray (N_data,)
        Probability [0, 1] for each data location.
    meta : dict
        Keys: 'X', 'Y', 'N_data', 'N_post', 'i_use', 'i_use_query'.
        ``i_use_query[i]`` holds the prior indices of location ``i`` that
        passed every constraint.

    Examples
    --------
    >>> query_def = {
    ...     "constraints": [{
    ...         "im": 2, "classes": [2],
    ...         "thickness_mode": "cumulative",
    ...         "thickness_comparison": ">",
    ...         "thickness_threshold": 10.0,
    ...         "depth_min": 0.0, "depth_max": 30.0
    ...     }]
    ... }
    >>> P, meta = query_probability('f_post.h5', query_def)
    """
    if isinstance(query_dict, str):
        with open(query_dict, 'r') as fh:
            query_dict = json.load(fh)

    constraints = query_dict['constraints']
    i_use, X, Y, prior_models, N_data, N_post = _load_query_inputs(
        f_post_h5, _collect_needed_ims(constraints)
    )

    P = np.zeros(N_data)
    i_use_query = []
    locations = tqdm(range(N_data), desc='Evaluating probability query', unit='location')
    for i_data in locations:
        rows = i_use[i_data]
        accepted = np.ones(N_post, dtype=bool)
        scalar_vals = _build_scalar_vals(prior_models, rows)
        for constraint in constraints:
            M_prior, z_prior, _ = prior_models[constraint['im']]
            accepted &= _evaluate_constraint(
                M_prior[rows, :], z_prior, constraint, scalar_model_values=scalar_vals
            )
        P[i_data] = accepted.mean()
        i_use_query.append(rows[accepted])

    meta = dict(
        X=X,
        Y=Y,
        N_data=N_data,
        N_post=N_post,
        i_use=i_use,
        i_use_query=i_use_query,
    )
    return P, meta
|
|
316
|
+
|
|
317
|
+
|
|
318
|
+
def query_percentile(f_post_h5, query_dict):
    """
    Compute per-data-point percentiles of a metric over posterior realizations.

    Rather than asking "what fraction of realizations satisfy condition X?", this
    asks "what is the p5/p50/p95 of metric X across realizations?". The metric
    is defined by the same fields as a probability constraint, minus the comparison
    fields (``thickness_comparison``, ``thickness_threshold``, ``negate``).

    Parameters
    ----------
    f_post_h5 : str
        Path to the posterior HDF5 file.
    query_dict : str or dict
        Path to a JSON file, or a dict with a ``"metric"`` key and an optional
        ``"percentiles"`` key (default ``[5, 50, 95]``). A single number is
        accepted for ``"percentiles"`` and treated as a one-element list.

    Returns
    -------
    percentile_values : ndarray (N_data, n_percentiles)
        Requested percentile values for each data location.
    meta : dict
        Keys: 'X', 'Y', 'N_data', 'N_post', 'i_use', 'percentiles'.

    Examples
    --------
    >>> query_def = {
    ...     "metric": {
    ...         "im": 2, "classes": [1, 2],
    ...         "thickness_mode": "cumulative",
    ...         "depth_max": 30.0
    ...     },
    ...     "percentiles": [5, 50, 95]
    ... }
    >>> pct_values, meta = query_percentile('f_post.h5', query_def)
    >>> # pct_values shape: (N_data, 3) — p5, p50, p95 per location
    """
    if isinstance(query_dict, str):
        with open(query_dict, 'r') as fh:
            query_dict = json.load(fh)

    metric_def = query_dict['metric']
    percentiles = query_dict.get('percentiles', [5, 50, 95])
    if np.isscalar(percentiles):
        # e.g. "percentiles": 50 in JSON — would otherwise fail at len() below
        # and break the (N_data, n_percentiles) result shape.
        percentiles = [percentiles]
    needed_ims = _collect_needed_ims(metric_def)
    i_use, X, Y, prior_models, N_data, N_post = _load_query_inputs(f_post_h5, needed_ims)

    M_main, z_main, _ = prior_models[metric_def['im']]
    n_pct = len(percentiles)
    result = np.zeros((N_data, n_pct))

    for i in tqdm(range(N_data), desc='Evaluating percentile query', unit='location'):
        idx = i_use[i]
        scalar_vals = _build_scalar_vals(prior_models, idx)
        values = _compute_metric(M_main[idx, :], z_main, metric_def,
                                 scalar_model_values=scalar_vals)
        result[i, :] = np.percentile(values, percentiles)

    meta = {
        'X': X, 'Y': Y,
        'N_data': N_data, 'N_post': N_post,
        'i_use': i_use,
        'percentiles': percentiles,
    }
    return result, meta
|
|
382
|
+
|
|
383
|
+
|
|
384
|
+
def query(f_post_h5, query_dict):
    """
    Run a query, choosing the percentile or probability variant automatically.

    Dicts with a ``"metric"`` key are handled by :func:`query_percentile`.
    Everything else (the ``"constraints"`` schema) goes to
    :func:`query_probability`, which keeps older query dicts working.

    Parameters
    ----------
    f_post_h5 : str
        Path to the posterior HDF5 file.
    query_dict : str or dict
        Query definition, or a path to a JSON file holding one. See
        :func:`query_probability` and :func:`query_percentile` for the schemas.

    Returns
    -------
    result, meta
        Whatever the chosen function returns. If the posterior file does not
        exist, a message is printed and ``(None, {})`` is returned.
    """
    if not os.path.isfile(str(f_post_h5)):
        print(f"[ig.query] Posterior file not found: {f_post_h5}")
        return None, {}

    if isinstance(query_dict, str):
        with open(query_dict, 'r') as fh:
            query_dict = json.load(fh)

    runner = query_percentile if 'metric' in query_dict else query_probability
    return runner(f_post_h5, query_dict)
|
|
414
|
+
|
|
415
|
+
|
|
416
|
+
def query_plot(P, meta, ip=None, query_dict=None, f_prior_h5=None, f_post_h5=None, title=None,
               query_text=None, interpretation=None, text_panel=False, hardcopy=False, **kwargs):
    """
    Plot query results and optionally detailed model visualization for a data point.

    If ip is None, displays the XY probability map showing P(x, y).
    If ip is provided (together with query_dict and f_prior_h5/f_post_h5), skips the
    probability map and shows only the detailed single-point visualization of all
    posterior realizations and the query-matching subset. The detailed view needs a
    probability (``"constraints"``) query; it uses the model of the first constraint.

    Parameters
    ----------
    P : ndarray (N_data,)
        Probability array from query().
    meta : dict
        Metadata dict from query() containing 'X', 'Y', 'i_use', 'i_use_query'.
    ip : int, optional
        Data point index to visualize in detail. If None, only shows probability map.
    query_dict : dict or str, optional
        Query dict used in query() (or the path to its JSON file). Required for
        detailed visualization.
    f_prior_h5 : str, optional
        Path to prior HDF5 file. If not provided, will be extracted from f_post_h5.
    f_post_h5 : str, optional
        Path to posterior HDF5 file. Used to automatically extract prior file path
        if f_prior_h5 is not provided.
    title : str, optional
        Custom title for the probability map (used even when text_panel=True).
        If None, a title is built from query_text and interpretation (if
        provided and text_panel is False), or 'Query Probability Map'.
    query_text : str, optional
        The original natural-language query string. Shown in the figure title,
        or in the text panel if text_panel=True.
    interpretation : str, optional
        The LLM interpretation string returned by query_from_text(). Shown as a
        second line in the figure title, or in the text panel if text_panel=True.
    text_panel : bool, optional
        If True and query_text or interpretation is provided, adds a narrow text
        column to the right of the probability map. The query text appears at the
        top and the interpretation below. Default False.
    hardcopy : bool or str, optional
        Save the current figure: the probability map, or the detailed view when
        ip is given. If True, saves as 'query_plot.png'. If a string, uses that
        as the filename (':' and '/' are replaced by '_'; '.png' is appended
        unless it already has a known image extension). Default False.
    **kwargs
        All remaining keyword arguments are forwarded to :func:`plot_xy`, giving
        full control over ``cmap``, ``clim``, ``uselog``, ``colorbar``,
        ``colorbar_label``, ``plotPoints``, ``plotPoints_color``,
        ``plotPoints_marker``, ``s``, etc.
        ``cmap`` defaults to ``'hot_r'`` and ``clim`` defaults to ``[0, 1]``.

    Examples
    --------
    >>> P, meta = query(f_post_h5, query_def)
    >>> query_plot(P, meta)  # Just probability map
    >>> query_plot(P, meta, title='Custom Query Title')  # Custom title
    >>> query_plot(P, meta, ip=1000, query_dict=query_def, f_post_h5='posterior.h5')
    >>> query_plot(P, meta, ip=1000, query_dict=query_def, f_prior_h5='prior.h5')
    >>> # With LLM query text and interpretation:
    >>> query_dict, interp = ig.query_from_text(text, f_prior_h5)
    >>> P, meta = ig.query(f_post_h5, query_dict)
    >>> ig.query_plot(P, meta, query_text=text, interpretation=interp)
    """
    import textwrap
    import matplotlib.pyplot as plt

    def _to_str(value):
        # HDF5 fixed-length strings are returned as bytes; decode for display/paths.
        return value.decode('utf-8') if isinstance(value, bytes) else str(value)

    # Accept a JSON path, as query() does.
    if isinstance(query_dict, str):
        with open(query_dict, 'r') as fh:
            query_dict = json.load(fh)

    # Auto-extract prior file path from posterior file if needed
    if f_prior_h5 is None and f_post_h5 is not None:
        with h5py.File(f_post_h5, 'r') as f:
            f_prior_h5 = _to_str(f.attrs.get('f5_prior', ''))
        if not f_prior_h5:
            print("Warning: Could not extract f5_prior attribute from posterior file")
            f_prior_h5 = None  # an empty path must not be opened below

    if ip is None:
        # ---- XY probability map ---------------------------------------------
        X = meta['X']
        Y = meta['Y']
        has_text = text_panel and (query_text is not None or interpretation is not None)
        if has_text:
            fig = plt.figure(figsize=(11, 6))
            gs = fig.add_gridspec(1, 2, width_ratios=[3, 1], wspace=0.05)
            ax = fig.add_subplot(gs[0])
            ax_text = fig.add_subplot(gs[1])
        else:
            fig, ax = plt.subplots(figsize=(8, 6))

        # Determine title: an explicit title always wins; with a text panel the
        # query/interpretation go into the panel instead of the title.
        if title is not None:
            _title = title
        elif has_text:
            _title = 'Query Probability Map'
        elif query_text is not None or interpretation is not None:
            parts = []
            if query_text is not None:
                parts.append(f"Query: {query_text}")
            if interpretation is not None:
                parts.append(f"Interpreted as: {interpretation}")
            _title = '\n'.join(parts)
        else:
            _title = 'Query Probability Map'

        _title = '\n'.join(
            '\n'.join(textwrap.wrap(line, width=60)) if len(line) > 60 else line
            for line in _title.splitlines()
        )

        # Background dots so P=0 (white) areas are visible, then probability scatter
        _cmap = kwargs.pop('cmap', 'hot_r')
        _clim = kwargs.pop('clim', [0, 1])
        _s = kwargs.pop('s', 1)
        _colorbar = kwargs.pop('colorbar', True)
        _colorbar_label = kwargs.pop('colorbar_label', 'Probability')
        _plotPoints = kwargs.pop('plotPoints', True)
        from integrate.integrate_plot import plot_xy
        _, ax, sc = plot_xy(P, X=X, Y=Y,
                            cmap=_cmap, clim=_clim,
                            title=_title, colorbar=_colorbar, colorbar_label=_colorbar_label,
                            ax=ax, s=_s, plotPoints=_plotPoints, **kwargs)
        ax.set_xlabel('UTMX [m]')
        ax.set_ylabel('UTMY [m]')

        if has_text:
            ax_text.set_axis_off()
            CHARS = 36          # characters per wrapped line
            LH = 0.062          # axes-fraction height per text line (fontsize 8)
            LABEL_GAP = 0.03    # gap between bold label and text box
            SECTION_GAP = 0.07  # gap between sections

            y = 0.97
            if query_text is not None:
                ax_text.text(0.02, y, "Query:", transform=ax_text.transAxes,
                             fontsize=8, fontweight='bold', va='top')
                y -= LH + LABEL_GAP
                wrapped_q = textwrap.fill(query_text, CHARS)
                n_q = wrapped_q.count('\n') + 1
                ax_text.text(0.02, y, wrapped_q, transform=ax_text.transAxes,
                             fontsize=7.5, va='top',
                             bbox=dict(boxstyle='round,pad=0.4', facecolor='#f0f0f0', edgecolor='none'))
                y -= n_q * LH + SECTION_GAP
            if interpretation is not None:
                ax_text.text(0.02, y, "Interpretation:", transform=ax_text.transAxes,
                             fontsize=8, fontweight='bold', va='top')
                y -= LH + LABEL_GAP
                wrapped_i = textwrap.fill(interpretation, CHARS)
                ax_text.text(0.02, y, wrapped_i, transform=ax_text.transAxes,
                             fontsize=7.5, va='top',
                             bbox=dict(boxstyle='round,pad=0.4', facecolor='#e8f4e8', edgecolor='none'))

        plt.tight_layout()

    elif query_dict is None or f_prior_h5 is None:
        print("Warning: detailed view requires query_dict and f_prior_h5 (or f_post_h5); "
              "nothing plotted")

    elif 'constraints' not in query_dict or 'i_use_query' not in meta:
        print("Warning: detailed view is only available for probability (constraint) queries")

    else:
        # ---- Detailed single-point view ---------------------------------------
        # Load prior model for the first constraint
        im = query_dict['constraints'][0]['im']
        with h5py.File(f_prior_h5, 'r') as f:
            dset = f[f'M{im}']
            M = dset[:]
            attrs = dset.attrs
            is_discrete = bool(attrs.get('is_discrete', 0))
            class_id = None
            class_name = None
            prior_cmap = None

            # Use the prior's colormap when present; any failure falls back to 'jet'.
            if 'cmap' in attrs:
                try:
                    from matplotlib.colors import ListedColormap
                    # Format is [3, nlev] or [4, nlev] - transpose to get [nlev, 3] or [nlev, 4]
                    prior_cmap = ListedColormap(attrs['cmap'][:].T)
                except Exception:
                    prior_cmap = None

            # Read class information for discrete models
            if is_discrete:
                if 'class_id' in attrs:
                    class_id = attrs['class_id'][:].flatten()
                if 'class_name' in attrs:
                    class_name = [_to_str(name) for name in attrs['class_name'][:].flatten()]

        # Get posterior and query-matching indices
        i_use = meta['i_use'][ip, :]
        i_use_query = meta['i_use_query'][ip]
        n_total = len(i_use)
        n_accepted = len(i_use_query)
        probability = P[ip]

        # All posterior realizations, and a float copy with non-matching ones set to NaN
        M_use_all = M[i_use]
        M_use_filtered = M_use_all.astype(float)
        M_use_filtered[~np.isin(i_use, i_use_query), :] = np.nan

        # Use the full range of class IDs from the prior file as colour limits
        if is_discrete and class_id is not None:
            vmin_plot = np.min(class_id) - 0.5
            vmax_plot = np.max(class_id) + 0.5
        else:
            vmin_plot = None
            vmax_plot = None

        def _show_realizations(data, panel_title):
            # Draw one realization panel into the current subplot, with colorbar.
            if prior_cmap is not None and vmin_plot is not None:
                img = plt.imshow(data.T, aspect='auto', cmap=prior_cmap, interpolation='nearest',
                                 vmin=vmin_plot, vmax=vmax_plot)
            elif prior_cmap is not None:
                img = plt.imshow(data.T, aspect='auto', cmap=prior_cmap, interpolation='nearest')
            else:
                img = plt.imshow(data.T, aspect='auto', cmap='jet', interpolation='nearest')
            plt.title(panel_title)
            plt.xlabel('Realization index')
            plt.ylabel('Layer index')
            if is_discrete and class_id is not None and class_name is not None:
                cbar = plt.colorbar(img)
                cbar.set_ticks(class_id)
                # Tick labels formatted as "ClassName (ID)"
                cbar.set_ticklabels([f'{name} ({int(cid)})' for name, cid in zip(class_name, class_id)])
                cbar.ax.invert_yaxis()
            else:
                plt.colorbar(img, label='Model value')

        plt.figure(figsize=(12, 8))
        plt.subplot(2, 1, 1)
        _show_realizations(
            M_use_all,
            f'All Posterior Realizations (Point {ip})\n'
            f'Total Realizations: {n_total} | Accepted: {n_accepted} | Probability: {probability:.3f}')
        plt.subplot(2, 1, 2)
        _show_realizations(M_use_filtered, 'Query-Matching Realizations Only (non-matching set to NaN)')
        plt.tight_layout()

    _VALID_EXTS = {'.png', '.jpg', '.jpeg', '.pdf', '.svg', '.eps', '.tif', '.tiff', '.webp'}
    if hardcopy is not False and hardcopy is not None:
        if isinstance(hardcopy, str):
            safe = hardcopy.replace(':', '_').replace('/', '_')
            f_png = safe if os.path.splitext(safe)[1].lower() in _VALID_EXTS else safe + '.png'
        else:
            f_png = 'query_plot.png'
        plt.savefig(f_png)
        print(f"Figure saved to {f_png}")

    plt.show()
|
|
692
|
+
|
|
693
|
+
|
|
694
|
+
def query_percentile_plot(percentile_values, meta, query_text=None, interpretation=None,
                          text_panel=False, hardcopy=False, **kwargs):
    """
    Plot one map per requested percentile as side-by-side subplots.

    Parameters
    ----------
    percentile_values : ndarray (N_data, n_percentiles) or (N_data,)
        Output of query_percentile(). A 1-D array is treated as a single
        percentile column.
    meta : dict
        Metadata dict from query_percentile(). 'X' and 'Y' are optional; when
        either is missing, each panel is a line plot against location index.
        'percentiles' defaults to [5, 50, 95] when absent.
    query_text : str, optional
        Original query string — shown as figure suptitle (or in the text panel).
    interpretation : str, optional
        LLM interpretation string — shown below query_text if provided.
    text_panel : bool, optional
        If True, add a narrow text column to the right of the maps and show
        query_text / interpretation there instead of in the suptitle.
    hardcopy : bool or str, optional
        Save figure to disk. True → 'query_percentile_plot.png'; a string is
        used as the filename, with '.png' appended unless it already ends in
        one of .png, .jpg, .jpeg, .pdf, .svg, .eps, .tif, .tiff or .webp.
    **kwargs
        All remaining keyword arguments are forwarded to :func:`plot_xy`, giving
        full control over ``cmap``, ``clim``, ``uselog``, ``colorbar``,
        ``colorbar_label``, ``plotPoints``, ``plotPoints_color``,
        ``plotPoints_marker``, ``s``, etc.
        ``clim`` defaults to the min/max of the finite entries of
        ``percentile_values`` (NaN/inf ignored) so that all subplots share the
        same colour scale.
        ``cmap`` defaults to ``'viridis'``.

    Returns
    -------
    fig : matplotlib Figure
    """
    import matplotlib.pyplot as plt

    # Extensions accepted as-is for hardcopy; anything else gets '.png' appended.
    image_exts = {'.png', '.jpg', '.jpeg', '.pdf', '.svg', '.eps', '.tif', '.tiff', '.webp'}

    values = np.asarray(percentile_values)
    if values.ndim == 1:
        # A single percentile column given as a flat vector.
        values = values[:, np.newaxis]

    percentiles = meta.get('percentiles', [5, 50, 95])
    n_pct = len(percentiles)
    X = meta.get('X')
    Y = meta.get('Y')

    # Shared colour scale across all subplots unless caller overrides.
    cmap = kwargs.pop('cmap', 'viridis')
    if 'clim' in kwargs:
        clim = kwargs.pop('clim')
    else:
        # Ignore NaN/inf so a single missing location does not turn clim into [nan, nan].
        finite = values[np.isfinite(values)]
        if finite.size:
            clim = [float(finite.min()), float(finite.max())]
        else:
            # Nothing finite to scale by: use a placeholder range.
            clim = [0.0, 1.0]

    ncols = n_pct + (1 if text_panel else 0)
    width_ratios = [4] * n_pct + ([1.5] if text_panel else [])
    fig, axes = plt.subplots(1, ncols, figsize=(4.5 * n_pct + (1.5 if text_panel else 0), 5),
                             gridspec_kw={'width_ratios': width_ratios} if text_panel else {},
                             squeeze=False)

    map_axes = axes[0, :n_pct]

    has_xy = X is not None and Y is not None
    if has_xy:
        # Imported lazily; only needed for map panels.
        from integrate.integrate_plot import plot_xy

    for k, (pct, ax) in enumerate(zip(percentiles, map_axes)):
        vals = values[:, k]
        if has_xy:
            _, ax, _ = plot_xy(vals, X=X, Y=Y,
                               cmap=cmap, clim=clim,
                               ax=ax, **kwargs)
            ax.set_xlabel('UTMX')
            if k == 0:
                ax.set_ylabel('UTMY')
        else:
            ax.plot(vals)
            ax.set_xlabel('Location index')
        # nanmedian so NaN entries do not make the whole title read 'median=nan'.
        ax.set_title(f'P{pct} (median={np.nanmedian(vals):.1f})')

    if text_panel:
        tax = axes[0, n_pct]
        tax.axis('off')
        txt = ''
        if query_text:
            txt += f'Query:\n{query_text}\n\n'
        if interpretation:
            txt += f'Interpretation:\n{interpretation}'
        if txt:
            tax.text(0.05, 0.95, txt, transform=tax.transAxes,
                     fontsize=7, va='top', wrap=True)

    suptitle_parts = [t for t in [query_text, interpretation] if t]
    if suptitle_parts and not text_panel:
        fig.suptitle('\n'.join(suptitle_parts), fontsize=8, y=1.01)

    plt.tight_layout()

    if hardcopy:
        fname = hardcopy if isinstance(hardcopy, str) else 'query_percentile_plot'
        # Check the real suffix rather than any '.' in the name, so e.g. 'map_v1.2'
        # becomes 'map_v1.2.png' instead of an unknown '.2' savefig format.
        if os.path.splitext(fname)[1].lower() not in image_exts:
            fname += '.png'
        plt.savefig(fname, bbox_inches='tight', dpi=150)
        print(f"Figure saved to {fname}")

    plt.show()
    return fig
|
|
788
|
+
|
|
789
|
+
|
|
790
|
+
def save_query(query, path):
    """
    Save a query dict to a JSON file.

    NumPy scalars and arrays inside the query (for example class IDs taken
    from get_prior_model_info()) are written as plain JSON numbers/lists; the
    standard json encoder would otherwise raise TypeError on them. The JSON
    text is fully built before the file is opened, so a serialisation error
    does not leave a truncated file behind.

    Parameters
    ----------
    query : dict
        Query definition dictionary.
    path : str
        Output JSON file path.

    Raises
    ------
    TypeError
        If the query contains an object that is neither JSON-serialisable nor
        a NumPy scalar/array.
    """
    def _to_builtin(obj):
        # json only calls this for objects it cannot encode natively.
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, np.generic):
            return obj.item()
        raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

    text = json.dumps(query, indent=2, default=_to_builtin)
    with open(path, 'w', encoding='utf-8') as f:
        f.write(text)
    print(f"Query saved to {path}")
|
|
804
|
+
|
|
805
|
+
|
|
806
|
+
def load_query(path):
    """
    Load a query dict from a JSON file.

    The file is always decoded as UTF-8 (the encoding JSON text is required
    to use), so hand-edited queries containing non-ASCII class names load the
    same on every platform regardless of the locale's default encoding.

    Parameters
    ----------
    path : str
        Input JSON file path.

    Returns
    -------
    query : dict
        Query definition dictionary.
    """
    with open(path, 'r', encoding='utf-8') as f:
        return json.load(f)
|
|
822
|
+
|
|
823
|
+
|
|
824
|
+
def get_prior_model_info(f_prior_h5, im):
    """
    Return metadata for prior model ``M{im}`` stored in a prior HDF5 file.

    Parameters
    ----------
    f_prior_h5 : str
        Path to the prior HDF5 file.
    im : int
        Model index; the dataset read is ``'M{im}'``.

    Returns
    -------
    info : dict
        'name' : str
            The dataset's ``name`` attribute (decoded from bytes if stored
            that way), or ``'M{im}'`` when the attribute is absent.
        'is_discrete' : bool
            The ``is_discrete`` attribute; False when absent.
        'z' : ndarray of float, always 1-D
            The dataset's ``x`` attribute. A scalar attribute is promoted to
            a length-1 array so callers can use ``z[0]``, ``z[-1]`` and ``len(z)``.
        'class_id' : attribute value or None
        'class_name' : attribute value or None

    Raises
    ------
    KeyError
        If dataset ``M{im}`` or its ``x`` attribute does not exist.
    """
    key = f'M{im}'
    with h5py.File(f_prior_h5, 'r') as f:
        attrs = f[key].attrs
        name = attrs.get('name', key)
        if isinstance(name, bytes):
            # Fixed-length string attributes come back as bytes; callers format
            # the name with ':<20s', which bytes do not support.
            name = name.decode('utf-8', errors='replace')
        info = {
            'name': name,
            'is_discrete': bool(attrs.get('is_discrete', 0)),
            'z': np.atleast_1d(np.asarray(attrs['x'], dtype=float)),
            'class_id': attrs.get('class_id', None),
            'class_name': attrs.get('class_name', None),
        }
    return info
|
|
851
|
+
|
|
852
|
+
|
|
853
|
+
def prior_describe(f_prior_h5):
    """
    Print an overview of every ``M<k>`` model dataset in a prior HDF5 file.

    The header gives the file path and the number of realizations (first
    dimension of the lowest-numbered model, or 0 if there are none). Then one
    line per model, in numeric order, shows its index, name and kind:
    SCALAR / SCALAR-DISCRETE when its depth axis has zero extent or no layers,
    otherwise DISCRETE / CONTINUOUS with depth range and layer count.
    Discrete models that carry class metadata also list their classes.

    Parameters
    ----------
    f_prior_h5 : str
        Path to the prior HDF5 file.

    Examples
    --------
    >>> ig.prior_describe('prior.h5')  # doctest: +SKIP
    Prior file: prior.h5
    N realizations: 1000000
     im=1 Resistivity          CONTINUOUS depth 0–89 m (89 layers)
     im=2 Lithology            DISCRETE depth 0–89 m (89 layers)
     class   1 = Sand
     class   2 = Grus
     im=3 Waterlevel           SCALAR
    """
    with h5py.File(f_prior_h5, 'r') as h5:
        keys = [k for k in h5.keys() if k.startswith('M') and k[1:].isdigit()]
        keys.sort(key=lambda k: int(k[1:]))
        n_realizations = h5[keys[0]].shape[0] if keys else 0

    print(f"Prior file: {f_prior_h5}")
    print(f"N realizations: {n_realizations}")

    for key in keys:
        im = int(key[1:])
        info = get_prior_model_info(f_prior_h5, im)
        z = info['z']
        top, bottom = float(z[0]), float(z[-1])
        layer_count = len(z) - 1
        label = info['name']
        discrete = info['is_discrete']

        if bottom - top == 0 or layer_count == 0:
            kind = 'SCALAR-DISCRETE' if discrete else 'SCALAR'
            print(f" im={im} {label:<20s} {kind}")
        else:
            kind = 'DISCRETE' if discrete else 'CONTINUOUS'
            print(f" im={im} {label:<20s} {kind} depth {top:.0f}–{bottom:.0f} m ({layer_count} layers)")

        if discrete and info['class_id'] is not None and info['class_name'] is not None:
            pairs = zip(info['class_id'].flatten(), info['class_name'].flatten())
            for cid, cname in pairs:
                print(f" class {int(cid):>3} = {cname}")
|
|
903
|
+
|
|
904
|
+
|
|
905
|
+
def _describe_model_for_prompt(im, info):
    """Return the 'Available Prior Models' prompt entry for model ``im`` described by ``info``."""
    z = info['z']
    n_layers = len(z) - 1
    depth_min = float(z[0])
    depth_max = float(z[-1])
    name = info['name']
    is_discrete = info['is_discrete']
    has_classes = info['class_id'] is not None and info['class_name'] is not None

    def _class_lines():
        # One line per (id, name) pair, telling the LLM which integer IDs to use.
        out = ["    Classes (use these integer IDs in the 'classes' field):"]
        for cid, cname in zip(info['class_id'].flatten(), info['class_name'].flatten()):
            out.append(f"      {int(cid)} = {cname}")
        return out

    # Same scalar test as prior_describe(): zero depth extent or no layers.
    if (depth_max - depth_min) == 0 or n_layers == 0:
        kind = 'SCALAR-DISCRETE' if is_discrete else 'SCALAR'
        lines = [f"  Model im={im}: {name} ({kind}) — single value per realization, no depth profile",
                 "    Use only 'value_comparison' and 'value_threshold'. Do NOT include any thickness fields."]
        if is_discrete and has_classes:
            lines.extend(_class_lines())
    elif is_discrete:
        lines = [f"  Model im={im}: {name} (DISCRETE), depth {depth_min:.1f}–{depth_max:.1f} m, {n_layers} layers"]
        if has_classes:
            lines.extend(_class_lines())
        else:
            lines.append("    (Class IDs not available in prior file)")
    else:
        lines = [f"  Model im={im}: {name} (CONTINUOUS), depth {depth_min:.1f}–{depth_max:.1f} m, {n_layers} layers",
                 "    Use 'value_comparison' ('<' or '>') and 'value_threshold' for this model."]
    return '\n'.join(lines)


def _build_llm_system_prompt(f_prior_h5):
    """
    Build a system prompt for the LLM that includes the query schema and prior model context.

    Every ``M<k>`` dataset in the prior file is described (name, kind, depth
    range, classes), in numeric model order, under "Available Prior Models".

    Parameters
    ----------
    f_prior_h5 : str
        Path to the prior HDF5 file.

    Returns
    -------
    prompt : str
        System prompt string.
    """
    with h5py.File(f_prior_h5, 'r') as f:
        model_keys = [k for k in f.keys() if k.startswith('M') and k[1:].isdigit()]
    # Numeric order (M1, M2, ..., M10); a plain string sort would put M10 before M2.
    model_keys.sort(key=lambda k: int(k[1:]))

    model_sections = []
    for key in model_keys:
        im = int(key[1:])
        model_sections.append(_describe_model_for_prompt(im, get_prior_model_info(f_prior_h5, im)))
    models_text = '\n'.join(model_sections)

    # NOTE: f-string — literal braces in the JSON examples are escaped as {{ }}.
    prompt = f"""You are a geophysics query assistant for the INTEGRATE probabilistic inversion module.
Your task is to translate a natural-language query about geological or geophysical properties
into a valid JSON query dict that can be executed by the query() function.

## Query types

Choose the query type based on the user's intent:

**"probability"** — the user asks for a probability, likelihood, or yes/no fraction.
  Example: "What is the probability that clay thickness exceeds 10 m?"
  Response includes "query_type": "probability" and a "constraints" list.

**"percentile"** — the user asks for a distribution, typical value, or p5/p50/p95.
  Example: "What are the p5, p50, p95 of the cumulative thickness of sand above 10 m depth?"
  Response includes "query_type": "percentile", a single "metric" object, and a "percentiles" list.

## Response format

You must always respond with a single JSON object. Required top-level keys:
- "interpretation": 1–2 sentence plain-English confirmation of what you understood.
- "query_type": either "probability" or "percentile".
- For probability: "constraints" (list of constraint objects — see below).
- For percentile: "metric" (a single metric object — see below) and "percentiles" (list of ints,
  default [5, 50, 95] if not specified by the user).

Probability response structure:
```json
{{
  "interpretation": "...",
  "query_type": "probability",
  "constraints": [ {{ ... }} ]
}}
```

Percentile response structure:
```json
{{
  "interpretation": "...",
  "query_type": "percentile",
  "metric": {{ ... }},
  "percentiles": [5, 50, 95]
}}
```

## Constraint fields

A constraint list contains one or more constraint objects combined with logical AND
(every constraint must be satisfied).

| Field                | Type      | Required          | Valid values                       | Description |
|----------------------|-----------|-------------------|------------------------------------|-------------|
| im                   | int       | always            | 1, 2, 3, ...                       | Prior model index (see Available Prior Models below) |
| classes              | list[int] | discrete only     | class IDs from the model           | Match any of these class IDs (discrete models) |
| value_comparison     | str       | continuous/scalar | "<" or ">"                         | Compare model value against threshold |
| value_threshold      | float     | continuous/scalar | any float                          | Threshold for the value comparison |
| thickness_mode       | str       | depth models only | "cumulative" or "first_occurrence" | How to aggregate thickness of matching layers |
| thickness_comparison | str       | depth models only | ">", "<", ">=", "<="               | Operator applied to the computed thickness |
| thickness_threshold  | float     | depth models only | any float (meters)                 | Thickness threshold in meters |
| depth_min            | float     | optional          | any float                          | Upper boundary of depth interval [m] |
| depth_max            | float     | optional          | any float                          | Lower boundary of depth interval [m] |
| depth_max_im         | int       | optional          | SCALAR model im                    | Per-realization depth_max from a scalar model |
| depth_min_im         | int       | optional          | SCALAR model im                    | Per-realization depth_min from a scalar model |
| negate               | bool      | optional          | true or false (default: false)     | If true, invert the constraint result |

> **Cross-model depth bounds:** `depth_max_im` / `depth_min_im` take the `im` index of a
> SCALAR model and use its per-realization value as the depth boundary. This enables
> queries like "Sand above the water table" where the cutoff depth varies per realization.
> Use `depth_max_im` to cut at the scalar model's value from above; use `depth_min_im`
> to cut from below. These may be combined with fixed `depth_min` / `depth_max`.

### thickness_mode explained
- "cumulative": sum the thickness of ALL matching layers within the depth interval
- "first_occurrence": thickness of the FIRST contiguous block of matching layers

### Scalar models (marked SCALAR or SCALAR-DISCRETE under Available Prior Models)
These store a single value per realization, not a depth profile. For scalar models:
- Omit ALL thickness fields (`thickness_mode`, `thickness_comparison`, `thickness_threshold`, `depth_min`, `depth_max`).
- Use only `im`, `value_comparison`, `value_threshold`, and optionally `negate`.

## Available Prior Models

{models_text}

## Metric fields (for percentile queries)

A metric object defines WHAT to measure per realization (no comparison or threshold).
It uses the same fields as a constraint, minus: thickness_comparison, thickness_threshold, negate.

| Field            | Type      | Required          | Valid values                       | Description |
|------------------|-----------|-------------------|------------------------------------|-------------|
| im               | int       | always            | 1, 2, 3, ...                       | Prior model index |
| classes          | list[int] | DISCRETE only     | class IDs from the model           | Thickness of these classes is measured |
| value_comparison | str       | CONTINUOUS only   | "<" or ">"                         | Condition on value before measuring thickness |
| value_threshold  | float     | CONTINUOUS only   | any float                          | Threshold for the value condition |
| thickness_mode   | str       | depth models only | "cumulative" or "first_occurrence" | How to aggregate thickness |
| depth_min        | float     | optional          | any float                          | Upper depth boundary [m] |
| depth_max        | float     | optional          | any float                          | Lower depth boundary [m] |
| depth_max_im     | int       | optional          | SCALAR model im                    | Per-realization depth_max from scalar model |
| depth_min_im     | int       | optional          | SCALAR model im                    | Per-realization depth_min from scalar model |

For SCALAR models used as a metric: returns the raw scalar value; no thickness fields needed.

## Examples

### Example 1: Discrete cumulative constraint
Query: "Probability that cumulative clay thickness exceeds 10 m within 0–30 m depth"
```json
{{
  "interpretation": "Probability that the cumulative thickness of clay (class 2) exceeds 10 m within 0–30 m depth.",
  "query_type": "probability",
  "constraints": [
    {{
      "im": 2,
      "classes": [2],
      "thickness_mode": "cumulative",
      "thickness_comparison": ">",
      "thickness_threshold": 10.0,
      "depth_min": 0.0,
      "depth_max": 30.0,
      "negate": false
    }}
  ]
}}
```

### Example 2: Continuous cumulative constraint
Query: "Probability that resistivity is below 100 ohm-m for at least 25 m within 0–50 m"
```json
{{
  "interpretation": "Probability that resistivity (im=1) is below 100 ohm-m for a cumulative thickness of at least 25 m within 0–50 m depth.",
  "query_type": "probability",
  "constraints": [
    {{
      "im": 1,
      "value_comparison": "<",
      "value_threshold": 100.0,
      "thickness_mode": "cumulative",
      "thickness_comparison": ">",
      "thickness_threshold": 25.0,
      "depth_min": 0.0,
      "depth_max": 50.0,
      "negate": false
    }}
  ]
}}
```

### Example 3: Multi-constraint AND
Query: "Probability that clay > 5 m within 0–20 m AND resistivity > 500 ohm-m for >= 1 m within 20–60 m"
```json
{{
  "interpretation": "Probability that cumulative clay (class 2) thickness exceeds 5 m within 0–20 m AND resistivity (im=1) exceeds 500 ohm-m for at least 1 m within 20–60 m. Both constraints must hold simultaneously.",
  "query_type": "probability",
  "constraints": [
    {{
      "im": 2,
      "classes": [2],
      "thickness_mode": "cumulative",
      "thickness_comparison": ">",
      "thickness_threshold": 5.0,
      "depth_min": 0.0,
      "depth_max": 20.0,
      "negate": false
    }},
    {{
      "im": 1,
      "value_comparison": ">",
      "value_threshold": 500.0,
      "thickness_mode": "cumulative",
      "thickness_comparison": ">",
      "thickness_threshold": 1.0,
      "depth_min": 20.0,
      "depth_max": 60.0,
      "negate": false
    }}
  ]
}}
```

### Example 4: First-occurrence with negation
Query: "Probability that the first occurrence of clay at the surface is less than 3 m thick"
```json
{{
  "interpretation": "Probability that the first contiguous block of clay (class 2) starting from the surface is less than 3 m thick, within 0–30 m depth.",
  "query_type": "probability",
  "constraints": [
    {{
      "im": 2,
      "classes": [2],
      "thickness_mode": "first_occurrence",
      "thickness_comparison": "<",
      "thickness_threshold": 3.0,
      "depth_min": 0.0,
      "depth_max": 30.0,
      "negate": false
    }}
  ]
}}
```

### Example 5: Scalar model query (no thickness fields)
Query: "Probability that the water table is shallower than 5 m"
```json
{{
  "interpretation": "Probability that the water table depth (im=3, SCALAR) is less than 5 m.",
  "query_type": "probability",
  "constraints": [
    {{
      "im": 3,
      "value_comparison": "<",
      "value_threshold": 5.0,
      "negate": false
    }}
  ]
}}
```

### Example 6: Cross-model depth constraint (dynamic depth bound from scalar model)
Query: "Probability that Sand and Grus have a cumulative thickness above the water table exceeding 5 m"
```json
{{
  "interpretation": "Probability that Sand (class 1) and Grus (class 2) have a cumulative thickness exceeding 5 m within the zone above the water table (im=3), starting from the surface.",
  "query_type": "probability",
  "constraints": [
    {{
      "im": 2,
      "classes": [1, 2],
      "thickness_mode": "cumulative",
      "thickness_comparison": ">",
      "thickness_threshold": 5.0,
      "depth_min": 0.0,
      "depth_max_im": 3,
      "negate": false
    }}
  ]
}}
```

### Example 7: Percentile query — thickness distribution
Query: "What are the p5, p50, and p95 of the cumulative thickness of Sand and Grus within 0 to 30 m depth?"
```json
{{
  "interpretation": "P5/P50/P95 of the cumulative thickness of Sand (class 1) and Grus (class 2) within 0–30 m depth.",
  "query_type": "percentile",
  "metric": {{
    "im": 2,
    "classes": [1, 2],
    "thickness_mode": "cumulative",
    "depth_min": 0.0,
    "depth_max": 30.0
  }},
  "percentiles": [5, 50, 95]
}}
```

### Example 8: Percentile query — cross-model depth bound
Query: "What is the typical (median) thickness of Sand and Grus above the water table?"
```json
{{
  "interpretation": "P5/P50/P95 of the cumulative thickness of Sand (class 1) and Grus (class 2) above the water table (depth bounded per realization by im=3).",
  "query_type": "percentile",
  "metric": {{
    "im": 2,
    "classes": [1, 2],
    "thickness_mode": "cumulative",
    "depth_min": 0.0,
    "depth_max_im": 3
  }},
  "percentiles": [5, 50, 95]
}}
```

## Instructions

- Respond with ONLY a valid JSON object. No markdown fences, no extra commentary.
- Always include "interpretation" (1–2 sentences) and "query_type" ("probability" or "percentile").
- Use only the model indices (im) and class IDs listed under Available Prior Models above.
- If the query cannot be expressed with the available schema and models, respond with exactly:
  UNSUPPORTED: <brief reason>
- Do not invent class IDs or model indices that are not listed above.
"""
    return prompt
|
|
1247
|
+
|
|
1248
|
+
|
|
1249
|
+
def _litellm_extra(model):
    """
    Provider-specific extra keyword arguments for ``litellm.completion``.

    Model strings starting with ``'ollama'`` get ``extra_body={'think': False}``
    to turn thinking off on Ollama models; all other providers get no extras.

    Parameters
    ----------
    model : str
        LiteLLM model string, e.g. ``'ollama/qwen3'`` or ``'openai/gpt-4o'``.

    Returns
    -------
    dict
        Keyword arguments to splat into ``litellm.completion(...)`` (a new dict
        on every call).
    """
    is_ollama = model.startswith('ollama')
    return {'extra_body': {'think': False}} if is_ollama else {}
|
|
1254
|
+
|
|
1255
|
+
|
|
1256
|
+
def query_from_text(text, f_prior_h5, model='anthropic/claude-sonnet-4-6', api_key=None, max_tokens=4096, verbose=False):
    """
    Translate a natural-language query into a query dict using an LLM.

    Uses LiteLLM to interpret the user's text query in the context of the
    available prior models and the integrate query schema, returning a query
    dict and a plain-English interpretation of what the LLM understood.
    If the first reply is not valid JSON, the LLM is asked once more to
    return only the JSON object.

    Parameters
    ----------
    text : str
        Natural language description of the query, e.g.
        "What is the probability that cumulative clay thickness exceeds 10 m?".
    f_prior_h5 : str
        Path to the prior HDF5 file. Model metadata (class names, depth ranges,
        discrete/continuous type) is read automatically and included in the
        LLM prompt so the model knows what constraints are valid.
    model : str, optional
        LiteLLM model string (default: 'anthropic/claude-sonnet-4-6'). Any
        LiteLLM-supported model works, e.g. 'openai/gpt-4o'.
    api_key : str, optional
        Provider API key. If None, the relevant environment variable
        (e.g. ANTHROPIC_API_KEY) is used.
    max_tokens : int, optional
        Maximum number of tokens the LLM may generate per request (default 4096).
    verbose : bool, optional
        If True, print the system prompt and LLM response for inspection.

    Returns
    -------
    query_dict : dict
        Query dict ready to pass to ig.query(f_post_h5, query_dict).
        ``{'metric': ..., 'percentiles': ...}`` when the LLM chose a percentile
        query, otherwise ``{'constraints': [...]}``.
    interpretation : str
        Plain English confirmation of what the LLM understood the query to mean.
        Check this before running ig.query() to catch misunderstandings cheaply.
    system_prompt : str
        The full system prompt sent to the LLM. Useful for inspection and debugging.

    Raises
    ------
    ImportError
        If the litellm package is not installed.
    ValueError
        If the LLM returns empty content, reports the query is unsupported
        (on the first reply or the retry), or if the response cannot be parsed
        as a JSON object.

    Notes
    -----
    Requires either the api_key parameter or the relevant provider environment
    variable to be set. Install the dependency with: pip install litellm

    Examples
    --------
    >>> import integrate as ig
    >>> query_dict, interpretation, system_prompt = ig.query_from_text(
    ...     "Probability that cumulative clay thickness > 10 m within 0-30 m",
    ...     f_prior_h5='prior.h5',
    ...     api_key='sk-ant-...',
    ... )
    >>> print(interpretation)
    >>> P, meta = ig.query('posterior.h5', query_dict)
    >>> ig.query_plot(P, meta)
    """
    try:
        import litellm
    except ImportError:
        raise ImportError(
            "The 'litellm' package is required for query_from_text(). "
            "Install it with: pip install litellm"
        )

    system_prompt = _build_llm_system_prompt(f_prior_h5)

    if verbose:
        print("=== SYSTEM PROMPT ===")
        print(system_prompt)
        print("=== USER TEXT ===")
        print(text)

    def _strip_fences(s):
        # Remove a leading ``` / ```json line and a trailing ``` if the model added them.
        if s.startswith("```"):
            s = s.split("\n", 1)[-1]
        if s.endswith("```"):
            s = s.rsplit("```", 1)[0].strip()
        return s

    def _ask(messages):
        # Send one completion request; return the reply text, stripped and de-fenced.
        reply = litellm.completion(
            model=model,
            max_tokens=max_tokens,
            messages=messages,
            api_key=api_key,
            **_litellm_extra(model),
        )
        # The message content may be None; treat that as an empty reply.
        return _strip_fences((reply.choices[0].message.content or '').strip())

    def _raise_if_unsupported(resp):
        # The prompt instructs the LLM to answer 'UNSUPPORTED: <reason>' for queries it cannot express.
        if resp.startswith("UNSUPPORTED:"):
            reason = resp[len("UNSUPPORTED:"):].strip()
            raise ValueError(f"Query cannot be expressed with the current schema: {reason}")

    base_messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": text},
    ]
    response = _ask(base_messages)

    if not response:
        raise ValueError(
            f"Model '{model}' returned empty content. "
            "If this is a thinking model (e.g. Qwen3, DeepSeek-R1), ensure /no_think "
            "is in the prompt or increase max_tokens."
        )

    if verbose:
        print("=== LLM RESPONSE ===")
        print(response)

    _raise_if_unsupported(response)

    try:
        parsed = json.loads(response)
    except json.JSONDecodeError:
        if verbose:
            print("=== JSON PARSE FAILED — RETRYING ===")
        response = _ask(base_messages + [
            {"role": "assistant", "content": response},
            {"role": "user", "content": "Your response was not valid JSON. Output ONLY the JSON object with no extra text, no markdown fences, no explanation."},
        ])
        if verbose:
            print("=== RETRY RESPONSE ===")
            print(response)
        _raise_if_unsupported(response)
        try:
            parsed = json.loads(response)
        except json.JSONDecodeError as e2:
            raise ValueError(
                f"LLM response could not be parsed as JSON after retry: {e2}\nRaw response:\n{response}"
            ) from e2

    # Valid JSON that is not an object (e.g. a list) cannot carry the expected keys.
    if not isinstance(parsed, dict):
        raise ValueError(f"LLM response is not a JSON object.\nRaw response:\n{response}")

    interpretation = parsed.pop('interpretation', '')
    # Normalise so 'Percentile' / ' percentile ' are not silently treated as probability queries.
    query_type = str(parsed.pop('query_type', 'probability')).strip().lower()
    print(f"Interpretation: {interpretation}")

    # Build the canonical query dict based on query_type
    if query_type == 'percentile':
        query_dict = {
            'metric': parsed.get('metric', {}),
            'percentiles': parsed.get('percentiles', [5, 50, 95]),
        }
    else:
        # probability (default, backward compatible)
        query_dict = {'constraints': parsed.get('constraints', [])}

    return query_dict, interpretation, system_prompt
|
|
1412
|
+
|
|
1413
|
+
|
|
1414
|
+
def title_from_json(file_json, f_prior_h5=None, model='anthropic/claude-sonnet-4-6',
                    api_key=None, showInfo=1):
    """
    Return a short plain-language title describing what a query will compute.

    Uses an LLM to produce a short human-readable title (about ten words, in
    title case) suitable for a figure title or log message. If the LLM is
    unavailable (missing package, no API key, network error) or the query
    file cannot be read, an empty string is returned instead of raising.

    Parameters
    ----------
    file_json : str, os.PathLike or dict
        Path to a query JSON file (read as UTF-8; a leading BOM is accepted),
        or a query dict directly (e.g. from ``ig.load_query()``).
    f_prior_h5 : str, optional
        Path to the prior HDF5 file. When provided, real model names, depth
        ranges, and class labels are included in the prompt so the description
        uses geological names instead of numeric model/class IDs. If this
        information cannot be read, the prompt is sent without it.
    model : str, optional
        LiteLLM model string (default: 'anthropic/claude-sonnet-4-6').
    api_key : str, optional
        Provider API key. If None, the relevant environment variable is used.
    showInfo : int, optional
        0 = silent; 1 = print a message when the query file cannot be read or
        the LLM cannot be reached (default); 2 = also print exception details,
        including failures while reading model info from ``f_prior_h5``.

    Returns
    -------
    description : str
        The LLM's reply with surrounding whitespace stripped, or an empty
        string if the query could not be read or the LLM could not be reached.

    Examples
    --------
    >>> description = ig.title_from_json('my_query.json')
    >>> description = ig.title_from_json('my_query.json', f_prior_h5='prior.h5')
    >>> query = ig.load_query('query_ex1.json')
    >>> title = ig.title_from_json(query, f_prior_h5='prior.h5')
    >>> title = ig.title_from_json(query, showInfo=0)  # silent on failure
    """
    try:
        import litellm
    except ImportError:
        if showInfo >= 1:
            print("[ig.title_from_json] LLM unavailable: 'litellm' package not installed "
                  "(pip install litellm). Returning empty description.")
        return ''

    # Accept both plain string paths and path-like objects (e.g. pathlib.Path);
    # anything else is treated as an already-loaded query mapping.
    if isinstance(file_json, (str, os.PathLike)):
        try:
            # Query files may contain non-ASCII class names; decode explicitly as
            # UTF-8 (tolerating a BOM) instead of the platform default encoding.
            with open(file_json, 'r', encoding='utf-8-sig') as fh:
                query_dict = json.load(fh)
        except (OSError, ValueError) as e:
            if showInfo >= 1:
                print(f"[ig.title_from_json] Could not read query file: {e}")
            return ''
    else:
        query_dict = dict(file_json)

    # Collect the im indices referenced by this query: probability queries list
    # 'constraints', percentile queries carry a single 'metric'.
    if 'constraints' in query_dict:
        items = query_dict['constraints']
    elif 'metric' in query_dict:
        items = [query_dict['metric']]
    else:
        items = []
    needed_ims = _collect_needed_ims(items) if items else set()

    # Build an optional model-context block from the prior file so the LLM can
    # use real model/class names rather than numeric IDs.
    model_context = ''
    if f_prior_h5 and needed_ims:
        try:
            lines = ['Available models referenced in this query:']
            for im in sorted(needed_ims):
                info = get_prior_model_info(f_prior_h5, im)
                z = info['z']
                depth_min, depth_max = float(z[0]), float(z[-1])
                # A zero depth span or a single z value is treated as a scalar model.
                is_scalar = (depth_max - depth_min) == 0 or len(z) == 1
                name = info['name']
                if is_scalar:
                    kind = 'scalar-discrete' if info['is_discrete'] else 'scalar'
                    lines.append(f" im={im}: '{name}' ({kind})")
                elif info['is_discrete']:
                    lines.append(f" im={im}: '{name}' (discrete), depth {depth_min:.1f}–{depth_max:.1f} m")
                    if info['class_id'] is not None and info['class_name'] is not None:
                        ids = info['class_id'].flatten()
                        names = info['class_name'].flatten()
                        for cid, cname in zip(ids, names):
                            lines.append(f"   class {int(cid)} = {cname}")
                else:
                    lines.append(f" im={im}: '{name}' (continuous), depth {depth_min:.1f}–{depth_max:.1f} m")
            model_context = '\n'.join(lines)
        except Exception as e:
            # Best effort: without prior info the title simply falls back to
            # numeric IDs, so report the problem only when asked for detail.
            model_context = ''
            if showInfo >= 2:
                print(f"[ig.title_from_json] Could not read model info from "
                      f"'{f_prior_h5}': {e}")

    system_prompt = (
        "You are a geophysics assistant for the INTEGRATE probabilistic inversion module. "
        "The user will provide a query dict in JSON format. "
        "Suggest a short figure title (max ~10 words) for the result this query produces. "
        "Use geological language and real model/class names where available. "
        "Write the title in title case. Do not start with 'Computes', 'Shows', or 'Displays'. "
        "Do not include JSON syntax. Reply with only the title — no preamble, no full stop."
    )
    if model_context:
        system_prompt += f"\n\n{model_context}"

    try:
        response_obj = litellm.completion(
            model=model,
            max_tokens=128,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": json.dumps(query_dict)},
            ],
            api_key=api_key,
            **_litellm_extra(model),
        )
        # Some providers return None content (e.g. when all tokens went to
        # reasoning); normalise that to an empty string.
        description = (response_obj.choices[0].message.content or '').strip()
        return description
    except Exception as e:
        # Titles are cosmetic: never let an LLM/network failure break the caller.
        if showInfo >= 1:
            print(f"[ig.title_from_json] LLM call failed — returning empty description. "
                  f"Use ig.query_test_llm() to diagnose. (model='{model}')")
        if showInfo >= 2:
            print(f"  Detail: {e}")
        return ''
|
|
1540
|
+
|
|
1541
|
+
|
|
1542
|
+
def query_test_llm(model='anthropic/claude-sonnet-4-6', api_key=None, verbose=1):
    """
    Test whether a given LLM model and API key are working correctly.

    Sends a minimal JSON-generation prompt and checks that the response is
    valid JSON (after removing an optional markdown code fence). Prints a
    summary and returns a status dict; call failures are recorded in the
    dict rather than raised.

    Parameters
    ----------
    model : str, optional
        LiteLLM model string (default: 'anthropic/claude-sonnet-4-6').
    api_key : str, optional
        Provider API key. If None, the relevant environment variable is used.
    verbose : int, optional
        0 = silent, 1 = summary only (default), 2 = full response included.

    Returns
    -------
    result : dict
        Keys: 'ok' (bool), 'model', 'response' (str or None), 'error' (str or None).

    Raises
    ------
    ImportError
        If the ``litellm`` package is not installed.
    """
    try:
        import litellm
    except ImportError as err:
        # Chain the original error so the underlying import failure stays visible.
        raise ImportError(
            "The 'litellm' package is required. Install it with: pip install litellm"
        ) from err

    test_prompt = 'Reply with exactly this JSON and nothing else: {"status": "ok"}'
    result = {'ok': False, 'model': model, 'response': None, 'error': None}

    try:
        response_obj = litellm.completion(
            model=model,
            max_tokens=256,
            messages=[{"role": "user", "content": test_prompt}],
            api_key=api_key,
            **_litellm_extra(model),
        )
        msg = response_obj.choices[0].message
        raw = (msg.content or '').strip()
        # Strip a markdown code fence if present. The opening fence may carry a
        # language tag on its own line (```json), or the JSON may start on the
        # fence line itself (```{"status": "ok"}```); only drop the first line
        # when it is not the start of the JSON payload.
        if raw.startswith("```"):
            raw = raw[3:]
            if raw.endswith("```"):
                raw = raw[:-3]
            head, newline, tail = raw.partition("\n")
            if newline and not head.lstrip().startswith(("{", "[")):
                raw = tail
            raw = raw.strip()
        result['response'] = raw

        if not raw:
            reasoning = getattr(msg, 'reasoning_content', None)
            hint = " (model returned empty content — it may be a thinking model with all tokens used for reasoning)" if not reasoning else f" (content was empty; reasoning_content present, length {len(reasoning)})"
            raise ValueError(f"Empty response from model{hint}")

        json.loads(raw)  # validate JSON
        result['ok'] = True
        if verbose >= 1:
            print(f"[query_test_llm] OK — model '{model}' responded with valid JSON.")
            if verbose >= 2:
                print(f"  Response: {raw}")
    except Exception as e:
        # Diagnostic helper: report any failure in the result instead of raising.
        result['error'] = str(e)
        if verbose >= 1:
            print(f"[query_test_llm] FAILED — model '{model}'")
            print(f"  Error: {e}")
            if verbose >= 2 and result['response']:
                print(f"  Raw response: {result['response']}")

    return result
|