uxarray-mcp 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
uxarray_mcp/state.py ADDED
@@ -0,0 +1,521 @@
1
+ """Persistent state, workflow, and result storage for UXarray MCP tools."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import json
6
+ import os
7
+ import shutil
8
+ import uuid
9
+ from contextlib import suppress
10
+ from dataclasses import dataclass, field
11
+ from datetime import datetime, timezone
12
+ from pathlib import Path
13
+ from typing import Any
14
+
15
+ import xarray as xr
16
+
17
+
18
def _now_utc() -> str:
    """Return the current time in UTC as an ISO-8601 string (with ``+00:00``)."""
    moment = datetime.now(tz=timezone.utc)
    return moment.isoformat()
20
+
21
+
22
def _state_root() -> Path:
    """Return the root directory for all persisted MCP state.

    A non-empty ``UXARRAY_MCP_STATE_DIR`` environment variable overrides the
    default (``~`` is expanded); otherwise ``~/.uxarray_mcp_server`` is used.
    The directory is not created here.
    """
    override = os.getenv("UXARRAY_MCP_STATE_DIR")
    if not override:
        return Path.home() / ".uxarray_mcp_server"
    return Path(override).expanduser()
27
+
28
+
29
def _ensure_dir(path: Path) -> Path:
    """Create *path* and any missing parents (no-op if it exists); return it."""
    path.mkdir(exist_ok=True, parents=True)
    return path
32
+
33
+
34
def _json_safe(value: Any) -> Any:
    """Recursively convert *value* into something ``json.dumps`` can encode.

    * ``Path`` -> ``str``; ``datetime`` -> ISO-8601 string
    * ``dict`` -> dict with ``str`` keys (values converted recursively)
    * ``list`` / ``tuple`` / ``set`` -> ``list`` (items converted recursively)
    * NumPy scalars / arrays -> native Python objects via ``item()`` /
      ``tolist()``, which are then converted again so that, e.g.,
      ``datetime64`` values end up as ISO strings instead of ``datetime``.
    * anything else is returned unchanged.
    """
    if isinstance(value, Path):
        return str(value)
    if isinstance(value, datetime):
        return value.isoformat()
    if isinstance(value, dict):
        return {str(k): _json_safe(v) for k, v in value.items()}
    if isinstance(value, (list, tuple, set)):
        return [_json_safe(v) for v in value]
    try:
        import numpy as np
    except ImportError:  # numpy is optional for this helper
        return value
    if isinstance(value, np.generic):
        native = value.item()
        # Some scalars (e.g. longdouble) return themselves from item();
        # avoid recursing forever on those.
        return native if isinstance(native, np.generic) else _json_safe(native)
    if isinstance(value, np.ndarray):
        return _json_safe(value.tolist())
    return value
51
+
52
+
53
def _write_json(path: Path, payload: dict[str, Any]) -> None:
    """Serialize *payload* (via ``_json_safe``) to *path* as indented JSON.

    The document is written to a uniquely named sibling temp file and then
    moved over *path* with ``os.replace``. Readers therefore see either the
    old or the new complete document, never a truncated one left by a crash
    mid-write. The temp name ends in ``.tmp``, so ``*.json`` globs ignore it.
    """
    text = json.dumps(_json_safe(payload), indent=2, sort_keys=True)
    tmp_path = path.with_name(f"{path.name}.{uuid.uuid4().hex}.tmp")
    try:
        tmp_path.write_text(text)
        os.replace(tmp_path, path)
    except BaseException:
        # Don't leave stray temp files behind on failure.
        with suppress(FileNotFoundError):
            tmp_path.unlink()
        raise
55
+
56
+
57
def _read_json(path: Path) -> dict[str, Any]:
    """Parse and return the JSON document stored at *path*."""
    with path.open() as handle:
        return json.load(handle)
59
+
60
+
61
def _record_path(kind: str, record_id: str) -> Path:
    """Return ``<state_root>/<kind>/<record_id>.json``, creating the kind dir.

    Record ids can arrive from MCP clients (e.g. a ``session_id`` argument).
    An id that is empty or contains a path separator or NUL byte could escape
    the state directory, so it is rejected with ``FileNotFoundError``. That
    is the same error callers already handle for unknown records.
    """
    if not record_id or any(ch in record_id for ch in ("/", "\\", "\x00")):
        raise FileNotFoundError(f"Invalid {kind} record id: {record_id!r}")
    return _ensure_dir(_state_root() / kind) / f"{record_id}.json"
63
+
64
+
65
def _artifacts_dir() -> Path:
    """Return the ``artifacts`` directory under the state root, creating it."""
    artifacts = _state_root() / "artifacts"
    return _ensure_dir(artifacts)
67
+
68
+
69
def _new_id(prefix: str) -> str:
    """Return a new identifier ``<prefix>_<first 12 hex chars of a uuid4>``."""
    token = uuid.uuid4().hex[:12]
    return "{}_{}".format(prefix, token)
71
+
72
+
73
def _result_path(result_id: str, suffix: str) -> Path:
    """Return the artifact file path for *result_id* ending in *suffix*."""
    return _artifacts_dir().joinpath(f"{result_id}{suffix}")
75
+
76
+
77
def _sanitize_netcdf_attr_value(value: Any) -> Any:
    """Coerce a single attribute value into a netCDF-writable type.

    * ``bool`` / ``numpy.bool_`` -> ``int``
    * NumPy integer / floating scalars -> equivalent Python ``int`` / ``float``
      (keeps numeric attributes numeric instead of stringifying them)
    * ``str`` / ``int`` / ``float`` -> unchanged
    * ``None`` -> ``""``
    * anything else -> ``str(value)``
    """
    import numpy as np  # xarray, imported by this module, depends on numpy

    if isinstance(value, (bool, np.bool_)):
        return int(value)
    if isinstance(value, (np.integer, np.floating)):
        native = value.item()
        # longdouble.item() stays a NumPy scalar; fall back to a string then.
        return native if isinstance(native, (int, float)) else str(value)
    if isinstance(value, (str, int, float)):
        return value
    if value is None:
        return ""
    return str(value)
85
+
86
+
87
def _sanitize_netcdf_attrs(data: Any) -> Any:
    """Return a shallow copy of *data* whose own ``attrs`` are netCDF-safe.

    Each key is converted to ``str`` and each value is passed through
    ``_sanitize_netcdf_attr_value``. Attributes of any contained variables
    are not touched here.
    """
    duplicate = data.copy(deep=False)
    clean: dict[str, Any] = {}
    for key, value in getattr(data, "attrs", {}).items():
        clean[str(key)] = _sanitize_netcdf_attr_value(value)
    duplicate.attrs = clean
    return duplicate
94
+
95
+
96
def summarize_grid(grid: Any) -> dict[str, Any]:
    """Return a JSON-friendly summary of a grid's spec and element counts.

    Missing attributes default to ``"Unknown"`` (format) and ``0`` (counts).
    """
    summary: dict[str, Any] = {
        "format": str(getattr(grid, "source_grid_spec", "Unknown")),
    }
    for attr in ("n_face", "n_node", "n_edge"):
        summary[attr] = int(getattr(grid, attr, 0))
    return summary
103
+
104
+
105
def summarize_array(data: xr.DataArray) -> dict[str, Any]:
    """Return a JSON-friendly summary of *data*: metadata plus basic stats.

    ``min`` / ``max`` / ``mean`` are computed over the non-NaN entries, so a
    few missing values don't turn every statistic into NaN. The three keys
    are added together or not at all. They are omitted when the array is
    empty, contains only NaN, or the reductions fail (e.g. non-numeric data).
    """
    values = data.values
    summary: dict[str, Any] = {
        "dims": list(data.dims),
        "shape": list(data.shape),
        "dtype": str(data.dtype),
        "name": str(data.name) if data.name is not None else None,
    }
    if values.size > 0:
        # Best-effort: any failure (string/object/datetime data) just means
        # no statistics are reported.
        with suppress(Exception):
            # NaN is the only value not equal to itself; this mask also works
            # for object arrays, unlike np.isnan.
            valid = values[values == values]
            if valid.size > 0:
                stats = {
                    "min": float(valid.min()),
                    "max": float(valid.max()),
                    "mean": float(valid.mean()),
                }
                summary.update(stats)
    return summary
119
+
120
+
121
def summarize_dataset(dataset: xr.Dataset) -> dict[str, Any]:
    """Return the data-variable names and integer dimension sizes of *dataset*."""
    names = list(dataset.data_vars)
    sizes = {dim: int(length) for dim, length in dataset.sizes.items()}
    return {"variables": names, "dims": sizes}
126
+
127
+
128
def create_session(name: str | None = None) -> dict[str, Any]:
    """Create, persist, and return a new empty session record.

    Args:
        name: Optional human-readable label stored on the session.

    Returns:
        The session dict, as written to ``sessions/<session_id>.json``.
    """
    session_id = _new_id("session")
    # Read the clock once so created_at and updated_at are identical at birth.
    timestamp = _now_utc()
    record: dict[str, Any] = {
        "session_id": session_id,
        "name": name,
        "created_at": timestamp,
        "updated_at": timestamp,
        "datasets": {},
        "results": {},
        "workflow_ids": [],
        "operation_ids": [],
        "last_result_handle": None,
    }
    _write_json(_record_path("sessions", session_id), record)
    return record
143
+
144
+
145
def get_session(session_id: str) -> dict[str, Any]:
    """Load and return the persisted session record for *session_id*.

    Raises:
        FileNotFoundError: If no such session has been saved.
    """
    record_file = _record_path("sessions", session_id)
    if record_file.exists():
        return _read_json(record_file)
    raise FileNotFoundError(f"Session not found: {session_id}")
150
+
151
+
152
def save_session(session: dict[str, Any]) -> dict[str, Any]:
    """Stamp ``updated_at`` on *session* (in place), persist it, and return it."""
    session.update(updated_at=_now_utc())
    target = _record_path("sessions", session["session_id"])
    _write_json(target, session)
    return session
156
+
157
+
158
def register_dataset(
    session_id: str,
    *,
    grid_path: str,
    data_path: str | None = None,
    name: str | None = None,
) -> dict[str, Any]:
    """Attach a grid (and optional data file) to a session under a new handle.

    When *name* is falsy, the stem of *data_path* (or of *grid_path* if there
    is no data file) is used. The updated session is persisted.

    Returns:
        The new dataset record stored under ``session["datasets"][handle]``.
    """
    session = get_session(session_id)
    handle = _new_id("dataset")
    display_name = name if name else Path(data_path or grid_path).stem
    entry = {
        "dataset_handle": handle,
        "name": display_name,
        "grid_path": grid_path,
        "data_path": data_path,
        "registered_at": _now_utc(),
    }
    session["datasets"][handle] = entry
    save_session(session)
    return entry
177
+
178
+
179
def get_dataset_handle(session_id: str, dataset_handle: str) -> dict[str, Any]:
    """Return the dataset record registered as *dataset_handle* in a session.

    Raises:
        FileNotFoundError: If the session, or the handle within it, is unknown.
    """
    datasets = get_session(session_id)["datasets"]
    try:
        return datasets[dataset_handle]
    except KeyError:
        raise FileNotFoundError(
            f"Dataset handle {dataset_handle!r} not found in session {session_id!r}"
        ) from None
186
+
187
+
188
def create_operation(
    *,
    tool_name: str,
    session_id: str | None = None,
    workflow_id: str | None = None,
) -> dict[str, Any]:
    """Create and persist a new ``running`` operation record for *tool_name*.

    If *session_id* is given, the operation id is also appended to that
    session's ``operation_ids``.

    Raises:
        FileNotFoundError: If *session_id* refers to an unknown session.
    """
    # Resolve the session first so an unknown session_id fails before any
    # operation file is written (no orphaned operation records).
    session = get_session(session_id) if session_id else None
    operation_id = _new_id("op")
    timestamp = _now_utc()
    record = {
        "operation_id": operation_id,
        "tool_name": tool_name,
        "session_id": session_id,
        "workflow_id": workflow_id,
        "status": "running",
        "stage": "started",
        "created_at": timestamp,
        "updated_at": timestamp,
        "events": [
            {
                "timestamp_utc": timestamp,
                "stage": "started",
                "message": f"{tool_name} started",
            }
        ],
    }
    _write_json(_record_path("operations", operation_id), record)
    if session is not None:
        session["operation_ids"].append(operation_id)
        save_session(session)
    return record
218
+
219
+
220
def get_operation(operation_id: str) -> dict[str, Any]:
    """Load and return the persisted operation record for *operation_id*.

    Raises:
        FileNotFoundError: If no such operation has been saved.
    """
    record_file = _record_path("operations", operation_id)
    if not record_file.exists():
        raise FileNotFoundError(f"Operation not found: {operation_id}")
    loaded = _read_json(record_file)
    return loaded
225
+
226
+
227
def save_operation(operation: dict[str, Any]) -> dict[str, Any]:
    """Stamp ``updated_at`` on *operation* (in place), persist it, return it."""
    operation.update(updated_at=_now_utc())
    target = _record_path("operations", operation["operation_id"])
    _write_json(target, operation)
    return operation
231
+
232
+
233
def append_operation_event(
    operation_id: str,
    *,
    stage: str,
    message: str,
    status: str | None = None,
    details: dict[str, Any] | None = None,
) -> dict[str, Any]:
    """Record a progress event on an operation and persist it.

    The operation's ``stage`` is set to *stage*. Its ``status`` changes only
    when *status* is not ``None``. *details* is attached to the event only
    when it is truthy.

    Returns:
        The updated, saved operation record.
    """
    op = get_operation(operation_id)
    op["stage"] = stage
    if status is not None:
        op["status"] = status
    extra = {"details": details} if details else {}
    op["events"].append(
        {"timestamp_utc": _now_utc(), "stage": stage, "message": message, **extra}
    )
    return save_operation(op)
254
+
255
+
256
def finalize_operation(
    operation_id: str, *, status: str, summary: str | None = None
) -> dict[str, Any]:
    """Mark an operation finished with *status* and append a closing event.

    Both ``status`` and ``stage`` are set to *status*. A truthy *summary* is
    stored on the record and used as the event message. Otherwise the
    message falls back to ``"<tool_name> <status>"``, so event messages are
    always strings.

    Returns:
        The updated, saved operation record.
    """
    operation = get_operation(operation_id)
    operation["status"] = status
    operation["stage"] = status
    if summary:
        operation["summary"] = summary
    message = summary or f"{operation.get('tool_name', 'operation')} {status}"
    operation["events"].append(
        {
            "timestamp_utc": _now_utc(),
            "stage": status,
            "message": message,
        }
    )
    return save_operation(operation)
272
+
273
+
274
def list_operations(session_id: str | None = None) -> list[dict[str, Any]]:
    """Return persisted operation records ordered by ``created_at``.

    Args:
        session_id: If given, only operations belonging to this session are
            returned.

    Some files cannot be read: they may have been deleted between listing and
    reading (e.g. by ``reset_session``), or they may be unparsable. Those
    files are skipped, so one bad record does not break the whole listing.
    """
    operations_dir = _ensure_dir(_state_root() / "operations")
    operations: list[dict[str, Any]] = []
    for path in operations_dir.glob("*.json"):
        try:
            record = _read_json(path)
        except (OSError, ValueError):  # ValueError covers JSONDecodeError
            continue
        if session_id is None or record.get("session_id") == session_id:
            operations.append(record)
    operations.sort(key=lambda item: item.get("created_at", ""))
    return operations
281
+
282
+
283
def persist_result(
    *,
    kind: str,
    name: str,
    summary: dict[str, Any],
    session_id: str | None = None,
    artifact_path: str | None = None,
    metadata: dict[str, Any] | None = None,
) -> dict[str, Any]:
    """Persist a new result record and return it.

    When *session_id* is given, the result is also indexed in that session's
    ``results`` and becomes its ``last_result_handle``.

    Raises:
        FileNotFoundError: If *session_id* refers to an unknown session.
    """
    # Resolve the session first so an unknown session_id fails before any
    # result file is written (no orphaned result records).
    session = get_session(session_id) if session_id else None
    result_handle = _new_id("result")
    record = {
        "result_handle": result_handle,
        "kind": kind,
        "name": name,
        "summary": summary,
        "artifact_path": artifact_path,
        "metadata": metadata or {},
        "created_at": _now_utc(),
        "session_id": session_id,
    }
    _write_json(_record_path("results", result_handle), record)
    if session is not None:
        session["results"][result_handle] = {
            "kind": kind,
            "name": name,
            "created_at": record["created_at"],
        }
        session["last_result_handle"] = result_handle
        save_session(session)
    return record
314
+
315
+
316
def get_result(result_handle: str) -> dict[str, Any]:
    """Load and return the persisted result record for *result_handle*.

    Raises:
        FileNotFoundError: If no such result has been saved.
    """
    record_file = _record_path("results", result_handle)
    if record_file.exists():
        return _read_json(record_file)
    raise FileNotFoundError(f"Result not found: {result_handle}")
321
+
322
+
323
def save_result(result: dict[str, Any]) -> dict[str, Any]:
    """Persist *result* under its ``result_handle`` and return it unchanged."""
    destination = _record_path("results", result["result_handle"])
    _write_json(destination, result)
    return result
326
+
327
+
328
def write_grid_artifact(grid: Any, result_id: str) -> str:
    """Write ``grid.to_xarray()`` to ``<artifacts>/<result_id>.nc``.

    Unlike the other ``write_*_artifact`` helpers, attributes are written
    as-is (no sanitization). Returns the written path as a string.
    """
    target = _result_path(result_id, ".nc")
    grid_dataset = grid.to_xarray()
    grid_dataset.to_netcdf(target)
    return str(target)
332
+
333
+
334
def write_dataarray_artifact(data: Any, result_id: str) -> str:
    """Write *data* to ``<artifacts>/<result_id>.nc`` with sanitized attrs.

    Only *data*'s own ``attrs`` are sanitized (via ``_sanitize_netcdf_attrs``).
    Returns the written path as a string.
    """
    target = _result_path(result_id, ".nc")
    cleaned = _sanitize_netcdf_attrs(data)
    cleaned.to_netcdf(target)
    return str(target)
338
+
339
+
340
def write_dataset_artifact(data: Any, result_id: str) -> str:
    """Write *data* to ``<artifacts>/<result_id>.nc`` with netCDF-safe attrs.

    Global attributes are sanitized by ``_sanitize_netcdf_attrs``, which
    returns a shallow copy. The attributes of every variable in that copy are
    then sanitized too: data variables *and* coordinates. Previously only
    data variables were covered, so a coordinate with e.g. a ``None``
    attribute made ``to_netcdf`` fail. Returns the written path as a string.
    """
    path = _result_path(result_id, ".nc")
    sanitized = _sanitize_netcdf_attrs(data)
    # ``variables`` (Dataset) includes coordinates as well as data variables.
    # Objects without it (e.g. a DataArray) keep only top-level sanitization.
    for variable in getattr(sanitized, "variables", {}).values():
        variable.attrs = {
            str(key): _sanitize_netcdf_attr_value(value)
            for key, value in variable.attrs.items()
        }
    sanitized.to_netcdf(path)
    return str(path)
356
+
357
+
358
def write_json_artifact(payload: dict[str, Any], result_id: str) -> str:
    """Write *payload* as JSON to ``<artifacts>/<result_id>.json``; return the path."""
    target = _result_path(result_id, ".json")
    _write_json(target, payload)
    return str(target)
362
+
363
+
364
def copy_artifact(src: str, dest: str) -> str:
    """Copy the file contents of *src* to *dest*, creating parent directories.

    Returns *dest* (normalized through ``Path``) as a string.
    """
    target = Path(dest)
    target.parent.mkdir(exist_ok=True, parents=True)
    shutil.copyfile(src, target)
    return str(target)
369
+
370
+
371
def create_workflow(
    *,
    template: str,
    inputs: dict[str, Any],
    session_id: str | None = None,
    steps: list[str],
) -> dict[str, Any]:
    """Create and persist a ``pending`` workflow with one entry per step name.

    *inputs* is stored in JSON-safe form. When *session_id* is given, the
    workflow id is appended to that session's ``workflow_ids``.

    Raises:
        FileNotFoundError: If *session_id* refers to an unknown session.
    """
    # Resolve the session first so an unknown session_id fails before any
    # workflow file is written (no orphaned workflow records).
    session = get_session(session_id) if session_id else None
    workflow_id = _new_id("workflow")
    timestamp = _now_utc()
    record = {
        "workflow_id": workflow_id,
        "template": template,
        "session_id": session_id,
        "inputs": _json_safe(inputs),
        "status": "pending",
        "created_at": timestamp,
        "updated_at": timestamp,
        "events": [],
        "steps": [
            {"name": name, "status": "pending", "summary": None, "error": None}
            for name in steps
        ],
        "result_handle": None,
    }
    _write_json(_record_path("workflows", workflow_id), record)
    if session is not None:
        session["workflow_ids"].append(workflow_id)
        save_session(session)
    return record
400
+
401
+
402
def get_workflow(workflow_id: str) -> dict[str, Any]:
    """Load and return the persisted workflow record for *workflow_id*.

    Raises:
        FileNotFoundError: If no such workflow has been saved.
    """
    record_file = _record_path("workflows", workflow_id)
    if not record_file.exists():
        raise FileNotFoundError(f"Workflow not found: {workflow_id}")
    loaded = _read_json(record_file)
    return loaded
407
+
408
+
409
def save_workflow(workflow: dict[str, Any]) -> dict[str, Any]:
    """Stamp ``updated_at`` on *workflow* (in place), persist it, return it."""
    workflow.update(updated_at=_now_utc())
    target = _record_path("workflows", workflow["workflow_id"])
    _write_json(target, workflow)
    return workflow
413
+
414
+
415
def append_workflow_event(
    workflow_id: str, *, stage: str, message: str, details: dict[str, Any] | None = None
) -> dict[str, Any]:
    """Append a timestamped event to a workflow's log and persist it.

    *details* is attached only when truthy. Returns the saved workflow.
    """
    wf = get_workflow(workflow_id)
    extra = {"details": details} if details else {}
    wf["events"].append(
        {"timestamp_utc": _now_utc(), "stage": stage, "message": message, **extra}
    )
    return save_workflow(wf)
428
+
429
+
430
def update_workflow_step(
    workflow_id: str,
    step_name: str,
    *,
    status: str,
    summary: str | None = None,
    error: str | None = None,
) -> dict[str, Any]:
    """Set status/summary/error on the first step named *step_name* and save.

    If no step has that name, nothing changes except that the workflow is
    still re-saved, which refreshes ``updated_at``.
    """
    wf = get_workflow(workflow_id)
    target = next((step for step in wf["steps"] if step["name"] == step_name), None)
    if target is not None:
        target.update(status=status, summary=summary, error=error)
    return save_workflow(wf)
446
+
447
+
448
def reset_session(session_id: str, *, clear_artifacts: bool = False) -> dict[str, Any]:
    """Delete a session's results, workflows, and operations, keeping datasets.

    Args:
        session_id: Session to reset.
        clear_artifacts: Also delete each result's ``artifact_path`` file if
            it exists.

    Returns:
        A report listing the cleared ids and the artifact paths removed.
    """
    session = get_session(session_id)
    result_handles = list(session["results"])
    workflow_ids = list(session["workflow_ids"])
    operation_ids = list(session["operation_ids"])

    removed_artifacts: list[str] = []
    if clear_artifacts:
        for handle in result_handles:
            with suppress(FileNotFoundError):
                artifact = get_result(handle).get("artifact_path")
                if artifact and Path(artifact).exists():
                    Path(artifact).unlink()
                    removed_artifacts.append(artifact)

    # Remove persisted records in the order: results, workflows, operations.
    for kind, record_ids in (
        ("results", result_handles),
        ("workflows", workflow_ids),
        ("operations", operation_ids),
    ):
        for record_id in record_ids:
            with suppress(FileNotFoundError):
                _record_path(kind, record_id).unlink()

    session.update(
        results={}, workflow_ids=[], operation_ids=[], last_result_handle=None
    )
    save_session(session)

    return {
        "session_id": session_id,
        "cleared_results": result_handles,
        "cleared_workflows": workflow_ids,
        "cleared_operations": operation_ids,
        "removed_artifacts": removed_artifacts,
    }
487
+
488
+
489
@dataclass
class OperationTracker:
    """Object wrapper around a persisted operation record.

    Constructing an instance immediately creates and persists a new
    ``running`` operation. ``stage`` appends progress events, and
    ``succeed`` / ``fail`` finalize it as ``completed`` / ``failed``.
    """

    tool_name: str
    session_id: str | None = None
    workflow_id: str | None = None
    operation_id: str = field(init=False)

    def __post_init__(self) -> None:
        created = create_operation(
            tool_name=self.tool_name,
            session_id=self.session_id,
            workflow_id=self.workflow_id,
        )
        self.operation_id = created["operation_id"]

    def stage(
        self, stage: str, message: str, details: dict[str, Any] | None = None
    ) -> None:
        """Record a progress event; the operation status stays ``running``."""
        append_operation_event(
            self.operation_id,
            stage=stage,
            message=message,
            status="running",
            details=details,
        )

    def _finalize(self, status: str, summary: str) -> None:
        # Shared body of succeed()/fail().
        finalize_operation(self.operation_id, status=status, summary=summary)

    def succeed(self, summary: str) -> None:
        """Finalize the operation with status ``completed``."""
        self._finalize("completed", summary)

    def fail(self, summary: str) -> None:
        """Finalize the operation with status ``failed``."""
        self._finalize("failed", summary)
@@ -0,0 +1,115 @@
1
+ from .advanced import (
2
+ calculate_anomaly,
3
+ calculate_bias,
4
+ calculate_ensemble_mean,
5
+ calculate_ensemble_spread,
6
+ calculate_pattern_correlation,
7
+ calculate_rmse,
8
+ calculate_temporal_mean,
9
+ compare_fields,
10
+ export_to_csv,
11
+ export_to_netcdf,
12
+ extract_cross_section,
13
+ regrid_dataset,
14
+ remap_variable,
15
+ subset_bbox,
16
+ subset_polygon,
17
+ write_result,
18
+ )
19
+ from .capabilities import get_capabilities
20
+ from .catalog import list_datasets
21
+ from .execution_control import (
22
+ endpoint_status,
23
+ get_execution_mode,
24
+ probe_path_access,
25
+ set_execution_mode,
26
+ validate_hpc_setup,
27
+ )
28
+ from .inspection import validate_dataset
29
+ from .orchestration import analyze_dataset
30
+
31
+ # Public tool surface for inspection and plotting. Each function is a
32
+ # dispatcher that runs locally by default and routes to a Globus Compute
33
+ # endpoint when ``use_remote=True``. Internal callers that need the pure
34
+ # local implementation can import the underscored helpers from
35
+ # ``.inspection`` / ``.plotting`` directly.
36
+ from .plotting import plot_mesh_geo
37
+ from .remote_tools import (
38
+ calculate_area,
39
+ calculate_zonal_mean,
40
+ inspect_mesh,
41
+ inspect_variable,
42
+ plot_mesh,
43
+ plot_variable,
44
+ plot_zonal_mean,
45
+ )
46
+ from .scientific_agent import run_scientific_agent
47
+ from .stateful import (
48
+ create_session,
49
+ get_operation_status,
50
+ get_result_handle,
51
+ get_session_state,
52
+ get_workflow_status,
53
+ list_operations,
54
+ register_dataset,
55
+ reset_session_state,
56
+ resume_workflow,
57
+ run_workflow,
58
+ )
59
+ from .vector_calc import (
60
+ calculate_azimuthal_mean,
61
+ calculate_curl,
62
+ calculate_divergence,
63
+ calculate_gradient,
64
+ )
65
+
66
# Public tool surface re-exported by this package. Entries are grouped by the
# submodule each name is imported from above.
__all__ = [
    # .capabilities, .catalog, .orchestration, .scientific_agent
    "get_capabilities",
    "list_datasets",
    "analyze_dataset",
    "run_scientific_agent",
    # .stateful: workflows, sessions, results, operations
    "run_workflow",
    "resume_workflow",
    "get_workflow_status",
    "create_session",
    "register_dataset",
    "get_session_state",
    "reset_session_state",
    "get_result_handle",
    "get_operation_status",
    "list_operations",
    # Inspection / plotting: mostly .remote_tools dispatchers, plus
    # validate_dataset (.inspection) and plot_mesh_geo (.plotting)
    "inspect_mesh",
    "inspect_variable",
    "calculate_area",
    "calculate_zonal_mean",
    "validate_dataset",
    "plot_mesh",
    "plot_mesh_geo",
    "plot_variable",
    "plot_zonal_mean",
    # .advanced: subsetting, comparison, remapping, statistics, export
    "subset_bbox",
    "subset_polygon",
    "extract_cross_section",
    "compare_fields",
    "calculate_bias",
    "calculate_rmse",
    "calculate_pattern_correlation",
    "remap_variable",
    "regrid_dataset",
    "calculate_temporal_mean",
    "calculate_anomaly",
    "calculate_ensemble_mean",
    "calculate_ensemble_spread",
    "export_to_netcdf",
    "export_to_csv",
    "write_result",
    # .vector_calc
    "calculate_gradient",
    "calculate_curl",
    "calculate_divergence",
    "calculate_azimuthal_mean",
    # .execution_control: execution mode and endpoint checks
    "endpoint_status",
    "get_execution_mode",
    "probe_path_access",
    "set_execution_mode",
    "validate_hpc_setup",
]