esgpull 0.9.1__py3-none-any.whl → 0.9.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
esgpull/config.py CHANGED
@@ -4,142 +4,56 @@ import logging
4
4
  from collections.abc import Iterator, Mapping
5
5
  from enum import Enum, auto
6
6
  from pathlib import Path
7
- from typing import Any, TypeVar, Union, cast, overload
7
+ from typing import Any, TypeVar
8
8
 
9
9
  import tomlkit
10
- from attrs import Factory, define, field
11
- from attrs import has as attrs_has
12
- from cattrs import Converter
13
- from cattrs.gen import make_dict_unstructure_fn, override
14
- from tomlkit import TOMLDocument
10
+ from pydantic import BaseModel, Field, PrivateAttr, field_validator
11
+ from pydantic_settings import (
12
+ BaseSettings,
13
+ PydanticBaseSettingsSource,
14
+ SettingsConfigDict,
15
+ TomlConfigSettingsSource,
16
+ )
15
17
 
16
18
  from esgpull.constants import CONFIG_FILENAME
17
19
  from esgpull.exceptions import BadConfigError, VirtualConfigError
18
20
  from esgpull.install_config import InstallConfig
19
- from esgpull.models.options import Options
21
+ from esgpull.models.options import Option, Options
20
22
 
21
23
  logger = logging.getLogger("esgpull")
22
-
23
24
  T = TypeVar("T")
24
25
 
25
26
 
26
- @overload
27
- def cast_value(
28
- target: str, value: Union[str, int, bool, float], key: str
29
- ) -> str: ...
30
-
31
-
32
- @overload
33
- def cast_value(
34
- target: bool, value: Union[str, int, bool, float], key: str
35
- ) -> bool: ...
36
-
37
-
38
- @overload
39
- def cast_value(
40
- target: int, value: Union[str, int, bool, float], key: str
41
- ) -> int: ...
42
-
43
-
44
- @overload
45
- def cast_value(
46
- target: float, value: Union[str, int, bool, float], key: str
47
- ) -> float: ...
48
-
49
-
50
- def cast_value(
51
- target: Any, value: Union[str, int, bool, float], key: str
52
- ) -> Any:
53
- if isinstance(value, type(target)):
54
- return value
55
- elif attrs_has(type(target)):
56
- raise KeyError(key)
57
- elif isinstance(target, str):
58
- return str(value)
59
- elif isinstance(target, float):
60
- try:
61
- return float(value)
62
- except Exception:
63
- raise ValueError(value)
64
- elif isinstance(target, bool):
65
- if isinstance(value, str):
66
- if value.lower() in ["on", "true"]:
67
- return True
68
- elif value.lower() in ["off", "false"]:
69
- return False
70
- else:
71
- raise ValueError(value)
72
- else:
73
- raise TypeError(value)
74
- elif isinstance(target, int):
75
- # int must be after bool, because isinstance(True, int) == True
76
- try:
77
- return int(value)
78
- except Exception:
79
- raise ValueError(value)
27
+ def _get_root() -> Path:
28
+ if InstallConfig.current is not None:
29
+ return InstallConfig.current.path
80
30
  else:
81
- raise TypeError(value)
82
-
83
-
84
- @define
85
- class Paths:
86
- auth: Path = field(converter=Path)
87
- data: Path = field(converter=Path)
88
- db: Path = field(converter=Path)
89
- log: Path = field(converter=Path)
90
- tmp: Path = field(converter=Path)
91
- plugins: Path = field(converter=Path)
92
-
93
- @auth.default
94
- def _auth_factory(self) -> Path:
95
- if InstallConfig.current is not None:
96
- root = InstallConfig.current.path
97
- else:
98
- root = InstallConfig.default
99
- return root / "auth"
100
-
101
- @data.default
102
- def _data_factory(self) -> Path:
103
- if InstallConfig.current is not None:
104
- root = InstallConfig.current.path
105
- else:
106
- root = InstallConfig.default
107
- return root / "data"
108
-
109
- @db.default
110
- def _db_factory(self) -> Path:
111
- if InstallConfig.current is not None:
112
- root = InstallConfig.current.path
113
- else:
114
- root = InstallConfig.default
115
- return root / "db"
116
-
117
- @log.default
118
- def _log_factory(self) -> Path:
119
- if InstallConfig.current is not None:
120
- root = InstallConfig.current.path
121
- else:
122
- root = InstallConfig.default
123
- return root / "log"
124
-
125
- @tmp.default
126
- def _tmp_factory(self) -> Path:
127
- if InstallConfig.current is not None:
128
- root = InstallConfig.current.path
129
- else:
130
- root = InstallConfig.default
131
- return root / "tmp"
132
-
133
- @plugins.default
134
- def _plugins_factory(self) -> Path:
135
- if InstallConfig.current is not None:
136
- root = InstallConfig.current.path
137
- else:
138
- root = InstallConfig.default
139
- return root / "plugins"
31
+ return InstallConfig.default
32
+
33
+
34
+ class Paths(BaseModel, validate_assignment=True, validate_default=True):
35
+ data: Path = Path("data")
36
+ db: Path = Path("db")
37
+ log: Path = Path("log")
38
+ tmp: Path = Path("tmp")
39
+ plugins: Path = Path("plugins")
40
+
41
@field_validator(
    "data", "db", "log", "tmp", "plugins",
    mode="after",
)
@classmethod
def _set_path_from_root(cls, value: Path) -> Path:
    """Anchor a relative path under the install root.

    Args:
        value: Path already validated for one of the fields listed in the
            decorator.

    Returns:
        ``value`` unchanged when it is absolute, otherwise the root
        returned by ``_get_root()`` joined with ``value``.
    """
    # The root is looked up on every validation, as it may differ per call.
    root = _get_root()
    return value if value.is_absolute() else root / value
140
55
 
141
56
  def values(self) -> Iterator[Path]:
142
- yield self.auth
143
57
  yield self.data
144
58
  yield self.db
145
59
  yield self.log
@@ -147,24 +61,20 @@ class Paths:
147
61
  yield self.plugins
148
62
 
149
63
 
150
- @define
151
- class Credentials:
64
+ class Credentials(BaseModel, validate_assignment=True):
152
65
  filename: str = "credentials.toml"
153
66
 
154
67
 
155
- @define
156
- class Cli:
68
+ class Cli(BaseModel, validate_assignment=True):
157
69
  page_size: int = 20
158
70
 
159
71
 
160
- @define
161
- class Db:
72
+ class Db(BaseModel, validate_assignment=True):
162
73
  filename: str = "esgpull.db"
163
74
 
164
75
 
165
- @define
166
- class Download:
167
- chunk_size: int = 1 << 26 # 64 MiB
76
+ class Download(BaseModel, validate_assignment=True):
77
+ chunk_size: int = 1 << 26
168
78
  http_timeout: int = 20
169
79
  max_concurrent: int = 5
170
80
  disable_ssl: bool = False
@@ -172,13 +82,22 @@ class Download:
172
82
  show_filename: bool = False
173
83
 
174
84
 
175
- @define
176
- class DefaultOptions:
85
+ class DefaultOptions(BaseModel, validate_assignment=True):
177
86
  distrib: str = Options._distrib_.name
178
87
  latest: str = Options._latest_.name
179
88
  replica: str = Options._replica_.name
180
89
  retracted: str = Options._retracted_.name
181
90
 
91
@field_validator("distrib", "latest", "replica", "retracted", mode="before")
@classmethod
def _is_valid_option(cls, value: str | Option) -> str:
    """Normalise an option value to the name of its ``Option`` member.

    Args:
        value: Either a string (lower-cased, then passed to ``Option(...)``)
            or an object exposing ``.name``, which is used as-is.

    Returns:
        The ``name`` attribute of the resulting option.
    """
    # Strings go through Option() so that unknown values are rejected there.
    option = Option(value.lower()) if isinstance(value, str) else value
    return option.name
100
+
182
101
  def asdict(self) -> dict[str, str]:
183
102
  return dict(
184
103
  distrib=self.distrib,
@@ -188,33 +107,18 @@ class DefaultOptions:
188
107
  )
189
108
 
190
109
 
191
- @define
192
- class API:
110
+ class API(BaseModel, validate_assignment=True):
193
111
  index_node: str = "esgf-node.ipsl.upmc.fr"
194
112
  http_timeout: int = 20
195
113
  max_concurrent: int = 5
196
114
  page_limit: int = 50
197
- default_options: DefaultOptions = Factory(DefaultOptions)
115
+ default_options: DefaultOptions = Field(default_factory=DefaultOptions)
198
116
  default_query_id: str = ""
199
117
  use_custom_distribution_algorithm: bool = False
200
118
 
201
119
 
202
- def fix_rename_search_api(doc: TOMLDocument) -> TOMLDocument:
203
- if "api" in doc and "search" in doc:
204
- raise KeyError(
205
- "Both 'api' and 'search' (deprecated) are used in your "
206
- "config, please use 'api' only."
207
- )
208
- elif "search" in doc:
209
- logger.warn(
210
- "Deprecated key 'search' is used in your config, "
211
- "please use 'api' instead."
212
- )
213
- doc["api"] = doc.pop("search")
214
- return doc
215
-
216
-
217
- config_fixers = [fix_rename_search_api]
120
+ class Plugins(BaseModel, validate_assignment=True):
121
+ enabled: bool = False
218
122
 
219
123
 
220
124
  class ConfigKind(Enum):
@@ -227,9 +131,13 @@ class ConfigKind(Enum):
227
131
  class ConfigKey:
228
132
  path: tuple[str, ...]
229
133
 
230
- def __init__(self, first: str | tuple[str, ...], *rest: str) -> None:
231
- if isinstance(first, tuple):
232
- self.path = first + rest
134
+ def __init__(
135
+ self,
136
+ first: str | tuple[str, ...] | list[str],
137
+ *rest: str,
138
+ ) -> None:
139
+ if isinstance(first, (tuple, list)):
140
+ self.path = tuple(first) + rest
233
141
  elif "." in first:
234
142
  self.path = tuple(first.split(".")) + rest
235
143
  else:
@@ -239,7 +147,14 @@ class ConfigKey:
239
147
  yield from self.path
240
148
 
241
149
  def __hash__(self) -> int:
242
- return hash(self.path)
150
+ return hash(str(self.path))
151
+
152
+ def __eq__(self, other: object) -> bool:
153
+ match other:
154
+ case ConfigKey():
155
+ return self.path == other.path
156
+ case _:
157
+ raise TypeError(type(other))
243
158
 
244
159
  def __repr__(self) -> str:
245
160
  return ".".join(self)
@@ -261,13 +176,41 @@ class ConfigKey:
261
176
  return False
262
177
  return True
263
178
 
264
- def value_of(self, source: Mapping) -> Any:
179
+ def value_of(self, source: Any) -> Any:
265
180
  doc = source
266
181
  for key in self:
267
- doc = doc[key]
182
+ try:
183
+ doc = doc[key]
184
+ except TypeError:
185
+ doc = getattr(doc, key)
268
186
  return doc
269
187
 
270
188
 
189
def fix_rename_search_api(doc: dict) -> dict:
    """Rename the deprecated top-level ``search`` key to ``api``.

    The mapping is modified in place and also returned.

    Args:
        doc: Parsed configuration document.

    Returns:
        The same ``doc`` object, with ``search`` moved to ``api`` when it
        was present.

    Raises:
        KeyError: If both ``api`` and the deprecated ``search`` keys are
            present.
    """
    if "api" in doc and "search" in doc:
        raise KeyError(
            "Both 'api' and 'search' (deprecated) are used in your "
            "config, please use 'api' only."
        )
    elif "search" in doc:
        # Logger.warn is a deprecated alias that emits DeprecationWarning;
        # Logger.warning is the supported spelling.
        logger.warning(
            "Deprecated key 'search' is used in your config, "
            "please use 'api' instead."
        )
        doc["api"] = doc.pop("search")
    return doc
202
+
203
+
204
def fix_remove_auth(doc: dict) -> dict:
    """Drop the deprecated ``paths.auth`` entry from a config document.

    The mapping is modified in place and also returned.

    Args:
        doc: Parsed configuration document.

    Returns:
        The same ``doc`` object, without ``paths.auth``.
    """
    if "paths" in doc and "auth" in doc["paths"]:
        # Logger.warn is a deprecated alias that emits DeprecationWarning;
        # Logger.warning is the supported spelling.
        logger.warning(
            "Deprecated 'paths.auth' is present in your config, "
            "you can remove it safely."
        )
        doc["paths"].pop("auth")
    return doc
212
+
213
+
271
214
  def iter_keys(
272
215
  source: Mapping,
273
216
  path: ConfigKey | None = None,
@@ -283,52 +226,88 @@ def iter_keys(
283
226
  yield local_path
284
227
 
285
228
 
286
- @define
287
- class Plugins:
288
- """Configuration for the plugin system"""
229
+ def pop_and_clear_empty_parents(source: Mapping, ckey: ConfigKey):
230
+ *parent_path, last_key = ckey.path
231
+ parent_ckey = ConfigKey(parent_path)
232
+ parent_ckey.value_of(source).pop(last_key)
289
233
 
290
- enabled: bool = False
234
+ for i in range(len(parent_path), 0, -1):
235
+ parent_ckey = ConfigKey(parent_path[: i - 1])
236
+ container_ckey = ConfigKey(parent_path[:i])
237
+ parent = parent_ckey.value_of(source)
238
+ container = container_ckey.value_of(source)
239
+ if isinstance(container, dict) and len(container) == 0:
240
+ parent.pop(container_ckey.path[-1])
241
+ else:
242
+ break # Stop if we hit a non-empty container
243
+
244
+
245
+ config_fixers = [fix_rename_search_api, fix_remove_auth]
246
+
247
+
248
class TomlKitConfigSettingsSource(TomlConfigSettingsSource):
    """TOML settings source that parses with tomlkit and applies fixers."""

    def _read_file(self, file_path: Path) -> dict[str, Any]:
        """Load ``file_path`` with tomlkit and run every config fixer on it.

        Args:
            file_path: Path of the TOML file to read.

        Returns:
            A plain ``dict`` built from the top level of the fixed document.

        Raises:
            BadConfigError: If any fixer in ``config_fixers`` raises; the
                fixer's exception is attached as the direct cause.
        """
        with open(file_path, mode="rb") as toml_file:
            doc = tomlkit.load(toml_file)
        for fixer in config_fixers:
            try:
                doc = fixer(doc)
            except Exception as err:
                # Chain explicitly so the traceback shows why the config
                # was rejected, not just which file was rejected.
                raise BadConfigError(file_path) from err
        return dict(doc)
291
258
 
292
259
 
293
- @define
294
- class Config:
295
- paths: Paths = Factory(Paths)
296
- credentials: Credentials = Factory(Credentials)
297
- cli: Cli = Factory(Cli)
298
- db: Db = Factory(Db)
299
- download: Download = Factory(Download)
300
- api: API = Factory(API)
301
- plugins: Plugins = Factory(Plugins)
302
- _raw: TOMLDocument | None = field(init=False, default=None)
303
- _config_file: Path | None = field(init=False, default=None)
260
+ class Config(BaseSettings):
261
+ # TODO: set in a load method instead
262
+ # model_config = SettingsConfigDict(toml_file=_get_root() / "config.toml")
263
+ model_config = SettingsConfigDict(toml_file=None)
264
+
265
+ paths: Paths = Field(default_factory=Paths)
266
+ credentials: Credentials = Field(default_factory=Credentials)
267
+ cli: Cli = Field(default_factory=Cli)
268
+ db: Db = Field(default_factory=Db)
269
+ download: Download = Field(default_factory=Download)
270
+ api: API = Field(default_factory=API)
271
+ plugins: Plugins = Field(default_factory=Plugins)
272
+ _raw: dict[str, Any] | None = PrivateAttr(default=None)
273
+ _config_file: Path | None = PrivateAttr(default=None)
274
+
275
@classmethod
def settings_customise_sources(
    cls,
    settings_cls: type[BaseSettings],
    init_settings: PydanticBaseSettingsSource,
    env_settings: PydanticBaseSettingsSource,
    dotenv_settings: PydanticBaseSettingsSource,
    file_secret_settings: PydanticBaseSettingsSource,
) -> tuple[PydanticBaseSettingsSource, ...]:
    """Select the settings sources used to build a ``Config``.

    Only the init source and a tomlkit-backed TOML source are returned,
    in that order. The env, dotenv and secrets sources are accepted but
    not returned.
    NOTE(review): pydantic-settings is expected to give earlier sources
    precedence - confirm against its documentation.
    """
    toml_source = TomlKitConfigSettingsSource(settings_cls)
    return (init_settings, toml_source)
304
288
 
305
289
  @classmethod
306
290
  def load(cls, path: Path) -> Config:
307
- config_file = path / CONFIG_FILENAME
308
- if config_file.is_file():
309
- with config_file.open() as fh:
310
- doc = tomlkit.load(fh)
311
- for fixer in config_fixers:
312
- try:
313
- doc = fixer(doc)
314
- except Exception:
315
- raise BadConfigError(config_file)
316
- raw = doc
291
+ try:
292
+ file_path = path / CONFIG_FILENAME
293
+ cls.model_config["toml_file"] = file_path
294
+ instance = cls()
295
+ finally:
296
+ cls.model_config["toml_file"] = None
297
+ instance._config_file = file_path
298
+ if file_path.is_file():
299
+ with file_path.open("rb") as toml_file:
300
+ instance._raw = tomlkit.load(toml_file)
317
301
  else:
318
- doc = TOMLDocument()
319
- raw = None
320
- config = _converter_defaults.structure(doc, cls)
321
- config._raw = raw
322
- config._config_file = config_file
323
- return config
302
+ instance._raw = None
303
+ return instance
324
304
 
325
305
  @classmethod
326
306
  def default(cls) -> Config:
327
- if InstallConfig.current is not None:
328
- root = InstallConfig.current.path
329
- else:
330
- root = InstallConfig.default
331
- return cls.load(root)
307
+ ## TODO: rename+deprecate
308
+ ## very bad name since this is loading from the **default** config
309
+ ## file path, as set in InstallConfig
310
+ return cls.load(path=_get_root())
332
311
 
333
312
  @property
334
313
  def kind(self) -> ConfigKind:
@@ -341,145 +320,82 @@ class Config:
341
320
  else:
342
321
  return ConfigKind.Complete
343
322
 
344
- def dumps(self, defaults: bool = True, comments: bool = False) -> str:
345
- return self.dump(defaults, comments).as_string()
323
def dump(self, with_defaults: bool = True) -> dict:
    """Return the configuration as a JSON-mode dictionary.

    Args:
        with_defaults: When ``False``, every key reported by
            ``unset_options()`` is removed from the result, along with
            any parent table left empty by the removal.

    Returns:
        The dumped (and possibly pruned) configuration.
    """
    dumped = self.model_dump(mode="json")
    if with_defaults:
        return dumped
    missing = set(self.unset_options())
    # Keys are enumerated from a separate dump, so pruning `dumped` never
    # disturbs the iteration.
    prunable = (key for key in iter_keys(self.model_dump()) if key in missing)
    for key in prunable:
        pop_and_clear_empty_parents(dumped, key)
    return dumped
346
331
 
347
- def dump(
348
- self,
349
- defaults: bool = True,
350
- comments: bool = False,
351
- ) -> TOMLDocument:
352
- if defaults:
353
- converter = _converter_defaults
354
- else:
355
- converter = _converter_no_defaults
356
- dump = converter.unstructure(self)
357
- if not defaults:
358
- pop_empty(dump)
359
- doc = TOMLDocument()
360
- doc.update(dump)
361
- # if comments and self._raw is not None:
362
- # original = tomlkit.loads(self._raw)
363
- return doc
332
def unset_options(self) -> list[ConfigKey]:
    """List the option keys of this config that are absent from ``_raw``.

    Returns:
        Every key yielded by ``iter_keys`` over the model dump for which
        ``exists_in(self._raw)`` is false, in iteration order.
    """
    return [
        option
        for option in iter_keys(self.model_dump())
        if not option.exists_in(self._raw)
    ]
364
339
 
365
340
  def update_item(
366
341
  self,
367
342
  key: str,
368
- value: str | int | bool,
343
+ value: T,
369
344
  empty_ok: bool = False,
370
- ) -> int | str | None:
345
+ ) -> T:
371
346
  if self._raw is None and empty_ok:
372
- self._raw = TOMLDocument()
373
- if self._raw is None:
347
+ self._raw = {}
348
+ elif self._raw is None:
374
349
  raise VirtualConfigError
375
- else:
376
- doc: dict = self._raw
350
+ doc = self._raw
377
351
  obj = self
378
- *parts, last = ConfigKey(key)
352
+ ckey = ConfigKey(key)
353
+ *parts, last = ckey
379
354
  for part in parts:
380
355
  doc.setdefault(part, {})
381
356
  doc = doc[part]
382
357
  obj = getattr(obj, part)
383
358
  old_value = getattr(obj, last)
384
- value = cast_value(old_value, value, key)
385
359
  setattr(obj, last, value)
386
360
  doc[last] = value
387
361
  return old_value
388
362
 
389
- def set_default(self, key: str) -> int | str | None:
363
+ def set_default(self, key: str) -> Any:
390
364
  ckey = ConfigKey(key)
391
365
  if self._raw is None:
392
366
  raise VirtualConfigError()
393
367
  elif not ckey.exists_in(self._raw):
394
368
  return None
395
- default_config = self.__class__()
396
- default_value = ckey.value_of(default_config.dump())
397
- old_value: Any = ckey.value_of(self.dump())
398
- first_pass = True
399
- obj = self
400
- for idx in range(len(ckey), 0, -1):
401
- *parts, last = ckey.path[:idx]
402
- doc: tomlkit.container.Container = self._raw
403
- for part in parts:
404
- if first_pass:
405
- obj = getattr(obj, part)
406
- doc = cast(tomlkit.container.Container, doc[part])
407
- if first_pass:
408
- doc.remove(last)
409
- setattr(obj, last, default_value)
410
- first_pass = False
411
- elif (
412
- (value := doc[last])
413
- and isinstance(value, tomlkit.container.Container)
414
- and len(value) == 0
415
- ):
416
- doc.remove(last)
369
+ default_config = Config()
370
+ default_value: Any = ckey.value_of(default_config)
371
+ old_value: Any = ckey.value_of(self)
372
+
373
+ *parent_path, last_key = ckey.path
374
+ parent_ckey = ConfigKey(parent_path)
375
+ obj = parent_ckey.value_of(self)
376
+ setattr(obj, last_key, default_value)
377
+ pop_and_clear_empty_parents(self._raw, ckey)
417
378
  return old_value
418
379
 
419
- def unset_options(self) -> list[ConfigKey]:
420
- result: list[ConfigKey] = []
421
- raw: dict
422
- dump = self.dump()
423
- if self._raw is None:
424
- raw = {}
425
- else:
426
- raw = self._raw
427
- for ckey in iter_keys(dump):
428
- if not ckey.exists_in(raw):
429
- result.append(ckey)
430
- return result
431
-
432
- def generate(
433
- self,
434
- overwrite: bool = False,
435
- ) -> None:
380
+ def generate(self, overwrite: bool = False) -> None:
436
381
  match (self.kind, overwrite):
437
382
  case (ConfigKind.Virtual, _):
438
383
  raise VirtualConfigError
439
384
  case (ConfigKind.Partial, overwrite):
440
- defaults = self.dump()
385
+ defaults = self.model_dump()
441
386
  for ckey in self.unset_options():
442
387
  self.update_item(str(ckey), ckey.value_of(defaults))
443
388
  case (ConfigKind.Partial | ConfigKind.Complete, _):
444
389
  raise FileExistsError(self._config_file)
445
390
  case (ConfigKind.NoFile, _):
446
- self._raw = self.dump()
391
+ self._raw = self.model_dump(mode="json")
447
392
  case _:
448
393
  raise ValueError(self.kind)
449
394
  self.write()
450
395
 
451
396
  def write(self) -> None:
452
- if self.kind == ConfigKind.Virtual or self._raw is None:
397
+ if self._config_file is None or self._raw is None:
453
398
  raise VirtualConfigError
454
- config_file = cast(Path, self._config_file)
455
- with config_file.open("w") as f:
399
+ self._config_file.parent.mkdir(parents=True, exist_ok=True)
400
+ with self._config_file.open("w") as f:
456
401
  tomlkit.dump(self._raw, f)
457
-
458
-
459
- def _make_converter(omit_default: bool) -> Converter:
460
- conv = Converter(omit_if_default=omit_default, forbid_extra_keys=True)
461
- conv.register_unstructure_hook(Path, str)
462
- conv.register_unstructure_hook(
463
- Config,
464
- make_dict_unstructure_fn(
465
- Config,
466
- conv,
467
- _raw=override(omit=True),
468
- _config_file=override(omit=True),
469
- ),
470
- )
471
- return conv
472
-
473
-
474
- _converter_defaults = _make_converter(omit_default=False)
475
- _converter_no_defaults = _make_converter(omit_default=True)
476
-
477
-
478
- def pop_empty(d: dict[str, Any]) -> None:
479
- keys = list(d.keys())
480
- for key in keys:
481
- value = d[key]
482
- if isinstance(value, dict):
483
- pop_empty(value)
484
- if not value:
485
- d.pop(key)
esgpull/constants.py CHANGED
@@ -4,6 +4,7 @@ CONFIG_FILENAME = "config.toml"
4
4
  INSTALLS_PATH_ENV = "ESGPULL_INSTALLS_PATH"
5
5
  ROOT_ENV = "ESGPULL_CURRENT"
6
6
  ESGPULL_DEBUG = os.environ.get("ESGPULL_DEBUG", "0") == "1"
7
+ ESGPULL_DEBUG_LOCALS = os.environ.get("ESGPULL_DEBUG_LOCALS", "0") == "1"
7
8
 
8
9
  IDP = "/esgf-idp/openid/"
9
10
  CEDA_IDP = "/OpenID/Provider/server/"