pydartdiags 0.6.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
File without changes
@@ -0,0 +1,191 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ import numpy as np
3
+ import plotly.express as px
4
+ import plotly.graph_objects as go
5
+ import pandas as pd
6
+ from pydartdiags.stats import stats
7
+
8
+
9
def plot_rank_histogram(df, phase, ens_size):
    """
    Plots a rank histogram colored by observation type.

    All histogram bars are initialized to be hidden and can be toggled visible
    in the plot's legend. The figure is shown immediately.

    Parameters:
        df (pd.DataFrame): DataFrame containing a ``f"{phase}_rank"`` column
            and a ``type`` column (used for coloring).
        phase (str): Prefix of the rank column to plot (presumably 'prior' or
            'posterior' — confirm against the caller).
        ens_size (int): Ensemble size; used as ``nbins`` and to set the x-axis
            range to ``[1, ens_size + 1]``.

    Returns:
        plotly.graph_objects.Figure: The generated Plotly figure.
    """
    fig = px.histogram(
        df,
        x=f"{phase}_rank",
        color="type",
        title="Histogram Colored by obs type",
        nbins=ens_size,
    )
    fig.update_xaxes(range=[1, ens_size + 1])
    # Start with every observation-type trace hidden; users enable the ones
    # they want by clicking the corresponding legend entries.
    for trace in fig.data:
        trace.visible = "legendonly"
    fig.show()
    # Return the figure (previously None was returned) so callers can save or
    # further customize it, matching plot_profile which returns its figures.
    return fig
26
+
27
+
28
def plot_profile(df_in, verticalUnit):
    """
    Plot rmse, bias and totalspread profiles by vertical level.

    Assumes diag_stats has been run on the dataframe and the resulting
    dataframe is passed in. The input is first aggregated with
    ``stats.layer_statistics``.
    - If the aggregated frame has a ``posterior_rmse`` column, prior and
      posterior statistics are plotted together.
    - Otherwise only the prior statistics are plotted.

    Each figure is shown as soon as it is created, in the order rmse, bias,
    totalspread.

    Parameters:
        df_in (pd.DataFrame): Output of diag_stats.
        verticalUnit (str): Vertical unit forwarded to the per-statistic
            plotting function (e.g. 'pressure (Pa)').

    Returns:
        tuple: ``(rmse_figure, totalspread_figure, bias_figure)`` — note that
        the return order differs from the display order.
    """
    layered = stats.layer_statistics(df_in)

    # Same three statistics either way; only the plotting routine differs.
    if "posterior_rmse" in layered.columns:
        plotter = plot_profile_prior_post
    else:
        plotter = plot_profile_prior

    figures = {}
    for stat_name in ("rmse", "bias", "totalspread"):
        figure = plotter(layered, stat_name, verticalUnit)
        figure.show()
        figures[stat_name] = figure

    return figures["rmse"], figures["totalspread"], figures["bias"]
48
+
49
+
50
def plot_profile_prior_post(df_profile, stat, verticalUnit):
    """
    Plots prior and posterior statistics by vertical level for different observation types.

    For each observation type:
    - the prior is drawn as a solid line and the posterior as a dashed line;
    - both traces share one color and one legend group.

    Parameters:
        df_profile (pd.DataFrame): DataFrame containing the prior and posterior
            statistics. It needs ``midpoint`` and ``type`` columns plus
            ``prior_<stat>`` and ``posterior_<stat>`` columns. When a
            ``vert_unit`` column is present, it is used to select rows.
        stat (str): The statistic to plot (e.g., 'rmse', 'bias', 'totalspread').
        verticalUnit (str): The unit of the vertical axis (e.g., 'pressure (Pa)').
            Only rows whose ``vert_unit`` equals this value are plotted. For
            'pressure (Pa)', the y axis is reversed.

    Returns:
        plotly.graph_objects.Figure: The generated Plotly figure.
    """
    # Keep only rows with the requested vertical unit, so that levels measured
    # in different units are not drawn on one axis. Previously this filter was
    # computed but the unfiltered frame was what got plotted.
    if "vert_unit" in df_profile.columns:
        df_profile = df_profile[df_profile["vert_unit"] == verticalUnit]

    prior_col = "prior_" + stat
    post_col = "posterior_" + stat

    # Reshape to long format: one row per (midpoint, type, prior/posterior value).
    df_long = pd.melt(
        df_profile,
        id_vars=["midpoint", "type"],
        value_vars=[prior_col, post_col],
        var_name=stat + "_type",
        value_name=stat + "_value",
    )

    # Assign each observation type a color, cycling through the Plotly palette.
    unique_types = df_long["type"].unique()
    colors = px.colors.qualitative.Plotly
    color_mapping = {
        type_: colors[i % len(colors)] for i, type_ in enumerate(unique_types)
    }

    # Line style and legend prefix for each of the two melted value columns.
    line_styles = {prior_col: "solid", post_col: "dash"}
    legend_prefix = {prior_col: "prior ", post_col: "post "}

    fig_stat = go.Figure()

    # One trace per (observation type, prior/posterior) pair.
    for t in unique_types:
        for stat_type, dash_style in line_styles.items():
            df_trace = df_long[
                (df_long[stat + "_type"] == stat_type) & (df_long["type"] == t)
            ]
            fig_stat.add_trace(
                go.Scatter(
                    x=df_trace[stat + "_value"],
                    y=df_trace["midpoint"],
                    mode="lines+markers",
                    # Include the obs type in both labels. Previously every
                    # posterior trace was labelled just "post ", which made the
                    # legend entries indistinguishable.
                    name=legend_prefix[stat_type] + t,
                    line=dict(dash=dash_style, color=color_mapping[t]),
                    marker=dict(size=8, color=color_mapping[t]),
                    legendgroup=t,  # group prior/posterior of a type together
                )
            )

    fig_stat.update_layout(
        title=stat + " by Level",
        xaxis_title=stat,
        yaxis_title=verticalUnit,
        width=800,
        height=800,
        template="plotly_white",
    )

    # Pressure decreases with height, so flip the axis to put the surface at the bottom.
    if verticalUnit == "pressure (Pa)":
        fig_stat.update_yaxes(autorange="reversed")

    return fig_stat
126
+
127
+
128
def plot_profile_prior(df_profile, stat, verticalUnit):
    """
    Plots prior statistics by vertical level for different observation types.

    Each observation type gets one solid line with markers, in its own color.

    Parameters:
        df_profile (pd.DataFrame): DataFrame containing the prior statistics.
            It needs ``midpoint``, ``type`` and ``prior_<stat>`` columns. When a
            ``vert_unit`` column is present, it is used to select rows.
        stat (str): The statistic to plot (e.g., 'rmse', 'bias', 'totalspread').
        verticalUnit (str): The unit of the vertical axis (e.g., 'pressure (Pa)').
            Only rows whose ``vert_unit`` equals this value are plotted. For
            'pressure (Pa)', the y axis is reversed.

    Returns:
        plotly.graph_objects.Figure: The generated Plotly figure.
    """
    # Keep only rows with the requested vertical unit (as plot_profile_prior_post
    # intends), so levels in different units are not drawn on the same axis.
    # Guarded so frames without a vert_unit column still plot as before.
    if "vert_unit" in df_profile.columns:
        df_profile = df_profile[df_profile["vert_unit"] == verticalUnit]

    # Reshape DataFrame to long format. This is not strictly needed for
    # prior-only data, but it is kept for consistency with
    # plot_profile_prior_post.
    df_long = pd.melt(
        df_profile,
        id_vars=["midpoint", "type"],
        value_vars=["prior_" + stat],
        var_name=stat + "_type",
        value_name=stat + "_value",
    )

    # Assign each observation type a color, cycling through the Plotly palette.
    unique_types = df_long["type"].unique()
    colors = px.colors.qualitative.Plotly
    color_mapping = {
        type_: colors[i % len(colors)] for i, type_ in enumerate(unique_types)
    }

    fig_stat = go.Figure()

    # One trace per observation type.
    for t in unique_types:
        df_trace = df_long[df_long["type"] == t]
        fig_stat.add_trace(
            go.Scatter(
                x=df_trace[stat + "_value"],
                y=df_trace["midpoint"],
                mode="lines+markers",
                name="prior " + t,
                line=dict(color=color_mapping[t]),
                marker=dict(size=8, color=color_mapping[t]),
                legendgroup=t,  # group traces by observation type
            )
        )

    fig_stat.update_layout(
        title=stat + " by Level",
        xaxis_title=stat,
        yaxis_title=verticalUnit,
        width=800,
        height=800,
        template="plotly_white",
    )

    # Pressure decreases with height, so flip the axis to put the surface at the bottom.
    if verticalUnit == "pressure (Pa)":
        fig_stat.update_yaxes(autorange="reversed")

    return fig_stat
File without changes