climate-ref-core 0.8.1__py3-none-any.whl → 0.9.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- climate_ref_core/cmip6_to_cmip7.py +598 -0
- climate_ref_core/dataset_registry.py +43 -0
- climate_ref_core/diagnostics.py +10 -0
- climate_ref_core/env.py +37 -0
- climate_ref_core/esgf/__init__.py +21 -0
- climate_ref_core/esgf/base.py +122 -0
- climate_ref_core/esgf/cmip6.py +119 -0
- climate_ref_core/esgf/fetcher.py +138 -0
- climate_ref_core/esgf/obs4mips.py +94 -0
- climate_ref_core/esgf/registry.py +307 -0
- climate_ref_core/exceptions.py +24 -0
- climate_ref_core/providers.py +143 -17
- climate_ref_core/testing.py +621 -0
- {climate_ref_core-0.8.1.dist-info → climate_ref_core-0.9.0.dist-info}/METADATA +4 -2
- climate_ref_core-0.9.0.dist-info/RECORD +32 -0
- climate_ref_core-0.8.1.dist-info/RECORD +0 -24
- {climate_ref_core-0.8.1.dist-info → climate_ref_core-0.9.0.dist-info}/WHEEL +0 -0
- {climate_ref_core-0.8.1.dist-info → climate_ref_core-0.9.0.dist-info}/licenses/LICENCE +0 -0
- {climate_ref_core-0.8.1.dist-info → climate_ref_core-0.9.0.dist-info}/licenses/NOTICE +0 -0
|
@@ -0,0 +1,307 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Registry-based dataset request implementation.
|
|
3
|
+
|
|
4
|
+
This module provides request classes for fetching datasets from pooch registries
|
|
5
|
+
(e.g., pmp-climatology) rather than ESGF.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
import re
|
|
11
|
+
from collections.abc import Callable
|
|
12
|
+
from typing import Any
|
|
13
|
+
|
|
14
|
+
import pandas as pd
|
|
15
|
+
from loguru import logger
|
|
16
|
+
|
|
17
|
+
from climate_ref_core.dataset_registry import dataset_registry_manager
|
|
18
|
+
|
|
19
|
+
# A PMP climatology registry key splits on "/" into exactly this many components:
#   PMP_obs4MIPsClims/{variable_id}/{grid_label}/{version}/{filename}
_PMP_CLIMATOLOGY_PATH_PARTS: int = 5
# An obs4REF registry key splits on "/" into exactly this many components:
#   obs4REF/{institution_id}/{source_id}/{frequency}/{variable_id}/{grid_label}/{version}/{filename}
_OBS4REF_PATH_PARTS: int = 8
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
# Compiled once at import time because the parser is applied to every key in a registry.
# Filename layout: {var}_{freq}_{source_id}_{inst_short}_{grid}_{time_range}.nc
# No field may contain "_", so the underscores fix the split unambiguously. Hyphens are
# allowed in IDs that use them (e.g. "HadISST-1-1", "GPCP-Monthly-3-2"), and frequencies
# and grid labels may contain digits (e.g. "1hr", "gr1").
_OBS4REF_FILENAME_RE = re.compile(
    r"^(?P<variable_id>[a-zA-Z0-9]+)_"
    r"(?P<frequency>[a-zA-Z0-9]+)_"
    r"(?P<source_id>[A-Za-z0-9-]+)_"
    r"(?P<institution_short>[A-Za-z0-9-]+)_"
    r"(?P<grid_label>[a-zA-Z0-9]+)_"
    r"(?P<time_range>\d+-\d+)\.nc$"
)


def _parse_obs4ref_key(key: str) -> dict[str, Any]:
    """
    Parse an obs4REF registry key to extract metadata.

    Keys follow the pattern:
    obs4REF/{institution_id}/{source_id}/{frequency}/{variable_id}/{grid_label}/{version}/{filename}

    Where filename is:
    {variable_id}_{frequency}_{source_id}_{inst_short}_{grid_label}_{time_range}.nc

    Parameters
    ----------
    key
        The registry key (path) to parse

    Returns
    -------
    Dictionary with the filename fields (variable_id, frequency, source_id,
    institution_short, grid_label, time_range) plus institution_id and version
    from the path, time_start/time_end split from time_range, and the original
    key. An empty dict is returned if the key cannot be parsed.
    """
    # Example: obs4REF/MOHC/HadISST-1-1/mon/ts/gn/v20250415/ts_mon_HadISST-1-1_PCMDI_gn_187001-202501.nc
    parts = key.split("/")
    if len(parts) != _OBS4REF_PATH_PARTS:
        logger.debug(f"Unexpected obs4REF key format (expected {_OBS4REF_PATH_PARTS} parts): {key}")
        return {}

    # source_id, frequency, variable_id and grid_label are also encoded in the filename
    # and are taken from there; institution_id and version only appear in the path.
    _, institution_id, _source_id, _frequency, _variable_id, _grid_label, version, filename = parts

    match = _OBS4REF_FILENAME_RE.match(filename)
    if not match:
        logger.debug(f"obs4REF filename doesn't match expected pattern: {filename}")
        return {}

    metadata: dict[str, Any] = match.groupdict()

    # Path-only metadata. The filename carries an abbreviated institution under
    # "institution_short", so these keys do not clash with any regex group.
    metadata["institution_id"] = institution_id
    metadata["version"] = version

    # Time range is "<start>-<end>" (e.g. YYYYMM-YYYYMM). The regex guarantees exactly
    # one hyphen, so partition always yields both halves.
    time_start, _, time_end = metadata["time_range"].partition("-")
    metadata["time_start"] = time_start
    metadata["time_end"] = time_end

    # Add the full key for reference
    metadata["key"] = key

    return metadata
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
# Compiled once at import time because the parser is applied to every key in a registry.
# Filename layout: {var}_mon_{source_id}_{inst_id}_{grid}_{time}_AC_{ver}_{res}.nc
# No field before the resolution may contain "_", so the underscores fix the split
# unambiguously. Hyphens are allowed in IDs that use them (e.g. "ERA-5",
# "GPCP-Monthly-3-2"), variable IDs may be mixed case (e.g. "sfcWind"), and grid labels
# may contain digits (e.g. "gr1").
_PMP_CLIMATOLOGY_FILENAME_RE = re.compile(
    r"^(?P<variable_id>[a-zA-Z0-9]+)_mon_"
    r"(?P<source_id>[A-Za-z0-9-]+)_"
    r"(?P<institution_id>[A-Za-z0-9-]+)_"
    r"(?P<grid_label>[a-zA-Z0-9]+)_"
    r"(?P<time_range>\d+-\d+)_AC_"
    r"(?P<version>v\d+)_"
    r"(?P<resolution>.+)\.nc$"
)


def _parse_pmp_climatology_key(key: str) -> dict[str, Any]:
    """
    Parse a PMP climatology registry key to extract metadata.

    Keys follow the pattern:
    PMP_obs4MIPsClims/{variable_id}/{grid_label}/{version}/{filename}

    Where filename is:
    {variable_id}_mon_{source_id}_{institution_id}_{grid_label}_{time_range}_AC_{version}_{resolution}.nc

    Parameters
    ----------
    key
        The registry key (path) to parse

    Returns
    -------
    Dictionary with the filename fields (variable_id, source_id, institution_id,
    grid_label, time_range, version, resolution) plus time_start/time_end split
    from time_range and the original key. An empty dict is returned if the key
    cannot be parsed.
    """
    # Example: PMP_obs4MIPsClims/psl/gr/v20250224/
    #   psl_mon_ERA-5_PCMDI_gr_198101-200412_AC_v20250224_2.5x2.5.nc
    parts = key.split("/")
    if len(parts) != _PMP_CLIMATOLOGY_PATH_PARTS:
        logger.debug(f"Unexpected key format (expected {_PMP_CLIMATOLOGY_PATH_PARTS} parts): {key}")
        return {}

    # The directory components (variable_id, grid_label, version) are repeated in the
    # filename, so only the filename is parsed.
    filename = parts[-1]

    match = _PMP_CLIMATOLOGY_FILENAME_RE.match(filename)
    if not match:
        logger.debug(f"Filename doesn't match expected pattern: {filename}")
        return {}

    metadata: dict[str, Any] = match.groupdict()

    # Time range is "<start>-<end>" (e.g. YYYYMM-YYYYMM). The regex guarantees exactly
    # one hyphen, so partition always yields both halves.
    time_start, _, time_end = metadata["time_range"].partition("-")
    metadata["time_start"] = time_start
    metadata["time_end"] = time_end

    # Add the full key for reference
    metadata["key"] = key

    return metadata
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
def _matches_facets(
    metadata: dict[str, Any],
    facets: dict[str, str | tuple[str, ...]],
) -> bool:
    """
    Report whether ``metadata`` satisfies every requested facet.

    Parameters
    ----------
    metadata
        Metadata parsed from a registry key
    facets
        Required facet values. A plain string demands an exact match; a tuple
        of strings accepts any one of its members.

    Returns
    -------
    True when every facet name is present in ``metadata`` with an accepted value
    (so an empty ``facets`` always matches), otherwise False
    """

    def _accepted(name: str, wanted: str | tuple[str, ...]) -> bool:
        if name not in metadata:
            return False
        # Wrap a bare string so membership is an exact comparison,
        # not a substring test.
        options = wanted if not isinstance(wanted, str) else (wanted,)
        return metadata[name] in options

    return all(_accepted(name, wanted) for name, wanted in facets.items())
|
|
174
|
+
|
|
175
|
+
|
|
176
|
+
class RegistryRequest:
    """
    Request for data from a pooch registry (e.g., pmp-climatology).

    These data are fetched from a pooch registry rather than ESGF.
    This is useful for pre-processed datasets like PMP climatologies
    that are hosted externally but not on ESGF.

    Parameters
    ----------
    slug
        Unique identifier for this request
    registry_name
        Name of the registry to fetch from (e.g., "pmp-climatology")
    facets
        Facets to filter datasets (e.g., {"variable_id": "psl", "source_id": "ERA-5"})
    source_type
        Type of dataset source (default: "PMPClimatology")
    time_span
        Optional time range filter (not used for registry filtering, but required for protocol)

    Example
    -------
    ```python
    request = RegistryRequest(
        slug="era5-psl",
        registry_name="pmp-climatology",
        facets={"variable_id": "psl", "source_id": "ERA-5"},
    )
    df = request.fetch_datasets()
    ```
    """

    # Metadata columns that together identify one dataset independently of its version.
    # Frequency and institution are part of the identity, so that, e.g., monthly and
    # daily files of the same variable are not treated as competing versions.
    # Only the columns actually present in a result are used.
    _DATASET_ID_COLUMNS: tuple[str, ...] = (
        "institution_id",
        "source_id",
        "frequency",
        "variable_id",
        "grid_label",
    )

    def __init__(
        self,
        slug: str,
        registry_name: str,
        facets: dict[str, str | tuple[str, ...]],
        source_type: str = "PMPClimatology",
        time_span: tuple[str, str] | None = None,
    ) -> None:
        self.slug = slug
        self.registry_name = registry_name
        self.facets = facets
        self.source_type = source_type
        # Stored for protocol compatibility only; fetch_datasets does not filter by time.
        self.time_span = time_span

    def __repr__(self) -> str:
        return (
            f"RegistryRequest(slug={self.slug!r}, registry_name={self.registry_name!r}, "
            f"facets={self.facets!r}, source_type={self.source_type!r}, time_span={self.time_span!r})"
        )

    def _get_parser(self) -> Callable[[str], dict[str, Any]]:
        """
        Get the key-parsing function for this request's registry.

        "pmp-climatology" uses the PMP climatology parser. Everything else uses the
        obs4REF parser, and a warning is logged for registries other than "obs4ref".
        """
        if self.registry_name == "pmp-climatology":
            return _parse_pmp_climatology_key
        if self.registry_name != "obs4ref":
            # Default to obs4ref parser as fallback
            logger.warning(f"Unknown registry '{self.registry_name}', using obs4ref parser")
        return _parse_obs4ref_key

    def fetch_datasets(self) -> pd.DataFrame:
        """
        Fetch matching datasets from the registry.

        Every registry key is parsed into metadata and compared against ``facets``.
        Each matching file is fetched (downloaded if not cached). Keys that cannot be
        parsed or fetched are skipped. When several versions of the same dataset
        match, only the rows carrying the latest version are kept.

        Returns
        -------
        DataFrame containing dataset metadata and file paths.
        Each row represents one file, with columns for metadata
        and a 'files' column containing a list with the file path.
        An empty DataFrame (no columns) is returned when nothing matches.

        Raises
        ------
        ValueError
            If ``registry_name`` is not found in the dataset registry manager.
        """
        logger.info(f"Fetching from registry '{self.registry_name}' for request: {self.slug}")

        try:
            registry = dataset_registry_manager[self.registry_name]
        except KeyError as err:
            raise ValueError(
                f"Registry '{self.registry_name}' not found. "
                f"Available registries: {list(dataset_registry_manager.keys())}"
            ) from err

        parser = self._get_parser()
        matching_rows: list[dict[str, Any]] = []

        for key in registry.registry:
            # Parse metadata from the registry key
            metadata = parser(key)
            if not metadata:
                continue

            # Check if it matches the requested facets
            if not _matches_facets(metadata, self.facets):
                continue

            # Fetch the file (downloads if not cached). This is best-effort: one failed
            # download should not abort the whole request.
            try:
                file_path = registry.fetch(key)
                logger.debug(f"Fetched: {key} -> {file_path}")
            except Exception as e:
                logger.warning(f"Failed to fetch {key}: {e}")
                continue

            # Build row compatible with ESGFFetcher expectations
            matching_rows.append(
                {
                    **metadata,
                    "files": [file_path],
                    "path": file_path,
                }
            )

        if not matching_rows:
            logger.warning(f"No datasets found matching facets: {self.facets}")
            return pd.DataFrame()

        result = pd.DataFrame(matching_rows)

        # Keep only the latest version of each dataset. Versions seen in registry keys
        # look like "vYYYYMMDD" (e.g. "v20250224"), so the string maximum is the newest.
        # This assumes a fixed-width date format.
        if "version" in result.columns:
            group_by_cols = [col for col in self._DATASET_ID_COLUMNS if col in result.columns]
            if group_by_cols:
                # dropna=False: rows with a missing identifying value must not be
                # discarded just because their group key is NaN.
                max_version = result.groupby(group_by_cols, sort=False, dropna=False)["version"].transform(
                    "max"
                )
                result = result[result["version"] == max_version]

        logger.info(f"Found {len(result)} datasets matching request: {self.slug}")

        return result
|
climate_ref_core/exceptions.py
CHANGED
|
@@ -67,3 +67,27 @@ class DiagnosticError(RefException):
|
|
|
67
67
|
def __reduce__(self) -> tuple[type["DiagnosticError"], tuple[str, Any]]:
    """Support pickling by rebuilding the exception from ``message`` and ``result``."""
    constructor_args = (self.message, self.result)
    return self.__class__, constructor_args
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
class TestCaseError(RefException):
    """Raised when there is an error with a test case."""

    # The class name starts with "Test", which matches pytest's default class-collection
    # pattern. Without this flag, importing the exception into a test module makes pytest
    # try to collect it as a test class and emit a collection warning. Subclasses
    # inherit the flag.
    __test__ = False
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
class TestCaseNotFoundError(TestCaseError):
    """Raised when a requested test case cannot be found."""
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
class NoTestDataSpecError(TestCaseError):
    """Raised when a diagnostic does not define a ``test_data_spec``."""
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
class DatasetResolutionError(TestCaseError):
    """Raised when the datasets required by a test case cannot be resolved."""
|
climate_ref_core/providers.py
CHANGED
|
@@ -175,6 +175,105 @@ class DiagnosticProvider:
|
|
|
175
175
|
"""
|
|
176
176
|
return self._diagnostics[slug.lower()]
|
|
177
177
|
|
|
178
|
+
def setup(
|
|
179
|
+
self,
|
|
180
|
+
config: Config,
|
|
181
|
+
*,
|
|
182
|
+
skip_env: bool = False,
|
|
183
|
+
skip_data: bool = False,
|
|
184
|
+
) -> None:
|
|
185
|
+
"""
|
|
186
|
+
Perform all setup required before offline execution.
|
|
187
|
+
|
|
188
|
+
This calls setup_environment and fetch_data in the correct order.
|
|
189
|
+
Override individual hooks for fine-grained control.
|
|
190
|
+
|
|
191
|
+
This method MUST be idempotent - safe to call multiple times.
|
|
192
|
+
|
|
193
|
+
Parameters
|
|
194
|
+
----------
|
|
195
|
+
config
|
|
196
|
+
The application configuration
|
|
197
|
+
skip_env
|
|
198
|
+
If True, skip environment setup (e.g., conda)
|
|
199
|
+
skip_data
|
|
200
|
+
If True, skip data fetching
|
|
201
|
+
"""
|
|
202
|
+
if not skip_env:
|
|
203
|
+
self.setup_environment(config)
|
|
204
|
+
if not skip_data:
|
|
205
|
+
self.fetch_data(config)
|
|
206
|
+
|
|
207
|
+
def setup_environment(self, config: Config) -> None:
    """
    Hook for preparing the execution environment (e.g., a conda environment).

    The base implementation is a no-op; providers that need an environment
    override it.

    This method MUST be idempotent.

    Parameters
    ----------
    config
        The application configuration
    """
    return None
|
|
222
|
+
|
|
223
|
+
def fetch_data(self, config: Config) -> None:
    """
    Hook for downloading everything the provider needs to run offline.

    That may include reference datasets, climatology files, map files, recipes,
    or any other provider-specific data. Each provider decides what it needs and
    how to obtain it; the base implementation is a no-op.

    Data should be downloaded to the pooch cache (via `fetch_all_files`
    with `output_dir=None`). Diagnostics can then access data via
    `registry.abspath`.

    This method MUST be idempotent.

    Parameters
    ----------
    config
        The application configuration
    """
    return None
|
|
246
|
+
|
|
247
|
+
def validate_setup(self, config: Config) -> bool:
    """
    Report whether the provider is ready for offline execution.

    The base implementation performs no checks and always reports success;
    subclasses override it with real validation.

    Parameters
    ----------
    config
        The application configuration

    Returns
    -------
    bool
        True if setup is valid and complete
    """
    ready = True
    return ready
|
|
265
|
+
|
|
266
|
+
def get_data_path(self) -> Path | None:
    """
    Return the location of this provider's cached data.

    The base implementation reports that no cached data is used.

    Returns
    -------
    Path | None
        The data cache path, or None if the provider doesn't use cached data.
    """
    cache_location = None
    return cache_location
|
|
276
|
+
|
|
178
277
|
|
|
179
278
|
def import_provider(fqn: str) -> DiagnosticProvider:
|
|
180
279
|
"""
|
|
@@ -316,7 +415,7 @@ class CondaDiagnosticProvider(CommandLineDiagnosticProvider):
|
|
|
316
415
|
self._conda_exe: Path | None = None
|
|
317
416
|
self._prefix: Path | None = None
|
|
318
417
|
self.url = f"git+{repo}@{tag_or_commit}" if repo and tag_or_commit else None
|
|
319
|
-
self.env_vars: dict[str, str] =
|
|
418
|
+
self.env_vars: dict[str, str] = os.environ.copy()
|
|
320
419
|
|
|
321
420
|
@property
|
|
322
421
|
def prefix(self) -> Path:
|
|
@@ -338,9 +437,26 @@ class CondaDiagnosticProvider(CommandLineDiagnosticProvider):
|
|
|
338
437
|
"""Configure the provider."""
|
|
339
438
|
super().configure(config)
|
|
340
439
|
self.prefix = config.paths.software / "conda"
|
|
440
|
+
self.env_vars.setdefault("HOME", str(self.prefix))
|
|
441
|
+
|
|
442
|
+
def _is_stale(self, path: Path) -> bool:
    """Check if a file was last written more than `MICROMAMBA_MAX_AGE` ago.

    Age is measured from the file's last content modification (``st_mtime``),
    i.e. when it was last written. ``st_ctime`` is deliberately not used: on
    POSIX systems it records the last *metadata* change (chmod, chown, ...)
    rather than creation, so unrelated permission changes would make an old
    file look fresh. Its Windows meaning is also deprecated since Python 3.12.

    Parameters
    ----------
    path
        The path to the file to check.

    Returns
    -------
    True if the file is older than `MICROMAMBA_MAX_AGE`, False otherwise.
    """
    # Both timestamps are naive local times, so the subtraction is consistent.
    last_written = datetime.datetime.fromtimestamp(path.stat().st_mtime)
    age = datetime.datetime.now() - last_written
    return age > MICROMAMBA_MAX_AGE
|
|
341
457
|
|
|
342
458
|
def _install_conda(self, update: bool) -> Path:
|
|
343
|
-
"""Install micromamba in a
|
|
459
|
+
"""Install micromamba in a specific location.
|
|
344
460
|
|
|
345
461
|
Parameters
|
|
346
462
|
----------
|
|
@@ -354,20 +470,15 @@ class CondaDiagnosticProvider(CommandLineDiagnosticProvider):
|
|
|
354
470
|
"""
|
|
355
471
|
conda_exe = self.prefix / "micromamba"
|
|
356
472
|
|
|
357
|
-
if conda_exe.exists()
|
|
358
|
-
# Only update if the executable is older than `MICROMAMBA_MAX_AGE`.
|
|
359
|
-
creation_time = datetime.datetime.fromtimestamp(conda_exe.stat().st_ctime)
|
|
360
|
-
age = datetime.datetime.now() - creation_time
|
|
361
|
-
if age < MICROMAMBA_MAX_AGE:
|
|
362
|
-
update = False
|
|
363
|
-
|
|
364
|
-
if not conda_exe.exists() or update:
|
|
473
|
+
if not conda_exe.exists() or update or self._is_stale(conda_exe):
|
|
365
474
|
logger.info("Installing conda")
|
|
366
475
|
self.prefix.mkdir(parents=True, exist_ok=True)
|
|
367
|
-
response = requests.get(_get_micromamba_url(), timeout=120)
|
|
476
|
+
response = requests.get(_get_micromamba_url(), timeout=120, stream=True)
|
|
368
477
|
response.raise_for_status()
|
|
369
478
|
with conda_exe.open(mode="wb") as file:
|
|
370
|
-
|
|
479
|
+
for chunk in response.iter_content(chunk_size=8192):
|
|
480
|
+
if chunk: # Filter out keep-alive new chunks
|
|
481
|
+
file.write(chunk)
|
|
371
482
|
conda_exe.chmod(stat.S_IRWXU)
|
|
372
483
|
logger.info("Successfully installed conda.")
|
|
373
484
|
|
|
@@ -428,7 +539,7 @@ class CondaDiagnosticProvider(CommandLineDiagnosticProvider):
|
|
|
428
539
|
f"{self.env_path}",
|
|
429
540
|
]
|
|
430
541
|
logger.debug(f"Running {' '.join(cmd)}")
|
|
431
|
-
subprocess.run(cmd, check=True) # noqa: S603
|
|
542
|
+
subprocess.run(cmd, check=True, env=self.env_vars) # noqa: S603
|
|
432
543
|
|
|
433
544
|
if self.url is not None:
|
|
434
545
|
logger.info(f"Installing development version of {self.slug} from {self.url}")
|
|
@@ -443,7 +554,7 @@ class CondaDiagnosticProvider(CommandLineDiagnosticProvider):
|
|
|
443
554
|
self.url,
|
|
444
555
|
]
|
|
445
556
|
logger.debug(f"Running {' '.join(cmd)}")
|
|
446
|
-
subprocess.run(cmd, check=True) # noqa: S603
|
|
557
|
+
subprocess.run(cmd, check=True, env=self.env_vars) # noqa: S603
|
|
447
558
|
|
|
448
559
|
def run(self, cmd: Iterable[str]) -> None:
|
|
449
560
|
"""
|
|
@@ -476,8 +587,6 @@ class CondaDiagnosticProvider(CommandLineDiagnosticProvider):
|
|
|
476
587
|
*cmd,
|
|
477
588
|
]
|
|
478
589
|
logger.info(f"Running '{' '.join(cmd)}'")
|
|
479
|
-
env_vars = os.environ.copy()
|
|
480
|
-
env_vars.update(self.env_vars)
|
|
481
590
|
try:
|
|
482
591
|
# This captures the log output until the execution is complete
|
|
483
592
|
# We could poll using `subprocess.Popen` if we want something more responsive
|
|
@@ -487,7 +596,7 @@ class CondaDiagnosticProvider(CommandLineDiagnosticProvider):
|
|
|
487
596
|
stdout=subprocess.PIPE,
|
|
488
597
|
stderr=subprocess.STDOUT,
|
|
489
598
|
text=True,
|
|
490
|
-
env=env_vars,
|
|
599
|
+
env=self.env_vars,
|
|
491
600
|
)
|
|
492
601
|
logger.info("Command output: \n" + res.stdout)
|
|
493
602
|
logger.info("Command execution successful")
|
|
@@ -495,3 +604,20 @@ class CondaDiagnosticProvider(CommandLineDiagnosticProvider):
|
|
|
495
604
|
logger.error(f"Failed to run {cmd}")
|
|
496
605
|
logger.error(e.stdout)
|
|
497
606
|
raise e
|
|
607
|
+
|
|
608
|
+
def setup_environment(self, config: Config) -> None:
    """
    Create the provider's conda environment.

    ``config`` is accepted to match the base-class hook signature but is not
    used here; all work is delegated to ``create_env``.
    """
    self.create_env()
|
|
611
|
+
|
|
612
|
+
def validate_setup(self, config: Config) -> bool:
    """
    Check that the provider's conda environment exists on disk.

    When ``self.env_path`` is missing, an error explaining how to install the
    environment is logged.

    Returns
    -------
    bool
        True if ``self.env_path`` exists, False otherwise
    """
    present = self.env_path.exists()
    if present:
        # TODO: Could add more validation here (e.g., check packages installed)
        return present

    logger.error(
        f"Conda environment for {self.slug} is not available at {self.env_path}. "
        f"Please run `ref providers setup --provider {self.slug}` to install it."
    )
    return present
|