pastastore 1.7.1-py3-none-any.whl → 1.8.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
pastastore/connectors.py CHANGED
@@ -1,27 +1,756 @@
  """Module containing classes for connecting to different data stores."""

  import json
+ import logging
  import os
  import warnings
+ from collections.abc import Iterable
+ from concurrent.futures import ProcessPoolExecutor
  from copy import deepcopy
- from typing import Dict, Optional, Union
+ from functools import partial
+
+ # import weakref
+ from typing import Callable, Dict, List, Optional, Tuple, Union

  import pandas as pd
+ import pastas as ps
+ from numpy import isin
+ from packaging.version import parse as parse_version
+ from pandas.testing import assert_series_equal
  from pastas.io.pas import PastasEncoder, pastas_hook
+ from tqdm.auto import tqdm
+ from tqdm.contrib.concurrent import process_map

- from pastastore.base import BaseConnector, ConnectorUtil, ModelAccessor
+ from pastastore.base import BaseConnector, ModelAccessor
  from pastastore.util import _custom_warning
+ from pastastore.version import PASTAS_LEQ_022

  FrameorSeriesUnion = Union[pd.DataFrame, pd.Series]
  warnings.showwarning = _custom_warning

+ logger = logging.getLogger(__name__)
+
+
+ class ConnectorUtil:
+     """Mix-in class for general Connector helper functions.
+
+     Only for internal methods, and not methods that are related to CRUD
+     operations on the database.
+     """
+
+     def _parse_names(
+         self,
+         names: Optional[Union[list, str]] = None,
+         libname: Optional[str] = "oseries",
+     ) -> list:
+         """Parse names kwarg, returns iterable with name(s) (internal method).
+
+         Parameters
+         ----------
+         names : Union[list, str], optional
+             str or list of str or None or 'all' (the last two options
+             retrieve all names)
+         libname : str, optional
+             name of library, default is 'oseries'
+
+         Returns
+         -------
+         list
+             list of names
+         """
+         if not isinstance(names, str) and isinstance(names, Iterable):
+             return names
+         elif isinstance(names, str) and names != "all":
+             return [names]
+         elif names is None or names == "all":
+             if libname == "oseries":
+                 return self.oseries_names
+             elif libname == "stresses":
+                 return self.stresses_names
+             elif libname == "models":
+                 return self.model_names
+             elif libname == "oseries_models":
+                 return self.oseries_with_models
+             else:
+                 raise ValueError(f"No library '{libname}'!")
+         else:
+             raise NotImplementedError(f"Cannot parse 'names': {names}")
+
+     @staticmethod
+     def _meta_list_to_frame(metalist: list, names: list):
+         """Convert list of metadata dictionaries to DataFrame.
+
+         Parameters
+         ----------
+         metalist : list
+             list of metadata dictionaries
+         names : list
+             list of names corresponding to data in metalist
+
+         Returns
+         -------
+         pandas.DataFrame
+             DataFrame containing overview of metadata
+         """
+         # convert to dataframe
+         if len(metalist) > 1:
+             meta = pd.DataFrame(metalist)
+             if len({"x", "y"}.difference(meta.columns)) == 0:
+                 meta["x"] = meta["x"].astype(float)
+                 meta["y"] = meta["y"].astype(float)
+         elif len(metalist) == 1:
+             meta = pd.DataFrame(metalist)
+         elif len(metalist) == 0:
+             meta = pd.DataFrame()
+
+         meta.index = names
+         meta.index.name = "name"
+         return meta
+
+     def _parse_model_dict(self, mdict: dict, update_ts_settings: bool = False):
+         """Parse dictionary describing pastas models (internal method).
+
+         Parameters
+         ----------
+         mdict : dict
+             dictionary describing pastas.Model
+         update_ts_settings : bool, optional
+             update stored tmin and tmax in time series settings
+             based on time series loaded from store.
+
+         Returns
+         -------
+         ml : pastas.Model
+             time series analysis model
+         """
+         PASFILE_LEQ_022 = parse_version(
+             mdict["file_info"]["pastas_version"]
+         ) <= parse_version("0.22.0")
+
+         # oseries
+         if "series" not in mdict["oseries"]:
+             name = str(mdict["oseries"]["name"])
+             if name not in self.oseries.index:
+                 msg = "oseries '{}' not present in library".format(name)
+                 raise LookupError(msg)
+             mdict["oseries"]["series"] = self.get_oseries(name).squeeze()
+             # update tmin/tmax from time series
+             if update_ts_settings:
+                 mdict["oseries"]["settings"]["tmin"] = mdict["oseries"]["series"].index[0]
+                 mdict["oseries"]["settings"]["tmax"] = mdict["oseries"]["series"].index[-1]
+
+         # StressModel, WellModel
+         for ts in mdict["stressmodels"].values():
+             if "stress" in ts.keys():
+                 # WellModel
+                 classkey = "stressmodel" if PASFILE_LEQ_022 else "class"
+                 if ts[classkey] == "WellModel":
+                     for stress in ts["stress"]:
+                         if "series" not in stress:
+                             name = str(stress["name"])
+                             if name in self.stresses.index:
+                                 stress["series"] = self.get_stresses(name).squeeze()
+                                 # update tmin/tmax from time series
+                                 if update_ts_settings:
+                                     stress["settings"]["tmin"] = stress["series"].index[0]
+                                     stress["settings"]["tmax"] = stress["series"].index[-1]
+                 # StressModel
+                 else:
+                     for stress in ts["stress"] if PASFILE_LEQ_022 else [ts["stress"]]:
+                         if "series" not in stress:
+                             name = str(stress["name"])
+                             if name in self.stresses.index:
+                                 stress["series"] = self.get_stresses(name).squeeze()
+                                 # update tmin/tmax from time series
+                                 if update_ts_settings:
+                                     stress["settings"]["tmin"] = stress["series"].index[0]
+                                     stress["settings"]["tmax"] = stress["series"].index[-1]
+
+             # RechargeModel, TarsoModel
+             if ("prec" in ts.keys()) and ("evap" in ts.keys()):
+                 for stress in [ts["prec"], ts["evap"]]:
+                     if "series" not in stress:
+                         name = str(stress["name"])
+                         if name in self.stresses.index:
+                             stress["series"] = self.get_stresses(name).squeeze()
+                             # update tmin/tmax from time series
+                             if update_ts_settings:
+                                 stress["settings"]["tmin"] = stress["series"].index[0]
+                                 stress["settings"]["tmax"] = stress["series"].index[-1]
+                         else:
+                             msg = "stress '{}' not present in library".format(name)
+                             raise KeyError(msg)
+
+         # hack for pcov w dtype object (when filled with NaNs on store?)
+         if "fit" in mdict:
+             if "pcov" in mdict["fit"]:
+                 pcov = mdict["fit"]["pcov"]
+                 if pcov.dtypes.apply(lambda dtyp: isinstance(dtyp, object)).any():
+                     mdict["fit"]["pcov"] = pcov.astype(float)
+
+         # check pastas version vs pas-file version
+         file_version = mdict["file_info"]["pastas_version"]
+
+         # check file version and pastas version
+         # if file<0.23 and pastas>=1.0 --> error
+         PASTAS_GT_023 = parse_version(ps.__version__) > parse_version("0.23.1")
+         if PASFILE_LEQ_022 and PASTAS_GT_023:
+             raise UserWarning(
+                 f"This file was created with Pastas v{file_version} "
+                 f"and cannot be loaded with Pastas v{ps.__version__}. Please load and "
+                 "save the file with Pastas 0.23 first to update the file "
+                 "format."
+             )
+
+         try:
+             # pastas>=0.15.0
+             ml = ps.io.base._load_model(mdict)
+         except AttributeError:
+             # pastas<0.15.0
+             ml = ps.io.base.load_model(mdict)
+         return ml
+
+     @staticmethod
+     def _validate_input_series(series):
+         """Check if series is pandas.DataFrame or pandas.Series.
+
+         Parameters
+         ----------
+         series : object
+             object to validate
+
+         Raises
+         ------
+         TypeError
+             if object is not of type pandas.DataFrame or pandas.Series
+         """
+         if not (isinstance(series, pd.DataFrame) or isinstance(series, pd.Series)):
+             raise TypeError("Please provide pandas.DataFrame or pandas.Series!")
+         if isinstance(series, pd.DataFrame):
+             if series.columns.size > 1:
+                 raise ValueError("Only DataFrames with one column are supported!")
+
+     @staticmethod
+     def _set_series_name(series, name):
+         """Set series name to match user-defined name in store.
+
+         Parameters
+         ----------
+         series : pandas.Series or pandas.DataFrame
+             set name for this time series
+         name : str
+             name of the time series (used in the pastastore)
+         """
+         if isinstance(series, pd.Series):
+             series.name = name
+             # empty string on index name causes trouble when reading
+             # data from ArcticDB: TODO: check if still an issue?
+             if series.index.name == "":
+                 series.index.name = None
+
+         if isinstance(series, pd.DataFrame):
+             series.columns = [name]
+             # check for hydropandas objects which are instances of DataFrame
+             # but do have a name attribute
+             if hasattr(series, "name"):
+                 series.name = name
+         return series
+
+     @staticmethod
+     def _check_stressmodels_supported(ml):
+         supported_stressmodels = [
+             "StressModel",
+             "StressModel2",
+             "RechargeModel",
+             "WellModel",
+             "TarsoModel",
+             "Constant",
+             "LinearTrend",
+             "StepModel",
+         ]
+         if isinstance(ml, ps.Model):
+             smtyps = [sm._name for sm in ml.stressmodels.values()]
+         elif isinstance(ml, dict):
+             classkey = "stressmodel" if PASTAS_LEQ_022 else "class"
+             smtyps = [sm[classkey] for sm in ml["stressmodels"].values()]
+         check = isin(smtyps, supported_stressmodels)
+         if not all(check):
+             unsupported = set(smtyps) - set(supported_stressmodels)
+             raise NotImplementedError(
+                 "PastaStore does not support storing models with the "
+                 f"following stressmodels: {unsupported}"
+             )
+
+     @staticmethod
+     def _check_model_series_names_for_store(ml):
+         prec_evap_model = ["RechargeModel", "TarsoModel"]
+
+         if isinstance(ml, ps.Model):
+             series_names = [
+                 istress.series.name
+                 for sm in ml.stressmodels.values()
+                 for istress in sm.stress
+             ]
+
+         elif isinstance(ml, dict):
+             # non-RechargeModel, TarsoModel, WellModel stressmodels
+             classkey = "stressmodel" if PASTAS_LEQ_022 else "class"
+             if PASTAS_LEQ_022:
+                 series_names = [
+                     istress["name"]
+                     for sm in ml["stressmodels"].values()
+                     if sm[classkey] not in (prec_evap_model + ["WellModel"])
+                     for istress in sm["stress"]
+                 ]
+             else:
+                 series_names = [
+                     sm["stress"]["name"]
+                     for sm in ml["stressmodels"].values()
+                     if sm[classkey] not in (prec_evap_model + ["WellModel"])
+                 ]
+
+             # WellModel
+             if isin(
+                 ["WellModel"],
+                 [i[classkey] for i in ml["stressmodels"].values()],
+             ).any():
+                 series_names += [
+                     istress["name"]
+                     for sm in ml["stressmodels"].values()
+                     if sm[classkey] in ["WellModel"]
+                     for istress in sm["stress"]
+                 ]
+
+             # RechargeModel, TarsoModel
+             if isin(
+                 prec_evap_model,
+                 [i[classkey] for i in ml["stressmodels"].values()],
+             ).any():
+                 series_names += [
+                     istress["name"]
+                     for sm in ml["stressmodels"].values()
+                     if sm[classkey] in prec_evap_model
+                     for istress in [sm["prec"], sm["evap"]]
+                 ]
+
+         else:
+             raise TypeError("Expected pastas.Model or dict!")
+         if len(series_names) - len(set(series_names)) > 0:
+             msg = (
+                 "There are multiple stress series with the same name! "
+                 "Each series name must be unique for the PastaStore!"
+             )
+             raise ValueError(msg)
+
+     def _check_oseries_in_store(self, ml: Union[ps.Model, dict]):
+         """Check if Model oseries are contained in PastaStore (internal method).
+
+         Parameters
+         ----------
+         ml : Union[ps.Model, dict]
+             pastas Model
+         """
+         if isinstance(ml, ps.Model):
+             name = ml.oseries.name
+         elif isinstance(ml, dict):
+             name = str(ml["oseries"]["name"])
+         else:
+             raise TypeError("Expected pastas.Model or dict!")
+         if name not in self.oseries.index:
+             msg = (
+                 f"Cannot add model because oseries '{name}' "
+                 "is not contained in store."
+             )
+             raise LookupError(msg)
+         # expensive check
+         if self.CHECK_MODEL_SERIES_VALUES and isinstance(ml, ps.Model):
+             s_org = self.get_oseries(name).squeeze().dropna()
+             if PASTAS_LEQ_022:
+                 so = ml.oseries.series_original
+             else:
+                 so = ml.oseries._series_original
+             try:
+                 assert_series_equal(
+                     so.dropna(),
+                     s_org,
+                     atol=self.SERIES_EQUALITY_ABSOLUTE_TOLERANCE,
+                     rtol=self.SERIES_EQUALITY_RELATIVE_TOLERANCE,
+                 )
+             except AssertionError as e:
+                 raise ValueError(
+                     f"Cannot add model because model oseries '{name}'"
+                     " is different from stored oseries! See stacktrace for differences."
+                 ) from e
+
+     def _check_stresses_in_store(self, ml: Union[ps.Model, dict]):
+         """Check if stresses time series are contained in PastaStore (internal method).
+
+         Parameters
+         ----------
+         ml : Union[ps.Model, dict]
+             pastas Model
+         """
+         prec_evap_model = ["RechargeModel", "TarsoModel"]
+         if isinstance(ml, ps.Model):
+             for sm in ml.stressmodels.values():
+                 if sm._name in prec_evap_model:
+                     stresses = [sm.prec, sm.evap]
+                 else:
+                     stresses = sm.stress
+                 for s in stresses:
+                     if str(s.name) not in self.stresses.index:
+                         msg = (
+                             f"Cannot add model because stress '{s.name}' "
+                             "is not contained in store."
+                         )
+                         raise LookupError(msg)
+                     if self.CHECK_MODEL_SERIES_VALUES:
+                         s_org = self.get_stresses(s.name).squeeze()
+                         if PASTAS_LEQ_022:
+                             so = s.series_original
+                         else:
+                             so = s._series_original
+                         try:
+                             assert_series_equal(
+                                 so,
+                                 s_org,
+                                 atol=self.SERIES_EQUALITY_ABSOLUTE_TOLERANCE,
+                                 rtol=self.SERIES_EQUALITY_RELATIVE_TOLERANCE,
+                             )
+                         except AssertionError as e:
+                             raise ValueError(
+                                 f"Cannot add model because model stress "
+                                 f"'{s.name}' is different from stored stress! "
+                                 "See stacktrace for differences."
+                             ) from e
+         elif isinstance(ml, dict):
+             for sm in ml["stressmodels"].values():
+                 classkey = "stressmodel" if PASTAS_LEQ_022 else "class"
+                 if sm[classkey] in prec_evap_model:
+                     stresses = [sm["prec"], sm["evap"]]
+                 elif sm[classkey] in ["WellModel"]:
+                     stresses = sm["stress"]
+                 else:
+                     stresses = sm["stress"] if PASTAS_LEQ_022 else [sm["stress"]]
+                 for s in stresses:
+                     if str(s["name"]) not in self.stresses.index:
+                         msg = (
+                             f"Cannot add model because stress '{s['name']}' "
+                             "is not contained in store."
+                         )
+                         raise LookupError(msg)
+         else:
+             raise TypeError("Expected pastas.Model or dict!")
+
+     def _stored_series_to_json(
+         self,
+         libname: str,
+         names: Optional[Union[list, str]] = None,
+         squeeze: bool = True,
+         progressbar: bool = False,
+     ):
+         """Write stored series to JSON.
+
+         Parameters
+         ----------
+         libname : str
+             library name
+         names : Optional[Union[list, str]], optional
+             names of series, by default None
+         squeeze : bool, optional
+             return single entry as JSON string instead
+             of list, by default True
+         progressbar : bool, optional
+             show progressbar, by default False
+
+         Returns
+         -------
+         files : list or str
+             list of series converted to JSON strings, or a single JSON
+             string if a single entry is returned and squeeze is True
+         """
+         names = self._parse_names(names, libname=libname)
+         files = []
+         for n in tqdm(names, desc=libname) if progressbar else names:
+             s = self._get_series(libname, n, progressbar=False)
+             if isinstance(s, pd.Series):
+                 s = s.to_frame()
+             try:
+                 sjson = s.to_json(orient="columns")
+             except ValueError as e:
+                 msg = (
+                     f"DatetimeIndex of '{n}' probably contains NaT "
+                     "or duplicate timestamps!"
+                 )
+                 raise ValueError(msg) from e
+             files.append(sjson)
+         if len(files) == 1 and squeeze:
+             return files[0]
+         else:
+             return files
+
+     def _stored_metadata_to_json(
+         self,
+         libname: str,
+         names: Optional[Union[list, str]] = None,
+         squeeze: bool = True,
+         progressbar: bool = False,
+     ):
+         """Write metadata from stored series to JSON.
+
+         Parameters
+         ----------
+         libname : str
+             library containing series
+         names : Optional[Union[list, str]], optional
+             names to parse, by default None
+         squeeze : bool, optional
+             return single entry as JSON string instead of list, by default True
+         progressbar : bool, optional
+             show progressbar, by default False
+
+         Returns
+         -------
+         files : list or str
+             list of JSON strings
+         """
+         names = self._parse_names(names, libname=libname)
+         files = []
+         for n in tqdm(names, desc=libname) if progressbar else names:
+             meta = self.get_metadata(libname, n, as_frame=False)
+             meta_json = json.dumps(meta, cls=PastasEncoder, indent=4)
+             files.append(meta_json)
+         if len(files) == 1 and squeeze:
+             return files[0]
+         else:
+             return files
+
+     def _series_to_archive(
+         self,
+         archive,
+         libname: str,
+         names: Optional[Union[list, str]] = None,
+         progressbar: bool = True,
+     ):
+         """Write DataFrame or Series to zipfile (internal method).
+
+         Parameters
+         ----------
+         archive : zipfile.ZipFile
+             reference to an archive to write data to
+         libname : str
+             name of the library to write to zipfile
+         names : str or list of str, optional
+             names of the time series to write to archive, by default None,
+             which writes all time series to archive
+         progressbar : bool, optional
+             show progressbar, by default True
+         """
+         names = self._parse_names(names, libname=libname)
+         for n in tqdm(names, desc=libname) if progressbar else names:
+             sjson = self._stored_series_to_json(
+                 libname, names=n, progressbar=False, squeeze=True
+             )
+             meta_json = self._stored_metadata_to_json(
+                 libname, names=n, progressbar=False, squeeze=True
+             )
+             archive.writestr(f"{libname}/{n}.pas", sjson)
+             archive.writestr(f"{libname}/{n}_meta.pas", meta_json)
+
+     def _models_to_archive(self, archive, names=None, progressbar=True):
+         """Write pastas.Model to zipfile (internal method).
+
+         Parameters
+         ----------
+         archive : zipfile.ZipFile
+             reference to an archive to write data to
+         names : str or list of str, optional
+             names of the models to write to archive, by default None,
+             which writes all models to archive
+         progressbar : bool, optional
+             show progressbar, by default True
+         """
+         names = self._parse_names(names, libname="models")
+         for n in tqdm(names, desc="models") if progressbar else names:
+             m = self.get_models(n, return_dict=True)
+             jsondict = json.dumps(m, cls=PastasEncoder, indent=4)
+             archive.writestr(f"models/{n}.pas", jsondict)
+
+     @staticmethod
+     def _series_from_json(fjson: str, squeeze: bool = True):
+         """Load time series from JSON.
+
+         Parameters
+         ----------
+         fjson : str
+             path to file
+         squeeze : bool, optional
+             squeeze time series object to obtain pandas Series
+
+         Returns
+         -------
+         s : pd.DataFrame
+             DataFrame containing time series
+         """
+         s = pd.read_json(fjson, orient="columns", precise_float=True, dtype=False)
+         if not isinstance(s.index, pd.DatetimeIndex):
+             s.index = pd.to_datetime(s.index, unit="ms")
+         s = s.sort_index()  # needed for some reason ...
+         if squeeze:
+             return s.squeeze()
+         return s
+
+     @staticmethod
+     def _metadata_from_json(fjson: str):
+         """Load metadata dictionary from JSON.
+
+         Parameters
+         ----------
+         fjson : str
+             path to file
+
+         Returns
+         -------
+         meta : dict
+             dictionary containing metadata
+         """
+         with open(fjson, "r") as f:
+             meta = json.load(f)
+         return meta
+
+     def _get_model_orphans(self):
+         """Get models whose oseries no longer exist in database.
+
+         Returns
+         -------
+         dict
+             dictionary with oseries names as keys and lists of model names
+             as values
+         """
+         d = {}
+         for mlnam in tqdm(self.model_names, desc="Identifying model orphans"):
+             mdict = self.get_models(mlnam, return_dict=True)
+             onam = mdict["oseries"]["name"]
+             if onam not in self.oseries_names:
+                 if onam in d:
+                     d[onam].append(mlnam)
+                 else:
+                     d[onam] = [mlnam]
+         return d
+
+     @staticmethod
+     def _solve_model(
+         ml_name: str,
+         connector: Optional[BaseConnector] = None,
+         report: bool = False,
+         ignore_solve_errors: bool = False,
+         **kwargs,
+     ) -> None:
+         """Solve a model in the store (internal method).
+
+         Parameters
+         ----------
+         ml_name : str
+             name of a model in the pastastore
+         connector : PasConnector, optional
+             connector to use, by default None, which picks up the global
+             connector set by the parallel initializer. Otherwise pass a
+             PasConnector.
+         report : boolean, optional
+             determines if a report is printed when the model is solved,
+             default is False
+         ignore_solve_errors : boolean, optional
+             if True, errors emerging from the solve method are ignored,
+             default is False, which will raise an exception when a model
+             cannot be optimized
+         **kwargs : dictionary
+             arguments are passed to the solve method.
+         """
+         if connector is not None:
+             conn = connector
+         else:
+             conn = globals()["conn"]
+
+         ml = conn.get_models(ml_name)
+         m_kwargs = {}
+         for key, value in kwargs.items():
+             if isinstance(value, pd.Series):
+                 m_kwargs[key] = value.loc[ml.name]
+             else:
+                 m_kwargs[key] = value
+         # Convert timestamps
+         for tstamp in ["tmin", "tmax"]:
+             if tstamp in m_kwargs:
+                 m_kwargs[tstamp] = pd.Timestamp(m_kwargs[tstamp])
+
+         try:
+             ml.solve(report=report, **m_kwargs)
+         except Exception as e:
+             if ignore_solve_errors:
+                 warning = "Solve error ignored for '%s': %s " % (ml.name, e)
+                 logger.warning(warning)
+             else:
+                 raise e
+
+         conn.add_model(ml, overwrite=True)
+
+     @staticmethod
+     def _get_statistics(
+         name: str,
+         statistics: List[str],
+         connector: Union[None, BaseConnector] = None,
+         **kwargs,
+     ) -> pd.Series:
+         """Get statistics for a model in the store (internal method).
+
+         This function was made to be run in parallel mode. For the odd user
+         who wants to run this function directly in sequential mode using an
+         ArcticDBConnector, the connector argument must be passed in the
+         kwargs of the apply method.
+         """
+         if connector is not None:
+             conn = connector
+         else:
+             conn = globals()["conn"]
+
+         ml = conn.get_model(name)
+         series = pd.Series(index=statistics, dtype=float)
+         for stat in statistics:
+             series.loc[stat] = getattr(ml.stats, stat)(**kwargs)
+         return series
+
+     @staticmethod
+     def _get_max_workers_and_chunksize(
+         max_workers: int, njobs: int, chunksize: int = None
+     ) -> Tuple[int, int]:
+         """Get the maximum workers and chunksize for parallel processing.
+
+         From: https://stackoverflow.com/a/42096963/10596229
+         """
+         max_workers = (
+             min(32, os.cpu_count() + 4) if max_workers is None else max_workers
+         )
+         if chunksize is None:
+             num_chunks = max_workers * 14
+             chunksize = max(njobs // num_chunks, 1)
+         return max_workers, chunksize
+

  class ArcticDBConnector(BaseConnector, ConnectorUtil):
      """ArcticDBConnector object using ArcticDB to store data."""

      conn_type = "arcticdb"

-     def __init__(self, name: str, uri: str):
+     def __init__(self, name: str, uri: str, verbose: bool = True):
          """Create an ArcticDBConnector object using ArcticDB to store data.

          Parameters
@@ -30,6 +759,8 @@ class ArcticDBConnector(BaseConnector, ConnectorUtil):
              name of the database
          uri : str
              URI connection string (e.g. 'lmdb://<your path here>')
+         verbose : bool, optional
+             whether to print message when database is initialized, by default True
          """
          try:
              import arcticdb
@@ -41,23 +772,24 @@ class ArcticDBConnector(BaseConnector, ConnectorUtil):

          self.libs: dict = {}
          self.arc = arcticdb.Arctic(uri)
-         self._initialize()
+         self._initialize(verbose=verbose)
          self.models = ModelAccessor(self)
          # for older versions of PastaStore, if oseries_models library is empty
          # populate oseries - models database
          self._update_all_oseries_model_links()

-     def _initialize(self) -> None:
+     def _initialize(self, verbose: bool = True) -> None:
          """Initialize the libraries (internal method)."""
          for libname in self._default_library_names:
              if self._library_name(libname) not in self.arc.list_libraries():
                  self.arc.create_library(self._library_name(libname))
              else:
-                 print(
-                     f"ArcticDBConnector: library "
-                     f"'{self._library_name(libname)}'"
-                     " already exists. Linking to existing library."
-                 )
+                 if verbose:
+                     print(
+                         f"ArcticDBConnector: library "
+                         f"'{self._library_name(libname)}'"
+                         " already exists. Linking to existing library."
+                     )
              self.libs[libname] = self._get_library(libname)

      def _library_name(self, libname: str) -> str:
@@ -159,6 +891,70 @@ class ArcticDBConnector(BaseConnector, ConnectorUtil):
          lib = self._get_library(libname)
          return lib.read_metadata(name).metadata

+     def _parallel(
+         self,
+         func: Callable,
+         names: List[str],
+         kwargs: Optional[Dict] = None,
+         progressbar: Optional[bool] = True,
+         max_workers: Optional[int] = None,
+         chunksize: Optional[int] = None,
+         desc: str = "",
+     ):
+         """Parallel processing of function.
+
+         Does not return results, so function must store results in database.
+
+         Parameters
+         ----------
+         func : function
+             function to apply in parallel
+         names : list
+             list of names to apply function to
+         kwargs : dict, optional
+             keyword arguments to pass to function
+         progressbar : bool, optional
+             show progressbar, by default True
+         max_workers : int, optional
+             maximum number of workers, by default None
+         chunksize : int, optional
+             chunksize for parallel processing, by default None
+         desc : str, optional
+             description for progressbar, by default ""
+         """
+         max_workers, chunksize = ConnectorUtil._get_max_workers_and_chunksize(
+             max_workers, len(names), chunksize
+         )
+
+         def initializer(*args):
+             global conn
+             conn = ArcticDBConnector(*args)
+
+         initargs = (self.name, self.uri, False)
+
+         if kwargs is None:
+             kwargs = {}
+
+         if progressbar:
+             result = []
+             with tqdm(total=len(names), desc=desc) as pbar:
+                 with ProcessPoolExecutor(
+                     max_workers=max_workers, initializer=initializer, initargs=initargs
+                 ) as executor:
+                     for item in executor.map(
+                         partial(func, **kwargs), names, chunksize=chunksize
+                     ):
+                         result.append(item)
+                         pbar.update()
+         else:
+             with ProcessPoolExecutor(
+                 max_workers=max_workers, initializer=initializer, initargs=initargs
+             ) as executor:
+                 result = executor.map(
+                     partial(func, **kwargs), names, chunksize=chunksize
+                 )
+         return result
+
      @property
      def oseries_names(self):
          """List of oseries names.
@@ -317,6 +1113,12 @@ class DictConnector(BaseConnector, ConnectorUtil):
          imeta = deepcopy(lib[name][0])
          return imeta

+     def _parallel(self, *args, **kwargs) -> None:
+         raise NotImplementedError(
+             "DictConnector does not support parallel processing,"
+             " use PasConnector or ArcticDBConnector."
+         )
+
      @property
      def oseries_names(self):
          """List of oseries names."""
@@ -347,7 +1149,7 @@ class PasConnector(BaseConnector, ConnectorUtil):

      conn_type = "pas"

-     def __init__(self, name: str, path: str):
+     def __init__(self, name: str, path: str, verbose: bool = True):
          """Create PasConnector object that stores data as JSON files on disk.

          Uses Pastas export format (pas-files) to store files.
@@ -359,28 +1161,32 @@ class PasConnector(BaseConnector, ConnectorUtil):
              directory in which the data will be stored
          path : str
              path to directory for storing the data
+         verbose : bool, optional
+             whether to print message when database is initialized, by default True
          """
          self.name = name
          self.path = os.path.abspath(os.path.join(path, self.name))
          self.relpath = os.path.relpath(self.path)
-         self._initialize()
+         self._initialize(verbose=verbose)
          self.models = ModelAccessor(self)
          # for older versions of PastaStore, if oseries_models library is empty
          # populate oseries_models library
          self._update_all_oseries_model_links()

-     def _initialize(self) -> None:
+     def _initialize(self, verbose: bool = True) -> None:
          """Initialize the libraries (internal method)."""
          for val in self._default_library_names:
              libdir = os.path.join(self.path, val)
              if not os.path.exists(libdir):
-                 print(f"PasConnector: library '{val}' created in '{libdir}'")
+                 if verbose:
+                     print(f"PasConnector: library '{val}' created in '{libdir}'")
                  os.makedirs(libdir)
              else:
-                 print(
-                     f"PasConnector: library '{val}' already exists. "
-                     f"Linking to existing directory: '{libdir}'"
-                 )
+                 if verbose:
+                     print(
+                         f"PasConnector: library '{val}' already exists. "
+                         f"Linking to existing directory: '{libdir}'"
+                     )
              setattr(self, f"lib_{val}", os.path.join(self.path, val))

      def _get_library(self, libname: str):
@@ -523,6 +1329,58 @@ class PasConnector(BaseConnector, ConnectorUtil):
              imeta = {}
          return imeta

+     def _parallel(
+         self,
+         func: Callable,
+         names: List[str],
+         kwargs: Optional[dict] = None,
+         progressbar: Optional[bool] = True,
+         max_workers: Optional[int] = None,
+         chunksize: Optional[int] = None,
+         desc: str = "",
+     ):
+         """Parallel processing of function.
+
+         Does not return results, so function must store results in database.
+
+         Parameters
+         ----------
+         func : function
+             function to apply in parallel
+         names : list
+             list of names to apply function to
+         kwargs : dict, optional
+             keyword arguments to pass to function
+         progressbar : bool, optional
+             show progressbar, by default True
+         max_workers : int, optional
+             maximum number of workers, by default None
+         chunksize : int, optional
+             chunksize for parallel processing, by default None
+         desc : str, optional
+             description for progressbar, by default ""
+         """
+         max_workers, chunksize = ConnectorUtil._get_max_workers_and_chunksize(
+             max_workers, len(names), chunksize
+         )
+
+         if kwargs is None:
+             kwargs = {}
+
+         if progressbar:
+             return process_map(
+                 partial(func, **kwargs),
+                 names,
+                 max_workers=max_workers,
+                 chunksize=chunksize,
+                 desc=desc,
+                 total=len(names),
+             )
+         else:
+             with ProcessPoolExecutor(max_workers=max_workers) as executor:
+                 result = executor.map(
+                     partial(func, **kwargs), names, chunksize=chunksize
+                 )
+             return result
+
      @property
      def oseries_names(self):
          """List of oseries names."""