uxarray-mcp 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1110 @@
1
+ """Advanced analysis, remapping, export, and subsetting tools."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import csv
6
+ import json
7
+ from pathlib import Path
8
+ from typing import Any, Sequence
9
+
10
+ import numpy as np
11
+ import uxarray as ux
12
+ import xarray as xr
13
+ from matplotlib.path import Path as MplPath
14
+
15
+ from uxarray_mcp.domain.mesh import load_grid
16
+ from uxarray_mcp.provenance import attach_provenance
17
+ from uxarray_mcp.state import (
18
+ OperationTracker,
19
+ copy_artifact,
20
+ get_result,
21
+ get_session,
22
+ persist_result,
23
+ save_result,
24
+ summarize_array,
25
+ summarize_dataset,
26
+ summarize_grid,
27
+ write_dataarray_artifact,
28
+ write_dataset_artifact,
29
+ write_grid_artifact,
30
+ write_json_artifact,
31
+ )
32
+
33
+
34
+ def _resolve_paths(
35
+ *,
36
+ session_id: str | None = None,
37
+ dataset_handle: str | None = None,
38
+ grid_path: str | None = None,
39
+ data_path: str | None = None,
40
+ ) -> tuple[str, str | None]:
41
+ if dataset_handle is not None:
42
+ if session_id is None:
43
+ raise ValueError("session_id is required when dataset_handle is provided.")
44
+ session = get_session(session_id)
45
+ dataset = session["datasets"].get(dataset_handle)
46
+ if dataset is None:
47
+ raise FileNotFoundError(
48
+ f"Dataset handle {dataset_handle!r} not found in session {session_id!r}"
49
+ )
50
+ return dataset["grid_path"], dataset.get("data_path")
51
+ if grid_path is None:
52
+ raise ValueError("grid_path is required when dataset_handle is not provided.")
53
+ return grid_path, data_path
54
+
55
+
56
def _load_dataarray(
    grid_path: str,
    data_path: str,
    variable_name: str | None,
) -> tuple[ux.UxDataset, ux.UxDataArray, str]:
    """Open a grid/data pair and select one variable from it.

    When ``variable_name`` is ``None`` the first data variable that has an
    ``n_face`` or ``nCells`` dimension is selected.

    The opened dataset is returned (still open) together with the selected
    array so callers can keep working with it. If selection fails, the
    dataset is closed before the error is raised so no file handle leaks.

    Returns:
        ``(dataset, selected_array, selected_name)``.

    Raises:
        ValueError: no face-centered variable exists (auto-selection), or the
            requested variable is not a data variable of the dataset.
    """
    uxds = ux.open_dataset(grid_path, data_path)
    selected = variable_name
    if selected is None:
        selected = next(
            (
                name
                for name, var in uxds.data_vars.items()
                if "n_face" in var.dims or "nCells" in var.dims
            ),
            None,
        )
        if selected is None:
            uxds.close()
            raise ValueError("No face-centered variable found in dataset.")
    if selected not in uxds.data_vars:
        available = list(uxds.data_vars)
        uxds.close()
        raise ValueError(
            f"Variable '{selected}' not found. Available variables: {available}"
        )
    return uxds, uxds[selected], selected
75
+
76
+
77
def _persist_grid_result(
    *,
    grid: Any,
    session_id: str | None,
    name: str,
    kind: str,
    summary: dict[str, Any],
) -> str:
    """Register a grid result, write its grid artifact, and return the handle.

    The result record is created first so its handle can name the artifact.
    The artifact path is then recorded on the record, which is saved again.
    """
    record = persist_result(kind=kind, name=name, summary=summary, session_id=session_id)
    record["artifact_path"] = write_grid_artifact(grid, record["result_handle"])
    save_result(record)
    return record["result_handle"]
95
+
96
+
97
def _persist_dataarray_result(
    *,
    data: Any,
    session_id: str | None,
    name: str,
    kind: str,
    summary: dict[str, Any],
    metadata: dict[str, Any] | None = None,
) -> str:
    """Register a data-array result, write its artifact, and return the handle.

    ``metadata`` is forwarded unchanged to ``persist_result``. The artifact
    path returned by the writer is stored on the record before it is saved.
    """
    record = persist_result(
        kind=kind,
        name=name,
        summary=summary,
        session_id=session_id,
        metadata=metadata,
    )
    record["artifact_path"] = write_dataarray_artifact(data, record["result_handle"])
    save_result(record)
    return record["result_handle"]
117
+
118
+
119
def _persist_dataset_result(
    *,
    dataset: Any,
    session_id: str | None,
    name: str,
    kind: str,
    summary: dict[str, Any],
    metadata: dict[str, Any] | None = None,
) -> str:
    """Register a dataset result, write its artifact, and return the handle.

    Mirrors the data-array variant but writes the artifact with
    ``write_dataset_artifact``.
    """
    record = persist_result(
        kind=kind,
        name=name,
        summary=summary,
        session_id=session_id,
        metadata=metadata,
    )
    record["artifact_path"] = write_dataset_artifact(dataset, record["result_handle"])
    save_result(record)
    return record["result_handle"]
139
+
140
+
141
def subset_bbox(
    lon_bounds: list[float],
    lat_bounds: list[float],
    grid_path: str | None = None,
    data_path: str | None = None,
    variable_name: str | None = None,
    session_id: str | None = None,
    dataset_handle: str | None = None,
    result_name: str | None = None,
) -> dict[str, Any]:
    """Subset a mesh or face-centered variable by longitude/latitude bounds.

    The grid is always subset with ``grid.subset.bounding_box``. When a data
    file is available (given directly or through ``dataset_handle``), the
    selected variable is subset too and persisted as a data-array result.
    Otherwise the subset grid itself is persisted.

    Returns:
        Summaries of the original and subset grids, the optional variable
        summary, ``result_handle``, ``recommended_next_steps`` and
        ``_provenance`` (including the tracker's ``operation_id``).
    """
    tracker = OperationTracker("subset_bbox", session_id=session_id)
    tracker.stage("loading", "Loading grid and optional dataset.")
    resolved_grid, resolved_data = _resolve_paths(
        session_id=session_id,
        dataset_handle=dataset_handle,
        grid_path=grid_path,
        data_path=data_path,
    )
    grid = load_grid(resolved_grid)
    bbox = (tuple(lon_bounds), tuple(lat_bounds))
    subset_grid = grid.subset.bounding_box(*bbox)
    has_data = resolved_data is not None
    variable_summary = None

    if has_data:
        tracker.stage("subsetting", "Applying bbox to variable data.")
        _, uxda, selected = _load_dataarray(resolved_grid, resolved_data, variable_name)
        subset_data = uxda.subset.bounding_box(*bbox)
        variable_summary = summarize_array(subset_data.to_xarray())
        result_handle = _persist_dataarray_result(
            data=subset_data,
            session_id=session_id,
            name=result_name or f"bbox:{selected}",
            kind="subset_bbox",
            summary=variable_summary,
            metadata={
                "selection_type": "bbox",
                "lon_bounds": lon_bounds,
                "lat_bounds": lat_bounds,
                "variable_name": selected,
            },
        )
    else:
        tracker.stage("subsetting", "Applying bbox to grid only.")
        result_handle = _persist_grid_result(
            grid=subset_grid,
            session_id=session_id,
            name=result_name or "bbox_grid_subset",
            kind="subset_bbox_grid",
            summary=summarize_grid(subset_grid),
        )

    # Variable plotting (when data exists) is suggested before the generic steps.
    next_steps: list[str] = []
    if has_data:
        next_steps.append(
            f'plot_variable("{resolved_grid}", "{resolved_data}", "<variable_name>")'
        )
    next_steps.extend(
        [
            f'plot_mesh(grid_path="{resolved_grid}")',
            f'export_to_netcdf("<output.nc>", result_handle="{result_handle}")',
        ]
    )

    result: dict[str, Any] = {
        "selection_type": "bbox",
        "lon_bounds": lon_bounds,
        "lat_bounds": lat_bounds,
        "original_grid": summarize_grid(grid),
        "subset_grid": summarize_grid(subset_grid),
        "variable_summary": variable_summary,
        "result_handle": result_handle,
        "recommended_next_steps": next_steps,
    }
    tracker.succeed("Bounding-box subset complete.")
    result = attach_provenance(
        result,
        tool="subset_bbox",
        inputs={
            "lon_bounds": lon_bounds,
            "lat_bounds": lat_bounds,
            "grid_path": grid_path,
            "data_path": data_path,
            "variable_name": variable_name,
            "session_id": session_id,
            "dataset_handle": dataset_handle,
        },
    )
    result["_provenance"]["operation_id"] = tracker.operation_id
    return result
228
+
229
+
230
def subset_polygon(
    polygon_lon_lat: list[list[float]],
    grid_path: str | None = None,
    data_path: str | None = None,
    variable_name: str | None = None,
    session_id: str | None = None,
    dataset_handle: str | None = None,
    result_name: str | None = None,
) -> dict[str, Any]:
    """Select faces whose centers fall within a polygon.

    Face centers (``grid.face_lon`` / ``grid.face_lat``) are tested against
    the polygon with matplotlib's planar point-in-polygon test. Vertices are
    used as given, with no longitude wrapping, so they must use the same
    longitude convention as the grid.

    When a data file is available (directly or via ``dataset_handle``), the
    selected variable is subset along its face dimension and persisted as a
    data-array result. Otherwise only the selected face indices are persisted
    as a JSON artifact.

    Args:
        polygon_lon_lat: Polygon vertices as ``[lon, lat]`` pairs (>= 3).
        grid_path: Grid file, required unless ``dataset_handle`` is given.
        data_path: Optional data file holding the variable to subset.
        variable_name: Variable to subset; auto-selected when ``None``.
        session_id: Session used to resolve ``dataset_handle`` and record results.
        dataset_handle: Registered dataset to read paths from.
        result_name: Optional name for the persisted result.

    Returns:
        Selection summary with ``result_handle``, ``recommended_next_steps``
        and ``_provenance`` (including the tracker's ``operation_id``).

    Raises:
        ValueError: fewer than three vertices, vertices that are not numeric
            ``[lon, lat]`` pairs, or errors from path/variable resolution.
    """
    if len(polygon_lon_lat) < 3:
        raise ValueError("polygon_lon_lat must contain at least three points.")
    # Validate the vertex array before any tracking or file I/O so a malformed
    # polygon fails fast with a clear message instead of a numpy/matplotlib error.
    try:
        vertices = np.asarray(polygon_lon_lat, dtype=float)
    except (TypeError, ValueError) as err:
        raise ValueError(
            "polygon_lon_lat must be a list of numeric [lon, lat] pairs."
        ) from err
    if vertices.ndim != 2 or vertices.shape[1] != 2:
        raise ValueError("polygon_lon_lat must be a list of numeric [lon, lat] pairs.")

    tracker = OperationTracker("subset_polygon", session_id=session_id)
    tracker.stage("loading", "Loading grid and optional dataset.")
    resolved_grid, resolved_data = _resolve_paths(
        session_id=session_id,
        dataset_handle=dataset_handle,
        grid_path=grid_path,
        data_path=data_path,
    )
    grid = load_grid(resolved_grid)
    face_points = np.column_stack(
        (np.asarray(grid.face_lon), np.asarray(grid.face_lat))
    )
    polygon = MplPath(vertices)
    selected_indices = np.flatnonzero(polygon.contains_points(face_points))

    variable_summary = None
    result_handle = None
    if resolved_data is not None:
        tracker.stage("subsetting", "Selecting polygon faces from variable data.")
        _, uxda, selected = _load_dataarray(resolved_grid, resolved_data, variable_name)
        face_dim = "n_face" if "n_face" in uxda.dims else "nCells"
        subset_data = uxda.isel({face_dim: selected_indices})
        variable_summary = summarize_array(subset_data.to_xarray())
        result_handle = _persist_dataarray_result(
            data=subset_data,
            session_id=session_id,
            name=result_name or f"polygon:{selected}",
            kind="subset_polygon",
            summary=variable_summary,
            metadata={
                "selection_type": "polygon",
                "polygon_lon_lat": polygon_lon_lat,
                "variable_name": selected,
                "selected_face_indices": selected_indices.tolist(),
            },
        )
    else:
        tracker.stage("subsetting", "Recording polygon face indices.")
        payload = {
            "selection_type": "polygon",
            "polygon_lon_lat": polygon_lon_lat,
            "selected_face_indices": selected_indices.tolist(),
        }
        selection_record = persist_result(
            kind="subset_polygon_indices",
            name=result_name or "polygon_face_selection",
            summary={
                "selected_face_count": int(selected_indices.size),
                "selected_face_indices_preview": selected_indices[:25].tolist(),
            },
            session_id=session_id,
        )
        artifact_path = write_json_artifact(payload, selection_record["result_handle"])
        selection_record["artifact_path"] = artifact_path
        save_result(selection_record)
        result_handle = selection_record["result_handle"]

    tracker.succeed("Polygon selection complete.")
    result: dict[str, Any] = {
        "selection_type": "polygon",
        "selected_face_count": int(selected_indices.size),
        "selected_face_indices_preview": selected_indices[:25].tolist(),
        "variable_summary": variable_summary,
        "result_handle": result_handle,
    }
    next_steps = [
        f'plot_mesh(grid_path="{resolved_grid}")',
        f'export_to_netcdf("<output.nc>", result_handle="{result_handle}")',
    ]
    if resolved_data is not None:
        next_steps.insert(
            0,
            f'plot_variable("{resolved_grid}", "{resolved_data}", "<variable_name>")',
        )
    result["recommended_next_steps"] = next_steps
    result = attach_provenance(
        result,
        tool="subset_polygon",
        inputs={
            "polygon_lon_lat": polygon_lon_lat,
            "grid_path": grid_path,
            "data_path": data_path,
            "variable_name": variable_name,
            "session_id": session_id,
            "dataset_handle": dataset_handle,
        },
    )
    result["_provenance"]["operation_id"] = tracker.operation_id
    return result
330
+
331
+
332
def extract_cross_section(
    *,
    latitude: float | None = None,
    longitude: float | None = None,
    grid_path: str | None = None,
    data_path: str | None = None,
    variable_name: str | None = None,
    session_id: str | None = None,
    dataset_handle: str | None = None,
    result_name: str | None = None,
) -> dict[str, Any]:
    """Extract a constant-latitude or constant-longitude cross-section.

    Exactly one of ``latitude`` / ``longitude`` must be given. The grid is
    always cut with the matching ``subset.constant_*`` accessor. When a data
    file is available, the selected variable is cut the same way and
    persisted. Otherwise the cut grid is persisted.

    Raises:
        ValueError: both or neither of ``latitude`` and ``longitude`` given,
            or errors from path/variable resolution.
    """
    if (latitude is None) == (longitude is None):
        raise ValueError("Provide exactly one of latitude or longitude.")

    tracker = OperationTracker("extract_cross_section", session_id=session_id)
    resolved_grid, resolved_data = _resolve_paths(
        session_id=session_id,
        dataset_handle=dataset_handle,
        grid_path=grid_path,
        data_path=data_path,
    )
    along_latitude = latitude is not None
    selection_type = "constant_latitude" if along_latitude else "constant_longitude"

    def cut(target: Any) -> Any:
        # Grids and data arrays both expose the same ``subset`` accessor methods.
        if along_latitude:
            return target.subset.constant_latitude(latitude)
        return target.subset.constant_longitude(longitude)

    grid = load_grid(resolved_grid)
    subset_grid = cut(grid)

    variable_summary = None
    if resolved_data is None:
        result_handle = _persist_grid_result(
            grid=subset_grid,
            session_id=session_id,
            name=result_name or selection_type,
            kind="cross_section_grid",
            summary=summarize_grid(subset_grid),
        )
    else:
        _, uxda, selected = _load_dataarray(resolved_grid, resolved_data, variable_name)
        subset_data = cut(uxda)
        variable_summary = summarize_array(subset_data.to_xarray())
        result_handle = _persist_dataarray_result(
            data=subset_data,
            session_id=session_id,
            name=result_name or f"{selection_type}:{selected}",
            kind="cross_section",
            summary=variable_summary,
            metadata={
                "selection_type": selection_type,
                "latitude": latitude,
                "longitude": longitude,
                "variable_name": selected,
            },
        )

    tracker.succeed("Cross-section extraction complete.")
    next_steps: list[str] = []
    if resolved_data is not None:
        next_steps.append(
            f'calculate_zonal_mean("{resolved_grid}", "{resolved_data}", "<variable_name>")'
        )
    next_steps += [
        f'plot_mesh(grid_path="{resolved_grid}")',
        f'export_to_netcdf("<output.nc>", result_handle="{result_handle}")',
    ]
    result: dict[str, Any] = {
        "selection_type": selection_type,
        "latitude": latitude,
        "longitude": longitude,
        "subset_grid": summarize_grid(subset_grid),
        "variable_summary": variable_summary,
        "result_handle": result_handle,
        "recommended_next_steps": next_steps,
    }
    result = attach_provenance(
        result,
        tool="extract_cross_section",
        inputs={
            "latitude": latitude,
            "longitude": longitude,
            "grid_path": grid_path,
            "data_path": data_path,
            "variable_name": variable_name,
            "session_id": session_id,
            "dataset_handle": dataset_handle,
        },
    )
    result["_provenance"]["operation_id"] = tracker.operation_id
    return result
427
+
428
+
429
def _load_comparison_arrays(
    *,
    grid_path: str | None,
    data_path_a: str,
    data_path_b: str,
    variable_name: str,
) -> tuple[xr.DataArray, xr.DataArray]:
    """Load ``variable_name`` from two files and check they are comparable.

    With a (truthy) ``grid_path`` both files are opened through uxarray
    against that grid and converted to plain xarray arrays. Otherwise they
    are opened directly with xarray.

    Each array is loaded into memory and its file is closed right away, so
    repeated comparisons do not leak file handles. The comparison metrics
    read every value anyway.

    Returns:
        ``(first, second)`` as in-memory ``xr.DataArray`` objects.

    Raises:
        KeyError: ``variable_name`` is missing from either file.
        ValueError: the two arrays differ in dims or shape.
    """

    def _read(data_path: str) -> xr.DataArray:
        # Load one file's variable and close the source file afterwards.
        if grid_path:
            with ux.open_dataset(grid_path, data_path) as uxds:
                return uxds[variable_name].to_xarray().load()
        with xr.open_dataset(data_path) as ds:
            return ds[variable_name].load()

    first = _read(data_path_a)
    second = _read(data_path_b)
    if first.shape != second.shape or first.dims != second.dims:
        raise ValueError(
            "Comparison requires same-grid, same-shape variables in v1. "
            f"Got dims {first.dims}/{second.dims} and shapes {first.shape}/{second.shape}."
        )
    return first, second
448
+
449
+
450
+ def _pattern_correlation(first: xr.DataArray, second: xr.DataArray) -> float:
451
+ a = np.asarray(first.values).ravel()
452
+ b = np.asarray(second.values).ravel()
453
+ mask = np.isfinite(a) & np.isfinite(b)
454
+ if not mask.any():
455
+ raise ValueError("No finite overlapping values available for correlation.")
456
+ a = a[mask] - np.mean(a[mask])
457
+ b = b[mask] - np.mean(b[mask])
458
+ denom = np.linalg.norm(a) * np.linalg.norm(b)
459
+ if denom == 0:
460
+ return 0.0
461
+ return float(np.dot(a, b) / denom)
462
+
463
+
464
def compare_fields(
    variable_name: str,
    data_path_a: str,
    data_path_b: str,
    grid_path: str | None = None,
    session_id: str | None = None,
    result_name: str | None = None,
) -> dict[str, Any]:
    """Compare two same-grid fields and compute core difference metrics.

    The difference field is ``a - b``. Reported metrics:

    * ``bias``: mean of the difference (NaNs skipped);
    * ``rmse``: square root of the mean squared difference (NaNs skipped);
    * ``pattern_correlation``: centered correlation over positions where
      both fields are finite;
    * ``max_abs_difference``: largest absolute difference (NaNs ignored).

    The difference field is persisted and its handle returned as
    ``difference_field_handle``.

    Args:
        variable_name: Variable read from both files.
        data_path_a: First data file (minuend).
        data_path_b: Second data file (subtrahend).
        grid_path: Optional grid file. When given, both files are read through
            uxarray against it. Otherwise they are read with plain xarray.
        session_id: Optional session to record the operation and result in.
        result_name: Optional name for the persisted difference field.

    Raises:
        ValueError: the variables differ in dims/shape, or share no finite
            values for the correlation.
    """
    tracker = OperationTracker("compare_fields", session_id=session_id)
    tracker.stage("loading", "Loading comparison fields.")
    first, second = _load_comparison_arrays(
        grid_path=grid_path,
        data_path_a=data_path_a,
        data_path_b=data_path_b,
        variable_name=variable_name,
    )
    tracker.stage("comparing", "Computing field-to-field metrics.")
    # The loader only guarantees identical dims and shape, and the pattern
    # correlation pairs values by position. Plain xarray arithmetic would
    # instead inner-join on index coordinates, so two runs whose coordinate
    # labels differ (e.g. different time stamps) would give an empty or
    # truncated difference. ``join="override"`` pairs values positionally and
    # reuses ``first``'s indexes.
    first, second = xr.align(first, second, join="override")
    diff = first - second
    bias = float(diff.mean(skipna=True).item())
    rmse = float(np.sqrt((diff**2).mean(skipna=True)).item())
    pattern = _pattern_correlation(first, second)
    result_handle = _persist_dataarray_result(
        data=diff,
        session_id=session_id,
        name=result_name or f"diff:{variable_name}",
        kind="comparison_difference",
        summary=summarize_array(diff),
        metadata={
            "variable_name": variable_name,
            "data_path_a": data_path_a,
            "data_path_b": data_path_b,
        },
    )
    tracker.succeed("Field comparison complete.")
    result: dict[str, Any] = {
        "variable_name": variable_name,
        "alignment_summary": {
            "same_dims": True,
            "dims": list(first.dims),
            "shape": list(first.shape),
            "grid_path": grid_path,
        },
        "metrics": {
            "bias": bias,
            "rmse": rmse,
            "pattern_correlation": pattern,
            "max_abs_difference": float(np.nanmax(np.abs(diff.values))),
        },
        "difference_field_handle": result_handle,
    }
    result = attach_provenance(
        result,
        tool="compare_fields",
        inputs={
            "variable_name": variable_name,
            "data_path_a": data_path_a,
            "data_path_b": data_path_b,
            "grid_path": grid_path,
            "session_id": session_id,
        },
    )
    result["_provenance"]["operation_id"] = tracker.operation_id
    return result
528
+
529
+
530
def calculate_bias(
    variable_name: str,
    data_path_a: str,
    data_path_b: str,
    grid_path: str | None = None,
) -> dict[str, Any]:
    """Calculate the mean bias between two same-grid fields.

    Runs the full ``compare_fields`` pipeline (which also persists the
    difference field) and reports only its ``bias`` metric.
    """
    call_inputs = {
        "variable_name": variable_name,
        "data_path_a": data_path_a,
        "data_path_b": data_path_b,
        "grid_path": grid_path,
    }
    metrics = compare_fields(**call_inputs)["metrics"]
    return attach_provenance(
        {"variable_name": variable_name, "bias": metrics["bias"]},
        tool="calculate_bias",
        inputs=call_inputs,
    )
553
+
554
+
555
def calculate_rmse(
    variable_name: str,
    data_path_a: str,
    data_path_b: str,
    grid_path: str | None = None,
) -> dict[str, Any]:
    """Calculate RMSE between two same-grid fields.

    Runs the full ``compare_fields`` pipeline (which also persists the
    difference field) and reports only its ``rmse`` metric.
    """
    call_inputs = {
        "variable_name": variable_name,
        "data_path_a": data_path_a,
        "data_path_b": data_path_b,
        "grid_path": grid_path,
    }
    metrics = compare_fields(**call_inputs)["metrics"]
    return attach_provenance(
        {"variable_name": variable_name, "rmse": metrics["rmse"]},
        tool="calculate_rmse",
        inputs=call_inputs,
    )
578
+
579
+
580
def calculate_pattern_correlation(
    variable_name: str,
    data_path_a: str,
    data_path_b: str,
    grid_path: str | None = None,
) -> dict[str, Any]:
    """Calculate pattern correlation between two same-grid fields.

    Runs the full ``compare_fields`` pipeline (which also persists the
    difference field) and reports only its ``pattern_correlation`` metric.
    """
    call_inputs = {
        "variable_name": variable_name,
        "data_path_a": data_path_a,
        "data_path_b": data_path_b,
        "grid_path": grid_path,
    }
    metrics = compare_fields(**call_inputs)["metrics"]
    return attach_provenance(
        {
            "variable_name": variable_name,
            "pattern_correlation": metrics["pattern_correlation"],
        },
        tool="calculate_pattern_correlation",
        inputs=call_inputs,
    )
606
+
607
+
608
def remap_variable(
    target_grid_path: str,
    variable_name: str,
    grid_path: str | None = None,
    data_path: str | None = None,
    method: str = "nearest_neighbor",
    remap_to: str = "faces",
    session_id: str | None = None,
    dataset_handle: str | None = None,
    result_name: str | None = None,
) -> dict[str, Any]:
    """Remap a face-centered variable onto a target grid.

    ``method`` names a method of uxarray's ``remap`` accessor and must be one
    of ``nearest_neighbor``, ``inverse_distance_weighted`` or ``bilinear``.
    ``remap_to`` is forwarded to that method unchanged. The remapped array is
    persisted as a data-array result.

    Raises:
        ValueError: unsupported ``method`` (checked before any file is
            opened, and again against the installed uxarray), missing data
            file, or errors from path/variable resolution.
    """
    supported_methods = ("nearest_neighbor", "inverse_distance_weighted", "bilinear")
    unsupported_message = (
        f"Unsupported remap method {method!r}. Choose from "
        "'nearest_neighbor', 'inverse_distance_weighted', or 'bilinear'."
    )
    # An explicit allowlist instead of a bare ``hasattr`` check: ``hasattr``
    # would also accept any other accessor attribute (e.g. dunder methods)
    # and the ``getattr`` call below would then invoke it.
    if method not in supported_methods:
        raise ValueError(unsupported_message)

    tracker = OperationTracker("remap_variable", session_id=session_id)
    resolved_grid, resolved_data = _resolve_paths(
        session_id=session_id,
        dataset_handle=dataset_handle,
        grid_path=grid_path,
        data_path=data_path,
    )
    if resolved_data is None:
        raise ValueError("data_path is required for remapping.")

    source_grid = load_grid(resolved_grid)
    target_grid = load_grid(target_grid_path)
    _, uxda, selected = _load_dataarray(resolved_grid, resolved_data, variable_name)

    # The installed uxarray version may not provide every allowlisted method.
    if not hasattr(uxda.remap, method):
        raise ValueError(unsupported_message)
    tracker.stage("remapping", f"Running {method} remap.")
    remapped = getattr(uxda.remap, method)(target_grid, remap_to=remap_to)
    result_handle = _persist_dataarray_result(
        data=remapped,
        session_id=session_id,
        name=result_name or f"remap:{selected}",
        kind="remapped_variable",
        summary=summarize_array(remapped.to_xarray()),
        metadata={
            "source_grid": resolved_grid,
            "target_grid": target_grid_path,
            "method": method,
            "remap_to": remap_to,
            "variable_name": selected,
        },
    )
    tracker.succeed("Variable remap complete.")
    result: dict[str, Any] = {
        "variable_name": selected,
        "method": method,
        "remap_to": remap_to,
        "source_grid": summarize_grid(source_grid),
        "target_grid": summarize_grid(target_grid),
        "result_handle": result_handle,
    }
    result = attach_provenance(
        result,
        tool="remap_variable",
        inputs={
            "target_grid_path": target_grid_path,
            "variable_name": variable_name,
            "grid_path": grid_path,
            "data_path": data_path,
            "method": method,
            "remap_to": remap_to,
            "session_id": session_id,
            "dataset_handle": dataset_handle,
        },
        selected_variable=selected,
    )
    result["_provenance"]["operation_id"] = tracker.operation_id
    return result
681
+
682
+
683
def regrid_dataset(
    target_grid_path: str,
    grid_path: str | None = None,
    data_path: str | None = None,
    variable_names: list[str] | None = None,
    method: str = "nearest_neighbor",
    remap_to: str = "faces",
    session_id: str | None = None,
    dataset_handle: str | None = None,
    result_name: str | None = None,
) -> dict[str, Any]:
    """Remap all selected face-centered variables in a dataset onto a target grid.

    When ``variable_names`` is empty or ``None``, every data variable with an
    ``n_face`` or ``nCells`` dimension is remapped. Each variable is remapped
    with the same ``method`` / ``remap_to``. The results are merged into one
    xarray dataset, which is persisted.

    Raises:
        ValueError: unsupported ``method`` (checked before any file is
            opened, and again against the installed uxarray), missing data
            file, requested variables not in the dataset, or no face-centered
            variables to remap.
    """
    supported_methods = ("nearest_neighbor", "inverse_distance_weighted", "bilinear")
    unsupported_message = (
        f"Unsupported remap method {method!r}. Choose from "
        "'nearest_neighbor', 'inverse_distance_weighted', or 'bilinear'."
    )
    # An explicit allowlist instead of a bare ``hasattr`` check, so arbitrary
    # accessor attributes (e.g. dunder methods) can never be invoked below.
    if method not in supported_methods:
        raise ValueError(unsupported_message)

    tracker = OperationTracker("regrid_dataset", session_id=session_id)
    resolved_grid, resolved_data = _resolve_paths(
        session_id=session_id,
        dataset_handle=dataset_handle,
        grid_path=grid_path,
        data_path=data_path,
    )
    if resolved_data is None:
        raise ValueError("data_path is required to regrid a dataset.")

    uxds = ux.open_dataset(resolved_grid, resolved_data)
    target_grid = load_grid(target_grid_path)
    if variable_names:
        missing = [name for name in variable_names if name not in uxds.data_vars]
        if missing:
            raise ValueError(
                f"Variables {missing} not found. Available variables: {list(uxds.data_vars)}"
            )
        variables = list(variable_names)
    else:
        variables = [
            name
            for name, var in uxds.data_vars.items()
            if "n_face" in var.dims or "nCells" in var.dims
        ]
    if not variables:
        raise ValueError("No face-centered variables available for remapping.")
    # The installed uxarray version may not provide every allowlisted method.
    # Checking a selected variable (rather than ``next(iter(data_vars))``)
    # also avoids a StopIteration on datasets without data variables.
    if not hasattr(uxds[variables[0]].remap, method):
        raise ValueError(unsupported_message)

    dataset_parts = []
    for name in variables:
        tracker.stage("remapping", f"Remapping variable {name}")
        remapped = getattr(uxds[name].remap, method)(target_grid, remap_to=remap_to)
        dataset_parts.append(remapped.to_dataset(name=name).to_xarray())
    remapped_dataset = xr.merge(dataset_parts)
    result_handle = _persist_dataset_result(
        dataset=remapped_dataset,
        session_id=session_id,
        name=result_name or "regridded_dataset",
        kind="regridded_dataset",
        summary=summarize_dataset(remapped_dataset),
        metadata={
            "source_grid": resolved_grid,
            "target_grid": target_grid_path,
            "method": method,
            "remap_to": remap_to,
            "variables": variables,
        },
    )
    tracker.succeed("Dataset regridding complete.")
    result: dict[str, Any] = {
        "method": method,
        "variables": variables,
        "target_grid": summarize_grid(target_grid),
        "result_handle": result_handle,
    }
    result = attach_provenance(
        result,
        tool="regrid_dataset",
        inputs={
            "target_grid_path": target_grid_path,
            "grid_path": grid_path,
            "data_path": data_path,
            "variable_names": variable_names,
            "method": method,
            "remap_to": remap_to,
            "session_id": session_id,
            "dataset_handle": dataset_handle,
        },
    )
    result["_provenance"]["operation_id"] = tracker.operation_id
    return result
761
+
762
+
763
def calculate_temporal_mean(
    data_path: str,
    variable_name: str,
    groupby: str | None = None,
    session_id: str | None = None,
    result_name: str | None = None,
) -> dict[str, Any]:
    """Calculate a temporal mean from a time-aware dataset.

    With ``groupby=None`` the variable is averaged over its entire ``time``
    dimension. Otherwise it is grouped by the ``time.<groupby>`` component
    (e.g. ``"month"`` or ``"season"``) and averaged per group. The result is
    persisted as a data-array result.

    Raises:
        ValueError: the variable is missing or has no ``time`` dimension.
    """
    tracker = OperationTracker("calculate_temporal_mean", session_id=session_id)
    # Compute inside ``with`` so the source file is closed afterwards (it was
    # previously leaked). ``.load()`` detaches the result from the open file.
    with xr.open_dataset(data_path) as ds:
        if variable_name not in ds:
            raise ValueError(f"Variable '{variable_name}' not found in {data_path}.")
        data = ds[variable_name]
        if "time" not in data.dims:
            raise ValueError("Temporal mean requires a variable with a 'time' dimension.")
        tracker.stage("aggregating", "Computing temporal mean.")
        if groupby is None:
            result_data = data.mean(dim="time")
        else:
            result_data = data.groupby(f"time.{groupby}").mean()
        result_data = result_data.load()

    # Summarize once; the same summary is persisted and returned.
    summary = summarize_array(result_data)
    result_handle = _persist_dataarray_result(
        data=result_data,
        session_id=session_id,
        name=result_name or f"temporal_mean:{variable_name}",
        kind="temporal_mean",
        summary=summary,
        metadata={"groupby": groupby, "variable_name": variable_name},
    )
    tracker.succeed("Temporal mean complete.")
    result: dict[str, Any] = {
        "variable_name": variable_name,
        "groupby": groupby,
        "summary": summary,
        "result_handle": result_handle,
    }
    result = attach_provenance(
        result,
        tool="calculate_temporal_mean",
        inputs={
            "data_path": data_path,
            "variable_name": variable_name,
            "groupby": groupby,
            "session_id": session_id,
        },
    )
    result["_provenance"]["operation_id"] = tracker.operation_id
    return result
810
+
811
+
812
def calculate_anomaly(
    data_path: str,
    variable_name: str,
    baseline: str = "temporal_mean",
    session_id: str | None = None,
    result_name: str | None = None,
) -> dict[str, Any]:
    """Calculate anomalies relative to the temporal mean baseline.

    The anomaly is ``data - data.mean(dim="time")``, keeping the ``time``
    dimension. It is persisted as a data-array result.

    Raises:
        ValueError: ``baseline`` is not ``"temporal_mean"``, or the variable
            is missing or has no ``time`` dimension.
    """
    if baseline != "temporal_mean":
        raise ValueError("v1 supports only baseline='temporal_mean'.")
    tracker = OperationTracker("calculate_anomaly", session_id=session_id)
    # Compute inside ``with`` so the source file is closed afterwards (it was
    # previously leaked). ``.load()`` detaches the result from the open file.
    with xr.open_dataset(data_path) as ds:
        if variable_name not in ds:
            raise ValueError(f"Variable '{variable_name}' not found in {data_path}.")
        data = ds[variable_name]
        if "time" not in data.dims:
            raise ValueError("Anomaly calculation requires a 'time' dimension.")
        tracker.stage("aggregating", "Computing temporal baseline and anomalies.")
        anomaly = (data - data.mean(dim="time")).load()

    # Summarize once; the same summary is persisted and returned.
    summary = summarize_array(anomaly)
    result_handle = _persist_dataarray_result(
        data=anomaly,
        session_id=session_id,
        name=result_name or f"anomaly:{variable_name}",
        kind="anomaly",
        summary=summary,
        metadata={"baseline": baseline, "variable_name": variable_name},
    )
    tracker.succeed("Anomaly calculation complete.")
    result: dict[str, Any] = {
        "variable_name": variable_name,
        "baseline": baseline,
        "summary": summary,
        "result_handle": result_handle,
    }
    result = attach_provenance(
        result,
        tool="calculate_anomaly",
        inputs={
            "data_path": data_path,
            "variable_name": variable_name,
            "baseline": baseline,
            "session_id": session_id,
        },
    )
    result["_provenance"]["operation_id"] = tracker.operation_id
    return result
858
+
859
+
860
+ def _load_ensemble(variable_name: str, data_paths: Sequence[str]) -> xr.DataArray:
861
+ datasets = []
862
+ for data_path in data_paths:
863
+ ds = xr.open_dataset(data_path)
864
+ if variable_name not in ds:
865
+ raise ValueError(f"Variable '{variable_name}' not found in {data_path}.")
866
+ datasets.append(ds[variable_name])
867
+ reference = datasets[0]
868
+ for dataset in datasets[1:]:
869
+ if dataset.shape != reference.shape or dataset.dims != reference.dims:
870
+ raise ValueError("All ensemble members must share dims and shape in v1.")
871
+ return xr.concat(datasets, dim="ensemble_member")
872
+
873
+
874
def calculate_ensemble_mean(
    variable_name: str,
    data_paths: list[str],
    session_id: str | None = None,
    result_name: str | None = None,
) -> dict[str, Any]:
    """Average one variable across ensemble member files.

    Members are stacked by ``_load_ensemble`` and averaged over the
    ``ensemble_member`` dimension. The mean field is persisted as a
    data-array result.
    """
    tracker = OperationTracker("calculate_ensemble_mean", session_id=session_id)
    mean_field = _load_ensemble(variable_name, data_paths).mean(dim="ensemble_member")
    member_count = len(data_paths)
    handle = _persist_dataarray_result(
        data=mean_field,
        session_id=session_id,
        name=result_name or f"ensemble_mean:{variable_name}",
        kind="ensemble_mean",
        summary=summarize_array(mean_field),
        metadata={"member_count": member_count, "variable_name": variable_name},
    )
    tracker.succeed("Ensemble mean complete.")
    provenance_inputs = {
        "variable_name": variable_name,
        "data_paths": data_paths,
        "session_id": session_id,
    }
    result: dict[str, Any] = attach_provenance(
        {
            "variable_name": variable_name,
            "member_count": member_count,
            "summary": summarize_array(mean_field),
            "result_handle": handle,
        },
        tool="calculate_ensemble_mean",
        inputs=provenance_inputs,
    )
    result["_provenance"]["operation_id"] = tracker.operation_id
    return result
910
+
911
+
912
def calculate_ensemble_spread(
    variable_name: str,
    data_paths: list[str],
    session_id: str | None = None,
    result_name: str | None = None,
) -> dict[str, Any]:
    """Compute ensemble spread as the standard deviation across members.

    Members are stacked by ``_load_ensemble`` and reduced with xarray's
    ``std`` over ``ensemble_member`` using its default ``ddof`` (0, i.e. the
    population standard deviation). The spread field is persisted as a
    data-array result.
    """
    tracker = OperationTracker("calculate_ensemble_spread", session_id=session_id)
    spread_field = _load_ensemble(variable_name, data_paths).std(dim="ensemble_member")
    member_count = len(data_paths)
    handle = _persist_dataarray_result(
        data=spread_field,
        session_id=session_id,
        name=result_name or f"ensemble_spread:{variable_name}",
        kind="ensemble_spread",
        summary=summarize_array(spread_field),
        metadata={"member_count": member_count, "variable_name": variable_name},
    )
    tracker.succeed("Ensemble spread complete.")
    provenance_inputs = {
        "variable_name": variable_name,
        "data_paths": data_paths,
        "session_id": session_id,
    }
    result: dict[str, Any] = attach_provenance(
        {
            "variable_name": variable_name,
            "member_count": member_count,
            "summary": summarize_array(spread_field),
            "result_handle": handle,
        },
        tool="calculate_ensemble_spread",
        inputs=provenance_inputs,
    )
    result["_provenance"]["operation_id"] = tracker.operation_id
    return result
948
+
949
+
950
def export_to_netcdf(
    output_path: str,
    result_handle: str | None = None,
    session_id: str | None = None,
    dataset_handle: str | None = None,
    variable_name: str | None = None,
) -> dict[str, Any]:
    """Export a persisted result or registered dataset to NetCDF.

    The first source provided is used, checked in this order:

    * ``result_handle`` -- the stored result's artifact file is copied to
      ``output_path`` and the stored summary is returned. JSON artifacts are
      rejected because they are not NetCDF files.
    * ``dataset_handle`` (requires ``session_id``) -- the dataset's data file
      is copied unchanged, or, when ``variable_name`` is given, only that
      variable is written to ``output_path``.

    Parent directories of ``output_path`` are created as needed.

    Args:
        output_path: Destination file path.
        result_handle: Handle of a persisted result to export.
        session_id: Session owning ``dataset_handle``; also used for tracking.
        dataset_handle: Handle of a dataset registered in ``session_id``.
        variable_name: Optional single variable to extract from the dataset
            (ignored for ``result_handle`` exports).

    Returns:
        A dict with ``output_path`` (the written path), ``summary`` and a
        ``_provenance`` entry carrying the tracker's operation id.

    Raises:
        ValueError: If neither handle is given, the result has no artifact or
            only a JSON artifact, ``session_id`` is missing for a dataset
            export, the dataset has no data file, or ``variable_name`` is not
            in that file.
        FileNotFoundError: If ``dataset_handle`` is not registered in the
            session.
    """
    # Validate before tracking the operation or touching the filesystem.
    if result_handle is None and dataset_handle is None:
        raise ValueError("Provide either result_handle or dataset_handle.")

    tracker = OperationTracker("export_to_netcdf", session_id=session_id)
    destination = Path(output_path)
    destination.parent.mkdir(parents=True, exist_ok=True)

    if result_handle is not None:
        stored_result = get_result(result_handle)
        artifact_path = stored_result.get("artifact_path")
        if artifact_path is None:
            raise ValueError("Result handle has no exportable artifact.")
        if Path(artifact_path).suffix == ".json":
            # Copying a JSON payload to a NetCDF destination would produce a
            # mislabelled file while reporting success.
            raise ValueError(
                "Result handle artifact is JSON and cannot be exported to NetCDF; "
                "export it as CSV instead."
            )
        written = copy_artifact(artifact_path, output_path)
        summary = stored_result["summary"]
    else:
        if session_id is None:
            raise ValueError("session_id is required when exporting a dataset_handle.")
        dataset = get_session(session_id)["datasets"].get(dataset_handle)
        if dataset is None:
            raise FileNotFoundError(
                f"Dataset handle {dataset_handle!r} not found in session {session_id!r}"
            )
        data_path = dataset.get("data_path")
        if data_path is None:
            raise ValueError("Dataset handle does not include a data file to export.")
        if variable_name is None:
            written = copy_artifact(data_path, output_path)
            summary = {"copied_source": data_path}
        else:
            # Close the source file once the subset is written and summarized.
            with xr.open_dataset(data_path) as ds:
                if variable_name not in ds:
                    raise ValueError(
                        f"Variable '{variable_name}' not found in {data_path}."
                    )
                subset = ds[[variable_name]]
                subset.to_netcdf(output_path)
                summary = summarize_dataset(subset)
            written = str(destination)

    tracker.succeed("NetCDF export complete.")
    response: dict[str, Any] = {"output_path": written, "summary": summary}
    response = attach_provenance(
        response,
        tool="export_to_netcdf",
        inputs={
            "output_path": output_path,
            "result_handle": result_handle,
            "session_id": session_id,
            "dataset_handle": dataset_handle,
            "variable_name": variable_name,
        },
    )
    response["_provenance"]["operation_id"] = tracker.operation_id
    return response
1009
+
1010
+
1011
def export_to_csv(
    output_path: str,
    result_handle: str | None = None,
    session_id: str | None = None,
    dataset_handle: str | None = None,
    variable_name: str | None = None,
) -> dict[str, Any]:
    """Export a persisted result or registered dataset to CSV.

    The first source provided is used, checked in this order:

    * ``result_handle`` -- the stored result's artifact is converted.
      ``.json`` artifacts become CSV records: a JSON object yields one row,
      and a list of objects yields one row per object. The header is the
      sorted union of keys, and missing values are left empty. Any other
      artifact is opened with xarray, first as a single DataArray and
      otherwise as a Dataset, then flattened with ``to_dataframe()``.
    * ``dataset_handle`` (requires ``session_id``) -- the dataset's data file
      is opened with xarray and flattened, optionally limited to
      ``variable_name``.

    Files are read and written as UTF-8. Parent directories of
    ``output_path`` are created as needed.

    Args:
        output_path: Destination CSV path.
        result_handle: Handle of a persisted result to export.
        session_id: Session owning ``dataset_handle``; also used for tracking.
        dataset_handle: Handle of a dataset registered in ``session_id``.
        variable_name: Optional single variable to extract from the dataset
            (ignored for ``result_handle`` exports).

    Returns:
        A dict with ``output_path``, ``summary`` (``{"rows_written": n}``) and
        a ``_provenance`` entry carrying the tracker's operation id.

    Raises:
        ValueError: If neither handle is given, the result has no artifact,
            a JSON artifact is not an object or a list of objects,
            ``session_id`` is missing for a dataset export, the dataset has no
            data file, or ``variable_name`` is not in that file.
        FileNotFoundError: If ``dataset_handle`` is not registered in the
            session.
    """
    # Validate before tracking the operation or touching the filesystem.
    if result_handle is None and dataset_handle is None:
        raise ValueError("Provide either result_handle or dataset_handle.")

    tracker = OperationTracker("export_to_csv", session_id=session_id)
    destination = Path(output_path)
    destination.parent.mkdir(parents=True, exist_ok=True)

    if result_handle is not None:
        stored_result = get_result(result_handle)
        artifact_path = stored_result.get("artifact_path")
        if artifact_path is None:
            raise ValueError("Result handle has no exportable artifact.")
        artifact = Path(artifact_path)
        if artifact.suffix == ".json":
            payload = json.loads(artifact.read_text(encoding="utf-8"))
            # A JSON object is a single record; a list of objects is one record each.
            records = [payload] if isinstance(payload, dict) else payload
            if not isinstance(records, list) or not all(
                isinstance(record, dict) for record in records
            ):
                raise ValueError(
                    "JSON artifact must contain an object or a list of objects "
                    "to be exported as CSV."
                )
            fieldnames = sorted({key for record in records for key in record})
            with destination.open("w", newline="", encoding="utf-8") as handle:
                writer = csv.DictWriter(handle, fieldnames=fieldnames)
                writer.writeheader()
                writer.writerows(records)
            summary = {"rows_written": len(records)}
        else:
            try:
                with xr.open_dataarray(artifact) as data:
                    frame = data.to_dataframe(name=data.name or "value").reset_index()
            except ValueError:
                # xr.open_dataarray raises ValueError for files holding more than
                # one data variable; flatten the whole file as a Dataset instead.
                with xr.open_dataset(artifact) as dataset_artifact:
                    frame = dataset_artifact.to_dataframe().reset_index()
            frame.to_csv(destination, index=False)
            summary = {"rows_written": int(len(frame))}
    else:
        if session_id is None:
            raise ValueError("session_id is required when exporting a dataset_handle.")
        dataset = get_session(session_id)["datasets"].get(dataset_handle)
        if dataset is None:
            raise FileNotFoundError(
                f"Dataset handle {dataset_handle!r} not found in session {session_id!r}"
            )
        data_path = dataset.get("data_path")
        if data_path is None:
            raise ValueError("Dataset handle does not include a data file to export.")
        # Materialize the frame inside the context so the source file gets closed.
        with xr.open_dataset(data_path) as ds:
            if variable_name is not None and variable_name not in ds:
                raise ValueError(
                    f"Variable '{variable_name}' not found in {data_path}."
                )
            export_ds = ds if variable_name is None else ds[[variable_name]]
            frame = export_ds.to_dataframe().reset_index()
        frame.to_csv(destination, index=False)
        summary = {"rows_written": int(len(frame))}

    tracker.succeed("CSV export complete.")
    response: dict[str, Any] = {"output_path": str(destination), "summary": summary}
    response = attach_provenance(
        response,
        tool="export_to_csv",
        inputs={
            "output_path": output_path,
            "result_handle": result_handle,
            "session_id": session_id,
            "dataset_handle": dataset_handle,
            "variable_name": variable_name,
        },
    )
    response["_provenance"]["operation_id"] = tracker.operation_id
    return response
1082
+
1083
+
1084
def write_result(
    output_path: str,
    format: str,
    result_handle: str | None = None,
    session_id: str | None = None,
    dataset_handle: str | None = None,
    variable_name: str | None = None,
) -> dict[str, Any]:
    """Write a result or dataset using the requested output format.

    ``format`` is matched case-insensitively against ``"netcdf"`` and
    ``"csv"``. The call is forwarded, with every other argument unchanged, to
    ``export_to_netcdf`` or ``export_to_csv`` respectively, and that
    exporter's response is returned as-is.

    Raises:
        ValueError: If ``format`` is neither ``"netcdf"`` nor ``"csv"`` in any
            letter case. Errors raised by the selected exporter propagate.
    """
    requested = format.lower()
    if requested not in ("netcdf", "csv"):
        raise ValueError("Unsupported format. Choose 'netcdf' or 'csv'.")
    # Both exporters share one keyword interface, so pick one and call it once.
    exporter = export_to_netcdf if requested == "netcdf" else export_to_csv
    return exporter(
        output_path=output_path,
        result_handle=result_handle,
        session_id=session_id,
        dataset_handle=dataset_handle,
        variable_name=variable_name,
    )