pydartdiags 0.0.42__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


@@ -1,161 +1,191 @@
-
+# SPDX-License-Identifier: Apache-2.0
 import numpy as np
 import plotly.express as px
+import plotly.graph_objects as go
 import pandas as pd
+from pydartdiags.stats import stats
+
 
-def plot_rank_histogram(df):
+def plot_rank_histogram(df, phase, ens_size):
     """
     Plots a rank histogram colored by observation type.
 
-    All histogram bars are initalized to be hidden and can be toggled visible in the plot's legend
+    All histogram bars are initialized to be hidden and can be toggled visible in the plot's legend
     """
-    _, _, df_hist = calculate_rank(df)
-    fig = px.histogram(df_hist, x='rank', color='obstype', title='Histogram Colored by obstype')
+    fig = px.histogram(
+        df,
+        x=f"{phase}_rank",
+        color="type",
+        title="Histogram Colored by obs type",
+        nbins=ens_size,
+    )
+    fig.update_xaxes(range=[1, ens_size + 1])
     for trace in fig.data:
-        trace.visible = 'legendonly'
+        trace.visible = "legendonly"
     fig.show()
 
 
-def calculate_rank(df):
-    """
-    Calculate the rank of observations within an ensemble.
-
-    This function takes a DataFrame containing ensemble predictions and observed values,
-    adds sampling noise to the ensemble predictions, and calculates the rank of the observed
-    value within the perturbed ensemble for each observation. The rank indicates the position
-    of the observed value within the sorted ensemble values, with 1 being the lowest. If the
-    observed value is larger than the largest ensemble member, its rank is set to the ensemble
-    size plus one.
-
-    Parameters:
-    df (pd.DataFrame): A DataFrame with columns for mean, standard deviation, observed values,
-    ensemble size, and observation type. The DataFrame should have one row per observation.
+def plot_profile(df_in, verticalUnit):
+    """Assumes diag_stats has been run on the dataframe and the resulting dataframe is passed in"""
 
-    Returns:
-    tuple: A tuple containing the rank array, ensemble size, and a result DataFrame. The result
-    DataFrame contains columns for 'rank' and 'obstype'.
-    """
-    ensemble_values = df.filter(regex='prior_ensemble_member').to_numpy().copy()
-    std_dev = np.sqrt(df['obs_err_var']).to_numpy()
-    obsvalue = df['observation'].to_numpy()
-    obstype = df['type'].to_numpy()
-    ens_size = ensemble_values.shape[1]
-    mean = 0.0 # mean of the sampling noise
-    rank = np.zeros(obsvalue.shape[0], dtype=int)
-
-    for obs in range(ensemble_values.shape[0]):
-        sampling_noise = np.random.normal(mean, std_dev[obs], ens_size)
-        ensemble_values[obs] += sampling_noise
-        ensemble_values[obs].sort()
-        for i, ens in enumerate(ensemble_values[obs]):
-            if obsvalue[obs] <= ens:
-                rank[obs] = i + 1
-                break
-
-        if rank[obs] == 0: # observation is larger than largest ensemble member
-            rank[obs] = ens_size + 1
-
-    result_df = pd.DataFrame({
-        'rank': rank,
-        'obstype': obstype
-    })
-
-    return (rank, ens_size, result_df)
-
-def plot_profile(df, levels):
-    """
-    Plots RMSE and Bias profiles for different observation types across specified pressure levels.
-
-    This function takes a DataFrame containing observational data and model predictions, categorizes
-    the data into specified pressure levels, and calculates the RMSE and Bias for each level and
-    observation type. It then plots two line charts: one for RMSE and another for Bias, both as functions
-    of pressure level. The pressure levels are plotted on the y-axis in reversed order to represent
-    the vertical profile in the atmosphere correctly.
+    df = stats.layer_statistics(df_in)
+    if "posterior_rmse" in df.columns:
+        fig_rmse = plot_profile_prior_post(df, "rmse", verticalUnit)
+        fig_rmse.show()
+        fig_bias = plot_profile_prior_post(df, "bias", verticalUnit)
+        fig_bias.show()
+        fig_ts = plot_profile_prior_post(df, "totalspread", verticalUnit)
+        fig_ts.show()
+    else:
+        fig_rmse = plot_profile_prior(df, "rmse", verticalUnit)
+        fig_rmse.show()
+        fig_bias = plot_profile_prior(df, "bias", verticalUnit)
+        fig_bias.show()
+        fig_ts = plot_profile_prior(df, "totalspread", verticalUnit)
+        fig_ts.show()
 
-    Parameters:
-    df (pd.DataFrame): The input DataFrame containing at least the 'vertical' column for pressure levels,
-    and other columns required by the `rmse_bias` function for calculating RMSE and Bias.
-    levels (array-like): The bin edges for categorizing the 'vertical' column values into pressure levels.
+    return fig_rmse, fig_ts, fig_bias
 
-    Returns:
-    tuple: A tuple containing the DataFrame with RMSE and Bias calculations, the RMSE plot figure, and the
-    Bias plot figure. The DataFrame includes a 'plevels' column representing the categorized pressure levels
-    and 'hPa' column representing the midpoint of each pressure level bin.
-
-    Raises:
-    ValueError: If there are missing values in the 'vertical' column of the input DataFrame.
-
-    Note:
-    - The function modifies the input DataFrame by adding 'plevels' and 'hPa' columns.
-    - The 'hPa' values are calculated as half the midpoint of each pressure level bin, which may need
-    adjustment based on the specific requirements for pressure level representation.
-    - The plots are generated using Plotly Express and are displayed inline. The y-axis of the plots is
-    reversed to align with standard atmospheric pressure level representation.
-    """
 
-    pd.options.mode.copy_on_write = True
-    if df['vertical'].isnull().values.any(): # what about horizontal observations?
-        raise ValueError("Missing values in 'vertical' column.")
-    else:
-        df.loc[:,'plevels'] = pd.cut(df['vertical'], levels)
-        df.loc[:,'hPa'] = df['plevels'].apply(lambda x: x.mid / 1000.) # HK todo units
-
-    df_profile = rmse_bias(df)
-    fig_rmse = px.line(df_profile, y='hPa', x='rmse', title='RMSE by Level', markers=True, color='type', width=800, height=800)
-    fig_rmse.update_yaxes(autorange="reversed")
-    fig_rmse.show()
-
-    fig_bias = px.line(df_profile, y='hPa', x='bias', title='Bias by Level', markers=True, color='type', width=800, height=800)
-    fig_bias.update_yaxes(autorange="reversed")
-    fig_bias.show()
-
-    return df_profile, fig_rmse, fig_bias
-
-
-def mean_then_sqrt(x):
+def plot_profile_prior_post(df_profile, stat, verticalUnit):
     """
-    Calculates the mean of an array-like object and then takes the square root of the result.
+    Plots prior and posterior statistics by vertical level for different observation types.
 
     Parameters:
-    arr (array-like): An array-like object (such as a list or a pandas Series).
-    The elements should be numeric.
+        df_profile (pd.DataFrame): DataFrame containing the prior and posterior statistics.
+        stat (str): The statistic to plot (e.g., 'rmse', 'bias', 'totalspread').
+        verticalUnit (str): The unit of the vertical axis (e.g., 'pressure (Pa)').
 
     Returns:
-    float: The square root of the mean of the input array.
-
-    Raises:
-    TypeError: If the input is not an array-like object containing numeric values.
-    ValueError: If the input array is empty.
+        plotly.graph_objects.Figure: The generated Plotly figure.
     """
-
-    return np.sqrt(np.mean(x))
-
-def rmse_bias(df):
-    rmse_bias_df = df.groupby(['hPa', 'type']).agg({'sq_err':mean_then_sqrt, 'bias':'mean'}).reset_index()
-    rmse_bias_df.rename(columns={'sq_err':'rmse'}, inplace=True)
-
-    return rmse_bias_df
-
-
-def rmse_bias_by_obs_type(df, obs_type):
+    # Filter the DataFrame to include only rows with the required verticalUnit
+    df_filtered = df_profile[df_profile["vert_unit"] == verticalUnit]
+
+    # Reshape DataFrame to long format for easier plotting
+    df_long = pd.melt(
+        df_profile,
+        id_vars=["midpoint", "type"],
+        value_vars=["prior_" + stat, "posterior_" + stat],
+        var_name=stat + "_type",
+        value_name=stat + "_value",
+    )
+
+    # Define a color mapping for observation each type
+    unique_types = df_long["type"].unique()
+    colors = px.colors.qualitative.Plotly
+    color_mapping = {
+        type_: colors[i % len(colors)] for i, type_ in enumerate(unique_types)
+    }
+
+    # Create a mapping for line styles based on stat
+    line_styles = {"prior_" + stat: "solid", "posterior_" + stat: "dash"}
+
+    # Create the figure
+    fig_stat = go.Figure()
+
+    # Loop through each type and type to add traces
+    for t in df_long["type"].unique():
+        for stat_type, dash_style in line_styles.items():
+            # Filter the DataFrame for this type and stat
+            df_filtered = df_long[
+                (df_long[stat + "_type"] == stat_type) & (df_long["type"] == t)
+            ]
+
+            # Add a trace
+            fig_stat.add_trace(
+                go.Scatter(
+                    x=df_filtered[stat + "_value"],
+                    y=df_filtered["midpoint"],
+                    mode="lines+markers",
+                    name=(
+                        "prior " + t if stat_type == "prior_" + stat else "post "
+                    ),  # Show legend for "prior_stat OBS TYPE" only
+                    line=dict(
+                        dash=dash_style, color=color_mapping[t]
+                    ),  # Same color for all traces in group
+                    marker=dict(size=8, color=color_mapping[t]),
+                    legendgroup=t,  # Group traces by type
+                )
+            )
+
+    # Update layout
+    fig_stat.update_layout(
+        title=stat + " by Level",
+        xaxis_title=stat,
+        yaxis_title=verticalUnit,
+        width=800,
+        height=800,
+        template="plotly_white",
+    )
+
+    if verticalUnit == "pressure (Pa)":
+        fig_stat.update_yaxes(autorange="reversed")
+
+    return fig_stat
+
+
+def plot_profile_prior(df_profile, stat, verticalUnit):
     """
-    Calculate the RMSE and bias for a given observation type.
+    Plots prior statistics by vertical level for different observation types.
 
     Parameters:
-    df (DataFrame): A pandas DataFrame.
-    obs_type (str): The observation type for which to calculate the RMSE and bias.
+        df_profile (pd.DataFrame): DataFrame containing the prior statistics.
+        stat (str): The statistic to plot (e.g., 'rmse', 'bias', 'totalspread').
+        verticalUnit (str): The unit of the vertical axis (e.g., 'pressure (Pa)').
 
     Returns:
-    DataFrame: A DataFrame containing the RMSE and bias for the given observation type.
-
-    Raises:
-    ValueError: If the observation type is not present in the DataFrame.
+        plotly.graph_objects.Figure: The generated Plotly figure.
     """
-    if obs_type not in df['type'].unique():
-        raise ValueError(f"Observation type '{obs_type}' not found in DataFrame.")
-    else:
-        obs_type_df = df[df['type'] == obs_type]
-        obs_type_agg = obs_type_df.groupby('plevels').agg({'sq_err':mean_then_sqrt, 'bias':'mean'}).reset_index()
-        obs_type_agg.rename(columns={'sq_err':'rmse'}, inplace=True)
-        return obs_type_agg
-
+    # Reshape DataFrame to long format for easier plotting - not needed for prior only, but
+    # leaving it in for consistency with the plot_profile_prior_post function for now
+    df_long = pd.melt(
+        df_profile,
+        id_vars=["midpoint", "type"],
+        value_vars=["prior_" + stat],
+        var_name=stat + "_type",
+        value_name=stat + "_value",
+    )
+
+    # Define a color mapping for observation each type
+    unique_types = df_long["type"].unique()
+    colors = px.colors.qualitative.Plotly
+    color_mapping = {
+        type_: colors[i % len(colors)] for i, type_ in enumerate(unique_types)
+    }
+
+    # Create the figure
+    fig_stat = go.Figure()
+
+    # Loop through each type to add traces
+    for t in df_long["type"].unique():
+        # Filter the DataFrame for this type and stat
+        df_filtered = df_long[(df_long["type"] == t)]
+
+        # Add a trace
+        fig_stat.add_trace(
+            go.Scatter(
+                x=df_filtered[stat + "_value"],
+                y=df_filtered["midpoint"],
+                mode="lines+markers",
+                name="prior " + t,
+                line=dict(color=color_mapping[t]),  # Same color for all traces in group
+                marker=dict(size=8, color=color_mapping[t]),
+                legendgroup=t,  # Group traces by type
+            )
+        )
+
+    # Update layout
+    fig_stat.update_layout(
+        title=stat + " by Level",
+        xaxis_title=stat,
+        yaxis_title=verticalUnit,
+        width=800,
+        height=800,
+        template="plotly_white",
+    )
+
+    if verticalUnit == "pressure (Pa)":
+        fig_stat.update_yaxes(autorange="reversed")
+
+    return fig_stat
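
The reworked plotting functions above no longer compute statistics or bin levels themselves; they expect the per-observation statistics and layer bins produced by the new pydartdiags.stats.stats module added in this release (shown below). The following is a minimal usage sketch on synthetic data, assuming the obs-sequence column layout referenced in the code above and assuming pydartdiags.plots.plots as the import path for this file (both assumptions, not taken from the diff):

    import numpy as np
    import pandas as pd
    from pydartdiags.stats import stats
    from pydartdiags.plots import plots  # assumed import path for the file above

    # Synthetic stand-in for an obs-sequence DataFrame: one row per observation.
    rng = np.random.default_rng(0)
    n, ens_size = 200, 10
    members = {f"prior_ensemble_member_{i + 1}": rng.normal(size=n) for i in range(ens_size)}
    df = pd.DataFrame(
        {
            "observation": rng.normal(size=n),
            "obs_err_var": 0.25,
            "type": np.where(rng.random(n) < 0.5, "RADIOSONDE_TEMPERATURE", "AIRCRAFT_TEMPERATURE"),
            "vert_unit": "pressure (Pa)",
            "vertical": rng.uniform(10000.0, 100000.0, size=n),
            "DART_quality_control": 0,
            **members,
        }
    )
    df["prior_ensemble_mean"] = df.filter(regex="prior_ensemble_member").mean(axis=1)
    df["prior_ensemble_spread"] = df.filter(regex="prior_ensemble_member").std(axis=1)

    stats.diag_stats(df)  # adds prior_sq_err, prior_bias, prior_totalvar in place
    stats.bin_by_layer(df, np.arange(0.0, 110000.0, 10000.0))  # adds 'vlevels' and 'midpoint'
    fig_rmse, fig_ts, fig_bias = plots.plot_profile(df, "pressure (Pa)")

    ranks = stats.calculate_rank(df)  # DataFrame with 'type' and 'prior_rank' columns
    plots.plot_rank_histogram(ranks, "prior", ens_size)

Because only prior ensemble columns are present in this sketch, plot_profile takes its plot_profile_prior branch; with posterior mean and spread columns the same call would draw the overlaid prior/posterior traces.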
File without changes
@@ -0,0 +1,323 @@
+# SPDX-License-Identifier: Apache-2.0
+import pandas as pd
+import numpy as np
+from functools import wraps
+
+# from pydartdiags.obs_sequence import obs_sequence as obsq
+
+
+def apply_to_phases_in_place(func):
+    """
+    Decorator to apply a function to both 'prior' and 'posterior' phases
+    and modify the DataFrame in place.
+
+    The decorated function should accept 'phase' as its first argument.
+    """
+
+    @wraps(func)
+    def wrapper(df, *args, **kwargs):
+        for phase in ["prior", "posterior"]:
+            if f"{phase}_ensemble_spread" in df.columns:
+                func(df, phase, *args, **kwargs)
+        return df
+
+    return wrapper
+
+
+def apply_to_phases_by_type_return_df(func):
+    """
+    Decorator to apply a function to both 'prior' and 'posterior' phases and return a new DataFrame.
+
+    The decorated function should accept 'phase' as its first argument and return a DataFrame.
+    """
+
+    @wraps(func)
+    def wrapper(df, *args, **kwargs):
+        results = []
+        for phase in ["prior", "posterior"]:
+            if f"{phase}_ensemble_mean" in df.columns:
+                result = func(df, phase, *args, **kwargs)
+                results.append(result)
+
+        if "midpoint" in result.columns:
+            if len(results) == 2:
+                return pd.merge(
+                    results[0],
+                    results[1],
+                    on=["midpoint", "vlevels", "type", "vert_unit"],
+                )
+            else:
+                return results[0]
+        else:
+            if len(results) == 2:
+                return pd.merge(results[0], results[1], on="type")
+            else:
+                return results[0]
+
+    return wrapper
+
+
+def apply_to_phases_by_obs(func):
+    """
+    Decorator to apply a function to both 'prior' and 'posterior' phases and return a new DataFrame.
+
+    The decorated function should accept 'phase' as its first argument and return a DataFrame.
+    """
+
+    @wraps(func)
+    def wrapper(df, *args, **kwargs):
+
+        res_df = func(df, "prior", *args, **kwargs)
+        if "posterior_ensemble_mean" in df.columns:
+            posterior_df = func(df, "posterior", *args, **kwargs)
+            res_df["posterior_rank"] = posterior_df["posterior_rank"]
+
+        return res_df
+
+    return wrapper
+
+
+@apply_to_phases_by_obs
+def calculate_rank(df, phase):
+    """
+    Calculate the rank of observations within an ensemble.
+
+    This function takes a DataFrame containing ensemble predictions and observed values,
+    adds sampling noise to the ensemble predictions, and calculates the rank of the observed
+    value within the perturbed ensemble for each observation. The rank indicates the position
+    of the observed value within the sorted ensemble values, with 1 being the lowest. If the
+    observed value is larger than the largest ensemble member, its rank is set to the ensemble
+    size plus one.
+
+    Parameters:
+        df (pd.DataFrame): A DataFrame with columns for rank, and observation type.
+
+        phase (str): The phase for which to calculate the statistics ('prior' or 'posterior')
+
+    Returns:
+        DataFrame containing columns for 'rank' and observation 'type'.
+    """
+    column = f"{phase}_ensemble_member"
+    ensemble_values = df.filter(regex=column).to_numpy().copy()
+    std_dev = np.sqrt(df["obs_err_var"]).to_numpy()
+    obsvalue = df["observation"].to_numpy()
+    obstype = df["type"].to_numpy()
+    ens_size = ensemble_values.shape[1]
+    mean = 0.0  # mean of the sampling noise
+    rank = np.zeros(obsvalue.shape[0], dtype=int)
+
+    for obs in range(ensemble_values.shape[0]):
+        sampling_noise = np.random.normal(mean, std_dev[obs], ens_size)
+        ensemble_values[obs] += sampling_noise
+        ensemble_values[obs].sort()
+        for i, ens in enumerate(ensemble_values[obs]):
+            if obsvalue[obs] <= ens:
+                rank[obs] = i + 1
+                break
+
+        if rank[obs] == 0:  # observation is larger than largest ensemble member
+            rank[obs] = ens_size + 1
+
+    result_df = pd.DataFrame({"type": obstype, f"{phase}_rank": rank})
+
+    return result_df
+
+
+def mean_then_sqrt(x):
+    """
+    Calculates the mean of an array-like object and then takes the square root of the result.
+
+    Parameters:
+        arr (array-like): An array-like object (such as a list or a pandas Series).
+            The elements should be numeric.
+
+    Returns:
+        float: The square root of the mean of the input array.
+
+    Raises:
+        TypeError: If the input is not an array-like object containing numeric values.
+        ValueError: If the input array is empty.
+    """
+
+    return np.sqrt(np.mean(x))
+
+
+@apply_to_phases_in_place
+def diag_stats(df, phase):
+    """
+    Calculate diagnostic statistics for a given phase and add them to the DataFrame.
+
+    Args:
+        df (pandas.DataFrame): The input DataFrame containing observation data and ensemble statistics.
+            The DataFrame must include the following columns:
+            - 'observation': The actual observation values.
+            - 'obs_err_var': The variance of the observation error.
+            - 'prior_ensemble_mean' and/or 'posterior_ensemble_mean': The mean of the ensemble.
+            - 'prior_ensemble_spread' and/or 'posterior_ensemble_spread': The spread of the ensemble.
+
+        phase (str): The phase for which to calculate the statistics ('prior' or 'posterior')
+
+    Returns:
+        None: The function modifies the DataFrame in place by adding the following columns:
+            - 'prior_sq_err' and/or 'posterior_sq_err': The square error for the 'prior' and 'posterior' phases.
+            - 'prior_bias' and/or 'posterior_bias': The bias for the 'prior' and 'posterior' phases.
+            - 'prior_totalvar' and/or 'posterior_totalvar': The total variance for the 'prior' and 'posterior' phases.
+
+    Notes:
+        - Spread is the standard deviation of the ensemble.
+        - The function modifies the input DataFrame by adding new columns for the calculated statistics.
+    """
+    pd.options.mode.copy_on_write = True
+
+    # input from the observation sequence
+    spread_column = f"{phase}_ensemble_spread"
+    mean_column = f"{phase}_ensemble_mean"
+
+    # Calculated from the observation sequence
+    sq_err_column = f"{phase}_sq_err"
+    bias_column = f"{phase}_bias"
+    totalvar_column = f"{phase}_totalvar"
+
+    df[sq_err_column] = (df[mean_column] - df["observation"]) ** 2
+    df[bias_column] = df[mean_column] - df["observation"]
+    df[totalvar_column] = df["obs_err_var"] + df[spread_column] ** 2
+
+
+def bin_by_layer(df, levels, verticalUnit="pressure (Pa)"):
+    """
+    Bin observations by vertical layers and add 'vlevels' and 'midpoint' columns to the DataFrame.
+
+    This function bins the observations in the DataFrame based on the specified vertical levels and adds two new columns:
+    'vlevels', which represents the categorized vertical levels, and 'midpoint', which represents the midpoint of each
+    vertical level bin. Only observations (row) with the specified vertical unit are binned.
+
+    Args:
+        df (pandas.DataFrame): The input DataFrame containing observation data. The DataFrame must include the following columns:
+            - 'vertical': The vertical coordinate values of the observations.
+            - 'vert_unit': The unit of the vertical coordinate values.
+        levels (list): A list of bin edges for the vertical levels.
+        verticalUnit (str, optional): The unit of the vertical axis (e.g., 'pressure (Pa)'). Default is 'pressure (Pa)'.
+
+    Returns:
+        pandas.DataFrame: The input DataFrame with additional columns for the binned vertical levels and their midpoints:
+            - 'vlevels': The categorized vertical levels.
+            - 'midpoint': The midpoint of each vertical level bin.
+
+    Notes:
+        - The function modifies the input DataFrame by adding 'vlevels' and 'midpoint' columns.
+        - The 'midpoint' values are calculated as half the midpoint of each vertical level bin.
+    """
+    pd.options.mode.copy_on_write = True
+    df.loc[df["vert_unit"] == verticalUnit, "vlevels"] = pd.cut(
+        df.loc[df["vert_unit"] == verticalUnit, "vertical"], levels
+    )
+    if verticalUnit == "pressure (Pa)":
+        df.loc[:, "midpoint"] = df["vlevels"].apply(
+            lambda x: x.mid
+        )  # HK todo units HPa - change now or in plotting?
+        df.loc[:, "vlevels"] = df["vlevels"].apply(
+            lambda x: x
+        )  # HK todo units HPa - change now or in plotting?
+    else:
+        df.loc[:, "midpoint"] = df["vlevels"].apply(lambda x: x.mid)
+
+
+@apply_to_phases_by_type_return_df
+def grand_statistics(df, phase):
+
+    # assuming diag_stats has been called
+    grand = (
+        df.groupby(["type"], observed=False)
+        .agg(
+            {
+                f"{phase}_sq_err": mean_then_sqrt,
+                f"{phase}_bias": "mean",
+                f"{phase}_totalvar": mean_then_sqrt,
+            }
+        )
+        .reset_index()
+    )
+
+    grand.rename(columns={f"{phase}_sq_err": f"{phase}_rmse"}, inplace=True)
+    grand.rename(columns={f"{phase}_totalvar": f"{phase}_totalspread"}, inplace=True)
+
+    return grand
+
+
+@apply_to_phases_by_type_return_df
+def layer_statistics(df, phase):
+
+    # assuming diag_stats has been called
+    layer_stats = (
+        df.groupby(["midpoint", "type"], observed=False)
+        .agg(
+            {
+                f"{phase}_sq_err": mean_then_sqrt,
+                f"{phase}_bias": "mean",
+                f"{phase}_totalvar": mean_then_sqrt,
+                "vert_unit": "first",
+                "vlevels": "first",
+            }
+        )
+        .reset_index()
+    )
+
+    layer_stats.rename(columns={f"{phase}_sq_err": f"{phase}_rmse"}, inplace=True)
+    layer_stats.rename(
+        columns={f"{phase}_totalvar": f"{phase}_totalspread"}, inplace=True
+    )
+
+    return layer_stats
+
+
+def possible_vs_used(df):
+    """
+    Calculates the count of possible vs. used observations by type.
+
+    This function takes a DataFrame containing observation data, including a 'type' column for the observation
+    type and an 'observation' column. The number of used observations ('used'), is the total number
+    minus the observations that failed quality control checks (as determined by the `select_failed_qcs` function).
+    The result is a DataFrame with each observation type, the count of possible observations, and the count of
+    used observations.
+
+    Returns:
+        pd.DataFrame: A DataFrame with three columns: 'type', 'possible', and 'used'. 'type' is the observation type,
+        'possible' is the count of all observations of that type, and 'used' is the count of observations of that type
+        that passed quality control checks.
+    """
+    possible = df.groupby("type")["observation"].count()
+    possible.rename("possible", inplace=True)
+
+    failed_qcs = select_failed_qcs(df).groupby("type")["observation"].count()
+    used = possible - failed_qcs.reindex(possible.index, fill_value=0)
+    used.rename("used", inplace=True)
+
+    return pd.concat([possible, used], axis=1).reset_index()
+
+
+def possible_vs_used_by_layer(df):
+    """
+    Calculates the count of possible vs. used observations by type and vertical level.
+    """
+    possible = df.groupby(["type", "midpoint"], observed=False)["type"].count()
+    possible.rename("possible", inplace=True)
+
+    failed_qcs = (
+        select_failed_qcs(df)
+        .groupby(["type", "midpoint"], observed=False)["type"]
+        .count()
+    )
+    used = possible - failed_qcs.reindex(possible.index, fill_value=0)
+    used.rename("used", inplace=True)
+
+    return pd.concat([possible, used], axis=1).reset_index()
+
+
+def select_failed_qcs(df):
+    """
+    Select rows from the DataFrame where the DART quality control flag is greater than 0.
+
+    Returns:
+        pandas.DataFrame: A DataFrame containing only the rows with a DART quality control flag greater than 0.
+    """
+    return df[df["DART_quality_control"] > 0]
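
As a worked note on the aggregation logic: layer_statistics and grand_statistics reduce the per-observation columns added by diag_stats, applying mean_then_sqrt to the squared error and total variance and a plain mean to the bias. Continuing the sketch shown after the plotting module (same synthetic df, with diag_stats and bin_by_layer already applied):

    by_layer = stats.layer_statistics(df)  # per (midpoint, type): prior_rmse, prior_bias, prior_totalspread
    overall = stats.grand_statistics(df)   # the same statistics per type, over all layers
    counts = stats.possible_vs_used(df)    # 'possible' vs. QC-passing 'used' counts per type

    # Reductions, in terms of the columns created by diag_stats:
    #   prior_rmse        = sqrt(mean(prior_sq_err))
    #   prior_bias        = mean(prior_ensemble_mean - observation)
    #   prior_totalspread = sqrt(mean(obs_err_var + prior_ensemble_spread ** 2))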