pastastore-1.7.2-py3-none-any.whl → pastastore-1.9.0-py3-none-any.whl

This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
pastastore/connectors.py CHANGED
@@ -1,27 +1,755 @@
 """Module containing classes for connecting to different data stores."""
 
 import json
+import logging
 import os
 import warnings
+from collections.abc import Iterable
+from concurrent.futures import ProcessPoolExecutor
 from copy import deepcopy
-from typing import Dict, Optional, Union
+from functools import partial
+
+# import weakref
+from typing import Callable, Dict, List, Optional, Tuple, Union
 
 import pandas as pd
+import pastas as ps
+from numpy import isin
+from packaging.version import parse as parse_version
+from pandas.testing import assert_series_equal
 from pastas.io.pas import PastasEncoder, pastas_hook
+from tqdm.auto import tqdm
+from tqdm.contrib.concurrent import process_map
 
-from pastastore.base import BaseConnector, ConnectorUtil, ModelAccessor
+from pastastore.base import BaseConnector, ModelAccessor
 from pastastore.util import _custom_warning
+from pastastore.version import PASTAS_LEQ_022
 
 FrameorSeriesUnion = Union[pd.DataFrame, pd.Series]
 warnings.showwarning = _custom_warning
 
+logger = logging.getLogger(__name__)
+
+
+class ConnectorUtil:
+    """Mix-in class for general Connector helper functions.
+
+    Only for internal methods, and not methods that are related to CRUD operations on
+    database.
+    """
+
+    def _parse_names(
+        self,
+        names: Optional[Union[list, str]] = None,
+        libname: Optional[str] = "oseries",
+    ) -> list:
+        """Parse names kwarg, returns iterable with name(s) (internal method).
+
+        Parameters
+        ----------
+        names : Union[list, str], optional
+            str or list of str or None or 'all' (last two options
+            retrieves all names)
+        libname : str, optional
+            name of library, default is 'oseries'
+
+        Returns
+        -------
+        list
+            list of names
+        """
+        if not isinstance(names, str) and isinstance(names, Iterable):
+            return names
+        elif isinstance(names, str) and names != "all":
+            return [names]
+        elif names is None or names == "all":
+            if libname == "oseries":
+                return self.oseries_names
+            elif libname == "stresses":
+                return self.stresses_names
+            elif libname == "models":
+                return self.model_names
+            elif libname == "oseries_models":
+                return self.oseries_with_models
+            else:
+                raise ValueError(f"No library '{libname}'!")
+        else:
+            raise NotImplementedError(f"Cannot parse 'names': {names}")
+
+    @staticmethod
+    def _meta_list_to_frame(metalist: list, names: list):
+        """Convert list of metadata dictionaries to DataFrame.
+
+        Parameters
+        ----------
+        metalist : list
+            list of metadata dictionaries
+        names : list
+            list of names corresponding to data in metalist
+
+        Returns
+        -------
+        pandas.DataFrame
+            DataFrame containing overview of metadata
+        """
+        # convert to dataframe
+        if len(metalist) > 1:
+            meta = pd.DataFrame(metalist)
+            if len({"x", "y"}.difference(meta.columns)) == 0:
+                meta["x"] = meta["x"].astype(float)
+                meta["y"] = meta["y"].astype(float)
+        elif len(metalist) == 1:
+            meta = pd.DataFrame(metalist)
+        elif len(metalist) == 0:
+            meta = pd.DataFrame()
+
+        meta.index = names
+        meta.index.name = "name"
+        return meta
+
+    def _parse_model_dict(self, mdict: dict, update_ts_settings: bool = False):
+        """Parse dictionary describing pastas models (internal method).
+
+        Parameters
+        ----------
+        mdict : dict
+            dictionary describing pastas.Model
+        update_ts_settings : bool, optional
+            update stored tmin and tmax in time series settings
+            based on time series loaded from store.
+
+        Returns
+        -------
+        ml : pastas.Model
+            time series analysis model
+        """
+        PASFILE_LEQ_022 = parse_version(
+            mdict["file_info"]["pastas_version"]
+        ) <= parse_version("0.22.0")
+
+        # oseries
+        if "series" not in mdict["oseries"]:
+            name = str(mdict["oseries"]["name"])
+            if name not in self.oseries.index:
+                msg = "oseries '{}' not present in library".format(name)
+                raise LookupError(msg)
+            mdict["oseries"]["series"] = self.get_oseries(name).squeeze()
+            # update tmin/tmax from time series
+            if update_ts_settings:
+                mdict["oseries"]["settings"]["tmin"] = mdict["oseries"]["series"].index[
+                    0
+                ]
+                mdict["oseries"]["settings"]["tmax"] = mdict["oseries"]["series"].index[
+                    -1
+                ]
+
+        # StressModel, WellModel
+        for ts in mdict["stressmodels"].values():
+            if "stress" in ts.keys():
+                # WellModel
+                classkey = "stressmodel" if PASFILE_LEQ_022 else "class"
+                if ts[classkey] == "WellModel":
+                    for stress in ts["stress"]:
+                        if "series" not in stress:
+                            name = str(stress["name"])
+                            if name in self.stresses.index:
+                                stress["series"] = self.get_stresses(name).squeeze()
+                                # update tmin/tmax from time series
+                                if update_ts_settings:
+                                    stress["settings"]["tmin"] = stress["series"].index[
+                                        0
+                                    ]
+                                    stress["settings"]["tmax"] = stress["series"].index[
+                                        -1
+                                    ]
+                # StressModel
+                else:
+                    for stress in ts["stress"] if PASFILE_LEQ_022 else [ts["stress"]]:
+                        if "series" not in stress:
+                            name = str(stress["name"])
+                            if name in self.stresses.index:
+                                stress["series"] = self.get_stresses(name).squeeze()
+                                # update tmin/tmax from time series
+                                if update_ts_settings:
+                                    stress["settings"]["tmin"] = stress["series"].index[
+                                        0
+                                    ]
+                                    stress["settings"]["tmax"] = stress["series"].index[
+                                        -1
+                                    ]
+
+            # RechargeModel, TarsoModel
+            if ("prec" in ts.keys()) and ("evap" in ts.keys()):
+                for stress in [ts["prec"], ts["evap"]]:
+                    if "series" not in stress:
+                        name = str(stress["name"])
+                        if name in self.stresses.index:
+                            stress["series"] = self.get_stresses(name).squeeze()
+                            # update tmin/tmax from time series
+                            if update_ts_settings:
+                                stress["settings"]["tmin"] = stress["series"].index[0]
+                                stress["settings"]["tmax"] = stress["series"].index[-1]
+                        else:
+                            msg = "stress '{}' not present in library".format(name)
+                            raise KeyError(msg)
+
+        # hack for pcov w dtype object (when filled with NaNs on store?)
+        if "fit" in mdict:
+            if "pcov" in mdict["fit"]:
+                pcov = mdict["fit"]["pcov"]
+                if pcov.dtypes.apply(lambda dtyp: isinstance(dtyp, object)).any():
+                    mdict["fit"]["pcov"] = pcov.astype(float)
+
+        # check pastas version vs pas-file version
+        file_version = mdict["file_info"]["pastas_version"]
+
+        # check file version and pastas version
+        # if file<0.23 and pastas>=1.0 --> error
+        PASTAS_GT_023 = parse_version(ps.__version__) > parse_version("0.23.1")
+        if PASFILE_LEQ_022 and PASTAS_GT_023:
+            raise UserWarning(
+                f"This file was created with Pastas v{file_version} "
+                f"and cannot be loaded with Pastas v{ps.__version__} Please load and "
+                "save the file with Pastas 0.23 first to update the file "
+                "format."
+            )
+
+        try:
+            # pastas>=0.15.0
+            ml = ps.io.base._load_model(mdict)
+        except AttributeError:
+            # pastas<0.15.0
+            ml = ps.io.base.load_model(mdict)
+        return ml
+
+    @staticmethod
+    def _validate_input_series(series):
+        """Check if series is pandas.DataFrame or pandas.Series.
+
+        Parameters
+        ----------
+        series : object
+            object to validate
+
+        Raises
+        ------
+        TypeError
+            if object is not of type pandas.DataFrame or pandas.Series
+        """
+        if not (isinstance(series, pd.DataFrame) or isinstance(series, pd.Series)):
+            raise TypeError("Please provide pandas.DataFrame or pandas.Series!")
+        if isinstance(series, pd.DataFrame):
+            if series.columns.size > 1:
+                raise ValueError("Only DataFrames with one column are supported!")
+
+    @staticmethod
+    def _set_series_name(series, name):
+        """Set series name to match user defined name in store.
+
+        Parameters
+        ----------
+        series : pandas.Series or pandas.DataFrame
+            set name for this time series
+        name : str
+            name of the time series (used in the pastastore)
+        """
+        if isinstance(series, pd.Series):
+            series.name = name
+            # empty string on index name causes trouble when reading
+            # data from ArcticDB: TODO: check if still an issue?
+            if series.index.name == "":
+                series.index.name = None
+
+        if isinstance(series, pd.DataFrame):
+            series.columns = [name]
+            # check for hydropandas objects which are instances of DataFrame but
+            # do have a name attribute
+            if hasattr(series, "name"):
+                series.name = name
+        return series
+
+    @staticmethod
+    def _check_stressmodels_supported(ml):
+        supported_stressmodels = [
+            "StressModel",
+            "StressModel2",
+            "RechargeModel",
+            "WellModel",
+            "TarsoModel",
+            "Constant",
+            "LinearTrend",
+            "StepModel",
+        ]
+        if isinstance(ml, ps.Model):
+            smtyps = [sm._name for sm in ml.stressmodels.values()]
+        elif isinstance(ml, dict):
+            classkey = "stressmodel" if PASTAS_LEQ_022 else "class"
+            smtyps = [sm[classkey] for sm in ml["stressmodels"].values()]
+        check = isin(smtyps, supported_stressmodels)
+        if not all(check):
+            unsupported = set(smtyps) - set(supported_stressmodels)
+            raise NotImplementedError(
+                "PastaStore does not support storing models with the "
+                f"following stressmodels: {unsupported}"
+            )
+
+    @staticmethod
+    def _check_model_series_names_for_store(ml):
+        prec_evap_model = ["RechargeModel", "TarsoModel"]
+
+        if isinstance(ml, ps.Model):
+            series_names = [
+                istress.series.name
+                for sm in ml.stressmodels.values()
+                for istress in sm.stress
+            ]
+
+        elif isinstance(ml, dict):
+            # non RechargeModel, Tarsomodel, WellModel stressmodels
+            classkey = "stressmodel" if PASTAS_LEQ_022 else "class"
+            if PASTAS_LEQ_022:
+                series_names = [
+                    istress["name"]
+                    for sm in ml["stressmodels"].values()
+                    if sm[classkey] not in (prec_evap_model + ["WellModel"])
+                    for istress in sm["stress"]
+                ]
+            else:
+                series_names = [
+                    sm["stress"]["name"]
+                    for sm in ml["stressmodels"].values()
+                    if sm[classkey] not in (prec_evap_model + ["WellModel"])
+                ]
+
+            # WellModel
+            if isin(
+                ["WellModel"],
+                [i[classkey] for i in ml["stressmodels"].values()],
+            ).any():
+                series_names += [
+                    istress["name"]
+                    for sm in ml["stressmodels"].values()
+                    if sm[classkey] in ["WellModel"]
+                    for istress in sm["stress"]
+                ]
+
+            # RechargeModel, TarsoModel
+            if isin(
+                prec_evap_model,
+                [i[classkey] for i in ml["stressmodels"].values()],
+            ).any():
+                series_names += [
+                    istress["name"]
+                    for sm in ml["stressmodels"].values()
+                    if sm[classkey] in prec_evap_model
+                    for istress in [sm["prec"], sm["evap"]]
+                ]
+
+        else:
+            raise TypeError("Expected pastas.Model or dict!")
+        if len(series_names) - len(set(series_names)) > 0:
+            msg = (
+                "There are multiple stresses series with the same name! "
+                "Each series name must be unique for the PastaStore!"
+            )
+            raise ValueError(msg)
+
+    def _check_oseries_in_store(self, ml: Union[ps.Model, dict]):
+        """Check if Model oseries are contained in PastaStore (internal method).
+
+        Parameters
+        ----------
+        ml : Union[ps.Model, dict]
+            pastas Model
+        """
+        if isinstance(ml, ps.Model):
+            name = ml.oseries.name
+        elif isinstance(ml, dict):
+            name = str(ml["oseries"]["name"])
+        else:
+            raise TypeError("Expected pastas.Model or dict!")
+        if name not in self.oseries.index:
+            msg = (
+                f"Cannot add model because oseries '{name}' is not contained in store."
+            )
+            raise LookupError(msg)
+        # expensive check
+        if self.CHECK_MODEL_SERIES_VALUES and isinstance(ml, ps.Model):
+            s_org = self.get_oseries(name).squeeze().dropna()
+            if PASTAS_LEQ_022:
+                so = ml.oseries.series_original
+            else:
+                so = ml.oseries._series_original
+            try:
+                assert_series_equal(
+                    so.dropna(),
+                    s_org,
+                    atol=self.SERIES_EQUALITY_ABSOLUTE_TOLERANCE,
+                    rtol=self.SERIES_EQUALITY_RELATIVE_TOLERANCE,
+                )
+            except AssertionError as e:
+                raise ValueError(
+                    f"Cannot add model because model oseries '{name}'"
+                    " is different from stored oseries! See stacktrace for differences."
+                ) from e
+
+    def _check_stresses_in_store(self, ml: Union[ps.Model, dict]):
+        """Check if stresses time series are contained in PastaStore (internal method).
+
+        Parameters
+        ----------
+        ml : Union[ps.Model, dict]
+            pastas Model
+        """
+        prec_evap_model = ["RechargeModel", "TarsoModel"]
+        if isinstance(ml, ps.Model):
+            for sm in ml.stressmodels.values():
+                if sm._name in prec_evap_model:
+                    stresses = [sm.prec, sm.evap]
+                else:
+                    stresses = sm.stress
+                for s in stresses:
+                    if str(s.name) not in self.stresses.index:
+                        msg = (
+                            f"Cannot add model because stress '{s.name}' "
+                            "is not contained in store."
+                        )
+                        raise LookupError(msg)
+                    if self.CHECK_MODEL_SERIES_VALUES:
+                        s_org = self.get_stresses(s.name).squeeze()
+                        if PASTAS_LEQ_022:
+                            so = s.series_original
+                        else:
+                            so = s._series_original
+                        try:
+                            assert_series_equal(
+                                so,
+                                s_org,
+                                atol=self.SERIES_EQUALITY_ABSOLUTE_TOLERANCE,
+                                rtol=self.SERIES_EQUALITY_RELATIVE_TOLERANCE,
+                            )
+                        except AssertionError as e:
+                            raise ValueError(
+                                f"Cannot add model because model stress "
+                                f"'{s.name}' is different from stored stress! "
+                                "See stacktrace for differences."
+                            ) from e
+        elif isinstance(ml, dict):
+            for sm in ml["stressmodels"].values():
+                classkey = "stressmodel" if PASTAS_LEQ_022 else "class"
+                if sm[classkey] in prec_evap_model:
+                    stresses = [sm["prec"], sm["evap"]]
+                elif sm[classkey] in ["WellModel"]:
+                    stresses = sm["stress"]
+                else:
+                    stresses = sm["stress"] if PASTAS_LEQ_022 else [sm["stress"]]
+                for s in stresses:
+                    if str(s["name"]) not in self.stresses.index:
+                        msg = (
+                            f"Cannot add model because stress '{s['name']}' "
+                            "is not contained in store."
+                        )
+                        raise LookupError(msg)
+        else:
+            raise TypeError("Expected pastas.Model or dict!")
+
+    def _stored_series_to_json(
+        self,
+        libname: str,
+        names: Optional[Union[list, str]] = None,
+        squeeze: bool = True,
+        progressbar: bool = False,
+    ):
+        """Write stored series to JSON.
+
+        Parameters
+        ----------
+        libname : str
+            library name
+        names : Optional[Union[list, str]], optional
+            names of series, by default None
+        squeeze : bool, optional
+            return single entry as json string instead
+            of list, by default True
+        progressbar : bool, optional
+            show progressbar, by default False
+
+        Returns
+        -------
+        files : list or str
+            list of series converted to JSON string or single string
+            if single entry is returned and squeeze is True
+        """
+        names = self._parse_names(names, libname=libname)
+        files = []
+        for n in tqdm(names, desc=libname) if progressbar else names:
+            s = self._get_series(libname, n, progressbar=False)
+            if isinstance(s, pd.Series):
+                s = s.to_frame()
+            try:
+                sjson = s.to_json(orient="columns")
+            except ValueError as e:
+                msg = (
+                    f"DatetimeIndex of '{n}' probably contains NaT "
+                    "or duplicate timestamps!"
+                )
+                raise ValueError(msg) from e
+            files.append(sjson)
+        if len(files) == 1 and squeeze:
+            return files[0]
+        else:
+            return files
+
+    def _stored_metadata_to_json(
+        self,
+        libname: str,
+        names: Optional[Union[list, str]] = None,
+        squeeze: bool = True,
+        progressbar: bool = False,
+    ):
+        """Write metadata from stored series to JSON.
+
+        Parameters
+        ----------
+        libname : str
+            library containing series
+        names : Optional[Union[list, str]], optional
+            names to parse, by default None
+        squeeze : bool, optional
+            return single entry as json string instead of list, by default True
+        progressbar : bool, optional
+            show progressbar, by default False
+
+        Returns
+        -------
+        files : list or str
+            list of json string
+        """
+        names = self._parse_names(names, libname=libname)
+        files = []
+        for n in tqdm(names, desc=libname) if progressbar else names:
+            meta = self.get_metadata(libname, n, as_frame=False)
+            meta_json = json.dumps(meta, cls=PastasEncoder, indent=4)
+            files.append(meta_json)
+        if len(files) == 1 and squeeze:
+            return files[0]
+        else:
+            return files
+
+    def _series_to_archive(
+        self,
+        archive,
+        libname: str,
+        names: Optional[Union[list, str]] = None,
+        progressbar: bool = True,
+    ):
+        """Write DataFrame or Series to zipfile (internal method).
+
+        Parameters
+        ----------
+        archive : zipfile.ZipFile
+            reference to an archive to write data to
+        libname : str
+            name of the library to write to zipfile
+        names : str or list of str, optional
+            names of the time series to write to archive, by default None,
+            which writes all time series to archive
+        progressbar : bool, optional
+            show progressbar, by default True
+        """
+        names = self._parse_names(names, libname=libname)
+        for n in tqdm(names, desc=libname) if progressbar else names:
+            sjson = self._stored_series_to_json(
+                libname, names=n, progressbar=False, squeeze=True
+            )
+            meta_json = self._stored_metadata_to_json(
+                libname, names=n, progressbar=False, squeeze=True
+            )
+            archive.writestr(f"{libname}/{n}.pas", sjson)
+            archive.writestr(f"{libname}/{n}_meta.pas", meta_json)
+
+    def _models_to_archive(self, archive, names=None, progressbar=True):
+        """Write pastas.Model to zipfile (internal method).
+
+        Parameters
+        ----------
+        archive : zipfile.ZipFile
+            reference to an archive to write data to
+        names : str or list of str, optional
+            names of the models to write to archive, by default None,
+            which writes all models to archive
+        progressbar : bool, optional
+            show progressbar, by default True
+        """
+        names = self._parse_names(names, libname="models")
+        for n in tqdm(names, desc="models") if progressbar else names:
+            m = self.get_models(n, return_dict=True)
+            jsondict = json.dumps(m, cls=PastasEncoder, indent=4)
+            archive.writestr(f"models/{n}.pas", jsondict)
+
+    @staticmethod
+    def _series_from_json(fjson: str, squeeze: bool = True):
+        """Load time series from JSON.
+
+        Parameters
+        ----------
+        fjson : str
+            path to file
+        squeeze : bool, optional
+            squeeze time series object to obtain pandas Series
+
+        Returns
+        -------
+        s : pd.DataFrame
+            DataFrame containing time series
+        """
+        s = pd.read_json(fjson, orient="columns", precise_float=True, dtype=False)
+        if not isinstance(s.index, pd.DatetimeIndex):
+            s.index = pd.to_datetime(s.index, unit="ms")
+        s = s.sort_index()  # needed for some reason ...
+        if squeeze:
+            return s.squeeze()
+        return s
+
+    @staticmethod
+    def _metadata_from_json(fjson: str):
+        """Load metadata dictionary from JSON.
+
+        Parameters
+        ----------
+        fjson : str
+            path to file
+
+        Returns
+        -------
+        meta : dict
+            dictionary containing metadata
+        """
+        with open(fjson, "r") as f:
+            meta = json.load(f)
+        return meta
+
+    def _get_model_orphans(self):
+        """Get models whose oseries no longer exist in database.
+
+        Returns
+        -------
+        dict
+            dictionary with oseries names as keys and lists of model names
+            as values
+        """
+        d = {}
+        for mlnam in tqdm(self.model_names, desc="Identifying model orphans"):
+            mdict = self.get_models(mlnam, return_dict=True)
+            onam = mdict["oseries"]["name"]
+            if onam not in self.oseries_names:
+                if onam in d:
+                    d[onam] = d[onam].append(mlnam)
+                else:
+                    d[onam] = [mlnam]
+        return d
+
+    @staticmethod
+    def _solve_model(
+        ml_name: str,
+        connector: Optional[BaseConnector] = None,
+        report: bool = False,
+        ignore_solve_errors: bool = False,
+        **kwargs,
+    ) -> None:
+        """Solve a model in the store (internal method).
+
+        ml_name : list of str, optional
+            name of a model in the pastastore
+        connector : PasConnector, optional
+            Connector to use, by default None which gets the global ArcticDB
+            Connector. Otherwise pass a PasConnector.
+        report : boolean, optional
+            determines if a report is printed when the model is solved,
+            default is False
+        ignore_solve_errors : boolean, optional
+            if True, errors emerging from the solve method are ignored,
+            default is False which will raise an exception when a model
+            cannot be optimized
+        **kwargs : dictionary
+            arguments are passed to the solve method.
+        """
+        if connector is not None:
+            conn = connector
+        else:
+            conn = globals()["conn"]
+
+        ml = conn.get_models(ml_name)
+        m_kwargs = {}
+        for key, value in kwargs.items():
+            if isinstance(value, pd.Series):
+                m_kwargs[key] = value.loc[ml.name]
+            else:
+                m_kwargs[key] = value
+        # Convert timestamps
+        for tstamp in ["tmin", "tmax"]:
+            if tstamp in m_kwargs:
+                m_kwargs[tstamp] = pd.Timestamp(m_kwargs[tstamp])
+
+        try:
+            ml.solve(report=report, **m_kwargs)
+        except Exception as e:
+            if ignore_solve_errors:
+                warning = "Solve error ignored for '%s': %s " % (ml.name, e)
+                logger.warning(warning)
+            else:
+                raise e
+
+        conn.add_model(ml, overwrite=True)
+
+    @staticmethod
+    def _get_statistics(
+        name: str,
+        statistics: List[str],
+        connector: Union[None, BaseConnector] = None,
+        **kwargs,
+    ) -> pd.Series:
+        """Get statistics for a model in the store (internal method).
+
+        This function was made to be run in parallel mode. For the odd user
+        that wants to run this function directly in sequential mode using
+        an ArcticDBConnector, the connector argument must be passed in the kwargs
+        of the apply method.
+        """
+        if connector is not None:
+            conn = connector
+        else:
+            conn = globals()["conn"]
+
+        ml = conn.get_model(name)
+        series = pd.Series(index=statistics, dtype=float)
+        for stat in statistics:
+            series.loc[stat] = getattr(ml.stats, stat)(**kwargs)
+        return series
+
+    @staticmethod
+    def _get_max_workers_and_chunksize(
+        max_workers: int, njobs: int, chunksize: int = None
+    ) -> Tuple[int, int]:
+        """Get the maximum workers and chunksize for parallel processing.
+
+        From: https://stackoverflow.com/a/42096963/10596229
+        """
+        max_workers = (
+            min(32, os.cpu_count() + 4) if max_workers is None else max_workers
+        )
+        if chunksize is None:
+            num_chunks = max_workers * 14
+            chunksize = max(njobs // num_chunks, 1)
+        return max_workers, chunksize
+
 
 class ArcticDBConnector(BaseConnector, ConnectorUtil):
     """ArcticDBConnector object using ArcticDB to store data."""
 
     conn_type = "arcticdb"
 
-    def __init__(self, name: str, uri: str):
+    def __init__(self, name: str, uri: str, verbose: bool = True):
         """Create an ArcticDBConnector object using ArcticDB to store data.
 
         Parameters
@@ -30,9 +758,12 @@ class ArcticDBConnector(BaseConnector, ConnectorUtil):
             name of the database
         uri : str
             URI connection string (e.g. 'lmdb://<your path here>')
+        verbose : bool, optional
+            whether to print message when database is initialized, by default True
         """
         try:
             import arcticdb
+
         except ModuleNotFoundError as e:
             print("Please install arcticdb with `pip install arcticdb`!")
             raise e
@@ -41,25 +772,52 @@ class ArcticDBConnector(BaseConnector, ConnectorUtil):
 
         self.libs: dict = {}
         self.arc = arcticdb.Arctic(uri)
-        self._initialize()
+        self._initialize(verbose=verbose)
         self.models = ModelAccessor(self)
         # for older versions of PastaStore, if oseries_models library is empty
         # populate oseries - models database
         self._update_all_oseries_model_links()
+        # write pstore file to store database info that can be used to load pstore
+        if "lmdb" in self.uri:
+            self.write_pstore_config_file()
 
-    def _initialize(self) -> None:
+    def _initialize(self, verbose: bool = True) -> None:
        """Initialize the libraries (internal method)."""
         for libname in self._default_library_names:
             if self._library_name(libname) not in self.arc.list_libraries():
                 self.arc.create_library(self._library_name(libname))
             else:
-                print(
-                    f"ArcticDBConnector: library "
-                    f"'{self._library_name(libname)}'"
-                    " already exists. Linking to existing library."
-                )
+                if verbose:
+                    print(
+                        f"ArcticDBConnector: library "
+                        f"'{self._library_name(libname)}'"
+                        " already exists. Linking to existing library."
+                    )
             self.libs[libname] = self._get_library(libname)
 
+    def write_pstore_config_file(self, path: str = None) -> None:
+        """Write pstore configuration file to store database info."""
+        # NOTE: method is not private as theoretically an ArcticDB
+        # database could also be hosted in the cloud, in which case,
+        # writing this config in the folder holding the database
+        # is no longer possible. For those situations, the user can
+        # write this config file and specify the path it should be
+        # written to.
+        config = {
+            "connector_type": self.conn_type,
+            "name": self.name,
+            "uri": self.uri,
+        }
+        if path is None and "lmdb" in self.uri:
+            path = self.uri.split("://")[1]
+        elif path is None and "lmdb" not in self.uri:
+            raise ValueError("Please provide a path to write the pastastore file!")
+
+        with open(
+            os.path.join(path, f"{self.name}.pastastore"), "w", encoding="utf-8"
+        ) as f:
+            json.dump(config, f)
+
     def _library_name(self, libname: str) -> str:
         """Get full library name according to ArcticDB (internal method)."""
         return ".".join([self.name, libname])
@@ -159,6 +917,70 @@ class ArcticDBConnector(BaseConnector, ConnectorUtil):
         lib = self._get_library(libname)
         return lib.read_metadata(name).metadata
 
+    def _parallel(
+        self,
+        func: Callable,
+        names: List[str],
+        kwargs: Optional[Dict] = None,
+        progressbar: Optional[bool] = True,
+        max_workers: Optional[int] = None,
+        chunksize: Optional[int] = None,
+        desc: str = "",
+    ):
+        """Parallel processing of function.
+
+        Does not return results, so function must store results in database.
+
+        Parameters
+        ----------
+        func : function
+            function to apply in parallel
+        names : list
+            list of names to apply function to
+        kwargs : dict, optional
+            keyword arguments to pass to function
+        progressbar : bool, optional
+            show progressbar, by default True
+        max_workers : int, optional
+            maximum number of workers, by default None
+        chunksize : int, optional
+            chunksize for parallel processing, by default None
+        desc : str, optional
+            description for progressbar, by default ""
+        """
+        max_workers, chunksize = ConnectorUtil._get_max_workers_and_chunksize(
+            max_workers, len(names), chunksize
+        )
+
+        def initializer(*args):
+            global conn
+            conn = ArcticDBConnector(*args)
+
+        initargs = (self.name, self.uri, False)
+
+        if kwargs is None:
+            kwargs = {}
+
+        if progressbar:
+            result = []
+            with tqdm(total=len(names), desc=desc) as pbar:
+                with ProcessPoolExecutor(
+                    max_workers=max_workers, initializer=initializer, initargs=initargs
+                ) as executor:
+                    for item in executor.map(
+                        partial(func, **kwargs), names, chunksize=chunksize
+                    ):
+                        result.append(item)
+                        pbar.update()
+        else:
+            with ProcessPoolExecutor(
+                max_workers=max_workers, initializer=initializer, initargs=initargs
+            ) as executor:
+                result = executor.map(
+                    partial(func, **kwargs), names, chunksize=chunksize
+                )
+        return result
+
     @property
     def oseries_names(self):
         """List of oseries names.
@@ -317,6 +1139,12 @@ class DictConnector(BaseConnector, ConnectorUtil):
         imeta = deepcopy(lib[name][0])
         return imeta
 
+    def _parallel(self, *args, **kwargs) -> None:
+        raise NotImplementedError(
+            "DictConnector does not support parallel processing,"
+            " use PasConnector or ArcticDBConnector."
+        )
+
     @property
     def oseries_names(self):
         """List of oseries names."""
@@ -347,7 +1175,7 @@ class PasConnector(BaseConnector, ConnectorUtil):
 
     conn_type = "pas"
 
-    def __init__(self, name: str, path: str):
+    def __init__(self, name: str, path: str, verbose: bool = True):
         """Create PasConnector object that stores data as JSON files on disk.
 
         Uses Pastas export format (pas-files) to store files.
@@ -359,30 +1187,49 @@ class PasConnector(BaseConnector, ConnectorUtil):
             directory in which the data will be stored
         path : str
             path to directory for storing the data
+        verbose : bool, optional
+            whether to print message when database is initialized, by default True
         """
         self.name = name
+        self.parentdir = path
         self.path = os.path.abspath(os.path.join(path, self.name))
         self.relpath = os.path.relpath(self.path)
-        self._initialize()
+        self._initialize(verbose=verbose)
         self.models = ModelAccessor(self)
         # for older versions of PastaStore, if oseries_models library is empty
         # populate oseries_models library
         self._update_all_oseries_model_links()
+        # write pstore file to store database info that can be used to load pstore
+        self._write_pstore_config_file()
 
-    def _initialize(self) -> None:
+    def _initialize(self, verbose: bool = True) -> None:
         """Initialize the libraries (internal method)."""
         for val in self._default_library_names:
             libdir = os.path.join(self.path, val)
             if not os.path.exists(libdir):
-                print(f"PasConnector: library '{val}' created in '{libdir}'")
+                if verbose:
+                    print(f"PasConnector: library '{val}' created in '{libdir}'")
                 os.makedirs(libdir)
             else:
-                print(
-                    f"PasConnector: library '{val}' already exists. "
-                    f"Linking to existing directory: '{libdir}'"
-                )
+                if verbose:
+                    print(
+                        f"PasConnector: library '{val}' already exists. "
+                        f"Linking to existing directory: '{libdir}'"
+                    )
             setattr(self, f"lib_{val}", os.path.join(self.path, val))
 
+    def _write_pstore_config_file(self):
+        """Write pstore configuration file to store database info."""
+        config = {
+            "connector_type": self.conn_type,
+            "name": self.name,
+            "path": self.parentdir,
+        }
+        with open(
+            os.path.join(self.path, f"{self.name}.pastastore"), "w", encoding="utf-8"
+        ) as f:
+            json.dump(config, f)
+
     def _get_library(self, libname: str):
         """Get path to directory holding data.
 
@@ -523,6 +1370,58 @@ class PasConnector(BaseConnector, ConnectorUtil):
             imeta = {}
         return imeta
 
+    def _parallel(
+        self,
+        func: Callable,
+        names: List[str],
+        kwargs: Optional[dict] = None,
+        progressbar: Optional[bool] = True,
+        max_workers: Optional[int] = None,
+        chunksize: Optional[int] = None,
+        desc: str = "",
+    ):
+        """Parallel processing of function.
+
+        Does not return results, so function must store results in database.
+
+        Parameters
+        ----------
+        func : function
+            function to apply in parallel
+        names : list
+            list of names to apply function to
+        progressbar : bool, optional
+            show progressbar, by default True
+        max_workers : int, optional
+            maximum number of workers, by default None
+        chunksize : int, optional
+            chunksize for parallel processing, by default None
+        desc : str, optional
+            description for progressbar, by default ""
+        """
+        max_workers, chunksize = ConnectorUtil._get_max_workers_and_chunksize(
+            max_workers, len(names), chunksize
+        )
+
+        if kwargs is None:
+            kwargs = {}
+
+        if progressbar:
+            return process_map(
+                partial(func, **kwargs),
+                names,
+                max_workers=max_workers,
+                chunksize=chunksize,
+                desc=desc,
+                total=len(names),
+            )
+        else:
+            with ProcessPoolExecutor(max_workers=max_workers) as executor:
+                result = executor.map(
+                    partial(func, **kwargs), names, chunksize=chunksize
+                )
+            return result
+
     @property
     def oseries_names(self):
         """List of oseries names."""