uxarray-mcp 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,493 @@
1
+ """Academy agent for orchestrating local and remote UXarray computations."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import asyncio
6
+ import warnings
7
+ from typing import Any, Dict, Optional
8
+
9
try:
    from academy.agent import Agent as _AcademyAgent, action
except ImportError:
    # academy is optional: fall back to a plain base class and a pass-through
    # decorator so this module still imports (and runs locally) without it.
    _AcademyAgent = object  # type: ignore[assignment,misc]

    def action(fn):  # type: ignore[no-redef]
        """Stand-in for ``academy.agent.action``: hand back *fn* untouched."""
        return fn
18
+
19
+
20
+ from .compute_functions import (
21
+ remote_calculate_area,
22
+ remote_calculate_azimuthal_mean,
23
+ remote_calculate_curl,
24
+ remote_calculate_divergence,
25
+ remote_calculate_gradient,
26
+ remote_calculate_zonal_mean,
27
+ remote_inspect_mesh,
28
+ remote_inspect_variable,
29
+ remote_plot_mesh,
30
+ remote_plot_variable,
31
+ remote_plot_zonal_mean,
32
+ remote_probe_path,
33
+ )
34
+ from .config import HPCConfig
35
+
36
+
37
class UXarrayComputeAgent(_AcademyAgent):
    """Academy agent for UXarray computations with HPC support.

    This agent orchestrates execution of UXarray operations either locally
    or on remote HPC resources via Globus Compute.

    Parameters
    ----------
    config : HPCConfig
        Configuration for HPC execution. This class reads its
        ``endpoint_id``, ``endpoint_name`` and ``timeout_seconds`` attributes.

    Examples
    --------
    >>> from uxarray_mcp.remote import load_config, UXarrayComputeAgent
    >>> config = load_config()
    >>> agent = UXarrayComputeAgent(config)
    """

    def __init__(self, config: HPCConfig):
        super().__init__()
        self.config = config
        # Globus Compute executor, created lazily by _get_executor().
        self._executor: Any = None

    @staticmethod
    def _ignore_env_mismatch_warning() -> None:
        """Install an "ignore" filter for Globus Compute's environment warning.

        The filter targets the ``UserWarning`` whose message contains
        "Environment differences detected between local SDK and endpoint".
        Call it only inside ``warnings.catch_warnings()`` so the filter is
        discarded again when that ``with`` block exits.
        """
        warnings.filterwarnings(
            "ignore",
            message=r"(?s).*Environment differences detected between local SDK and endpoint.*",
            category=UserWarning,
        )

    def _get_executor(self):
        """Get or create the Globus Compute executor with AllCodeStrategies.

        AllCodeStrategies serializes the actual function code instead of
        just the module reference, so the HPC endpoint does not need
        uxarray_mcp installed — only uxarray and its dependencies.

        Returns
        -------
        globus_compute_sdk.Executor or None
            The cached executor, or None if none has been created yet and
            ``config.endpoint_id`` is falsy.
        """
        if self._executor is None and self.config.endpoint_id:
            from globus_compute_sdk import Executor
            from globus_compute_sdk.serialize import (
                AllCodeStrategies,
                ComputeSerializer,
            )

            with warnings.catch_warnings():
                self._ignore_env_mismatch_warning()
                self._executor = Executor(
                    endpoint_id=self.config.endpoint_id,
                    serializer=ComputeSerializer(strategy_code=AllCodeStrategies()),
                )
        return self._executor

    @action
    async def inspect_mesh_remote(
        self, file_path: str, use_remote: bool = False
    ) -> Dict[str, Any]:
        """Inspect mesh topology with optional remote execution.

        Parameters
        ----------
        file_path : str
            Path to mesh file
        use_remote : bool
            If True and ``config.endpoint_id`` is set, execute on HPC;
            otherwise execute locally.

        Returns
        -------
        dict
            Mesh topology info (n_face, n_node, n_edge, source)
        """
        if use_remote and self.config.endpoint_id:
            return await self._run_on_hpc(remote_inspect_mesh, file_path)
        return self._run_local_inspect_mesh(file_path)

    @action
    async def calculate_area_remote(
        self, file_path: str, use_remote: bool = False
    ) -> Dict[str, Any]:
        """Calculate face areas with optional remote execution.

        Parameters
        ----------
        file_path : str
            Path to mesh file
        use_remote : bool
            If True and ``config.endpoint_id`` is set, execute on HPC;
            otherwise execute locally.

        Returns
        -------
        dict
            Area statistics

        Examples
        --------
        >>> result = await agent.calculate_area_remote("mesh.nc", use_remote=False)
        """
        if use_remote and self.config.endpoint_id:
            return await self._run_on_hpc(remote_calculate_area, file_path)
        return self._run_local_calculate_area(file_path)

    @action
    async def inspect_variable_remote(
        self,
        grid_path: str,
        data_path: str,
        variable_name: Optional[str] = None,
        use_remote: bool = False,
    ) -> Dict[str, Any]:
        """Inspect variables with optional remote execution.

        Parameters
        ----------
        grid_path : str
            Path to grid file
        data_path : str
            Path to data file
        variable_name : str | None
            Variable to inspect, or None for all
        use_remote : bool
            If True and ``config.endpoint_id`` is set, execute on HPC;
            otherwise execute locally.

        Returns
        -------
        dict
            Variable metadata
        """
        if use_remote and self.config.endpoint_id:
            return await self._run_on_hpc(
                remote_inspect_variable, grid_path, data_path, variable_name
            )
        return self._run_local_inspect_variable(grid_path, data_path, variable_name)

    @action
    async def calculate_zonal_mean_remote(
        self,
        grid_path: str,
        data_path: str,
        variable_name: str,
        lat_spec: Optional[tuple | float | list] = None,
        conservative: bool = False,
        use_remote: bool = False,
    ) -> Dict[str, Any]:
        """Calculate zonal mean with optional remote execution.

        Parameters
        ----------
        grid_path : str
            Path to grid file
        data_path : str
            Path to data file
        variable_name : str
            Variable to compute zonal mean for
        lat_spec : tuple | float | list | None
            Latitude specification
        conservative : bool
            Use conservative averaging
        use_remote : bool
            If True and ``config.endpoint_id`` is set, execute on HPC;
            otherwise execute locally.

        Returns
        -------
        dict
            Zonal mean results
        """
        if use_remote and self.config.endpoint_id:
            return await self._run_on_hpc(
                remote_calculate_zonal_mean,
                grid_path,
                data_path,
                variable_name,
                lat_spec,
                conservative,
            )
        return self._run_local_calculate_zonal_mean(
            grid_path, data_path, variable_name, lat_spec, conservative
        )

    @action
    async def calculate_gradient_remote(
        self, grid_path: str, data_path: str, variable_name: str
    ) -> Dict[str, Any]:
        """Compute spatial gradient on HPC.

        There is no local fallback: raises RuntimeError when no HPC endpoint
        is configured.
        """
        return await self._run_on_hpc(
            remote_calculate_gradient, grid_path, data_path, variable_name
        )

    @action
    async def calculate_curl_remote(
        self, grid_path: str, data_path: str, u_variable: str, v_variable: str
    ) -> Dict[str, Any]:
        """Compute relative vorticity (curl) on HPC.

        There is no local fallback: raises RuntimeError when no HPC endpoint
        is configured.
        """
        return await self._run_on_hpc(
            remote_calculate_curl, grid_path, data_path, u_variable, v_variable
        )

    @action
    async def calculate_divergence_remote(
        self, grid_path: str, data_path: str, u_variable: str, v_variable: str
    ) -> Dict[str, Any]:
        """Compute horizontal divergence on HPC.

        There is no local fallback: raises RuntimeError when no HPC endpoint
        is configured.
        """
        return await self._run_on_hpc(
            remote_calculate_divergence, grid_path, data_path, u_variable, v_variable
        )

    @action
    async def calculate_azimuthal_mean_remote(
        self,
        grid_path: str,
        data_path: str,
        variable_name: str,
        center_lon: float,
        center_lat: float,
        outer_radius: float,
        radius_step: float,
    ) -> Dict[str, Any]:
        """Compute azimuthal mean around a centre point on HPC.

        There is no local fallback: raises RuntimeError when no HPC endpoint
        is configured.
        """
        return await self._run_on_hpc(
            remote_calculate_azimuthal_mean,
            grid_path,
            data_path,
            variable_name,
            center_lon,
            center_lat,
            outer_radius,
            radius_step,
        )

    @action
    async def probe_path_remote(
        self, file_path: str, inspect_netcdf: bool = True, use_remote: bool = False
    ) -> Dict[str, Any]:
        """Probe whether a worker can read the exact target path.

        Runs on HPC when ``use_remote`` is True and ``config.endpoint_id`` is
        set; otherwise calls the same probe function in this process.
        """
        if use_remote and self.config.endpoint_id:
            return await self._run_on_hpc(remote_probe_path, file_path, inspect_netcdf)
        return remote_probe_path(file_path, inspect_netcdf)

    @action
    async def plot_mesh_remote(
        self,
        grid_path: str,
        width: int = 800,
        height: int = 400,
        use_remote: bool = False,
    ) -> Dict[str, Any]:
        """Render a mesh wireframe PNG and return base64 bytes.

        Runs on HPC when ``use_remote`` is True and ``config.endpoint_id`` is
        set; otherwise calls the same render function in this process.
        """
        if use_remote and self.config.endpoint_id:
            return await self._run_on_hpc(remote_plot_mesh, grid_path, width, height)
        return remote_plot_mesh(grid_path, width, height)

    @action
    async def plot_variable_remote(
        self,
        grid_path: str,
        data_path: str,
        variable_name: Optional[str] = None,
        width: int = 800,
        height: int = 400,
        cmap: str = "viridis",
        vmin: Optional[float] = None,
        vmax: Optional[float] = None,
        title: Optional[str] = None,
        time_index: int = 0,
        use_remote: bool = False,
    ) -> Dict[str, Any]:
        """Render a face-centered variable PNG and return base64 bytes.

        Runs on HPC when ``use_remote`` is True and ``config.endpoint_id`` is
        set; otherwise calls the same render function in this process.
        """
        # Positional order must match remote_plot_variable's signature.
        call_args = (
            grid_path,
            data_path,
            variable_name,
            width,
            height,
            cmap,
            vmin,
            vmax,
            title,
            time_index,
        )
        if use_remote and self.config.endpoint_id:
            return await self._run_on_hpc(remote_plot_variable, *call_args)
        return remote_plot_variable(*call_args)

    @action
    async def plot_zonal_mean_remote(
        self,
        grid_path: str,
        data_path: str,
        variable_name: str,
        width: int = 800,
        height: int = 400,
        lat_spec: Optional[tuple | float | list] = None,
        conservative: bool = False,
        line_color: str = "#1f77b4",
        title: Optional[str] = None,
        use_remote: bool = False,
    ) -> Dict[str, Any]:
        """Render a zonal mean profile PNG and return base64 bytes.

        Runs on HPC when ``use_remote`` is True and ``config.endpoint_id`` is
        set; otherwise calls the same render function in this process.
        """
        # Positional order must match remote_plot_zonal_mean's signature.
        call_args = (
            grid_path,
            data_path,
            variable_name,
            width,
            height,
            lat_spec,
            conservative,
            line_color,
            title,
        )
        if use_remote and self.config.endpoint_id:
            return await self._run_on_hpc(remote_plot_zonal_mean, *call_args)
        return remote_plot_zonal_mean(*call_args)

    async def _run_on_hpc(self, func, *args, **kwargs) -> Dict[str, Any]:
        """Execute function on HPC via Globus Compute.

        Parameters
        ----------
        func : callable
            Remote function to execute
        *args
            Positional arguments for function
        **kwargs
            Keyword arguments for function

        Returns
        -------
        dict
            Function result from HPC, passed through ``attach_provenance`` with
            venue ``"hpc:<endpoint_name or 'configured'>"``.

        Raises
        ------
        RuntimeError
            If no HPC endpoint is configured.
        Exception
            Whatever ``future.result`` raises (for example, when
            ``config.timeout_seconds`` elapses) propagates unchanged.
        """
        executor = self._get_executor()
        if executor is None:
            raise RuntimeError("HPC endpoint not configured")

        loop = asyncio.get_running_loop()
        # NOTE(review): catch_warnings() saves and restores the process-global
        # filter list. Because it is held across the await below, concurrent
        # calls can interleave their save/restore. It is kept around the await
        # so the filter stays active while the result is being waited on.
        with warnings.catch_warnings():
            self._ignore_env_mismatch_warning()
            future = executor.submit(func, *args, **kwargs)
            # future.result() blocks, so wait on the default thread pool
            # rather than blocking the event loop.
            result = await loop.run_in_executor(
                None, future.result, self.config.timeout_seconds
            )

        # Attach provenance with the correct HPC venue — the remote functions
        # are self contained and don't call attach_provenance themselves.
        from uxarray_mcp.provenance import attach_provenance

        inputs: Dict[str, Any] = {"args": [str(a) for a in args]}
        if kwargs:
            # Keyword arguments are forwarded to the task, so record them too.
            inputs["kwargs"] = {name: str(value) for name, value in kwargs.items()}

        endpoint_label = self.config.endpoint_name or "configured"
        return attach_provenance(
            result,
            tool=func.__name__,
            inputs=inputs,
            venue=f"hpc:{endpoint_label}",
        )

    def _run_local_inspect_mesh(self, file_path: str) -> Dict[str, Any]:
        """Execute inspect_mesh locally as fallback."""
        from uxarray_mcp.tools import inspect_mesh

        return inspect_mesh(file_path)

    def _run_local_calculate_area(self, file_path: str) -> Dict[str, Any]:
        """Execute calculate_area locally as fallback."""
        from uxarray_mcp.tools import calculate_area

        return calculate_area(file_path)

    def _run_local_inspect_variable(
        self, grid_path: str, data_path: str, variable_name: Optional[str]
    ) -> Dict[str, Any]:
        """Execute inspect_variable locally as fallback."""
        from uxarray_mcp.tools import inspect_variable

        return inspect_variable(grid_path, data_path, variable_name)

    def _run_local_calculate_zonal_mean(
        self,
        grid_path: str,
        data_path: str,
        variable_name: str,
        lat_spec: Optional[tuple | float | list],
        conservative: bool,
    ) -> Dict[str, Any]:
        """Execute calculate_zonal_mean locally as fallback."""
        from uxarray_mcp.tools import calculate_zonal_mean

        return calculate_zonal_mean(
            grid_path, data_path, variable_name, lat_spec, conservative
        )
453
+
454
+
455
# Default agent returned by get_agent() when neither endpoint nor path is given.
_agent_instance: Optional[UXarrayComputeAgent] = None
# Agents for explicit endpoint/path selections, keyed by a string built from
# the resolved config (endpoint name, endpoint id, execution mode, timeout).
_agent_instances: dict[str, UXarrayComputeAgent] = {}


def get_agent(
    endpoint: str | None = None, path: str | None = None
) -> UXarrayComputeAgent:
    """Get or create a cached agent instance.

    With no arguments, a single module-wide default agent is returned; it is
    built from ``load_config().for_endpoint(endpoint=None, path=None)`` on the
    first call only. When ``endpoint`` or ``path`` is given, the loaded config
    is narrowed with ``for_endpoint`` and one agent is cached per distinct
    resolved config.

    Parameters
    ----------
    endpoint : str | None
        Endpoint selector forwarded to ``for_endpoint``.
    path : str | None
        Path forwarded to ``for_endpoint``.

    Returns
    -------
    UXarrayComputeAgent
        Configured agent instance

    Examples
    --------
    >>> agent = get_agent()
    >>> result = await agent.calculate_area_remote("mesh.nc")
    """
    global _agent_instance
    use_default = endpoint is None and path is None

    # The default agent is built once; later calls would load a config only
    # to discard it, so return the cached instance straight away.
    if use_default and _agent_instance is not None:
        return _agent_instance

    from .config import load_config

    config = load_config().for_endpoint(endpoint=endpoint, path=path)

    if use_default:
        _agent_instance = UXarrayComputeAgent(config)
        return _agent_instance

    # NOTE(review): the key covers only these four config fields; confirm that
    # for_endpoint() cannot yield configs differing solely in other fields.
    key = (
        f"{config.endpoint_name or 'default'}:"
        f"{config.endpoint_id or 'local'}:"
        f"{config.execution_mode}:"
        f"{config.timeout_seconds}"
    )
    agent = _agent_instances.get(key)
    if agent is None:
        agent = _agent_instances[key] = UXarrayComputeAgent(config)
    return agent