pastastore 1.10.2__py3-none-any.whl → 1.11.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
pastastore/base.py CHANGED
@@ -2,25 +2,236 @@
2
2
  """Base classes for PastaStore Connectors."""
3
3
 
4
4
  import functools
5
+ import logging
5
6
  import warnings
6
7
 
7
8
  # import weakref
8
9
  from abc import ABC, abstractmethod
10
+ from collections.abc import Iterable
9
11
  from itertools import chain
10
- from typing import Callable, Dict, List, Optional, Tuple, Union
12
+ from random import choice
13
+
14
+ # import weakref
15
+ from typing import Callable, Dict, List, Optional, Union
11
16
 
12
17
  import pandas as pd
13
18
  import pastas as ps
19
+ from packaging.version import parse as parse_version
14
20
  from tqdm.auto import tqdm
15
21
 
16
- from pastastore.util import ItemInLibraryException, _custom_warning, validate_names
17
- from pastastore.version import PASTAS_GEQ_150, PASTAS_LEQ_022
22
+ from pastastore.typing import AllLibs, FrameOrSeriesUnion, TimeSeriesLibs
23
+ from pastastore.util import (
24
+ ItemInLibraryException,
25
+ SeriesUsedByModel,
26
+ _custom_warning,
27
+ validate_names,
28
+ )
29
+ from pastastore.validator import Validator
30
+ from pastastore.version import PASTAS_GEQ_150
18
31
 
19
- FrameorSeriesUnion = Union[pd.DataFrame, pd.Series]
20
32
  warnings.showwarning = _custom_warning  # process-wide side effect: importing this module replaces the global warnings display hook with pastastore.util._custom_warning
21
33
 
34
+ logger = logging.getLogger(__name__)  # module-level logger (used e.g. for the "Deleted ... model(s)" message in del_models)
35
+
36
+
37
class ConnectorUtil:
    """Mix-in class for utility methods used by BaseConnector subclasses.

    This class contains internal methods for parsing names, handling metadata,
    and parsing model dictionaries. It is designed to be mixed into
    BaseConnector subclasses and relies on attributes and methods provided by
    BaseConnector: ``oseries_names``, ``stresses_names``, ``model_names``,
    ``oseries_with_models``, ``stresses_with_models``, ``oseries``,
    ``stresses``, ``get_oseries`` and ``get_stresses``.

    Note
    ----
    This class should not be instantiated directly. It is intended to be used
    as a mixin with BaseConnector subclasses only.
    """

    def _parse_names(
        self,
        names: list[str] | str | None = None,
        libname: AllLibs = "oseries",
    ) -> list:
        """Parse names kwarg, returns list with name(s) (internal method).

        Parameters
        ----------
        names : Union[list, str], optional
            str or list of str or None or 'all' (last two options
            retrieves all names)
        libname : str, optional
            name of library, default is 'oseries'

        Returns
        -------
        list
            list of names. An iterable passed as ``names`` is materialized
            into a new list so callers can safely iterate over it more than
            once.

        Raises
        ------
        ValueError
            if all names are requested for an unknown ``libname``
        NotImplementedError
            if ``names`` is not None, a string or an iterable
        """
        if isinstance(names, str):
            if names != "all":
                return [names]
        elif names is not None:
            if isinstance(names, Iterable):
                # materialize one-shot iterables (generators, map objects)
                return list(names)
            raise NotImplementedError(f"Cannot parse 'names': {names}")

        # names is None or 'all': return every name in the requested library
        if libname == "oseries":
            return self.oseries_names
        elif libname == "stresses":
            return self.stresses_names
        elif libname == "models":
            return self.model_names
        elif libname == "oseries_models":
            return self.oseries_with_models
        elif libname == "stresses_models":
            return self.stresses_with_models
        raise ValueError(f"No library '{libname}'!")

    @staticmethod
    def _meta_list_to_frame(metalist: list, names: list):
        """Convert list of metadata dictionaries to DataFrame.

        Parameters
        ----------
        metalist : list
            list of metadata dictionaries
        names : list
            list of names corresponding to data in metalist, used as index

        Returns
        -------
        pandas.DataFrame
            DataFrame containing overview of metadata, with index named
            "name". When both "x" and "y" columns are present they are cast
            to float.
        """
        meta = pd.DataFrame(metalist) if len(metalist) > 0 else pd.DataFrame()
        # cast coordinates for any number of entries, so the dtype of x/y
        # does not depend on whether one or several items were requested
        if {"x", "y"}.issubset(meta.columns):
            meta["x"] = meta["x"].astype(float)
            meta["y"] = meta["y"].astype(float)
        meta.index = names
        meta.index.name = "name"
        return meta

    def _parse_model_dict(self, mdict: dict, update_ts_settings: bool = False):
        """Parse dictionary describing pastas models (internal method).

        Time series that are referenced by name only (no "series" entry) are
        loaded from the store and inserted into ``mdict`` (in place) before
        the model is reconstructed with pastas.

        Parameters
        ----------
        mdict : dict
            dictionary describing pastas.Model
        update_ts_settings : bool, optional
            update stored tmin and tmax in time series settings
            based on time series loaded from store.

        Returns
        -------
        ml : pastas.Model
            time series analysis model

        Raises
        ------
        LookupError
            if the model's oseries is not present in the oseries library
        KeyError
            if a "prec" or "evap" stress (RechargeModel, TarsoModel) is not
            present in the stresses library
        UserWarning
            if the file was written with pastas<=0.22 and the installed
            pastas version is newer than 0.23.1
        """
        file_version = mdict["file_info"]["pastas_version"]
        pasfile_leq_022 = parse_version(file_version) <= parse_version("0.22.0")

        def _set_tmin_tmax(tsdict: dict) -> None:
            # copy first/last timestamp of the attached series into settings
            tsdict["settings"]["tmin"] = tsdict["series"].index[0]
            tsdict["settings"]["tmax"] = tsdict["series"].index[-1]

        def _attach_stress(stress: dict, required: bool) -> None:
            # load a stress from the store if the dict only holds its name;
            # missing stresses raise only when `required` is True
            if "series" in stress:
                return
            name = str(stress["name"])
            if name not in self.stresses.index:
                if required:
                    raise KeyError(f"stress '{name}' not present in library")
                return
            stress["series"] = self.get_stresses(name).squeeze()
            if update_ts_settings:
                _set_tmin_tmax(stress)

        # oseries
        if "series" not in mdict["oseries"]:
            name = str(mdict["oseries"]["name"])
            if name not in self.oseries.index:
                msg = f"oseries '{name}' not present in library"
                raise LookupError(msg)
            mdict["oseries"]["series"] = self.get_oseries(name).squeeze()
            if update_ts_settings:
                _set_tmin_tmax(mdict["oseries"])

        for ts in mdict["stressmodels"].values():
            # StressModel, WellModel
            if "stress" in ts:
                classkey = "stressmodel" if pasfile_leq_022 else "class"
                if ts[classkey] == "WellModel":
                    # WellModel: "stress" is an iterable of stress dicts
                    stresses = ts["stress"]
                elif pasfile_leq_022:
                    # old file format: StressModel "stress" is iterable too
                    stresses = ts["stress"]
                else:
                    stresses = [ts["stress"]]
                for stress in stresses:
                    _attach_stress(stress, required=False)

            # RechargeModel, TarsoModel
            if "prec" in ts and "evap" in ts:
                for stress in [ts["prec"], ts["evap"]]:
                    _attach_stress(stress, required=True)

        # hack for pcov w dtype object (when filled with NaNs on store?):
        # only cast when at least one column actually has object dtype
        if "fit" in mdict:
            if "pcov" in mdict["fit"]:
                pcov = mdict["fit"]["pcov"]
                if any(pd.api.types.is_object_dtype(dtyp) for dtyp in pcov.dtypes):
                    mdict["fit"]["pcov"] = pcov.astype(float)

        # check file version and pastas version
        # if file<0.23 and pastas>=1.0 --> error
        pastas_gt_023 = parse_version(ps.__version__) > parse_version("0.23.1")
        if pasfile_leq_022 and pastas_gt_023:
            raise UserWarning(
                f"This file was created with Pastas v{file_version} "
                f"and cannot be loaded with Pastas v{ps.__version__} Please load and "
                "save the file with Pastas 0.23 first to update the file "
                "format."
            )

        # Use pastas' internal _load_model - required for model reconstruction
        ml = ps.io.base._load_model(mdict)  # noqa: SLF001
        return ml
232
+
233
+
234
+ class BaseConnector(ABC, ConnectorUtil):
24
235
  """Base Connector class.
25
236
 
26
237
  Class holds base logic for dealing with time series and Pastas Models. Create your
@@ -33,18 +244,12 @@ class BaseConnector(ABC):
33
244
  "stresses",
34
245
  "models",
35
246
  "oseries_models",
247
+ "stresses_models",
36
248
  ]
37
249
 
38
- # whether to check model time series contents against stored copies
39
- CHECK_MODEL_SERIES_VALUES = True
40
-
41
- # whether to validate time series according to pastas rules
42
- # True for pastas>=0.23.0 and False for pastas<=0.22.0
43
- USE_PASTAS_VALIDATE_SERIES = False if PASTAS_LEQ_022 else True
44
-
45
- # set series equality comparison settings (using assert_series_equal)
46
- SERIES_EQUALITY_ABSOLUTE_TOLERANCE = 1e-10
47
- SERIES_EQUALITY_RELATIVE_TOLERANCE = 0.0
250
+ _conn_type: Optional[str] = None  # must be set by connector subclasses; read through the conn_type property, which raises if it is None
251
+ _validator: Optional[Validator] = None  # read through the validator property, which raises AttributeError while this is None
252
+ name = None  # NOTE(review): presumably the connector/store name assigned by subclasses — confirm
48
253
 
49
254
  def __repr__(self):
50
255
  """Representation string of the object."""
@@ -55,13 +260,34 @@ class BaseConnector(ABC):
55
260
  f"{self.n_models} models"
56
261
  )
57
262
 
263
@property
def validation_settings(self):
    """Settings of the Validator attached to this connector.

    Returns the ``settings`` attribute of ``self.validator``. Accessing this
    property raises AttributeError when no validator has been set.
    """
    active_validator = self.validator
    return active_validator.settings
267
+
58
268
@property
def empty(self):
    """Check if the database is empty.

    Returns
    -------
    bool
        True when the store holds no oseries, no stresses and no models.
    """
    # Short-circuit: each count may query the storage backend, so stop at
    # the first non-empty library instead of evaluating all three eagerly.
    return not (self.n_oseries > 0 or self.n_stresses > 0 or self.n_models > 0)
62
272
 
273
@property
def validator(self) -> Validator:
    """Validator instance used by this connector.

    Raises
    ------
    AttributeError
        if no validator has been assigned to ``_validator``.
    """
    current = self._validator
    if current is not None:
        return current
    raise AttributeError("Validator not set for this connector.")
279
+
280
@property
def conn_type(self) -> str:
    """Type identifier of this connector.

    Raises
    ------
    AttributeError
        if the connector class did not set ``_conn_type``.
    """
    kind = self._conn_type
    if kind is not None:
        return kind
    raise AttributeError(
        "Connector class must set a connector type in `conn_type` attribute."
    )
288
+
63
289
  @abstractmethod
64
- def _get_library(self, libname: str):
290
+ def _get_library(self, libname: AllLibs):
65
291
  """Get library handle.
66
292
 
67
293
  Must be overridden by subclass.
@@ -80,11 +306,10 @@ class BaseConnector(ABC):
80
306
  @abstractmethod
81
307
  def _add_item(
82
308
  self,
83
- libname: str,
84
- item: Union[FrameorSeriesUnion, Dict],
309
+ libname: AllLibs,
310
+ item: Union[FrameOrSeriesUnion, Dict],
85
311
  name: str,
86
312
  metadata: Optional[Dict] = None,
87
- overwrite: bool = False,
88
313
  ) -> None:
89
314
  """Add item for both time series and pastas.Models (internal method).
90
315
 
@@ -100,10 +325,17 @@ class BaseConnector(ABC):
100
325
  name of the item
101
326
  metadata : dict, optional
102
327
  dictionary containing metadata, by default None
328
+
329
+ Note
330
+ ----
331
+ Metadata storage can vary by connector:
332
+ - ArcticDB: Native metadata support via write()
333
+ - DictConnector: Stored as tuple (metadata, item)
334
+ - PasConnector: Separate {name}_meta.pas JSON file
103
335
  """
104
336
 
105
337
  @abstractmethod
106
- def _get_item(self, libname: str, name: str) -> Union[FrameorSeriesUnion, Dict]:
338
+ def _get_item(self, libname: AllLibs, name: str) -> Union[FrameOrSeriesUnion, Dict]:
107
339
  """Get item (series or pastas.Models) (internal method).
108
340
 
109
341
  Must be overridden by subclass.
@@ -122,7 +354,7 @@ class BaseConnector(ABC):
122
354
  """
123
355
 
124
356
  @abstractmethod
125
- def _del_item(self, libname: str, name: str) -> None:
357
+ def _del_item(self, libname: AllLibs, name: str, force: bool = False) -> None:
126
358
  """Delete items (series or models) (internal method).
127
359
 
128
360
  Must be overridden by subclass.
@@ -136,7 +368,7 @@ class BaseConnector(ABC):
136
368
  """
137
369
 
138
370
  @abstractmethod
139
- def _get_metadata(self, libname: str, name: str) -> Dict:
371
+ def _get_metadata(self, libname: TimeSeriesLibs, name: str) -> Dict:
140
372
  """Get metadata (internal method).
141
373
 
142
374
  Must be overridden by subclass.
@@ -154,35 +386,56 @@ class BaseConnector(ABC):
154
386
  dictionary containing metadata
155
387
  """
156
388
 
157
- @property
158
389
@abstractmethod
def _list_symbols(self, libname: AllLibs) -> List[str]:
    """List the names of all items stored in library ``libname``.

    Abstract: every connector subclass must provide this.
    """
392
+
393
@property
def oseries_names(self):
    """List of oseries names.

    Implemented on the base class by delegating to
    ``_list_symbols("oseries")``; subclasses implement ``_list_symbols``
    rather than overriding this property.
    """
    return self._list_symbols("oseries")
164
400
 
165
401
@property
def stresses_names(self):
    """List of stresses names.

    Implemented on the base class by delegating to
    ``_list_symbols("stresses")``; subclasses implement ``_list_symbols``
    rather than overriding this property.
    """
    return self._list_symbols("stresses")
172
408
 
173
409
@property
def model_names(self):
    """List of model names.

    Served from the cached ``_modelnames_cache`` property, so the result
    reflects additions/removals only after that cache has been cleared
    (``_clear_cache("_modelnames_cache")``).
    """
    return self._modelnames_cache
416
+
417
@property
def oseries_with_models(self):
    """List of oseries used in models.

    Returns the names of the entries in the 'oseries_models' library via
    ``_list_symbols("oseries_models")``; subclasses implement
    ``_list_symbols`` rather than overriding this property.
    """
    return self._list_symbols("oseries_models")
424
+
425
@property
def stresses_with_models(self):
    """List of stresses used in models.

    Returns the names of the entries in the 'stresses_models' library via
    ``_list_symbols("stresses_models")``; subclasses implement
    ``_list_symbols`` rather than overriding this property.
    """
    return self._list_symbols("stresses_models")
180
432
 
181
433
  @abstractmethod
182
434
  def _parallel(
183
435
  self,
184
436
  func: Callable,
185
437
  names: List[str],
438
+ kwargs: Optional[Dict] = None,
186
439
  progressbar: Optional[bool] = True,
187
440
  max_workers: Optional[int] = None,
188
441
  chunksize: Optional[int] = None,
@@ -198,6 +451,8 @@ class BaseConnector(ABC):
198
451
  function to apply in parallel
199
452
  names : list
200
453
  list of names to apply function to
454
+ kwargs : dict
455
+ additional keyword arguments to pass to function
201
456
  progressbar : bool, optional
202
457
  show progressbar, by default True
203
458
  max_workers : int, optional
@@ -208,72 +463,120 @@ class BaseConnector(ABC):
208
463
  description for progressbar, by default ""
209
464
  """
210
465
 
211
- def set_check_model_series_values(self, b: bool):
212
- """Turn CHECK_MODEL_SERIES_VALUES option on (True) or off (False).
466
def parse_names(
    self,
    names: list[str] | str | None = None,
    libname: AllLibs = "oseries",
) -> list:
    """Return the list of names selected by ``names`` in library ``libname``.

    Public wrapper around the internal ``_parse_names`` method.

    Parameters
    ----------
    names : Union[list, str], optional
        a single name, a list of names, or None / 'all' to select every
        name in the library
    libname : str, optional
        name of library, default is 'oseries'

    Returns
    -------
    list
        list of names
    """
    parsed = self._parse_names(names, libname)
    return parsed
231
489
 
232
- def set_use_pastas_validate_series(self, b: bool):
233
- """Turn USE_PASTAS_VALIDATE_SERIES option on (True) or off (False).
490
@property  # type: ignore
@functools.lru_cache()
def oseries(self):
    """Metadata overview (DataFrame) of every stored oseries.

    Memoized per instance with ``functools.lru_cache``; presumably refreshed
    through ``_clear_cache`` when the oseries library changes.
    """
    all_names = self.oseries_names
    return self.get_metadata("oseries", all_names)
234
495
 
235
- This will use pastas.validate_oseries() or pastas.validate_stresses()
236
- to test the time series. If they do not meet the criteria, an error is
237
- raised. Turning this option off will allow the user to store any time
238
- series but this will mean that time series models cannot be made from
239
- stored time series directly and will have to be modified before
240
- building the models. This in turn will mean that storing the models
241
- will not work as the stored time series copy is checked against the
242
- time series in the model to check if they are equal.
496
@property  # type: ignore
@functools.lru_cache()
def stresses(self):
    """Metadata overview (DataFrame) of every stored stress.

    Memoized per instance with ``functools.lru_cache``; presumably refreshed
    through ``_clear_cache`` when the stresses library changes.
    """
    all_names = self.stresses_names
    return self.get_metadata("stresses", all_names)
243
501
 
244
- Note: this option requires pastas>=0.23.0, otherwise it is turned off.
502
@property  # type: ignore
@functools.lru_cache()
def _modelnames_cache(self):
    """Cached list of the names stored in the 'models' library.

    Cleared explicitly with ``_clear_cache("_modelnames_cache")`` after
    models are added or deleted.
    """
    library = "models"
    return self._list_symbols(library)
245
507
 
246
- Parameters
247
- ----------
248
- b : bool
249
- boolean indicating whether option should be turned on (True) or
250
- off (False). Option is on by default.
508
@property
def n_oseries(self):
    """Number of oseries in the store.

    Returns
    -------
    int
        length of ``oseries_names``
    """
    names = self.oseries_names
    return len(names)
257
519
 
258
- Parameters
259
- ----------
260
- validate : bool, NoneType
261
- value of validate keyword argument
520
@property
def n_stresses(self):
    """Number of stresses in the store.

    Returns
    -------
    int
        length of ``stresses_names``
    """
    names = self.stresses_names
    return len(names)
531
+
532
@property
def n_models(self):
    """Number of models in the store.

    Returns
    -------
    int
        length of ``model_names``
    """
    names = self.model_names
    return len(names)
543
+
544
@property  # type: ignore
@functools.lru_cache()
def oseries_models(self):
    """Mapping from oseries name to the model names that use it.

    Returns
    -------
    d : dict
        keys are the names in ``oseries_with_models``; values are the items
        stored under that name in the 'oseries_models' library (the list of
        model names)
    """
    return {
        oseries_name: self._get_item("oseries_models", oseries_name)
        for oseries_name in self.oseries_with_models
    }
559
+
560
@property  # type: ignore
@functools.lru_cache()
def stresses_models(self):
    """Mapping from stress name to the model names that use it.

    Returns
    -------
    d : dict
        keys are the names in ``stresses_with_models``; values are the items
        stored under that name in the 'stresses_models' library (the list of
        model names)
    """
    return {
        stress_name: self._get_item("stresses_models", stress_name)
        for stress_name in self.stresses_with_models
    }
272
575
 
273
576
  def _add_series(
274
577
  self,
275
- libname: str,
276
- series: FrameorSeriesUnion,
578
+ libname: TimeSeriesLibs,
579
+ series: FrameOrSeriesUnion,
277
580
  name: str,
278
581
  metadata: Optional[dict] = None,
279
582
  validate: Optional[bool] = None,
@@ -305,9 +608,9 @@ class BaseConnector(ABC):
305
608
  """
306
609
  if not isinstance(name, str):
307
610
  name = str(name)
308
- self._validate_input_series(series)
309
- series = self._set_series_name(series, name)
310
- if self._pastas_validate(validate):
611
+ self.validator.validate_input_series(series)
612
+ series = self.validator.set_series_name(series, name)
613
+ if self.validator.pastas_validation_status(validate):
311
614
  if libname == "oseries":
312
615
  if PASTAS_GEQ_150 and not ps.validate_oseries(series):
313
616
  raise ValueError(
@@ -326,10 +629,16 @@ class BaseConnector(ABC):
326
629
  ps.validate_stress(series)
327
630
  in_store = getattr(self, f"{libname}_names")
328
631
  if name not in in_store or overwrite:
329
- self._add_item(
330
- libname, series, name, metadata=metadata, overwrite=overwrite
331
- )
632
+ self._add_item(libname, series, name, metadata=metadata)
332
633
  self._clear_cache(libname)
634
+ elif (libname == "oseries" and name in self.oseries_models) or (
635
+ libname == "stresses" and name in self.stresses_models
636
+ ):
637
+ raise SeriesUsedByModel(
638
+ f"Time series with name '{name}' is used by a model! "
639
+ "Use overwrite=True to replace existing time series. "
640
+ "Note that this may modify the model!"
641
+ )
333
642
  else:
334
643
  raise ItemInLibraryException(
335
644
  f"Time series with name '{name}' already in '{libname}' library! "
@@ -338,11 +647,12 @@ class BaseConnector(ABC):
338
647
 
339
648
  def _update_series(
340
649
  self,
341
- libname: str,
342
- series: FrameorSeriesUnion,
650
+ libname: TimeSeriesLibs,
651
+ series: FrameOrSeriesUnion,
343
652
  name: str,
344
653
  metadata: Optional[dict] = None,
345
654
  validate: Optional[bool] = None,
655
+ force: bool = False,
346
656
  ) -> None:
347
657
  """Update time series (internal method).
348
658
 
@@ -360,11 +670,16 @@ class BaseConnector(ABC):
360
670
  validate: bool, optional
361
671
  use pastas to validate series, default is None, which will use the
362
672
  USE_PASTAS_VALIDATE_SERIES value (default is True).
673
+ force : bool, optional
674
+ force update even if time series is used in a model, by default False
675
+
363
676
  """
364
677
  if libname not in ["oseries", "stresses"]:
365
678
  raise ValueError("Library must be 'oseries' or 'stresses'!")
366
- self._validate_input_series(series)
367
- series = self._set_series_name(series, name)
679
+ if not force:
680
+ self.validator.check_series_in_models(libname, name)
681
+ self.validator.validate_input_series(series)
682
+ series = self.validator.set_series_name(series, name)
368
683
  stored = self._get_series(libname, name, progressbar=False)
369
684
  if self.conn_type == "pas" and not isinstance(series, type(stored)):
370
685
  if isinstance(series, pd.DataFrame):
@@ -389,11 +704,12 @@ class BaseConnector(ABC):
389
704
 
390
705
  def _upsert_series(
391
706
  self,
392
- libname: str,
393
- series: FrameorSeriesUnion,
707
+ libname: TimeSeriesLibs,
708
+ series: FrameOrSeriesUnion,
394
709
  name: str,
395
710
  metadata: Optional[dict] = None,
396
711
  validate: Optional[bool] = None,
712
+ force: bool = False,
397
713
  ) -> None:
398
714
  """Update or insert series depending on whether it exists in store.
399
715
 
@@ -410,19 +726,23 @@ class BaseConnector(ABC):
410
726
  validate : bool, optional
411
727
  use pastas to validate series, default is None, which will use the
412
728
  USE_PASTAS_VALIDATE_SERIES value (default is True).
729
+ force : bool, optional
730
+ force update even if time series is used in a model, by default False
413
731
  """
414
732
  if libname not in ["oseries", "stresses"]:
415
733
  raise ValueError("Library must be 'oseries' or 'stresses'!")
416
734
  if name in getattr(self, f"{libname}_names"):
417
735
  self._update_series(
418
- libname, series, name, metadata=metadata, validate=validate
736
+ libname, series, name, metadata=metadata, validate=validate, force=force
419
737
  )
420
738
  else:
421
739
  self._add_series(
422
740
  libname, series, name, metadata=metadata, validate=validate
423
741
  )
424
742
 
425
- def update_metadata(self, libname: str, name: str, metadata: dict) -> None:
743
+ def update_metadata(
744
+ self, libname: TimeSeriesLibs, name: str, metadata: dict
745
+ ) -> None:
426
746
  """Update metadata.
427
747
 
428
748
  Note: also retrieves and stores time series as updating only metadata
@@ -449,7 +769,7 @@ class BaseConnector(ABC):
449
769
 
450
770
  def add_oseries(
451
771
  self,
452
- series: FrameorSeriesUnion,
772
+ series: FrameOrSeriesUnion,
453
773
  name: str,
454
774
  metadata: Optional[dict] = None,
455
775
  validate: Optional[bool] = None,
@@ -472,7 +792,6 @@ class BaseConnector(ABC):
472
792
  overwrite existing dataset with the same name,
473
793
  by default False
474
794
  """
475
- series, metadata = self._parse_series_input(series, metadata)
476
795
  self._add_series(
477
796
  "oseries",
478
797
  series,
@@ -484,7 +803,7 @@ class BaseConnector(ABC):
484
803
 
485
804
  def add_stress(
486
805
  self,
487
- series: FrameorSeriesUnion,
806
+ series: FrameOrSeriesUnion,
488
807
  name: str,
489
808
  kind: str,
490
809
  metadata: Optional[dict] = None,
@@ -511,7 +830,6 @@ class BaseConnector(ABC):
511
830
  overwrite existing dataset with the same name,
512
831
  by default False
513
832
  """
514
- series, metadata = self._parse_series_input(series, metadata)
515
833
  if metadata is None:
516
834
  metadata = {}
517
835
  metadata["kind"] = kind
@@ -565,56 +883,84 @@ class BaseConnector(ABC):
565
883
  name = str(name)
566
884
  if name not in self.model_names or overwrite:
567
885
  # check if stressmodels supported
568
- self._check_stressmodels_supported(ml)
569
- # check if oseries and stresses exist in store
570
- self._check_model_series_names_for_store(ml)
571
- self._check_oseries_in_store(ml)
572
- self._check_stresses_in_store(ml)
886
+ self.validator.check_stressmodels_supported(ml)
887
+ # check oseries and stresses names and if they exist in store
888
+ self.validator.check_model_series_names_duplicates(ml)
889
+ self.validator.check_oseries_in_store(ml)
890
+ self.validator.check_stresses_in_store(ml)
573
891
  # write model to store
574
- self._add_item(
575
- "models", mldict, name, metadata=metadata, overwrite=overwrite
576
- )
892
+ self._add_item("models", mldict, name, metadata=metadata)
893
+ self._clear_cache("_modelnames_cache")
894
+ self._add_oseries_model_links(str(mldict["oseries"]["name"]), name)
895
+ self._add_stresses_model_links(self._get_model_stress_names(mldict), name)
577
896
  else:
578
897
  raise ItemInLibraryException(
579
898
  f"Model with name '{name}' already in 'models' library! "
580
899
  "Use overwrite=True to replace existing model."
581
900
  )
582
- self._clear_cache("_modelnames_cache")
583
- self._add_oseries_model_links(str(mldict["oseries"]["name"]), name)
584
901
 
585
- @staticmethod
586
- def _parse_series_input(
587
- series: FrameorSeriesUnion,
588
- metadata: Optional[Dict] = None,
589
- ) -> Tuple[FrameorSeriesUnion, Optional[Dict]]:
590
- """Parse series input (internal method).
902
# NOTE(review): BaseConnector also defines ``_update_series`` earlier in the
# class body; because this definition comes later, it is the one in effect.
def _update_series(
    self,
    libname: TimeSeriesLibs,
    series: FrameOrSeriesUnion,
    name: str,
    metadata: Optional[dict] = None,
    validate: Optional[bool] = None,
    force: bool = False,
) -> None:
    """Update time series (internal method).

    The stored series is reindexed to the union of the stored and new
    indices, then overwritten with the non-NA values of ``series`` (pandas
    ``update`` semantics). The result is written back with
    ``overwrite=True``.

    Parameters
    ----------
    libname : str
        name of library, 'oseries' or 'stresses'
    series : FrameorSeriesUnion
        time series containing update values
    name : str
        name of the time series to update
    metadata : Optional[dict], optional
        optionally provide metadata dictionary which will also update
        the current stored metadata dictionary, by default None
    validate: bool, optional
        use pastas to validate series, default is None, which defers to the
        connector's validator settings.
    force : bool, optional
        force update even if time series is used in a model, by default False

    Raises
    ------
    ValueError
        if ``libname`` is not 'oseries' or 'stresses'
    """
    if libname not in ["oseries", "stresses"]:
        raise ValueError("Library must be 'oseries' or 'stresses'!")
    if not force:
        # presumably raises when the series is referenced by a stored model
        self.validator.check_series_in_models(libname, name)
    self.validator.validate_input_series(series)
    series = self.validator.set_series_name(series, name)
    stored = self._get_series(libname, name, progressbar=False)
    if self.conn_type == "pas" and not isinstance(series, type(stored)):
        if isinstance(series, pd.DataFrame):
            stored = stored.to_frame()
    # get union of index, then overwrite with new (non-NA) values
    idx_union = stored.index.union(series.index)
    update = stored.reindex(idx_union)
    update.update(series)
    # merge metadata on a copy so the dict returned by _get_metadata is not
    # mutated in place; tolerate connectors that return None
    stored_meta = self._get_metadata(libname, name)
    update_meta = None if stored_meta is None else dict(stored_meta)
    if metadata is not None:
        if update_meta is None:
            update_meta = {}
        update_meta.update(metadata)
    self._add_series(
        libname,
        update,
        name,
        metadata=update_meta,
        validate=validate,
        overwrite=True,
    )
612
957
 
613
958
  def update_oseries(
614
959
  self,
615
- series: FrameorSeriesUnion,
960
+ series: FrameOrSeriesUnion,
616
961
  name: str,
617
962
  metadata: Optional[dict] = None,
963
+ force: bool = False,
618
964
  ) -> None:
619
965
  """Update oseries values.
620
966
 
@@ -627,61 +973,67 @@ class BaseConnector(ABC):
627
973
  metadata : Optional[dict], optional
628
974
  optionally provide metadata, which will update
629
975
  the stored metadata dictionary, by default None
976
+ force : bool, optional
977
+ force update even if time series is used in a model, by default False
630
978
  """
631
- series, metadata = self._parse_series_input(series, metadata)
632
- self._update_series("oseries", series, name, metadata=metadata)
979
+ self._update_series("oseries", series, name, metadata=metadata, force=force)
633
980
 
634
- def upsert_oseries(
981
def update_stress(
    self,
    series: FrameOrSeriesUnion,
    name: str,
    metadata: Optional[dict] = None,
    force: bool = False,
) -> None:
    """Update the values of a stored stress.

    The 'kind' of a stress cannot be changed this way; to change it, delete
    the stress and add it again.

    Parameters
    ----------
    series : FrameorSeriesUnion
        new values to merge into the stored stress
    name : str
        name of the stored stress to update
    metadata : Optional[dict], optional
        optional metadata whose entries update the stored metadata
        dictionary, by default None
    force : bool, optional
        update even when the time series is used in a model, by default False
    """
    target_library = "stresses"
    self._update_series(target_library, series, name, metadata=metadata, force=force)
654
1006
 
655
- def update_stress(
1007
def upsert_oseries(
    self,
    series: FrameOrSeriesUnion,
    name: str,
    metadata: Optional[dict] = None,
    force: bool = False,
) -> None:
    """Insert an oseries, or update it when one with ``name`` already exists.

    Parameters
    ----------
    series : FrameorSeriesUnion
        time series to insert or merge into the stored oseries
    name : str
        name of the oseries
    metadata : Optional[dict], optional
        optional metadata; when the oseries exists its stored metadata
        dictionary is updated with these entries, by default None
    force : bool, optional
        update even when the time series is used in a model, by default False
    """
    target_library = "oseries"
    self._upsert_series(target_library, series, name, metadata=metadata, force=force)
678
1029
 
679
1030
  def upsert_stress(
680
1031
  self,
681
- series: FrameorSeriesUnion,
1032
+ series: FrameOrSeriesUnion,
682
1033
  name: str,
683
1034
  kind: str,
684
1035
  metadata: Optional[dict] = None,
1036
+ force: bool = False,
685
1037
  ) -> None:
686
1038
  """Update or insert stress values depending on whether it exists.
687
1039
 
@@ -694,12 +1046,16 @@ class BaseConnector(ABC):
694
1046
  metadata : Optional[dict], optional
695
1047
  optionally provide metadata, which will update
696
1048
  the stored metadata dictionary if it exists, by default None
1049
+ kind : str
1050
+ category to identify type of stress, this label is added to the
1051
+ metadata dictionary.
1052
+ force : bool, optional
1053
+ force update even if time series is used in a model, by default False
697
1054
  """
698
- series, metadata = self._parse_series_input(series, metadata)
699
1055
  if metadata is None:
700
1056
  metadata = {}
701
1057
  metadata["kind"] = kind
702
- self._upsert_series("stresses", series, name, metadata=metadata)
1058
+ self._upsert_series("stresses", series, name, metadata=metadata, force=force)
703
1059
 
704
1060
  def del_models(self, names: Union[list, str], verbose: bool = True) -> None:
705
1061
  """Delete model(s) from the database.
@@ -717,9 +1073,10 @@ class BaseConnector(ABC):
717
1073
  oname = mldict["oseries"]["name"]
718
1074
  self._del_item("models", n)
719
1075
  self._del_oseries_model_link(oname, n)
1076
+ self._del_stress_model_link(self._get_model_stress_names(mldict), n)
720
1077
  self._clear_cache("_modelnames_cache")
721
1078
  if verbose:
722
- print(f"Deleted {len(names)} model(s) from database.")
1079
+ logger.info("Deleted %d model(s) from database.", len(names))
723
1080
 
724
1081
  def del_model(self, names: Union[list, str], verbose: bool = True) -> None:
725
1082
  """Delete model(s) from the database.
@@ -736,7 +1093,11 @@ class BaseConnector(ABC):
736
1093
  self.del_models(names=names, verbose=verbose)
737
1094
 
738
1095
  def del_oseries(
739
- self, names: Union[list, str], remove_models: bool = False, verbose: bool = True
1096
+ self,
1097
+ names: Union[list, str],
1098
+ remove_models: bool = False,
1099
+ force: bool = False,
1100
+ verbose: bool = True,
740
1101
  ):
741
1102
  """Delete oseries from the database.
742
1103
 
@@ -746,38 +1107,60 @@ class BaseConnector(ABC):
746
1107
  name(s) of the oseries to delete
747
1108
  remove_models : bool, optional
748
1109
  also delete models for deleted oseries, default is False
1110
+ force : bool, optional
1111
+ force deletion of oseries that are used in models, by default False
749
1112
  verbose : bool, optional
750
1113
  print information about deleted oseries, by default True
751
1114
  """
752
1115
  names = self._parse_names(names, libname="oseries")
753
1116
  for n in names:
754
- self._del_item("oseries", n)
1117
+ self._del_item("oseries", n, force=force)
755
1118
  self._clear_cache("oseries")
756
1119
  if verbose:
757
- print(f"Deleted {len(names)} oseries from database.")
1120
+ logger.info("Deleted %d oseries from database.", len(names))
758
1121
  # remove associated models from database
759
1122
  if remove_models:
760
1123
  modelnames = list(
761
1124
  chain.from_iterable([self.oseries_models.get(n, []) for n in names])
762
1125
  )
763
1126
  self.del_models(modelnames, verbose=verbose)
1127
+ if verbose:
1128
+ logger.info("Deleted %d models(s) from database.", len(modelnames))
764
1129
 
765
- def del_stress(self, names: Union[list, str], verbose: bool = True):
1130
+ def del_stress(
1131
+ self,
1132
+ names: Union[list, str],
1133
+ remove_models: bool = False,
1134
+ force: bool = False,
1135
+ verbose: bool = True,
1136
+ ):
766
1137
  """Delete stress from the database.
767
1138
 
768
1139
  Parameters
769
1140
  ----------
770
1141
  names : str or list of str
771
1142
  name(s) of the stress to delete
1143
+ remove_models : bool, optional
1144
+ also delete models for deleted stresses, default is False
1145
+ force : bool, optional
1146
+ force deletion of stresses that are used in models, by default False
772
1147
  verbose : bool, optional
773
1148
  print information about deleted stresses, by default True
774
1149
  """
775
1150
  names = self._parse_names(names, libname="stresses")
776
1151
  for n in names:
777
- self._del_item("stresses", n)
1152
+ self._del_item("stresses", n, force=force)
778
1153
  self._clear_cache("stresses")
779
1154
  if verbose:
780
- print(f"Deleted {len(names)} stress(es) from database.")
1155
+ logger.info("Deleted %d stress(es) from database.", len(names))
1156
+ # remove associated models from database
1157
+ if remove_models:
1158
+ modelnames = list(
1159
+ chain.from_iterable([self.stresses_models.get(n, []) for n in names])
1160
+ )
1161
+ self.del_models(modelnames, verbose=verbose)
1162
+ if verbose:
1163
+ logger.info("Deleted %d models(s) from database.", len(modelnames))
781
1164
 
782
1165
  def _get_series(
783
1166
  self,
@@ -785,7 +1168,7 @@ class BaseConnector(ABC):
785
1168
  names: Union[list, str],
786
1169
  progressbar: bool = True,
787
1170
  squeeze: bool = True,
788
- ) -> FrameorSeriesUnion:
1171
+ ) -> FrameOrSeriesUnion:
789
1172
  """Get time series (internal method).
790
1173
 
791
1174
  Parameters
@@ -809,6 +1192,7 @@ class BaseConnector(ABC):
809
1192
  ts = {}
810
1193
  names = self._parse_names(names, libname=libname)
811
1194
  desc = f"Get {libname}"
1195
+ n = None
812
1196
  for n in tqdm(names, desc=desc) if progressbar else names:
813
1197
  ts[n] = self._get_item(libname, n)
814
1198
  # return frame if len == 1
@@ -865,7 +1249,7 @@ class BaseConnector(ABC):
865
1249
  return_metadata: bool = False,
866
1250
  progressbar: bool = False,
867
1251
  squeeze: bool = True,
868
- ) -> Union[Union[FrameorSeriesUnion, Dict], Optional[Union[Dict, List]]]:
1252
+ ) -> Union[Union[FrameOrSeriesUnion, Dict], Optional[Union[Dict, List]]]:
869
1253
  """Get oseries from database.
870
1254
 
871
1255
  Parameters
@@ -910,7 +1294,7 @@ class BaseConnector(ABC):
910
1294
  return_metadata: bool = False,
911
1295
  progressbar: bool = False,
912
1296
  squeeze: bool = True,
913
- ) -> Union[Union[FrameorSeriesUnion, Dict], Optional[Union[Dict, List]]]:
1297
+ ) -> Union[Union[FrameOrSeriesUnion, Dict], Optional[Union[Dict, List]]]:
914
1298
  """Get stresses from database.
915
1299
 
916
1300
  Parameters
@@ -955,7 +1339,7 @@ class BaseConnector(ABC):
955
1339
  return_metadata: bool = False,
956
1340
  progressbar: bool = False,
957
1341
  squeeze: bool = True,
958
- ) -> Union[Union[FrameorSeriesUnion, Dict], Optional[Union[Dict, List]]]:
1342
+ ) -> Union[Union[FrameOrSeriesUnion, Dict], Optional[Union[Dict, List]]]:
959
1343
  """Get stresses from database.
960
1344
 
961
1345
  Alias for `get_stresses()`
@@ -1078,7 +1462,7 @@ class BaseConnector(ABC):
1078
1462
  )
1079
1463
 
1080
1464
  def empty_library(
1081
- self, libname: str, prompt: bool = True, progressbar: bool = True
1465
+ self, libname: AllLibs, prompt: bool = True, progressbar: bool = True
1082
1466
  ):
1083
1467
  """Empty library of all its contents.
1084
1468
 
@@ -1114,11 +1498,13 @@ class BaseConnector(ABC):
1114
1498
  if progressbar
1115
1499
  else names
1116
1500
  ):
1117
- self._del_item(libname, name)
1501
+ self._del_item(libname, name, force=True)
1118
1502
  self._clear_cache(libname)
1119
- print(f"Emptied library {libname} in {self.name}: {self.__class__}")
1503
+ logger.info(
1504
+ "Emptied library %s in %s: %s", libname, self.name, self.__class__
1505
+ )
1120
1506
 
1121
- def _iter_series(self, libname: str, names: Optional[List[str]] = None):
1507
+ def _iter_series(self, libname: TimeSeriesLibs, names: Optional[List[str]] = None):
1122
1508
  """Iterate over time series in library (internal method).
1123
1509
 
1124
1510
  Parameters
@@ -1196,34 +1582,64 @@ class BaseConnector(ABC):
1196
1582
  for mlnam in modelnames:
1197
1583
  yield self.get_models(mlnam, return_dict=return_dict, progressbar=False)
1198
1584
 
1199
- def _add_oseries_model_links(self, onam: str, mlnames: Union[str, List[str]]):
1585
+ def _add_oseries_model_links(
1586
+ self, oseries_name: str, model_names: Union[str, List[str]]
1587
+ ):
1200
1588
  """Add model name to stored list of models per oseries.
1201
1589
 
1202
1590
  Parameters
1203
1591
  ----------
1204
- onam : str
1592
+ oseries_name : str
1205
1593
  name of oseries
1206
- mlnames : Union[str, List[str]]
1594
+ model_names : Union[str, List[str]]
1207
1595
  model name or list of model names for an oseries with name
1208
- onam.
1596
+ oseries_name.
1209
1597
  """
1210
1598
  # get stored list of model names
1211
- if str(onam) in self.oseries_with_models:
1212
- modellist = self._get_item("oseries_models", onam)
1599
+ if str(oseries_name) in self.oseries_with_models:
1600
+ modellist = self._get_item("oseries_models", oseries_name)
1213
1601
  else:
1214
1602
  # else empty list
1215
1603
  modellist = []
1216
1604
  # if one model name, make list for loop
1217
- if isinstance(mlnames, str):
1218
- mlnames = [mlnames]
1605
+ if isinstance(model_names, str):
1606
+ model_names = [model_names]
1219
1607
  # loop over model names
1220
- for iml in mlnames:
1608
+ for iml in model_names:
1221
1609
  # if not present, add to list
1222
1610
  if iml not in modellist:
1223
1611
  modellist.append(iml)
1224
- self._add_item("oseries_models", modellist, onam, overwrite=True)
1612
+ self._add_item("oseries_models", modellist, oseries_name)
1225
1613
  self._clear_cache("oseries_models")
1226
1614
 
1615
+ def _add_stresses_model_links(self, stress_names, model_names):
1616
+ """Add model name to stored list of models per stress.
1617
+
1618
+ Parameters
1619
+ ----------
1620
+ stress_names : list of str
1621
+ names of stresses
1622
+ model_names : Union[str, List[str]]
1623
+ model name or list of model names for a stress with name
1624
+ """
1625
+ # if one model name, make list for loop
1626
+ if isinstance(model_names, str):
1627
+ model_names = [model_names]
1628
+ for snam in stress_names:
1629
+ # get stored list of model names
1630
+ if str(snam) in self.stresses_with_models:
1631
+ modellist = self._get_item("stresses_models", snam)
1632
+ else:
1633
+ # else empty list
1634
+ modellist = []
1635
+ # loop over model names
1636
+ for iml in model_names:
1637
+ # if not present, add to list
1638
+ if iml not in modellist:
1639
+ modellist.append(iml)
1640
+ self._add_item("stresses_models", modellist, snam)
1641
+ self._clear_cache("stresses_models")
1642
+
1227
1643
  def _del_oseries_model_link(self, onam, mlnam):
1228
1644
  """Delete model name from stored list of models per oseries.
1229
1645
 
@@ -1239,128 +1655,183 @@ class BaseConnector(ABC):
1239
1655
  if len(modellist) == 0:
1240
1656
  self._del_item("oseries_models", onam)
1241
1657
  else:
1242
- self._add_item("oseries_models", modellist, onam, overwrite=True)
1658
+ self._add_item("oseries_models", modellist, onam)
1243
1659
  self._clear_cache("oseries_models")
1244
1660
 
1245
- def _update_all_oseries_model_links(self):
1246
- """Add all model names to oseries metadata dictionaries.
1661
+ def _del_stress_model_link(self, stress_names, model_name):
1662
+ """Delete model name from stored list of models per stress.
1663
+
1664
+ Parameters
1665
+ ----------
1666
+ stress_names : list of str
1667
+ List of stress names for which to remove the model link.
1668
+ model_name : str
1669
+ Name of the model to remove from the stress links.
1670
+ """
1671
+ for stress_name in stress_names:
1672
+ modellist = self._get_item("stresses_models", stress_name)
1673
+ modellist.remove(model_name)
1674
+ if len(modellist) == 0:
1675
+ self._del_item("stresses_models", stress_name)
1676
+ else:
1677
+ self._add_item("stresses_models", modellist, stress_name)
1678
+ self._clear_cache("stresses_models")
1679
+
1680
+ def _update_time_series_model_links(self):
1681
+ """Add all model names to reverse lookup time series dictionaries.
1247
1682
 
1248
- Used for old PastaStore versions, where relationship between oseries and models
1249
- was not stored. If there are any models in the database and if the
1250
- oseries_models library is empty, loops through all models to determine which
1251
- oseries each model belongs to.
1683
+ Used for old PastaStore versions, where relationship between time series and
1684
+ models was not stored. If there are any models in the database and if the
1685
+ oseries_models or stresses_models libraries are empty, loop through all models
1686
+ to determine which time series are used in each model.
1252
1687
  """
1253
- # get oseries_models library if there are any contents, if empty
1254
- # add all model links.
1688
+ # get oseries_models and stresses_models libraries,
1689
+ # if empty add all time series -> model links.
1255
1690
  if self.n_models > 0:
1256
- if len(self.oseries_models) == 0:
1257
- links = self._get_all_oseries_model_links()
1258
- for onam, mllinks in tqdm(
1259
- links.items(),
1260
- desc="Store models per oseries",
1261
- total=len(links),
1262
- ):
1263
- self._add_oseries_model_links(onam, mllinks)
1264
-
1265
- def _get_all_oseries_model_links(self):
1266
- """Get all model names per oseries in dictionary.
1691
+ if len(self.oseries_models) == 0 or len(self.stresses_models) == 0:
1692
+ links = self._get_time_series_model_links()
1693
+ for k in ["oseries", "stresses"]:
1694
+ for name, model_links in tqdm(
1695
+ links[k].items(),
1696
+ desc=f"Store models per {k}",
1697
+ total=len(links[k]),
1698
+ ):
1699
+ if k == "oseries":
1700
+ self._add_oseries_model_links(name, model_links)
1701
+ elif k == "stresses":
1702
+ self._add_stresses_model_links(name, model_links)
1703
+
1704
+ def _get_time_series_model_links(self):
1705
+ """Get model names per oseries and stresses time series in a dictionary.
1267
1706
 
1268
1707
  Returns
1269
1708
  -------
1270
1709
  links : dict
1271
- dictionary with oseries names as keys and lists of model names as
1272
- values
1710
+ dictionary with 'oseries' and 'stresses' as keys containing
1711
+ dictionaries with time series names as keys and lists of model
1712
+ names as values.
1273
1713
  """
1274
- links = {}
1714
+ oseries_links = {}
1715
+ stresses_links = {}
1275
1716
  for mldict in tqdm(
1276
1717
  self.iter_models(return_dict=True),
1277
1718
  total=self.n_models,
1278
- desc="Get models per oseries",
1719
+ desc="Get models per time series",
1279
1720
  ):
1280
- onam = mldict["oseries"]["name"]
1281
1721
  mlnam = mldict["name"]
1282
- if onam in links:
1283
- links[onam].append(mlnam)
1722
+ # oseries
1723
+ onam = mldict["oseries"]["name"]
1724
+ if onam in oseries_links:
1725
+ oseries_links[onam].append(mlnam)
1284
1726
  else:
1285
- links[onam] = [mlnam]
1286
- return links
1287
-
1288
- @staticmethod
1289
- def _clear_cache(libname: str) -> None:
1290
- """Clear cached property."""
1291
- if libname == "models":
1292
- libname = "_modelnames_cache"
1293
- getattr(BaseConnector, libname).fget.cache_clear()
1294
-
1295
- @property # type: ignore
1296
- @functools.lru_cache()
1297
- def oseries(self):
1298
- """Dataframe with overview of oseries."""
1299
- return self.get_metadata("oseries", self.oseries_names)
1300
-
1301
- @property # type: ignore
1302
- @functools.lru_cache()
1303
- def stresses(self):
1304
- """Dataframe with overview of stresses."""
1305
- return self.get_metadata("stresses", self.stresses_names)
1306
-
1307
- @property # type: ignore
1308
- @functools.lru_cache()
1309
- def _modelnames_cache(self):
1310
- """List of model names."""
1311
- return self.model_names
1312
-
1313
- @property
1314
- def n_oseries(self):
1315
- """
1316
- Returns the number of oseries.
1727
+ oseries_links[onam] = [mlnam]
1728
+ # stresses
1729
+ stress_names = self._get_model_stress_names(mldict)
1730
+ for snam in stress_names:
1731
+ if snam in stresses_links:
1732
+ stresses_links[snam].append(mlnam)
1733
+ else:
1734
+ stresses_links[snam] = [mlnam]
1735
+ return {"oseries": oseries_links, "stresses": stresses_links}
1317
1736
 
1318
- Returns
1319
- -------
1320
- int
1321
- The number of oseries names.
1322
- """
1323
- return len(self.oseries_names)
1737
+ def _get_model_stress_names(self, ml: ps.Model | dict) -> List[str]:
1738
+ """Get list of stress names used in model.
1324
1739
 
1325
- @property
1326
- def n_stresses(self):
1327
- """
1328
- Returns the number of stresses.
1740
+ Parameters
1741
+ ----------
1742
+ ml : pastas.Model or dict
1743
+ model to get stress names from
1329
1744
 
1330
1745
  Returns
1331
1746
  -------
1332
- int
1333
- The number of stresses.
1747
+ list of str
1748
+ list of stress names used in model
1334
1749
  """
1335
- return len(self.stresses_names)
1750
+ stresses = []
1751
+ if isinstance(ml, dict):
1752
+ for sm in ml["stressmodels"].values():
1753
+ class_key = "class"
1754
+ if sm[class_key] == "RechargeModel":
1755
+ stresses.append(sm["prec"]["name"])
1756
+ stresses.append(sm["evap"]["name"])
1757
+ if sm["temp"] is not None:
1758
+ stresses.append(sm["temp"]["name"])
1759
+ elif "stress" in sm:
1760
+ smstress = sm["stress"]
1761
+ if isinstance(smstress, dict):
1762
+ smstress = [smstress]
1763
+ for s in smstress:
1764
+ stresses.append(s["name"])
1765
+ else:
1766
+ for sm in ml.stressmodels.values():
1767
+ # Check class name using type instead of protected _name attribute
1768
+ if type(sm).__name__ == "RechargeModel":
1769
+ stresses.append(sm.prec.name)
1770
+ stresses.append(sm.evap.name)
1771
+ if sm.temp is not None:
1772
+ stresses.append(sm.temp.name)
1773
+ elif hasattr(sm, "stress"):
1774
+ smstress = sm.stress
1775
+ if not isinstance(smstress, list):
1776
+ smstress = [smstress]
1777
+ for s in smstress:
1778
+ stresses.append(s.name)
1779
+ return list(set(stresses))
1780
+
1781
+ def get_model_time_series_names(
1782
+ self,
1783
+ modelnames: Optional[Union[list, str]] = None,
1784
+ dropna: bool = True,
1785
+ progressbar: bool = True,
1786
+ ) -> FrameOrSeriesUnion:
1787
+ """Get time series names contained in model.
1336
1788
 
1337
- @property
1338
- def n_models(self):
1339
- """
1340
- Returns the number of models in the store.
1789
+ Parameters
1790
+ ----------
1791
+ modelnames : Optional[Union[list, str]], optional
1792
+ list or name of models to get time series names for,
1793
+ by default None which will use all modelnames
1794
+ dropna : bool, optional
1795
+ drop stresses from table if stress is not included in any
1796
+ model, by default True
1797
+ progressbar : bool, optional
1798
+ show progressbar, by default True
1341
1799
 
1342
1800
  Returns
1343
1801
  -------
1344
- int
1345
- The number of models in the store.
1802
+ structure : pandas.DataFrame
1803
+ returns DataFrame with oseries name per model, and a flag
1804
+ indicating whether a stress is contained within a time series
1805
+ model.
1346
1806
  """
1347
- return len(self.model_names)
1807
+ model_names = self._parse_names(modelnames, libname="models")
1808
+ structure = pd.DataFrame(
1809
+ index=model_names, columns=["oseries"] + self.stresses_names
1810
+ )
1811
+ structure.index.name = "model"
1348
1812
 
1349
- @property # type: ignore
1350
- @functools.lru_cache()
1351
- def oseries_models(self):
1352
- """List of model names per oseries.
1813
+ for mlnam in (
1814
+ tqdm(model_names, desc="Get model time series names")
1815
+ if progressbar
1816
+ else model_names
1817
+ ):
1818
+ mldict = self.get_models(mlnam, return_dict=True)
1819
+ stresses_names = self._get_model_stress_names(mldict)
1820
+ # oseries
1821
+ structure.loc[mlnam, "oseries"] = mldict["oseries"]["name"]
1822
+ # stresses
1823
+ structure.loc[mlnam, stresses_names] = 1
1824
+ if dropna:
1825
+ return structure.dropna(how="all", axis=1)
1826
+ else:
1827
+ return structure
1353
1828
 
1354
- Returns
1355
- -------
1356
- d : dict
1357
- dictionary with oseries names as keys and list of model names as
1358
- values
1359
- """
1360
- d = {}
1361
- for onam in self.oseries_with_models:
1362
- d[onam] = self._get_item("oseries_models", onam)
1363
- return d
1829
+ @staticmethod
1830
+ def _clear_cache(libname: AllLibs) -> None:
1831
+ """Clear cached property."""
1832
+ if libname == "models":
1833
+ libname = "_modelnames_cache"
1834
+ getattr(BaseConnector, libname).fget.cache_clear()
1364
1835
 
1365
1836
 
1366
1837
  class ModelAccessor:
@@ -1412,7 +1883,7 @@ class ModelAccessor:
1412
1883
  """Representation contains the number of models and the list of model names."""
1413
1884
  return (
1414
1885
  f"<{self.__class__.__name__}> {len(self)} model(s): \n"
1415
- + self.conn._modelnames_cache.__repr__()
1886
+ + self.conn.model_names.__repr__()
1416
1887
  )
1417
1888
 
1418
1889
  def __getitem__(self, name: str):
@@ -1463,9 +1934,7 @@ class ModelAccessor:
1463
1934
  pastas.Model
1464
1935
  A random model object from the connection.
1465
1936
  """
1466
- from random import choice
1467
-
1468
- return self.conn.get_models(choice(self.conn._modelnames_cache))
1937
+ return self.conn.get_models(choice(self.conn.model_names))
1469
1938
 
1470
1939
  @property
1471
1940
  def metadata(self):