pydartdiags 0.6.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pydartdiags/__init__.py +0 -0
- pydartdiags/data.py +246 -0
- pydartdiags/matplots/__init__.py +0 -0
- pydartdiags/matplots/matplots.py +522 -0
- pydartdiags/obs_sequence/__init__.py +0 -0
- pydartdiags/obs_sequence/composite_types.yaml +35 -0
- pydartdiags/obs_sequence/obs_sequence.py +1360 -0
- pydartdiags/plots/__init__.py +0 -0
- pydartdiags/plots/plots.py +191 -0
- pydartdiags/stats/__init__.py +0 -0
- pydartdiags/stats/stats.py +510 -0
- pydartdiags-0.6.4.dist-info/METADATA +45 -0
- pydartdiags-0.6.4.dist-info/RECORD +16 -0
- pydartdiags-0.6.4.dist-info/WHEEL +5 -0
- pydartdiags-0.6.4.dist-info/licenses/LICENSE +201 -0
- pydartdiags-0.6.4.dist-info/top_level.txt +1 -0
|
File without changes
|
|
@@ -0,0 +1,191 @@
|
|
|
1
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
2
|
+
import numpy as np
|
|
3
|
+
import plotly.express as px
|
|
4
|
+
import plotly.graph_objects as go
|
|
5
|
+
import pandas as pd
|
|
6
|
+
from pydartdiags.stats import stats
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def plot_rank_histogram(df, phase, ens_size):
    """
    Plots a rank histogram colored by observation type.

    All histogram bars are initialized to be hidden and can be toggled visible in the plot's legend

    Parameters:
        df (pd.DataFrame): Must contain a ``f"{phase}_rank"`` column and a
            ``type`` column (one histogram trace is drawn per type).
        phase (str): Prefix of the rank column (presumably 'prior' or
            'posterior' -- confirm against callers).
        ens_size (int): Ensemble size. Used as the requested bin count and to
            fix the x-axis range to [1, ens_size + 1].

    Returns:
        None. The figure is displayed via ``show()`` and not returned.
    """
    rank_column = f"{phase}_rank"

    # NOTE(review): an N-member ensemble usually has N + 1 possible ranks;
    # nbins=ens_size may merge two of them -- confirm against the rank
    # convention used in stats.
    histogram = px.histogram(
        df,
        x=rank_column,
        color="type",
        title="Histogram Colored by obs type",
        nbins=ens_size,
    )
    histogram.update_xaxes(range=[1, ens_size + 1])

    # Start with every trace hidden; users enable types from the legend.
    histogram.update_traces(visible="legendonly")
    histogram.show()
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def plot_profile(df_in, verticalUnit):
    """
    Compute layer statistics and display rmse, bias and totalspread profiles.

    Assumes diag_stats has been run on the dataframe and the resulting
    dataframe is passed in.

    Parameters:
        df_in (pd.DataFrame): Output of diag_stats; passed to
            ``stats.layer_statistics`` to obtain per-layer statistics.
        verticalUnit (str): The unit of the vertical axis (e.g., 'pressure (Pa)').

    Returns:
        tuple: ``(rmse_figure, totalspread_figure, bias_figure)`` -- note the
        order. Each figure is also shown as it is created.
    """
    df_layers = stats.layer_statistics(df_in)

    # Posterior columns are only present when posterior statistics exist;
    # otherwise fall back to prior-only profiles.
    if "posterior_rmse" in df_layers.columns:
        plotter = plot_profile_prior_post
    else:
        plotter = plot_profile_prior

    figures = {}
    for stat_name in ("rmse", "bias", "totalspread"):
        figure = plotter(df_layers, stat_name, verticalUnit)
        figure.show()
        figures[stat_name] = figure

    return figures["rmse"], figures["totalspread"], figures["bias"]
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def plot_profile_prior_post(df_profile, stat, verticalUnit):
    """
    Plots prior and posterior statistics by vertical level for different observation types.

    Only rows whose ``vert_unit`` equals ``verticalUnit`` are plotted, so
    observation types measured in another vertical coordinate are not drawn
    against an axis labelled with ``verticalUnit``.

    Parameters:
        df_profile (pd.DataFrame): DataFrame containing the prior and posterior
            statistics. Must have the columns ``midpoint``, ``type``,
            ``vert_unit``, ``"prior_" + stat`` and ``"posterior_" + stat``.
        stat (str): The statistic to plot (e.g., 'rmse', 'bias', 'totalspread').
        verticalUnit (str): The unit of the vertical axis (e.g., 'pressure (Pa)').

    Returns:
        plotly.graph_objects.Figure: The generated Plotly figure. Prior values
        are drawn with solid lines and posterior values with dashed lines, in
        one color per observation type. For 'pressure (Pa)' the y-axis is
        reversed.
    """
    prior_col = "prior_" + stat
    posterior_col = "posterior_" + stat
    type_col = stat + "_type"
    value_col = stat + "_value"

    # Filter to the requested vertical unit BEFORE reshaping; the rest of the
    # function must work on this filtered frame, not on df_profile.
    df_unit = df_profile[df_profile["vert_unit"] == verticalUnit]

    # Reshape to long format: one row per (midpoint, type, prior/posterior)
    df_long = pd.melt(
        df_unit,
        id_vars=["midpoint", "type"],
        value_vars=[prior_col, posterior_col],
        var_name=type_col,
        value_name=value_col,
    )

    # Define a color mapping for each observation type (palette cycles)
    unique_types = df_long["type"].unique()
    colors = px.colors.qualitative.Plotly
    color_mapping = {
        type_: colors[i % len(colors)] for i, type_ in enumerate(unique_types)
    }

    # Solid line for the prior statistic, dashed for the posterior
    line_styles = {prior_col: "solid", posterior_col: "dash"}

    fig_stat = go.Figure()

    # One trace per (observation type, prior/posterior) pair
    for t in unique_types:
        for stat_type, dash_style in line_styles.items():
            df_trace = df_long[
                (df_long[type_col] == stat_type) & (df_long["type"] == t)
            ]

            fig_stat.add_trace(
                go.Scatter(
                    x=df_trace[value_col],
                    y=df_trace["midpoint"],
                    mode="lines+markers",
                    # Only the prior trace carries the type in its name;
                    # posterior traces are labelled "post ".
                    name=("prior " + t if stat_type == prior_col else "post "),
                    line=dict(dash=dash_style, color=color_mapping[t]),
                    marker=dict(size=8, color=color_mapping[t]),
                    legendgroup=t,  # Group prior/posterior traces by type
                )
            )

    fig_stat.update_layout(
        title=stat + " by Level",
        xaxis_title=stat,
        yaxis_title=verticalUnit,
        width=800,
        height=800,
        template="plotly_white",
    )

    # Pressure decreases with height, so put high pressure at the bottom
    if verticalUnit == "pressure (Pa)":
        fig_stat.update_yaxes(autorange="reversed")

    return fig_stat
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
def plot_profile_prior(df_profile, stat, verticalUnit):
    """
    Plots prior statistics by vertical level for different observation types.

    If ``df_profile`` has a ``vert_unit`` column, only rows whose unit equals
    ``verticalUnit`` are plotted. This stops observation types measured in
    another vertical coordinate from being drawn against an axis labelled
    with ``verticalUnit``. Frames without that column are plotted unfiltered.

    Parameters:
        df_profile (pd.DataFrame): DataFrame containing the prior statistics.
            Must have the columns ``midpoint``, ``type`` and
            ``"prior_" + stat``; ``vert_unit`` is optional.
        stat (str): The statistic to plot (e.g., 'rmse', 'bias', 'totalspread').
        verticalUnit (str): The unit of the vertical axis (e.g., 'pressure (Pa)').

    Returns:
        plotly.graph_objects.Figure: The generated Plotly figure, with one
        line per observation type. For 'pressure (Pa)' the y-axis is reversed.
    """
    value_col = stat + "_value"

    # Restrict to the requested vertical unit when that information exists.
    # Rebinding (not mutating) leaves the caller's frame untouched.
    if "vert_unit" in df_profile.columns:
        df_profile = df_profile[df_profile["vert_unit"] == verticalUnit]

    # Reshape DataFrame to long format - not needed for prior only, but
    # kept for consistency with the plot_profile_prior_post function
    df_long = pd.melt(
        df_profile,
        id_vars=["midpoint", "type"],
        value_vars=["prior_" + stat],
        var_name=stat + "_type",
        value_name=value_col,
    )

    # Define a color mapping for each observation type (palette cycles)
    unique_types = df_long["type"].unique()
    colors = px.colors.qualitative.Plotly
    color_mapping = {
        type_: colors[i % len(colors)] for i, type_ in enumerate(unique_types)
    }

    fig_stat = go.Figure()

    # One trace per observation type
    for t in unique_types:
        df_trace = df_long[df_long["type"] == t]

        fig_stat.add_trace(
            go.Scatter(
                x=df_trace[value_col],
                y=df_trace["midpoint"],
                mode="lines+markers",
                name="prior " + t,
                line=dict(color=color_mapping[t]),
                marker=dict(size=8, color=color_mapping[t]),
                legendgroup=t,  # Group traces by type
            )
        )

    fig_stat.update_layout(
        title=stat + " by Level",
        xaxis_title=stat,
        yaxis_title=verticalUnit,
        width=800,
        height=800,
        template="plotly_white",
    )

    # Pressure decreases with height, so put high pressure at the bottom
    if verticalUnit == "pressure (Pa)":
        fig_stat.update_yaxes(autorange="reversed")

    return fig_stat
|
|
File without changes
|