google-meridian 1.5.0__py3-none-any.whl → 1.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {google_meridian-1.5.0.dist-info → google_meridian-1.5.1.dist-info}/METADATA +2 -2
- {google_meridian-1.5.0.dist-info → google_meridian-1.5.1.dist-info}/RECORD +12 -12
- meridian/analysis/review/checks.py +118 -116
- meridian/analysis/review/constants.py +3 -3
- meridian/analysis/review/results.py +131 -68
- meridian/analysis/review/reviewer.py +4 -22
- meridian/model/eda/eda_engine.py +1 -0
- meridian/version.py +1 -1
- schema/serde/meridian_serde.py +6 -2
- {google_meridian-1.5.0.dist-info → google_meridian-1.5.1.dist-info}/WHEEL +0 -0
- {google_meridian-1.5.0.dist-info → google_meridian-1.5.1.dist-info}/licenses/LICENSE +0 -0
- {google_meridian-1.5.0.dist-info → google_meridian-1.5.1.dist-info}/top_level.txt +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: google-meridian
|
|
3
|
-
Version: 1.5.
|
|
3
|
+
Version: 1.5.1
|
|
4
4
|
Summary: Google's open source mixed marketing model library, helps you understand your return on investment and direct your ad spend with confidence.
|
|
5
5
|
Author-email: The Meridian Authors <no-reply@google.com>
|
|
6
6
|
Project-URL: homepage, https://github.com/google/meridian
|
|
@@ -210,7 +210,7 @@ To cite this repository:
|
|
|
210
210
|
author = {Google Meridian Marketing Mix Modeling Team},
|
|
211
211
|
title = {Meridian: Marketing Mix Modeling},
|
|
212
212
|
url = {https://github.com/google/meridian},
|
|
213
|
-
version = {1.5.
|
|
213
|
+
version = {1.5.1},
|
|
214
214
|
year = {2025},
|
|
215
215
|
}
|
|
216
216
|
```
|
|
@@ -1,7 +1,7 @@
|
|
|
1
|
-
google_meridian-1.5.
|
|
1
|
+
google_meridian-1.5.1.dist-info/licenses/LICENSE,sha256=WNHhf_5RCaeuKWyq_K39vmp9F28LxKsB4SpomwSZ2L0,11357
|
|
2
2
|
meridian/__init__.py,sha256=0fOT5oNZF7-pbiWWGUefV-ysafttieG079m1ijMFQO8,861
|
|
3
3
|
meridian/constants.py,sha256=3OTX4TKp_lLuuO6V2fA2icgoJQxQUHJjIDxH8yBD4Kw,20359
|
|
4
|
-
meridian/version.py,sha256=
|
|
4
|
+
meridian/version.py,sha256=kAZOoCSOwL13YeWOWHzABMmvsYhX2egAYQqEeUsPS2M,644
|
|
5
5
|
meridian/analysis/__init__.py,sha256=AM7xpqoeC-mmY4tPIyHisjQ2MICI7v3jSri--DhDqXA,874
|
|
6
6
|
meridian/analysis/analyzer.py,sha256=8x6yrDnk_Sy_-fp9M9ZDZzKXqrszrZ_xmXcA78Tf3AY,224888
|
|
7
7
|
meridian/analysis/optimizer.py,sha256=2_saKikIXuQuwGzNJpkuX2hYpdgcoBYR6_-m1MbbkSM,127223
|
|
@@ -10,11 +10,11 @@ meridian/analysis/summary_text.py,sha256=I_smDkZJYp2j77ea-9AIbgeraDa7-qUYyb-IthP
|
|
|
10
10
|
meridian/analysis/test_utils.py,sha256=pQQPhKertGawgH2ry1hxiV1aAOlVcV5aHoafqQPzS6s,98743
|
|
11
11
|
meridian/analysis/visualizer.py,sha256=MprZXNMOAF1BTJi5zdtHRnSwNZ20YYEe8UL2jGJk64k,94850
|
|
12
12
|
meridian/analysis/review/__init__.py,sha256=cF24EbhiVSs-tvtRf59uVin39tu6aCTTCaeEdv6ISZ8,804
|
|
13
|
-
meridian/analysis/review/checks.py,sha256=
|
|
13
|
+
meridian/analysis/review/checks.py,sha256=Q-niQrgyird1orFgkjgEqQDef57ooRRKR8IhYFPMXsc,27155
|
|
14
14
|
meridian/analysis/review/configs.py,sha256=5JJ8v6n22GNBmE78xNX6jwdjkZz2qar4Q9YTcVqzcoI,3653
|
|
15
|
-
meridian/analysis/review/constants.py,sha256=
|
|
16
|
-
meridian/analysis/review/results.py,sha256=
|
|
17
|
-
meridian/analysis/review/reviewer.py,sha256=
|
|
15
|
+
meridian/analysis/review/constants.py,sha256=DM6mgDXiLXcyA89EYthWOHHCN6_CP3ey_DuPhP3ZWu4,1497
|
|
16
|
+
meridian/analysis/review/results.py,sha256=HtKW3qw8T2wJ_Ei4wv-qhWTizczD3H_dS0tc9Nik5-4,19252
|
|
17
|
+
meridian/analysis/review/reviewer.py,sha256=llE4ssH4QK4xVZwKcVzrXqEsef2pnCWTKHnmA4sT46M,6009
|
|
18
18
|
meridian/backend/__init__.py,sha256=DaFTfvsqYtkheFvgV2kdPsyJoz8c-X2_ISSMlleHbVk,45411
|
|
19
19
|
meridian/backend/config.py,sha256=B9VQnhBfg9RW04GNbt7F5uCugByenoJzt-keFLLYEp8,3561
|
|
20
20
|
meridian/backend/test_utils.py,sha256=oJNosF_x_BzNuia8LzLFb_YfjGWHRCzR5FXNN5KQ8sw,13738
|
|
@@ -45,7 +45,7 @@ meridian/model/spec.py,sha256=hmVz1LZlE1un3Lt2Hx6L8FR7iG8OtL1i6XScCXqvVzE,19684
|
|
|
45
45
|
meridian/model/transformers.py,sha256=HxlVJitxP-wu-NOHU0tArFUZ4NAO3c7adAYj4Zvqnvo,8363
|
|
46
46
|
meridian/model/eda/__init__.py,sha256=bMj9kd2LWU_LQZAjQv54FFggzdv4CKRYblvc-0cHXc4,768
|
|
47
47
|
meridian/model/eda/constants.py,sha256=maaZ0suGwhWbHIoNqQis9mV4LwlNexyADYx92U2Mrew,15124
|
|
48
|
-
meridian/model/eda/eda_engine.py,sha256
|
|
48
|
+
meridian/model/eda/eda_engine.py,sha256=-lz2PLAhujkCfq2IMhDnIFePPEF-oZAScnuvFmgJK-Y,88516
|
|
49
49
|
meridian/model/eda/eda_outcome.py,sha256=xCy0sl92Vge0ANnMqLuadjFTeZlyAs0rBh0zRBUrpzM,11328
|
|
50
50
|
meridian/model/eda/eda_spec.py,sha256=diieYyZH0ee3ZLy0rGFMcWrrgiUrz2HctMwOrmtJR6w,2871
|
|
51
51
|
meridian/templates/card.html.jinja,sha256=AWgKGLPf7qTFVNy-vylXm-_tatzW9ngPk1mJm9sTCPg,1332
|
|
@@ -99,14 +99,14 @@ schema/serde/function_registry.py,sha256=GbgC5_9NDcA9Y7nqmdJ-4-LK5JPhhfI50Lmfy5Z
|
|
|
99
99
|
schema/serde/hyperparameters.py,sha256=0Lgep_lT5Ro6svvLPdR6OyL_qCb0-bRrxJVxsmySmJs,12176
|
|
100
100
|
schema/serde/inference_data.py,sha256=DrwE9hU8LMrl0z8W_sUSIaPrRdym_lu0iOqpT4KZxsA,3623
|
|
101
101
|
schema/serde/marketing_data.py,sha256=yb-fRTe84Sjg7-v3wsvYRRXvrxLSFWSenO0_ikMvUpk,44845
|
|
102
|
-
schema/serde/meridian_serde.py,sha256=
|
|
102
|
+
schema/serde/meridian_serde.py,sha256=5q2AkZ52Ew0SJUH9g4VXqWHSwjzVJ_-ChCx6B5FA8CE,16246
|
|
103
103
|
schema/serde/serde.py,sha256=8vUqhJxvZgX9UY3rXTyWJznRgapwDzzaHXDHwV_kKTA,1612
|
|
104
104
|
schema/serde/test_data.py,sha256=7hfEWyvZ9WcAkVAOXt6elX8stJlsfhfd-ASlHo9SRb8,107342
|
|
105
105
|
schema/utils/__init__.py,sha256=OzDmXWCpogCt6EkremIShzTowsZF8dHzfEjkJkE9qfk,767
|
|
106
106
|
schema/utils/date_range_bucketing.py,sha256=14vcRGf3odWT9mBdCykRNmVCEiuUI_1SvVygNzvqBuM,3809
|
|
107
107
|
schema/utils/proto_enum_converter.py,sha256=vCKGQGWfCt6W7GZy7QQRFAj3XqLUQwt_eWZzsX6pA0E,4021
|
|
108
108
|
schema/utils/time_record.py,sha256=-KzHFjvSBUUXsfESPAfcJP_VFxaFLqj90Ac0kgKWfpI,4624
|
|
109
|
-
google_meridian-1.5.
|
|
110
|
-
google_meridian-1.5.
|
|
111
|
-
google_meridian-1.5.
|
|
112
|
-
google_meridian-1.5.
|
|
109
|
+
google_meridian-1.5.1.dist-info/METADATA,sha256=SWHh9POzkljYJxkNgvQaG0GMHK9BHsjsW0ksaoIjqaI,10024
|
|
110
|
+
google_meridian-1.5.1.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
|
|
111
|
+
google_meridian-1.5.1.dist-info/top_level.txt,sha256=oAi0z-fUuo6p8SnJ0WrojGR2mKOWDz43yr6EjzaXqy8,32
|
|
112
|
+
google_meridian-1.5.1.dist-info/RECORD,,
|
|
@@ -15,7 +15,7 @@
|
|
|
15
15
|
"""Implementation of the Model Quality Checks."""
|
|
16
16
|
|
|
17
17
|
import abc
|
|
18
|
-
from collections.abc import Sequence
|
|
18
|
+
from collections.abc import MutableMapping, Sequence
|
|
19
19
|
import dataclasses
|
|
20
20
|
from typing import Generic, TypeVar
|
|
21
21
|
import warnings
|
|
@@ -77,31 +77,16 @@ class ConvergenceCheck(
|
|
|
77
77
|
if not valid_rhat_items:
|
|
78
78
|
return results.ConvergenceCheckResult(
|
|
79
79
|
case=results.ConvergenceCases.CONVERGED,
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
review_constants.CONVERGENCE_THRESHOLD: (
|
|
84
|
-
self._config.convergence_threshold
|
|
85
|
-
),
|
|
86
|
-
},
|
|
80
|
+
config=self._config,
|
|
81
|
+
max_rhat=np.nan,
|
|
82
|
+
max_parameter=np.nan,
|
|
87
83
|
)
|
|
88
84
|
|
|
89
85
|
max_parameter, max_rhat = max(max_rhats.items(), key=lambda item: item[1])
|
|
90
86
|
|
|
91
|
-
details = {
|
|
92
|
-
review_constants.RHAT: max_rhat,
|
|
93
|
-
review_constants.PARAMETER: max_parameter,
|
|
94
|
-
review_constants.CONVERGENCE_THRESHOLD: (
|
|
95
|
-
self._config.convergence_threshold
|
|
96
|
-
),
|
|
97
|
-
}
|
|
98
|
-
|
|
99
87
|
# Case 1: Converged.
|
|
100
88
|
if max_rhat < self._config.convergence_threshold:
|
|
101
|
-
|
|
102
|
-
case=results.ConvergenceCases.CONVERGED,
|
|
103
|
-
details=details,
|
|
104
|
-
)
|
|
89
|
+
case = results.ConvergenceCases.CONVERGED
|
|
105
90
|
|
|
106
91
|
# Case 2: Not fully converged, but potentially acceptable.
|
|
107
92
|
elif (
|
|
@@ -109,17 +94,18 @@ class ConvergenceCheck(
|
|
|
109
94
|
<= max_rhat
|
|
110
95
|
< self._config.not_fully_convergence_threshold
|
|
111
96
|
):
|
|
112
|
-
|
|
113
|
-
case=results.ConvergenceCases.NOT_FULLY_CONVERGED,
|
|
114
|
-
details=details,
|
|
115
|
-
)
|
|
97
|
+
case = results.ConvergenceCases.NOT_FULLY_CONVERGED
|
|
116
98
|
|
|
117
99
|
# Case 3: Not converged and unacceptable.
|
|
118
100
|
else: # max_rhat >= divergence_threshold
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
101
|
+
case = results.ConvergenceCases.NOT_CONVERGED
|
|
102
|
+
|
|
103
|
+
return results.ConvergenceCheckResult(
|
|
104
|
+
case=case,
|
|
105
|
+
config=self._config,
|
|
106
|
+
max_rhat=max_rhat,
|
|
107
|
+
max_parameter=max_parameter,
|
|
108
|
+
)
|
|
123
109
|
|
|
124
110
|
|
|
125
111
|
# ==============================================================================
|
|
@@ -131,33 +117,25 @@ class BaselineCheck(
|
|
|
131
117
|
"""Checks for negative baseline probability."""
|
|
132
118
|
|
|
133
119
|
def run(self) -> results.BaselineCheckResult:
|
|
134
|
-
prob = self._analyzer.negative_baseline_probability()
|
|
135
|
-
|
|
136
|
-
review_constants.NEGATIVE_BASELINE_PROB: prob,
|
|
137
|
-
review_constants.NEGATIVE_BASELINE_PROB_FAIL_THRESHOLD: (
|
|
138
|
-
self._config.negative_baseline_prob_fail_threshold
|
|
139
|
-
),
|
|
140
|
-
review_constants.NEGATIVE_BASELINE_PROB_REVIEW_THRESHOLD: (
|
|
141
|
-
self._config.negative_baseline_prob_review_threshold
|
|
142
|
-
),
|
|
143
|
-
}
|
|
120
|
+
prob = float(self._analyzer.negative_baseline_probability())
|
|
121
|
+
|
|
144
122
|
# Case 1: FAIL
|
|
145
123
|
if prob > self._config.negative_baseline_prob_fail_threshold:
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
details=details,
|
|
149
|
-
)
|
|
124
|
+
case = results.BaselineCases.FAIL
|
|
125
|
+
|
|
150
126
|
# Case 2: REVIEW
|
|
151
127
|
elif prob >= self._config.negative_baseline_prob_review_threshold:
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
details=details,
|
|
155
|
-
)
|
|
128
|
+
case = results.BaselineCases.REVIEW
|
|
129
|
+
|
|
156
130
|
# Case 3: PASS
|
|
157
131
|
else:
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
132
|
+
case = results.BaselineCases.PASS
|
|
133
|
+
|
|
134
|
+
return results.BaselineCheckResult(
|
|
135
|
+
case=case,
|
|
136
|
+
config=self._config,
|
|
137
|
+
negative_baseline_prob=prob,
|
|
138
|
+
)
|
|
161
139
|
|
|
162
140
|
|
|
163
141
|
# ==============================================================================
|
|
@@ -189,45 +167,40 @@ class BayesianPPPCheck(
|
|
|
189
167
|
>= np.abs(total_outcome_actual - total_outcome_expected_mean)
|
|
190
168
|
)
|
|
191
169
|
|
|
192
|
-
details = {
|
|
193
|
-
review_constants.BAYESIAN_PPP: bayesian_ppp,
|
|
194
|
-
}
|
|
195
|
-
|
|
196
170
|
if bayesian_ppp >= self._config.ppp_threshold:
|
|
197
|
-
|
|
198
|
-
case=results.BayesianPPPCases.PASS,
|
|
199
|
-
details=details,
|
|
200
|
-
)
|
|
171
|
+
case = results.BayesianPPPCases.PASS
|
|
201
172
|
else:
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
173
|
+
case = results.BayesianPPPCases.FAIL
|
|
174
|
+
|
|
175
|
+
return results.BayesianPPPCheckResult(
|
|
176
|
+
case=case,
|
|
177
|
+
config=self._config,
|
|
178
|
+
bayesian_ppp=bayesian_ppp,
|
|
179
|
+
)
|
|
206
180
|
|
|
207
181
|
|
|
208
182
|
# ==============================================================================
|
|
209
183
|
# Check: Goodness of Fit
|
|
210
184
|
# ==============================================================================
|
|
211
|
-
def
|
|
212
|
-
|
|
185
|
+
def _set_metrics_from_gof_dataframe(
|
|
186
|
+
metrics: MutableMapping[str, float],
|
|
213
187
|
gof_df: pd.DataFrame,
|
|
214
188
|
geo_granularity: str,
|
|
215
|
-
suffix: str
|
|
189
|
+
suffix: str,
|
|
216
190
|
) -> None:
|
|
217
|
-
"""Sets the `
|
|
191
|
+
"""Sets the `metrics` variable of the GoodnessOfFitCheckResult.
|
|
218
192
|
|
|
219
193
|
This method takes a DataFrame containing goodness of fit metrics and pivots it
|
|
220
|
-
to a Series, which is then added to the `
|
|
194
|
+
to a Series, which is then added to the `metrics` variable of the
|
|
221
195
|
`GoodnessOfFitCheckResult`.
|
|
222
196
|
|
|
223
197
|
Args:
|
|
224
|
-
|
|
198
|
+
metrics: A dictionary to store the goodness of fit metrics in.
|
|
225
199
|
gof_df: A DataFrame containing predictive accuracy of the whole data (if
|
|
226
200
|
holdout set is not used) of filtered to a single evaluation set ("all",
|
|
227
201
|
"train", or "test").
|
|
228
202
|
geo_granularity: The geo granularity of the data ("geo" or "national").
|
|
229
|
-
suffix: A suffix to add to the metric names (e.g., "
|
|
230
|
-
If None, the metrics are added without a suffix.
|
|
203
|
+
suffix: A suffix to add to the metric names (e.g., "_train", "_test").
|
|
231
204
|
"""
|
|
232
205
|
gof_metrics_pivoted = gof_df.pivot(
|
|
233
206
|
index=constants.GEO_GRANULARITY,
|
|
@@ -235,22 +208,15 @@ def _set_details_from_gof_dataframe(
|
|
|
235
208
|
values=constants.VALUE,
|
|
236
209
|
)
|
|
237
210
|
gof_metrics_series = gof_metrics_pivoted.loc[geo_granularity]
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
]
|
|
248
|
-
else:
|
|
249
|
-
details[review_constants.R_SQUARED] = gof_metrics_series[
|
|
250
|
-
constants.R_SQUARED
|
|
251
|
-
]
|
|
252
|
-
details[review_constants.MAPE] = gof_metrics_series[constants.MAPE]
|
|
253
|
-
details[review_constants.WMAPE] = gof_metrics_series[constants.WMAPE]
|
|
211
|
+
metrics[f"{review_constants.R_SQUARED}{suffix}"] = gof_metrics_series[
|
|
212
|
+
constants.R_SQUARED
|
|
213
|
+
]
|
|
214
|
+
metrics[f"{review_constants.MAPE}{suffix}"] = gof_metrics_series[
|
|
215
|
+
constants.MAPE
|
|
216
|
+
]
|
|
217
|
+
metrics[f"{review_constants.WMAPE}{suffix}"] = gof_metrics_series[
|
|
218
|
+
constants.WMAPE
|
|
219
|
+
]
|
|
254
220
|
|
|
255
221
|
|
|
256
222
|
class GoodnessOfFitCheck(
|
|
@@ -269,7 +235,7 @@ class GoodnessOfFitCheck(
|
|
|
269
235
|
gof_metrics = gof_df[gof_df[constants.GEO_GRANULARITY] == geo_granularity]
|
|
270
236
|
is_holdout = constants.EVALUATION_SET_VAR in gof_df.columns
|
|
271
237
|
|
|
272
|
-
|
|
238
|
+
metrics_dict = {}
|
|
273
239
|
case = results.GoodnessOfFitCases.PASS
|
|
274
240
|
|
|
275
241
|
if is_holdout:
|
|
@@ -281,29 +247,71 @@ class GoodnessOfFitCheck(
|
|
|
281
247
|
set_metrics = gof_metrics[
|
|
282
248
|
gof_metrics[constants.EVALUATION_SET_VAR] == evaluation_set
|
|
283
249
|
]
|
|
284
|
-
|
|
285
|
-
|
|
250
|
+
_set_metrics_from_gof_dataframe(
|
|
251
|
+
metrics=metrics_dict,
|
|
286
252
|
gof_df=set_metrics,
|
|
287
253
|
geo_granularity=geo_granularity,
|
|
288
254
|
suffix=suffix,
|
|
289
255
|
)
|
|
290
|
-
if
|
|
256
|
+
if metrics_dict[f"{review_constants.R_SQUARED}{suffix}"] <= 0:
|
|
291
257
|
case = results.GoodnessOfFitCases.REVIEW
|
|
258
|
+
return results.GoodnessOfFitCheckResult(
|
|
259
|
+
case=case,
|
|
260
|
+
metrics=results.GoodnessOfFitMetrics(
|
|
261
|
+
r_squared=metrics_dict[
|
|
262
|
+
f"{review_constants.R_SQUARED}{review_constants.ALL_SUFFIX}"
|
|
263
|
+
],
|
|
264
|
+
mape=metrics_dict[
|
|
265
|
+
f"{review_constants.MAPE}{review_constants.ALL_SUFFIX}"
|
|
266
|
+
],
|
|
267
|
+
wmape=metrics_dict[
|
|
268
|
+
f"{review_constants.WMAPE}{review_constants.ALL_SUFFIX}"
|
|
269
|
+
],
|
|
270
|
+
r_squared_train=metrics_dict[
|
|
271
|
+
f"{review_constants.R_SQUARED}{review_constants.TRAIN_SUFFIX}"
|
|
272
|
+
],
|
|
273
|
+
mape_train=metrics_dict[
|
|
274
|
+
f"{review_constants.MAPE}{review_constants.TRAIN_SUFFIX}"
|
|
275
|
+
],
|
|
276
|
+
wmape_train=metrics_dict[
|
|
277
|
+
f"{review_constants.WMAPE}{review_constants.TRAIN_SUFFIX}"
|
|
278
|
+
],
|
|
279
|
+
r_squared_test=metrics_dict[
|
|
280
|
+
f"{review_constants.R_SQUARED}{review_constants.TEST_SUFFIX}"
|
|
281
|
+
],
|
|
282
|
+
mape_test=metrics_dict[
|
|
283
|
+
f"{review_constants.MAPE}{review_constants.TEST_SUFFIX}"
|
|
284
|
+
],
|
|
285
|
+
wmape_test=metrics_dict[
|
|
286
|
+
f"{review_constants.WMAPE}{review_constants.TEST_SUFFIX}"
|
|
287
|
+
],
|
|
288
|
+
),
|
|
289
|
+
is_holdout=is_holdout,
|
|
290
|
+
)
|
|
292
291
|
else:
|
|
293
|
-
|
|
294
|
-
|
|
292
|
+
_set_metrics_from_gof_dataframe(
|
|
293
|
+
metrics=metrics_dict,
|
|
295
294
|
gof_df=gof_metrics,
|
|
296
295
|
geo_granularity=geo_granularity,
|
|
297
|
-
suffix=
|
|
296
|
+
suffix=review_constants.ALL_SUFFIX,
|
|
298
297
|
)
|
|
299
|
-
if
|
|
298
|
+
if metrics_dict[review_constants.R_SQUARED] <= 0:
|
|
300
299
|
case = results.GoodnessOfFitCases.REVIEW
|
|
301
|
-
|
|
302
|
-
|
|
303
|
-
|
|
304
|
-
|
|
305
|
-
|
|
306
|
-
|
|
300
|
+
return results.GoodnessOfFitCheckResult(
|
|
301
|
+
case=case,
|
|
302
|
+
metrics=results.GoodnessOfFitMetrics(
|
|
303
|
+
r_squared=metrics_dict[
|
|
304
|
+
f"{review_constants.R_SQUARED}{review_constants.ALL_SUFFIX}"
|
|
305
|
+
],
|
|
306
|
+
mape=metrics_dict[
|
|
307
|
+
f"{review_constants.MAPE}{review_constants.ALL_SUFFIX}"
|
|
308
|
+
],
|
|
309
|
+
wmape=metrics_dict[
|
|
310
|
+
f"{review_constants.WMAPE}{review_constants.ALL_SUFFIX}"
|
|
311
|
+
],
|
|
312
|
+
),
|
|
313
|
+
is_holdout=is_holdout,
|
|
314
|
+
)
|
|
307
315
|
|
|
308
316
|
|
|
309
317
|
# ==============================================================================
|
|
@@ -475,8 +483,10 @@ def _compute_channel_results(
|
|
|
475
483
|
channel_results.append(
|
|
476
484
|
results.ROIConsistencyChannelResult(
|
|
477
485
|
case=case,
|
|
478
|
-
details={},
|
|
479
486
|
channel_name=channel,
|
|
487
|
+
prior_roi_lo=np.nan,
|
|
488
|
+
prior_roi_hi=np.nan,
|
|
489
|
+
posterior_roi_mean=np.nan,
|
|
480
490
|
)
|
|
481
491
|
)
|
|
482
492
|
for i, channel in enumerate(channel_data.all_channels):
|
|
@@ -491,14 +501,10 @@ def _compute_channel_results(
|
|
|
491
501
|
channel_results.append(
|
|
492
502
|
results.ROIConsistencyChannelResult(
|
|
493
503
|
case=case,
|
|
494
|
-
details={
|
|
495
|
-
review_constants.PRIOR_ROI_LO: channel_data.prior_roi_los[i],
|
|
496
|
-
review_constants.PRIOR_ROI_HI: channel_data.prior_roi_his[i],
|
|
497
|
-
review_constants.POSTERIOR_ROI_MEAN: (
|
|
498
|
-
channel_data.posterior_means[i]
|
|
499
|
-
),
|
|
500
|
-
},
|
|
501
504
|
channel_name=channel,
|
|
505
|
+
prior_roi_lo=channel_data.prior_roi_los[i],
|
|
506
|
+
prior_roi_hi=channel_data.prior_roi_his[i],
|
|
507
|
+
posterior_roi_mean=channel_data.posterior_means[i],
|
|
502
508
|
)
|
|
503
509
|
)
|
|
504
510
|
return channel_results
|
|
@@ -558,7 +564,7 @@ def _compute_aggregate_result(
|
|
|
558
564
|
|
|
559
565
|
return results.ROIConsistencyCheckResult(
|
|
560
566
|
case=aggregate_case,
|
|
561
|
-
|
|
567
|
+
aggregate_details=aggregate_details,
|
|
562
568
|
channel_results=channel_results,
|
|
563
569
|
)
|
|
564
570
|
|
|
@@ -734,7 +740,7 @@ class PriorPosteriorShiftCheck(
|
|
|
734
740
|
no_shift_channels.append(channel_name)
|
|
735
741
|
channel_results.append(
|
|
736
742
|
results.PriorPosteriorShiftChannelResult(
|
|
737
|
-
case=case,
|
|
743
|
+
case=case, channel_name=channel_name
|
|
738
744
|
)
|
|
739
745
|
)
|
|
740
746
|
return channel_results, no_shift_channels
|
|
@@ -752,17 +758,13 @@ class PriorPosteriorShiftCheck(
|
|
|
752
758
|
|
|
753
759
|
if no_shift_channels:
|
|
754
760
|
agg_case = results.PriorPosteriorShiftAggregateCases.REVIEW
|
|
755
|
-
final_details = {
|
|
756
|
-
"channels_str": ", ".join(
|
|
757
|
-
f"`{channel}`" for channel in no_shift_channels
|
|
758
|
-
)
|
|
759
|
-
}
|
|
760
761
|
else:
|
|
761
762
|
agg_case = results.PriorPosteriorShiftAggregateCases.PASS
|
|
762
|
-
final_details = {}
|
|
763
763
|
|
|
764
764
|
return results.PriorPosteriorShiftCheckResult(
|
|
765
|
-
case=agg_case,
|
|
765
|
+
case=agg_case,
|
|
766
|
+
channel_results=channel_results,
|
|
767
|
+
no_shift_channels=no_shift_channels,
|
|
766
768
|
)
|
|
767
769
|
|
|
768
770
|
def run(self) -> results.PriorPosteriorShiftCheckResult:
|
|
@@ -32,9 +32,9 @@ NEGATIVE_BASELINE_PROB_REVIEW_THRESHOLD = (
|
|
|
32
32
|
R_SQUARED = "r_squared"
|
|
33
33
|
MAPE = "mape"
|
|
34
34
|
WMAPE = "wmape"
|
|
35
|
-
ALL_SUFFIX = "
|
|
36
|
-
TRAIN_SUFFIX = "
|
|
37
|
-
TEST_SUFFIX = "
|
|
35
|
+
ALL_SUFFIX = ""
|
|
36
|
+
TRAIN_SUFFIX = "_train"
|
|
37
|
+
TEST_SUFFIX = "_test"
|
|
38
38
|
EVALUATION_SET_SUFFIXES = (ALL_SUFFIX, TRAIN_SUFFIX, TEST_SUFFIX)
|
|
39
39
|
MEAN = "mean"
|
|
40
40
|
VARIANCE = "variance"
|
|
@@ -14,9 +14,13 @@
|
|
|
14
14
|
|
|
15
15
|
"""Data structures for the Model Quality Checks results."""
|
|
16
16
|
|
|
17
|
+
import abc
|
|
18
|
+
from collections.abc import Mapping
|
|
17
19
|
import dataclasses
|
|
18
20
|
import enum
|
|
19
21
|
from typing import Any
|
|
22
|
+
|
|
23
|
+
from meridian.analysis.review import configs
|
|
20
24
|
from meridian.analysis.review import constants
|
|
21
25
|
|
|
22
26
|
|
|
@@ -58,11 +62,16 @@ class ModelCheckCase(BaseCase):
|
|
|
58
62
|
|
|
59
63
|
|
|
60
64
|
@dataclasses.dataclass(frozen=True)
|
|
61
|
-
class BaseResultData:
|
|
65
|
+
class BaseResultData(abc.ABC):
|
|
62
66
|
"""Base class for check result data."""
|
|
63
67
|
|
|
64
68
|
case: BaseCase
|
|
65
|
-
|
|
69
|
+
|
|
70
|
+
@property
|
|
71
|
+
@abc.abstractmethod
|
|
72
|
+
def details(self) -> Mapping[str, Any]:
|
|
73
|
+
"""Returns the details for message formatting."""
|
|
74
|
+
raise NotImplementedError
|
|
66
75
|
|
|
67
76
|
|
|
68
77
|
@dataclasses.dataclass(frozen=True)
|
|
@@ -145,17 +154,18 @@ class ConvergenceCheckResult(CheckResult):
|
|
|
145
154
|
"""The immutable result of the Convergence Check."""
|
|
146
155
|
|
|
147
156
|
case: ConvergenceCases
|
|
157
|
+
config: configs.ConvergenceConfig
|
|
158
|
+
max_rhat: float
|
|
159
|
+
max_parameter: str
|
|
148
160
|
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
f" Details: {self.details}."
|
|
158
|
-
)
|
|
161
|
+
@property
|
|
162
|
+
def details(self) -> Mapping[str, Any]:
|
|
163
|
+
"""The check result details."""
|
|
164
|
+
return {
|
|
165
|
+
constants.RHAT: self.max_rhat,
|
|
166
|
+
constants.PARAMETER: self.max_parameter,
|
|
167
|
+
constants.CONVERGENCE_THRESHOLD: self.config.convergence_threshold,
|
|
168
|
+
}
|
|
159
169
|
|
|
160
170
|
|
|
161
171
|
# ==============================================================================
|
|
@@ -223,24 +233,21 @@ class BaselineCheckResult(CheckResult):
|
|
|
223
233
|
"""The immutable result of the Baseline Check."""
|
|
224
234
|
|
|
225
235
|
case: BaselineCases
|
|
236
|
+
config: configs.BaselineConfig
|
|
237
|
+
negative_baseline_prob: float
|
|
226
238
|
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
"The message template is missing required formatting arguments:"
|
|
240
|
-
" negative_baseline_prob, negative_baseline_prob_fail_threshold,"
|
|
241
|
-
" negative_baseline_prob_review_threshold. Details:"
|
|
242
|
-
f" {self.details}."
|
|
243
|
-
)
|
|
239
|
+
@property
|
|
240
|
+
def details(self) -> Mapping[str, Any]:
|
|
241
|
+
"""The check result details."""
|
|
242
|
+
return {
|
|
243
|
+
constants.NEGATIVE_BASELINE_PROB: self.negative_baseline_prob,
|
|
244
|
+
constants.NEGATIVE_BASELINE_PROB_FAIL_THRESHOLD: (
|
|
245
|
+
self.config.negative_baseline_prob_fail_threshold
|
|
246
|
+
),
|
|
247
|
+
constants.NEGATIVE_BASELINE_PROB_REVIEW_THRESHOLD: (
|
|
248
|
+
self.config.negative_baseline_prob_review_threshold
|
|
249
|
+
),
|
|
250
|
+
}
|
|
244
251
|
|
|
245
252
|
|
|
246
253
|
# ==============================================================================
|
|
@@ -287,14 +294,15 @@ class BayesianPPPCheckResult(CheckResult):
|
|
|
287
294
|
"""The immutable result of the Bayesian Posterior Predictive P-value Check."""
|
|
288
295
|
|
|
289
296
|
case: BayesianPPPCases
|
|
297
|
+
config: configs.BayesianPPPConfig
|
|
298
|
+
bayesian_ppp: float
|
|
290
299
|
|
|
291
|
-
|
|
292
|
-
|
|
293
|
-
|
|
294
|
-
|
|
295
|
-
|
|
296
|
-
|
|
297
|
-
)
|
|
300
|
+
@property
|
|
301
|
+
def details(self) -> Mapping[str, Any]:
|
|
302
|
+
"""The check result details."""
|
|
303
|
+
return {
|
|
304
|
+
constants.BAYESIAN_PPP: self.bayesian_ppp,
|
|
305
|
+
}
|
|
298
306
|
|
|
299
307
|
|
|
300
308
|
# ==============================================================================
|
|
@@ -337,55 +345,77 @@ class GoodnessOfFitCases(ModelCheckCase, enum.Enum):
|
|
|
337
345
|
super().__init__(status, message_template, recommendation)
|
|
338
346
|
|
|
339
347
|
|
|
348
|
+
@dataclasses.dataclass(frozen=True)
|
|
349
|
+
class GoodnessOfFitMetrics:
|
|
350
|
+
"""The metrics for the Goodness of Fit Check."""
|
|
351
|
+
|
|
352
|
+
r_squared: float
|
|
353
|
+
mape: float
|
|
354
|
+
wmape: float
|
|
355
|
+
r_squared_train: float | None = None
|
|
356
|
+
mape_train: float | None = None
|
|
357
|
+
wmape_train: float | None = None
|
|
358
|
+
r_squared_test: float | None = None
|
|
359
|
+
mape_test: float | None = None
|
|
360
|
+
wmape_test: float | None = None
|
|
361
|
+
|
|
362
|
+
|
|
340
363
|
@dataclasses.dataclass(frozen=True)
|
|
341
364
|
class GoodnessOfFitCheckResult(CheckResult):
|
|
342
365
|
"""The immutable result of the Goodness of Fit Check."""
|
|
343
366
|
|
|
344
367
|
case: GoodnessOfFitCases
|
|
368
|
+
metrics: GoodnessOfFitMetrics
|
|
345
369
|
is_holdout: bool = False
|
|
346
370
|
|
|
347
371
|
def __post_init__(self):
|
|
348
372
|
if self.is_holdout:
|
|
349
|
-
|
|
350
|
-
|
|
351
|
-
|
|
352
|
-
|
|
353
|
-
|
|
354
|
-
|
|
355
|
-
|
|
356
|
-
|
|
357
|
-
|
|
358
|
-
|
|
359
|
-
|
|
360
|
-
if any(key not in self.details for key in required_keys):
|
|
373
|
+
if any(
|
|
374
|
+
metric is None
|
|
375
|
+
for metric in (
|
|
376
|
+
self.metrics.r_squared_train,
|
|
377
|
+
self.metrics.mape_train,
|
|
378
|
+
self.metrics.wmape_train,
|
|
379
|
+
self.metrics.r_squared_test,
|
|
380
|
+
self.metrics.mape_test,
|
|
381
|
+
self.metrics.wmape_test,
|
|
382
|
+
)
|
|
383
|
+
):
|
|
361
384
|
raise ValueError(
|
|
362
385
|
"The message template is missing required formatting arguments for"
|
|
363
|
-
|
|
364
|
-
|
|
365
|
-
|
|
366
|
-
elif any(
|
|
367
|
-
key not in self.details
|
|
368
|
-
for key in (
|
|
369
|
-
constants.R_SQUARED,
|
|
370
|
-
constants.MAPE,
|
|
371
|
-
constants.WMAPE,
|
|
386
|
+
" holdout case. Required keys: r_squared_train, mape_train,"
|
|
387
|
+
" wmape_train, r_squared_test, mape_test, wmape_test. Metrics:"
|
|
388
|
+
f" {self.metrics}."
|
|
372
389
|
)
|
|
373
|
-
|
|
374
|
-
|
|
375
|
-
|
|
376
|
-
|
|
377
|
-
|
|
378
|
-
|
|
390
|
+
|
|
391
|
+
@property
|
|
392
|
+
def details(self) -> Mapping[str, Any]:
|
|
393
|
+
"""The check result details."""
|
|
394
|
+
return {
|
|
395
|
+
f"{constants.R_SQUARED}{constants.ALL_SUFFIX}": self.metrics.r_squared,
|
|
396
|
+
f"{constants.MAPE}{constants.ALL_SUFFIX}": self.metrics.mape,
|
|
397
|
+
f"{constants.WMAPE}{constants.ALL_SUFFIX}": self.metrics.wmape,
|
|
398
|
+
f"{constants.R_SQUARED}{constants.TRAIN_SUFFIX}": (
|
|
399
|
+
self.metrics.r_squared_train
|
|
400
|
+
),
|
|
401
|
+
f"{constants.MAPE}{constants.TRAIN_SUFFIX}": self.metrics.mape_train,
|
|
402
|
+
f"{constants.WMAPE}{constants.TRAIN_SUFFIX}": self.metrics.wmape_train,
|
|
403
|
+
f"{constants.R_SQUARED}{constants.TEST_SUFFIX}": (
|
|
404
|
+
self.metrics.r_squared_test
|
|
405
|
+
),
|
|
406
|
+
f"{constants.MAPE}{constants.TEST_SUFFIX}": self.metrics.mape_test,
|
|
407
|
+
f"{constants.WMAPE}{constants.TEST_SUFFIX}": self.metrics.wmape_test,
|
|
408
|
+
}
|
|
379
409
|
|
|
380
410
|
@property
|
|
381
411
|
def recommendation(self) -> str:
|
|
382
|
-
"""
|
|
412
|
+
"""The check result message."""
|
|
383
413
|
if self.is_holdout:
|
|
384
414
|
report_str = (
|
|
385
|
-
"R-squared = {
|
|
415
|
+
"R-squared = {r_squared:.4f} (All),"
|
|
386
416
|
" {r_squared_train:.4f} (Train), {r_squared_test:.4f} (Test); MAPE"
|
|
387
|
-
" = {
|
|
388
|
-
" {mape_test:.4f} (Test); wMAPE = {
|
|
417
|
+
" = {mape:.4f} (All), {mape_train:.4f} (Train),"
|
|
418
|
+
" {mape_test:.4f} (Test); wMAPE = {wmape:.4f} (All),"
|
|
389
419
|
" {wmape_train:.4f} (Train), {wmape_test:.4f} (Test)".format(
|
|
390
420
|
**self.details
|
|
391
421
|
)
|
|
@@ -450,6 +480,18 @@ class ROIConsistencyChannelResult(ChannelResult):
|
|
|
450
480
|
"""The immutable result of ROI Consistency Check for a single channel."""
|
|
451
481
|
|
|
452
482
|
case: ROIConsistencyChannelCases
|
|
483
|
+
prior_roi_lo: float
|
|
484
|
+
prior_roi_hi: float
|
|
485
|
+
posterior_roi_mean: float
|
|
486
|
+
|
|
487
|
+
@property
|
|
488
|
+
def details(self) -> Mapping[str, Any]:
|
|
489
|
+
"""Returns the check result details."""
|
|
490
|
+
return {
|
|
491
|
+
constants.PRIOR_ROI_LO: self.prior_roi_lo,
|
|
492
|
+
constants.PRIOR_ROI_HI: self.prior_roi_hi,
|
|
493
|
+
constants.POSTERIOR_ROI_MEAN: self.posterior_roi_mean,
|
|
494
|
+
}
|
|
453
495
|
|
|
454
496
|
|
|
455
497
|
@dataclasses.dataclass(frozen=True)
|
|
@@ -458,6 +500,12 @@ class ROIConsistencyCheckResult(CheckResult):
|
|
|
458
500
|
|
|
459
501
|
case: ROIConsistencyAggregateCases
|
|
460
502
|
channel_results: list[ROIConsistencyChannelResult]
|
|
503
|
+
aggregate_details: Mapping[str, Any]
|
|
504
|
+
|
|
505
|
+
@property
|
|
506
|
+
def details(self) -> Mapping[str, Any]:
|
|
507
|
+
"""Returns the check result details."""
|
|
508
|
+
return self.aggregate_details
|
|
461
509
|
|
|
462
510
|
|
|
463
511
|
# ==============================================================================
|
|
@@ -517,6 +565,11 @@ class PriorPosteriorShiftChannelResult(ChannelResult):
|
|
|
517
565
|
|
|
518
566
|
case: PriorPosteriorShiftChannelCases
|
|
519
567
|
|
|
568
|
+
@property
|
|
569
|
+
def details(self) -> Mapping[str, Any]:
|
|
570
|
+
"""Returns the check result details."""
|
|
571
|
+
return {}
|
|
572
|
+
|
|
520
573
|
|
|
521
574
|
@dataclasses.dataclass(frozen=True)
|
|
522
575
|
class PriorPosteriorShiftCheckResult(CheckResult):
|
|
@@ -524,6 +577,16 @@ class PriorPosteriorShiftCheckResult(CheckResult):
|
|
|
524
577
|
|
|
525
578
|
case: PriorPosteriorShiftAggregateCases
|
|
526
579
|
channel_results: list[PriorPosteriorShiftChannelResult]
|
|
580
|
+
no_shift_channels: list[str]
|
|
581
|
+
|
|
582
|
+
@property
|
|
583
|
+
def details(self) -> Mapping[str, Any]:
|
|
584
|
+
"""Returns the check result details."""
|
|
585
|
+
return {
|
|
586
|
+
"channels_str": ", ".join(
|
|
587
|
+
f"`{channel}`" for channel in self.no_shift_channels
|
|
588
|
+
)
|
|
589
|
+
}
|
|
527
590
|
|
|
528
591
|
|
|
529
592
|
# ==============================================================================
|
|
@@ -567,7 +630,7 @@ class ReviewSummary:
|
|
|
567
630
|
return "\n".join(report)
|
|
568
631
|
|
|
569
632
|
@property
|
|
570
|
-
def checks_status(self) ->
|
|
633
|
+
def checks_status(self) -> Mapping[str, str]:
|
|
571
634
|
"""Returns a dictionary of check names and statuses."""
|
|
572
635
|
return {
|
|
573
636
|
result.__class__.__name__: result.case.status.name
|
|
@@ -29,7 +29,7 @@ CheckType = typing.Type[checks.BaseCheck]
|
|
|
29
29
|
ConfigInstance = configs.BaseConfig
|
|
30
30
|
ChecksBattery = immutabledict.immutabledict[CheckType, ConfigInstance]
|
|
31
31
|
|
|
32
|
-
|
|
32
|
+
_POST_CONVERGENCE_CHECKS: ChecksBattery = immutabledict.immutabledict({
|
|
33
33
|
checks.BaselineCheck: configs.BaselineConfig(),
|
|
34
34
|
checks.BayesianPPPCheck: configs.BayesianPPPConfig(),
|
|
35
35
|
checks.GoodnessOfFitCheck: configs.GoodnessOfFitConfig(),
|
|
@@ -39,39 +39,22 @@ _DEFAULT_POST_CONVERGENCE_CHECKS: ChecksBattery = immutabledict.immutabledict({
|
|
|
39
39
|
|
|
40
40
|
|
|
41
41
|
class ModelReviewer:
|
|
42
|
-
"""
|
|
42
|
+
"""A tool for executing a series of quality checks on a Meridian model.
|
|
43
43
|
|
|
44
44
|
The reviewer first runs a convergence check. If the model has converged, it
|
|
45
45
|
proceeds to run a battery of post-convergence checks.
|
|
46
46
|
|
|
47
|
-
The
|
|
47
|
+
The battery of post-convergence checks includes:
|
|
48
48
|
- BaselineCheck
|
|
49
49
|
- BayesianPPPCheck
|
|
50
50
|
- GoodnessOfFitCheck
|
|
51
51
|
- PriorPosteriorShiftCheck
|
|
52
52
|
- ROIConsistencyCheck
|
|
53
|
-
Each with its default configuration.
|
|
54
|
-
|
|
55
|
-
This battery of checks can be customized by passing a dictionary to the
|
|
56
|
-
`post_convergence_checks` argument of the constructor, mapping check
|
|
57
|
-
classes to their configuration instances. For example, to run only the
|
|
58
|
-
BaselineCheck with a non-default configuration:
|
|
59
|
-
|
|
60
|
-
```python
|
|
61
|
-
my_checks = {
|
|
62
|
-
checks.BaselineCheck: configs.BaselineConfig(
|
|
63
|
-
negative_baseline_prob_review_threshold=0.1,
|
|
64
|
-
negative_baseline_prob_fail_threshold=0.5,
|
|
65
|
-
)
|
|
66
|
-
}
|
|
67
|
-
reviewer = ModelReviewer(meridian_model, post_convergence_checks=my_checks)
|
|
68
|
-
```
|
|
69
53
|
"""
|
|
70
54
|
|
|
71
55
|
def __init__(
|
|
72
56
|
self,
|
|
73
57
|
meridian,
|
|
74
|
-
post_convergence_checks: ChecksBattery = _DEFAULT_POST_CONVERGENCE_CHECKS,
|
|
75
58
|
):
|
|
76
59
|
self._meridian = meridian
|
|
77
60
|
self._results: list[results.CheckResult] = []
|
|
@@ -79,7 +62,6 @@ class ModelReviewer:
|
|
|
79
62
|
model_context=meridian.model_context,
|
|
80
63
|
inference_data=meridian.inference_data,
|
|
81
64
|
)
|
|
82
|
-
self._post_convergence_checks = post_convergence_checks
|
|
83
65
|
|
|
84
66
|
def _run_and_handle(self, check_class, config):
|
|
85
67
|
instance = check_class(self._meridian, self._analyzer, config) # pytype: disable=not-instantiable
|
|
@@ -139,7 +121,7 @@ class ModelReviewer:
|
|
|
139
121
|
)
|
|
140
122
|
|
|
141
123
|
# Run all other checks in sequence.
|
|
142
|
-
for check_class, config in
|
|
124
|
+
for check_class, config in _POST_CONVERGENCE_CHECKS.items():
|
|
143
125
|
if (
|
|
144
126
|
check_class == checks.PriorPosteriorShiftCheck
|
|
145
127
|
and not self._uses_roi_priors()
|
meridian/model/eda/eda_engine.py
CHANGED
meridian/version.py
CHANGED
schema/serde/meridian_serde.py
CHANGED
|
@@ -357,7 +357,9 @@ def save_meridian(
|
|
|
357
357
|
if not _file_exists(os.path.dirname(file_path)):
|
|
358
358
|
_make_dirs(os.path.dirname(file_path))
|
|
359
359
|
|
|
360
|
-
|
|
360
|
+
mode = 'wb' if file_path.endswith('.binpb') else 'w'
|
|
361
|
+
|
|
362
|
+
with _file_open(file_path, mode) as f:
|
|
361
363
|
# Creates an MmmKernel.
|
|
362
364
|
serialized_kernel = MeridianSerde().serialize(
|
|
363
365
|
mmm,
|
|
@@ -402,7 +404,9 @@ def load_meridian(
|
|
|
402
404
|
Returns:
|
|
403
405
|
Model object loaded from the file path.
|
|
404
406
|
"""
|
|
405
|
-
|
|
407
|
+
mode = 'rb' if file_path.endswith('.binpb') else 'r'
|
|
408
|
+
|
|
409
|
+
with _file_open(file_path, mode) as f:
|
|
406
410
|
if file_path.endswith('.binpb'):
|
|
407
411
|
serialized_model = kernel_pb.MmmKernel.FromString(f.read())
|
|
408
412
|
elif file_path.endswith('.textproto') or file_path.endswith('.txtpb'):
|
|
File without changes
|
|
File without changes
|
|
File without changes
|