pydartdiags 0.0.42__tar.gz → 0.0.43__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pydartdiags might be problematic. Click here for more details.
- pydartdiags-0.0.43/PKG-INFO +45 -0
- pydartdiags-0.0.43/README.md +24 -0
- {pydartdiags-0.0.42 → pydartdiags-0.0.43}/pyproject.toml +1 -1
- {pydartdiags-0.0.42 → pydartdiags-0.0.43}/src/pydartdiags/obs_sequence/obs_sequence.py +127 -63
- pydartdiags-0.0.43/src/pydartdiags/plots/plots.py +339 -0
- pydartdiags-0.0.43/src/pydartdiags.egg-info/PKG-INFO +45 -0
- pydartdiags-0.0.43/tests/test_obs_sequence.py +225 -0
- pydartdiags-0.0.42/PKG-INFO +0 -404
- pydartdiags-0.0.42/README.md +0 -383
- pydartdiags-0.0.42/src/pydartdiags/plots/plots.py +0 -161
- pydartdiags-0.0.42/src/pydartdiags.egg-info/PKG-INFO +0 -404
- pydartdiags-0.0.42/tests/test_obs_sequence.py +0 -87
- {pydartdiags-0.0.42 → pydartdiags-0.0.43}/LICENSE +0 -0
- {pydartdiags-0.0.42 → pydartdiags-0.0.43}/setup.cfg +0 -0
- {pydartdiags-0.0.42 → pydartdiags-0.0.43}/setup.py +0 -0
- {pydartdiags-0.0.42 → pydartdiags-0.0.43}/src/pydartdiags/__init__.py +0 -0
- {pydartdiags-0.0.42 → pydartdiags-0.0.43}/src/pydartdiags/obs_sequence/__init__.py +0 -0
- {pydartdiags-0.0.42 → pydartdiags-0.0.43}/src/pydartdiags/plots/__init__.py +0 -0
- {pydartdiags-0.0.42 → pydartdiags-0.0.43}/src/pydartdiags.egg-info/SOURCES.txt +0 -0
- {pydartdiags-0.0.42 → pydartdiags-0.0.43}/src/pydartdiags.egg-info/dependency_links.txt +0 -0
- {pydartdiags-0.0.42 → pydartdiags-0.0.43}/src/pydartdiags.egg-info/requires.txt +0 -0
- {pydartdiags-0.0.42 → pydartdiags-0.0.43}/src/pydartdiags.egg-info/top_level.txt +0 -0
- {pydartdiags-0.0.42 → pydartdiags-0.0.43}/tests/test_plots.py +0 -0
|
@@ -0,0 +1,45 @@
|
|
|
1
|
+
Metadata-Version: 2.1
|
|
2
|
+
Name: pydartdiags
|
|
3
|
+
Version: 0.0.43
|
|
4
|
+
Summary: Observation Sequence Diagnostics for DART
|
|
5
|
+
Home-page: https://github.com/NCAR/pyDARTdiags.git
|
|
6
|
+
Author: Helen Kershaw
|
|
7
|
+
Author-email: Helen Kershaw <hkershaw@ucar.edu>
|
|
8
|
+
Project-URL: Homepage, https://github.com/NCAR/pyDARTdiags.git
|
|
9
|
+
Project-URL: Issues, https://github.com/NCAR/pyDARTdiags/issues
|
|
10
|
+
Project-URL: Documentation, https://ncar.github.io/pyDARTdiags
|
|
11
|
+
Classifier: Programming Language :: Python :: 3
|
|
12
|
+
Classifier: License :: OSI Approved :: Apache Software License
|
|
13
|
+
Classifier: Operating System :: OS Independent
|
|
14
|
+
Requires-Python: >=3.8
|
|
15
|
+
Description-Content-Type: text/markdown
|
|
16
|
+
License-File: LICENSE
|
|
17
|
+
Requires-Dist: pandas>=2.2.0
|
|
18
|
+
Requires-Dist: numpy>=1.26
|
|
19
|
+
Requires-Dist: plotly>=5.22.0
|
|
20
|
+
Requires-Dist: pyyaml>=6.0.2
|
|
21
|
+
|
|
22
|
+
[](https://opensource.org/licenses/Apache-2.0)
|
|
23
|
+
[](https://codecov.io/gh/NCAR/pyDARTdiags)
|
|
24
|
+
[](https://pypi.org/project/pydartdiags/)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
# pyDARTdiags
|
|
28
|
+
|
|
29
|
+
pyDARTdiags is a Python library for observation space diagnostics for the Data Assimilation Research Testbed ([DART](https://github.com/NCAR/DART)).
|
|
30
|
+
|
|
31
|
+
pyDARTdiags is under initial development, so please use caution.
|
|
32
|
+
The MATLAB [observation space diagnostics](https://docs.dart.ucar.edu/en/latest/guide/matlab-observation-space.html) are available through [DART](https://github.com/NCAR/DART).
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
pyDARTdiags can be installed through pip: https://pypi.org/project/pydartdiags/
|
|
36
|
+
Documentation : https://ncar.github.io/pyDARTdiags/
|
|
37
|
+
|
|
38
|
+
## Contributing
|
|
39
|
+
Contributions are welcome! If you have a feature request, bug report, or a suggestion, please open an issue on our GitHub repository.
|
|
40
|
+
Please read our [Contributors Guide](https://github.com/NCAR/pyDARTdiags/blob/main/CONTRIBUTING.md) if you would like to contribute to
|
|
41
|
+
pyDARTdiags.
|
|
42
|
+
|
|
43
|
+
## License
|
|
44
|
+
|
|
45
|
+
pyDARTdiags is released under the Apache License 2.0. For more details, see the LICENSE file in the root directory of this source tree or visit [Apache License 2.0](https://www.apache.org/licenses/LICENSE-2.0).
|
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
[](https://opensource.org/licenses/Apache-2.0)
|
|
2
|
+
[](https://codecov.io/gh/NCAR/pyDARTdiags)
|
|
3
|
+
[](https://pypi.org/project/pydartdiags/)
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
# pyDARTdiags
|
|
7
|
+
|
|
8
|
+
pyDARTdiags is a Python library for observation space diagnostics for the Data Assimilation Research Testbed ([DART](https://github.com/NCAR/DART)).
|
|
9
|
+
|
|
10
|
+
pyDARTdiags is under initial development, so please use caution.
|
|
11
|
+
The MATLAB [observation space diagnostics](https://docs.dart.ucar.edu/en/latest/guide/matlab-observation-space.html) are available through [DART](https://github.com/NCAR/DART).
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
pyDARTdiags can be installed through pip: https://pypi.org/project/pydartdiags/
|
|
15
|
+
Documentation : https://ncar.github.io/pyDARTdiags/
|
|
16
|
+
|
|
17
|
+
## Contributing
|
|
18
|
+
Contributions are welcome! If you have a feature request, bug report, or a suggestion, please open an issue on our GitHub repository.
|
|
19
|
+
Please read our [Contributors Guide](https://github.com/NCAR/pyDARTdiags/blob/main/CONTRIBUTING.md) if you would like to contribute to
|
|
20
|
+
pyDARTdiags.
|
|
21
|
+
|
|
22
|
+
## License
|
|
23
|
+
|
|
24
|
+
pyDARTdiags is released under the Apache License 2.0. For more details, see the LICENSE file in the root directory of this source tree or visit [Apache License 2.0](https://www.apache.org/licenses/LICENSE-2.0).
|
|
@@ -5,6 +5,23 @@ import os
|
|
|
5
5
|
import yaml
|
|
6
6
|
import struct
|
|
7
7
|
|
|
8
|
+
def requires_assimilation_info(func):
|
|
9
|
+
def wrapper(self, *args, **kwargs):
|
|
10
|
+
if self.has_assimilation_info:
|
|
11
|
+
return func(self, *args, **kwargs)
|
|
12
|
+
else:
|
|
13
|
+
raise ValueError("Assimilation information is required to call this function.")
|
|
14
|
+
return wrapper
|
|
15
|
+
|
|
16
|
+
def requires_posterior_info(func):
|
|
17
|
+
def wrapper(self, *args, **kwargs):
|
|
18
|
+
if self.has_posterior_info:
|
|
19
|
+
return func(self, *args, **kwargs)
|
|
20
|
+
else:
|
|
21
|
+
raise ValueError("Posterior information is required to call this function.")
|
|
22
|
+
return wrapper
|
|
23
|
+
|
|
24
|
+
|
|
8
25
|
class obs_sequence:
|
|
9
26
|
"""Create an obs_sequence object from an ascii observation sequence file.
|
|
10
27
|
|
|
@@ -59,6 +76,8 @@ class obs_sequence:
|
|
|
59
76
|
|
|
60
77
|
def __init__(self, file, synonyms=None):
|
|
61
78
|
self.loc_mod = 'None'
|
|
79
|
+
self.has_assimilation_info = False
|
|
80
|
+
self.has_posterior = False
|
|
62
81
|
self.file = file
|
|
63
82
|
self.synonyms_for_obs = ['NCEP BUFR observation',
|
|
64
83
|
'AIRS observation',
|
|
@@ -72,6 +91,17 @@ class obs_sequence:
|
|
|
72
91
|
else:
|
|
73
92
|
self.synonyms_for_obs.append(synonyms)
|
|
74
93
|
|
|
94
|
+
if file is None:
|
|
95
|
+
# Early exit for testing purposes
|
|
96
|
+
self.df = pd.DataFrame()
|
|
97
|
+
self.types = {}
|
|
98
|
+
self.reverse_types = {}
|
|
99
|
+
self.copie_names = []
|
|
100
|
+
self.n_copies = 0
|
|
101
|
+
self.seq = []
|
|
102
|
+
self.all_obs = []
|
|
103
|
+
return
|
|
104
|
+
|
|
75
105
|
module_dir = os.path.dirname(__file__)
|
|
76
106
|
self.default_composite_types = os.path.join(module_dir,"composite_types.yaml")
|
|
77
107
|
|
|
@@ -103,11 +133,16 @@ class obs_sequence:
|
|
|
103
133
|
self.synonyms_for_obs = [synonym.replace(' ', '_') for synonym in self.synonyms_for_obs]
|
|
104
134
|
rename_dict = {old: 'observation' for old in self.synonyms_for_obs if old in self.df.columns}
|
|
105
135
|
self.df = self.df.rename(columns=rename_dict)
|
|
136
|
+
|
|
106
137
|
# calculate bias and sq_err is the obs_seq is an obs_seq.final
|
|
107
138
|
if 'prior_ensemble_mean'.casefold() in map(str.casefold, self.columns):
|
|
108
|
-
self.
|
|
109
|
-
self.df['
|
|
110
|
-
|
|
139
|
+
self.has_assimilation_info = True
|
|
140
|
+
self.df['prior_bias'] = (self.df['prior_ensemble_mean'] - self.df['observation'])
|
|
141
|
+
self.df['prior_sq_err'] = self.df['prior_bias']**2 # squared error
|
|
142
|
+
if 'posterior_ensemble_mean'.casefold() in map(str.casefold, self.columns):
|
|
143
|
+
self.has_posterior_info = True
|
|
144
|
+
self.df['posterior_bias'] = (self.df['posterior_ensemble_mean'] - self.df['observation'])
|
|
145
|
+
self.df['posterior_sq_err'] = self.df['posterior_bias']**2
|
|
111
146
|
|
|
112
147
|
def create_all_obs(self):
|
|
113
148
|
""" steps through the generator to create a
|
|
@@ -152,14 +187,38 @@ class obs_sequence:
|
|
|
152
187
|
data.append(self.types[type_value]) # observation type
|
|
153
188
|
|
|
154
189
|
# any observation specific obs def info is between here and the end of the list
|
|
190
|
+
# can be obs_def & external forward operator
|
|
191
|
+
metadata = obs[typeI+2:-2]
|
|
192
|
+
obs_def_metadata, external_metadata = self.split_metadata(metadata)
|
|
193
|
+
data.append(obs_def_metadata)
|
|
194
|
+
data.append(external_metadata)
|
|
195
|
+
|
|
155
196
|
time = obs[-2].split()
|
|
156
197
|
data.append(int(time[0])) # seconds
|
|
157
198
|
data.append(int(time[1])) # days
|
|
158
199
|
data.append(convert_dart_time(int(time[0]), int(time[1]))) # datetime # HK todo what is approprate for 1d models?
|
|
159
200
|
data.append(float(obs[-1])) # obs error variance ?convert to sd?
|
|
160
|
-
|
|
201
|
+
|
|
161
202
|
return data
|
|
162
203
|
|
|
204
|
+
@staticmethod
|
|
205
|
+
def split_metadata(metadata):
|
|
206
|
+
"""
|
|
207
|
+
Split the metadata list at the first occurrence of an element starting with 'external_FO'.
|
|
208
|
+
|
|
209
|
+
Args:
|
|
210
|
+
metadata (list of str): The metadata list to be split.
|
|
211
|
+
|
|
212
|
+
Returns:
|
|
213
|
+
tuple: Two sublists, the first containing elements before 'external_FO', and the second
|
|
214
|
+
containing 'external_FO' and all elements after it. If 'external_FO' is not found,
|
|
215
|
+
the first sublist contains the entire metadata list, and the second is empty.
|
|
216
|
+
"""
|
|
217
|
+
for i, item in enumerate(metadata):
|
|
218
|
+
if item.startswith('external_FO'):
|
|
219
|
+
return metadata[:i], metadata[i:]
|
|
220
|
+
return metadata, []
|
|
221
|
+
|
|
163
222
|
def list_to_obs(self, data):
|
|
164
223
|
obs = []
|
|
165
224
|
obs.append('OBS ' + str(data[0])) # obs_num lots of space
|
|
@@ -171,10 +230,16 @@ class obs_sequence:
|
|
|
171
230
|
obs.append(' '.join(map(str, data[self.n_copies+2:self.n_copies+5])) + ' ' + str(self.reversed_vert[data[self.n_copies+5]]) ) # location x, y, z, vert
|
|
172
231
|
obs.append('kind') # this is type of observation
|
|
173
232
|
obs.append(self.reverse_types[data[self.n_copies + 6]]) # observation type
|
|
233
|
+
# Convert metadata to a string and append
|
|
234
|
+
obs.extend(data[self.n_copies + 7]) # metadata
|
|
174
235
|
elif self.loc_mod == 'loc1d':
|
|
175
236
|
obs.append(data[self.n_copies+2]) # 1d location
|
|
176
237
|
obs.append('kind') # this is type of observation
|
|
177
238
|
obs.append(self.reverse_types[data[self.n_copies + 3]]) # observation type
|
|
239
|
+
# Convert metadata to a string and append
|
|
240
|
+
metadata = ' '.join(map(str, data[self.n_copies + 4:-4]))
|
|
241
|
+
if metadata:
|
|
242
|
+
obs.append(metadata) # metadata
|
|
178
243
|
obs.append(' '.join(map(str, data[-4:-2]))) # seconds, days
|
|
179
244
|
obs.append(data[-1]) # obs error variance
|
|
180
245
|
|
|
@@ -273,12 +338,70 @@ class obs_sequence:
|
|
|
273
338
|
elif self.loc_mod == 'loc1d':
|
|
274
339
|
heading.append('location')
|
|
275
340
|
heading.append('type')
|
|
341
|
+
heading.append('metadata')
|
|
342
|
+
heading.append('external_FO')
|
|
276
343
|
heading.append('seconds')
|
|
277
344
|
heading.append('days')
|
|
278
345
|
heading.append('time')
|
|
279
346
|
heading.append('obs_err_var')
|
|
280
347
|
return heading
|
|
281
348
|
|
|
349
|
+
@requires_assimilation_info
|
|
350
|
+
def select_by_dart_qc(self, dart_qc):
|
|
351
|
+
"""
|
|
352
|
+
Selects rows from a DataFrame based on the DART quality control flag.
|
|
353
|
+
|
|
354
|
+
Parameters:
|
|
355
|
+
df (DataFrame): A pandas DataFrame.
|
|
356
|
+
dart_qc (int): The DART quality control flag to select.
|
|
357
|
+
|
|
358
|
+
Returns:
|
|
359
|
+
DataFrame: A DataFrame containing only the rows with the specified DART quality control flag.
|
|
360
|
+
|
|
361
|
+
Raises:
|
|
362
|
+
ValueError: If the DART quality control flag is not present in the DataFrame.
|
|
363
|
+
"""
|
|
364
|
+
if dart_qc not in self.df['DART_quality_control'].unique():
|
|
365
|
+
raise ValueError(f"DART quality control flag '{dart_qc}' not found in DataFrame.")
|
|
366
|
+
else:
|
|
367
|
+
return self.df[self.df['DART_quality_control'] == dart_qc]
|
|
368
|
+
|
|
369
|
+
@requires_assimilation_info
|
|
370
|
+
def select_failed_qcs(self):
|
|
371
|
+
"""
|
|
372
|
+
Select rows from the DataFrame where the DART quality control flag is greater than 0.
|
|
373
|
+
|
|
374
|
+
Returns:
|
|
375
|
+
pandas.DataFrame: A DataFrame containing only the rows with a DART quality control flag greater than 0.
|
|
376
|
+
"""
|
|
377
|
+
return self.df[self.df['DART_quality_control'] > 0]
|
|
378
|
+
|
|
379
|
+
@requires_assimilation_info
|
|
380
|
+
def possible_vs_used(self):
|
|
381
|
+
"""
|
|
382
|
+
Calculates the count of possible vs. used observations by type.
|
|
383
|
+
|
|
384
|
+
This function takes a DataFrame containing observation data, including a 'type' column for the observation
|
|
385
|
+
type and an 'observation' column. The number of used observations ('used'), is the total number
|
|
386
|
+
minus the observations that failed quality control checks (as determined by the `select_failed_qcs` function).
|
|
387
|
+
The result is a DataFrame with each observation type, the count of possible observations, and the count of
|
|
388
|
+
used observations.
|
|
389
|
+
|
|
390
|
+
Returns:
|
|
391
|
+
pd.DataFrame: A DataFrame with three columns: 'type', 'possible', and 'used'. 'type' is the observation type,
|
|
392
|
+
'possible' is the count of all observations of that type, and 'used' is the count of observations of that type
|
|
393
|
+
that passed quality control checks.
|
|
394
|
+
"""
|
|
395
|
+
possible = self.df.groupby('type')['observation'].count()
|
|
396
|
+
possible.rename('possible', inplace=True)
|
|
397
|
+
|
|
398
|
+
failed_qcs = self.select_failed_qcs().groupby('type')['observation'].count()
|
|
399
|
+
used = possible - failed_qcs.reindex(possible.index, fill_value=0)
|
|
400
|
+
used.rename('used', inplace=True)
|
|
401
|
+
|
|
402
|
+
return pd.concat([possible, used], axis=1).reset_index()
|
|
403
|
+
|
|
404
|
+
|
|
282
405
|
@staticmethod
|
|
283
406
|
def is_binary(file):
|
|
284
407
|
"""Check if a file is binary file."""
|
|
@@ -659,65 +782,6 @@ def convert_dart_time(seconds, days):
|
|
|
659
782
|
"""
|
|
660
783
|
time = dt.datetime(1601,1,1) + dt.timedelta(days=days, seconds=seconds)
|
|
661
784
|
return time
|
|
662
|
-
|
|
663
|
-
def select_by_dart_qc(df, dart_qc):
|
|
664
|
-
"""
|
|
665
|
-
Selects rows from a DataFrame based on the DART quality control flag.
|
|
666
|
-
|
|
667
|
-
Parameters:
|
|
668
|
-
df (DataFrame): A pandas DataFrame.
|
|
669
|
-
dart_qc (int): The DART quality control flag to select.
|
|
670
|
-
|
|
671
|
-
Returns:
|
|
672
|
-
DataFrame: A DataFrame containing only the rows with the specified DART quality control flag.
|
|
673
|
-
|
|
674
|
-
Raises:
|
|
675
|
-
ValueError: If the DART quality control flag is not present in the DataFrame.
|
|
676
|
-
"""
|
|
677
|
-
if dart_qc not in df['DART_quality_control'].unique():
|
|
678
|
-
raise ValueError(f"DART quality control flag '{dart_qc}' not found in DataFrame.")
|
|
679
|
-
else:
|
|
680
|
-
return df[df['DART_quality_control'] == dart_qc]
|
|
681
|
-
|
|
682
|
-
def select_failed_qcs(df):
|
|
683
|
-
"""
|
|
684
|
-
Selects rows from a DataFrame where the DART quality control flag is greater than 0.
|
|
685
|
-
|
|
686
|
-
Parameters:
|
|
687
|
-
df (DataFrame): A pandas DataFrame.
|
|
688
|
-
|
|
689
|
-
Returns:
|
|
690
|
-
DataFrame: A DataFrame containing only the rows with a DART quality control flag greater than 0.
|
|
691
|
-
"""
|
|
692
|
-
return df[df['DART_quality_control'] > 0]
|
|
693
|
-
|
|
694
|
-
def possible_vs_used(df):
|
|
695
|
-
"""
|
|
696
|
-
Calculates the count of possible vs. used observations by type.
|
|
697
|
-
|
|
698
|
-
This function takes a DataFrame containing observation data, including a 'type' column for the observation
|
|
699
|
-
type and an 'observation' column. The number of used observations ('used'), is the total number
|
|
700
|
-
minus the observations that failed quality control checks (as determined by the `select_failed_qcs` function).
|
|
701
|
-
The result is a DataFrame with each observation type, the count of possible observations, and the count of
|
|
702
|
-
used observations.
|
|
703
|
-
|
|
704
|
-
Parameters:
|
|
705
|
-
df (pd.DataFrame): A DataFrame with at least two columns: 'type' for the observation type and 'observation'
|
|
706
|
-
for the observation data. It may also contain other columns required by the `select_failed_qcs` function
|
|
707
|
-
to determine failed quality control checks.
|
|
708
|
-
|
|
709
|
-
Returns:
|
|
710
|
-
pd.DataFrame: A DataFrame with three columns: 'type', 'possible', and 'used'. 'type' is the observation type,
|
|
711
|
-
'possible' is the count of all observations of that type, and 'used' is the count of observations of that type
|
|
712
|
-
that passed quality control checks.
|
|
713
|
-
|
|
714
|
-
"""
|
|
715
|
-
possible = df.groupby('type')['observation'].count()
|
|
716
|
-
possible.rename('possible', inplace=True)
|
|
717
|
-
used = df.groupby('type')['observation'].count() - select_failed_qcs(df).groupby('type')['observation'].count()
|
|
718
|
-
used.rename('used', inplace=True)
|
|
719
|
-
return pd.concat([possible, used], axis=1).reset_index()
|
|
720
|
-
|
|
721
785
|
|
|
722
786
|
def construct_composit(df_comp, composite, components):
|
|
723
787
|
"""
|
|
@@ -0,0 +1,339 @@
|
|
|
1
|
+
|
|
2
|
+
import numpy as np
|
|
3
|
+
import plotly.express as px
|
|
4
|
+
import plotly.graph_objects as go
|
|
5
|
+
import pandas as pd
|
|
6
|
+
|
|
7
|
+
def plot_rank_histogram(df):
|
|
8
|
+
"""
|
|
9
|
+
Plots a rank histogram colored by observation type.
|
|
10
|
+
|
|
11
|
+
All histogram bars are initialized to be hidden and can be toggled visible in the plot's legend
|
|
12
|
+
"""
|
|
13
|
+
_, _, df_hist = calculate_rank(df)
|
|
14
|
+
fig = px.histogram(df_hist, x='rank', color='obstype', title='Histogram Colored by obstype')
|
|
15
|
+
for trace in fig.data:
|
|
16
|
+
trace.visible = 'legendonly'
|
|
17
|
+
fig.show()
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def calculate_rank(df):
|
|
21
|
+
"""
|
|
22
|
+
Calculate the rank of observations within an ensemble.
|
|
23
|
+
|
|
24
|
+
This function takes a DataFrame containing ensemble predictions and observed values,
|
|
25
|
+
adds sampling noise to the ensemble predictions, and calculates the rank of the observed
|
|
26
|
+
value within the perturbed ensemble for each observation. The rank indicates the position
|
|
27
|
+
of the observed value within the sorted ensemble values, with 1 being the lowest. If the
|
|
28
|
+
observed value is larger than the largest ensemble member, its rank is set to the ensemble
|
|
29
|
+
size plus one.
|
|
30
|
+
|
|
31
|
+
Parameters:
|
|
32
|
+
df (pd.DataFrame): A DataFrame with columns for mean, standard deviation, observed values,
|
|
33
|
+
ensemble size, and observation type. The DataFrame should have one row per observation.
|
|
34
|
+
|
|
35
|
+
Returns:
|
|
36
|
+
tuple: A tuple containing the rank array, ensemble size, and a result DataFrame. The result
|
|
37
|
+
DataFrame contains columns for 'rank' and 'obstype'.
|
|
38
|
+
"""
|
|
39
|
+
ensemble_values = df.filter(regex='prior_ensemble_member').to_numpy().copy()
|
|
40
|
+
std_dev = np.sqrt(df['obs_err_var']).to_numpy()
|
|
41
|
+
obsvalue = df['observation'].to_numpy()
|
|
42
|
+
obstype = df['type'].to_numpy()
|
|
43
|
+
ens_size = ensemble_values.shape[1]
|
|
44
|
+
mean = 0.0 # mean of the sampling noise
|
|
45
|
+
rank = np.zeros(obsvalue.shape[0], dtype=int)
|
|
46
|
+
|
|
47
|
+
for obs in range(ensemble_values.shape[0]):
|
|
48
|
+
sampling_noise = np.random.normal(mean, std_dev[obs], ens_size)
|
|
49
|
+
ensemble_values[obs] += sampling_noise
|
|
50
|
+
ensemble_values[obs].sort()
|
|
51
|
+
for i, ens in enumerate(ensemble_values[obs]):
|
|
52
|
+
if obsvalue[obs] <= ens:
|
|
53
|
+
rank[obs] = i + 1
|
|
54
|
+
break
|
|
55
|
+
|
|
56
|
+
if rank[obs] == 0: # observation is larger than largest ensemble member
|
|
57
|
+
rank[obs] = ens_size + 1
|
|
58
|
+
|
|
59
|
+
result_df = pd.DataFrame({
|
|
60
|
+
'rank': rank,
|
|
61
|
+
'obstype': obstype
|
|
62
|
+
})
|
|
63
|
+
|
|
64
|
+
return (rank, ens_size, result_df)
|
|
65
|
+
|
|
66
|
+
def plot_profile(df, levels, verticalUnit = "pressure (Pa)"):
|
|
67
|
+
"""
|
|
68
|
+
Plots RMSE, bias, and total spread profiles for different observation types across specified vertical levels.
|
|
69
|
+
|
|
70
|
+
This function takes a DataFrame containing observational data and model predictions, categorizes
|
|
71
|
+
the data into specified vertical levels, and calculates the RMSE, bias and total spread for each level and
|
|
72
|
+
observation type. It then plots three line charts: one for RMSE, one for bias, one for total spread, as functions
|
|
73
|
+
of vertical level. The vertical levels are plotted on the y-axis in reversed order to represent
|
|
74
|
+
the vertical profile in the atmosphere correctly if the vertical units are pressure.
|
|
75
|
+
|
|
76
|
+
Parameters:
|
|
77
|
+
df (pd.DataFrame): The input DataFrame containing at least the 'vertical' column for vertical levels,
|
|
78
|
+
the vert_unit column, and other columns required by the `rmse_bias` function for calculating RMSE and
|
|
79
|
+
Bias.
|
|
80
|
+
levels (array-like): The bin edges for categorizing the 'vertical' column values into the desired
|
|
81
|
+
vertical levels.
|
|
82
|
+
verticalUnit (string) (optional): The vertical unit to be used. Only observations in df which have this
|
|
83
|
+
string in the vert_unit column will be plotted. Defaults to 'pressure (Pa)'.
|
|
84
|
+
|
|
85
|
+
Returns:
|
|
86
|
+
tuple: A tuple containing the DataFrame with RMSE, bias and total spread calculations,
|
|
87
|
+
The DataFrame includes a 'vlevels' column representing the categorized vertical levels
|
|
88
|
+
and 'midpoint' column representing the midpoint of each vertical level bin. And the three figures.
|
|
89
|
+
|
|
90
|
+
Raises:
|
|
91
|
+
ValueError: If there are missing values in the 'vertical' column of the input DataFrame.
|
|
92
|
+
ValueError: If none of the input obs have 'verticalUnit' in the 'vert_unit' column of the input DataFrame.
|
|
93
|
+
|
|
94
|
+
Note:
|
|
95
|
+
- The function modifies the input DataFrame by adding 'vlevels' and 'midpoint' columns.
|
|
96
|
+
- The 'midpoint' values are calculated as half the midpoint of each vertical level bin, which may need
|
|
97
|
+
adjustment based on the specific requirements for vertical level representation.
|
|
98
|
+
- The plots are generated using Plotly Express and are displayed inline. The y-axis of the plots is
|
|
99
|
+
reversed to align with standard atmospheric pressure level representation if the vertical units
|
|
100
|
+
are atmospheric pressure.
|
|
101
|
+
"""
|
|
102
|
+
|
|
103
|
+
pd.options.mode.copy_on_write = True
|
|
104
|
+
if df['vertical'].isnull().values.any(): # what about horizontal observations?
|
|
105
|
+
raise ValueError("Missing values in 'vertical' column.")
|
|
106
|
+
elif verticalUnit not in df['vert_unit'].values:
|
|
107
|
+
raise ValueError("No obs with expected vertical unit '"+verticalUnit+"'.")
|
|
108
|
+
else:
|
|
109
|
+
df = df[df["vert_unit"].isin({verticalUnit})] # Subset to only rows with the correct vertical unit
|
|
110
|
+
df.loc[:,'vlevels'] = pd.cut(df['vertical'], levels)
|
|
111
|
+
if verticalUnit == "pressure (Pa)":
|
|
112
|
+
df.loc[:,'midpoint'] = df['vlevels'].apply(lambda x: x.mid / 100.) # HK todo units
|
|
113
|
+
else:
|
|
114
|
+
df.loc[:,'midpoint'] = df['vlevels'].apply(lambda x: x.mid)
|
|
115
|
+
|
|
116
|
+
# Calculations
|
|
117
|
+
df_profile_prior = rmse_bias_totalspread(df, phase='prior')
|
|
118
|
+
df_profile_posterior = None
|
|
119
|
+
if 'posterior_ensemble_mean' in df.columns:
|
|
120
|
+
df_profile_posterior = rmse_bias_totalspread(df, phase='posterior')
|
|
121
|
+
|
|
122
|
+
# Merge prior and posterior dataframes
|
|
123
|
+
if df_profile_posterior is not None:
|
|
124
|
+
df_profile = pd.merge(df_profile_prior, df_profile_posterior, on=['midpoint', 'type'], suffixes=('_prior', '_posterior'))
|
|
125
|
+
fig_rmse = plot_profile_prior_post(df_profile, 'rmse', verticalUnit)
|
|
126
|
+
fig_rmse.show()
|
|
127
|
+
fig_bias = plot_profile_prior_post(df_profile, 'bias', verticalUnit)
|
|
128
|
+
fig_bias.show()
|
|
129
|
+
fig_ts = plot_profile_prior_post(df_profile, 'totalspread', verticalUnit)
|
|
130
|
+
fig_ts.show()
|
|
131
|
+
else:
|
|
132
|
+
df_profile = df_profile_prior
|
|
133
|
+
fig_rmse = plot_profile_prior(df_profile, 'rmse', verticalUnit)
|
|
134
|
+
fig_rmse.show()
|
|
135
|
+
fig_bias = plot_profile_prior(df_profile, 'bias', verticalUnit)
|
|
136
|
+
fig_bias.show()
|
|
137
|
+
fig_ts = plot_profile_prior(df_profile, 'totalspread', verticalUnit)
|
|
138
|
+
fig_ts.show()
|
|
139
|
+
|
|
140
|
+
return df_profile, fig_rmse, fig_ts, fig_bias
|
|
141
|
+
|
|
142
|
+
def plot_profile_prior_post(df_profile, stat, verticalUnit):
|
|
143
|
+
"""
|
|
144
|
+
Plots prior and posterior statistics by vertical level for different observation types.
|
|
145
|
+
|
|
146
|
+
Parameters:
|
|
147
|
+
df_profile (pd.DataFrame): DataFrame containing the prior and posterior statistics.
|
|
148
|
+
stat (str): The statistic to plot (e.g., 'rmse', 'bias', 'totalspread').
|
|
149
|
+
verticalUnit (str): The unit of the vertical axis (e.g., 'pressure (Pa)').
|
|
150
|
+
|
|
151
|
+
Returns:
|
|
152
|
+
plotly.graph_objects.Figure: The generated Plotly figure.
|
|
153
|
+
"""
|
|
154
|
+
# Reshape DataFrame to long format for easier plotting
|
|
155
|
+
df_long = pd.melt(
|
|
156
|
+
df_profile,
|
|
157
|
+
id_vars=["midpoint", "type"],
|
|
158
|
+
value_vars=["prior_"+stat, "posterior_"+stat],
|
|
159
|
+
var_name=stat+"_type",
|
|
160
|
+
value_name=stat+"_value"
|
|
161
|
+
)
|
|
162
|
+
|
|
163
|
+
# Define a color mapping for observation each type
|
|
164
|
+
unique_types = df_long["type"].unique()
|
|
165
|
+
colors = px.colors.qualitative.Plotly
|
|
166
|
+
color_mapping = {type_: colors[i % len(colors)] for i, type_ in enumerate(unique_types)}
|
|
167
|
+
|
|
168
|
+
# Create a mapping for line styles based on stat
|
|
169
|
+
line_styles = {"prior_"+stat: "solid", "posterior_"+stat: "dash"}
|
|
170
|
+
|
|
171
|
+
# Create the figure
|
|
172
|
+
fig_stat = go.Figure()
|
|
173
|
+
|
|
174
|
+
# Loop through each type and type to add traces
|
|
175
|
+
for t in df_long["type"].unique():
|
|
176
|
+
for stat_type, dash_style in line_styles.items():
|
|
177
|
+
# Filter the DataFrame for this type and stat
|
|
178
|
+
df_filtered = df_long[(df_long[stat+"_type"] == stat_type) & (df_long["type"] == t)]
|
|
179
|
+
|
|
180
|
+
# Add a trace
|
|
181
|
+
fig_stat.add_trace(go.Scatter(
|
|
182
|
+
x=df_filtered[stat+"_value"],
|
|
183
|
+
y=df_filtered["midpoint"],
|
|
184
|
+
mode='lines+markers',
|
|
185
|
+
name='prior '+t if stat_type == "prior_"+stat else 'post ', # Show legend for "prior_stat OBS TYPE" only
|
|
186
|
+
line=dict(dash=dash_style, color=color_mapping[t]), # Same color for all traces in group
|
|
187
|
+
marker=dict(size=8, color=color_mapping[t]),
|
|
188
|
+
legendgroup=t # Group traces by type
|
|
189
|
+
))
|
|
190
|
+
|
|
191
|
+
# Update layout
|
|
192
|
+
fig_stat.update_layout(
|
|
193
|
+
title= stat+' by Level',
|
|
194
|
+
xaxis_title=stat,
|
|
195
|
+
yaxis_title=verticalUnit,
|
|
196
|
+
width=800,
|
|
197
|
+
height=800,
|
|
198
|
+
template="plotly_white"
|
|
199
|
+
)
|
|
200
|
+
|
|
201
|
+
if verticalUnit == "pressure (Pa)":
|
|
202
|
+
fig_stat.update_yaxes(autorange="reversed")
|
|
203
|
+
|
|
204
|
+
return fig_stat
|
|
205
|
+
|
|
206
|
+
|
|
207
|
+
def plot_profile_prior(df_profile, stat, verticalUnit):
|
|
208
|
+
"""
|
|
209
|
+
Plots prior statistics by vertical level for different observation types.
|
|
210
|
+
|
|
211
|
+
Parameters:
|
|
212
|
+
df_profile (pd.DataFrame): DataFrame containing the prior statistics.
|
|
213
|
+
stat (str): The statistic to plot (e.g., 'rmse', 'bias', 'totalspread').
|
|
214
|
+
verticalUnit (str): The unit of the vertical axis (e.g., 'pressure (Pa)').
|
|
215
|
+
|
|
216
|
+
Returns:
|
|
217
|
+
plotly.graph_objects.Figure: The generated Plotly figure.
|
|
218
|
+
"""
|
|
219
|
+
# Reshape DataFrame to long format for easier plotting - not needed for prior only, but
|
|
220
|
+
# leaving it in for consistency with the plot_profile_prior_post function for now
|
|
221
|
+
df_long = pd.melt(
|
|
222
|
+
df_profile,
|
|
223
|
+
id_vars=["midpoint", "type"],
|
|
224
|
+
value_vars=["prior_"+stat],
|
|
225
|
+
var_name=stat+"_type",
|
|
226
|
+
value_name=stat+"_value"
|
|
227
|
+
)
|
|
228
|
+
|
|
229
|
+
# Define a color mapping for observation each type
|
|
230
|
+
unique_types = df_long["type"].unique()
|
|
231
|
+
colors = px.colors.qualitative.Plotly
|
|
232
|
+
color_mapping = {type_: colors[i % len(colors)] for i, type_ in enumerate(unique_types)}
|
|
233
|
+
|
|
234
|
+
# Create the figure
|
|
235
|
+
fig_stat = go.Figure()
|
|
236
|
+
|
|
237
|
+
# Loop through each type to add traces
|
|
238
|
+
for t in df_long["type"].unique():
|
|
239
|
+
# Filter the DataFrame for this type and stat
|
|
240
|
+
df_filtered = df_long[(df_long["type"] == t)]
|
|
241
|
+
|
|
242
|
+
# Add a trace
|
|
243
|
+
fig_stat.add_trace(go.Scatter(
|
|
244
|
+
x=df_filtered[stat+"_value"],
|
|
245
|
+
y=df_filtered["midpoint"],
|
|
246
|
+
mode='lines+markers',
|
|
247
|
+
name='prior ' + t,
|
|
248
|
+
line=dict(color=color_mapping[t]), # Same color for all traces in group
|
|
249
|
+
marker=dict(size=8, color=color_mapping[t]),
|
|
250
|
+
legendgroup=t # Group traces by type
|
|
251
|
+
))
|
|
252
|
+
|
|
253
|
+
# Update layout
|
|
254
|
+
fig_stat.update_layout(
|
|
255
|
+
title=stat + ' by Level',
|
|
256
|
+
xaxis_title=stat,
|
|
257
|
+
yaxis_title=verticalUnit,
|
|
258
|
+
width=800,
|
|
259
|
+
height=800,
|
|
260
|
+
template="plotly_white"
|
|
261
|
+
)
|
|
262
|
+
|
|
263
|
+
if verticalUnit == "pressure (Pa)":
|
|
264
|
+
fig_stat.update_yaxes(autorange="reversed")
|
|
265
|
+
|
|
266
|
+
return fig_stat
|
|
267
|
+
|
|
268
|
+
|
|
269
|
+
def mean_then_sqrt(x):
|
|
270
|
+
"""
|
|
271
|
+
Calculates the mean of an array-like object and then takes the square root of the result.
|
|
272
|
+
|
|
273
|
+
Parameters:
|
|
274
|
+
arr (array-like): An array-like object (such as a list or a pandas Series).
|
|
275
|
+
The elements should be numeric.
|
|
276
|
+
|
|
277
|
+
Returns:
|
|
278
|
+
float: The square root of the mean of the input array.
|
|
279
|
+
|
|
280
|
+
Raises:
|
|
281
|
+
TypeError: If the input is not an array-like object containing numeric values.
|
|
282
|
+
ValueError: If the input array is empty.
|
|
283
|
+
"""
|
|
284
|
+
|
|
285
|
+
return np.sqrt(np.mean(x))
|
|
286
|
+
|
|
287
|
+
def rmse_bias_totalspread(df, phase='prior'):
|
|
288
|
+
if phase == 'prior':
|
|
289
|
+
sq_err_column = 'prior_sq_err'
|
|
290
|
+
bias_column = 'prior_bias'
|
|
291
|
+
rmse_column = 'prior_rmse'
|
|
292
|
+
spread_column = 'prior_ensemble_spread'
|
|
293
|
+
totalspread_column = 'prior_totalspread'
|
|
294
|
+
elif phase == 'posterior':
|
|
295
|
+
sq_err_column = 'posterior_sq_err'
|
|
296
|
+
bias_column = 'posterior_bias'
|
|
297
|
+
rmse_column = 'posterior_rmse'
|
|
298
|
+
spread_column = 'posterior_ensemble_spread'
|
|
299
|
+
totalspread_column = 'posterior_totalspread'
|
|
300
|
+
else:
|
|
301
|
+
raise ValueError("Invalid phase. Must be 'prior' or 'posterior'.")
|
|
302
|
+
|
|
303
|
+
rmse_bias_ts_df = df.groupby(['midpoint', 'type'], observed=False).agg({
|
|
304
|
+
sq_err_column: mean_then_sqrt,
|
|
305
|
+
bias_column: 'mean',
|
|
306
|
+
spread_column: mean_then_sqrt,
|
|
307
|
+
'obs_err_var': mean_then_sqrt
|
|
308
|
+
}).reset_index()
|
|
309
|
+
|
|
310
|
+
# Add column for totalspread
|
|
311
|
+
rmse_bias_ts_df[totalspread_column] = np.sqrt(rmse_bias_ts_df[spread_column] + rmse_bias_ts_df['obs_err_var'])
|
|
312
|
+
|
|
313
|
+
# Rename square error to root mean square error
|
|
314
|
+
rmse_bias_ts_df.rename(columns={sq_err_column: rmse_column}, inplace=True)
|
|
315
|
+
|
|
316
|
+
return rmse_bias_ts_df
|
|
317
|
+
|
|
318
|
+
def rmse_bias_by_obs_type(df, obs_type):
|
|
319
|
+
"""
|
|
320
|
+
Calculate the RMSE and bias for a given observation type.
|
|
321
|
+
|
|
322
|
+
Parameters:
|
|
323
|
+
df (DataFrame): A pandas DataFrame.
|
|
324
|
+
obs_type (str): The observation type for which to calculate the RMSE and bias.
|
|
325
|
+
|
|
326
|
+
Returns:
|
|
327
|
+
DataFrame: A DataFrame containing the RMSE and bias for the given observation type.
|
|
328
|
+
|
|
329
|
+
Raises:
|
|
330
|
+
ValueError: If the observation type is not present in the DataFrame.
|
|
331
|
+
"""
|
|
332
|
+
if obs_type not in df['type'].unique():
|
|
333
|
+
raise ValueError(f"Observation type '{obs_type}' not found in DataFrame.")
|
|
334
|
+
else:
|
|
335
|
+
obs_type_df = df[df['type'] == obs_type]
|
|
336
|
+
obs_type_agg = obs_type_df.groupby('vlevels', observed=False).agg({'sq_err':mean_then_sqrt, 'bias':'mean'}).reset_index()
|
|
337
|
+
obs_type_agg.rename(columns={'sq_err':'rmse'}, inplace=True)
|
|
338
|
+
return obs_type_agg
|
|
339
|
+
|