google-meridian 1.3.1__py3-none-any.whl → 1.3.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {google_meridian-1.3.1.dist-info → google_meridian-1.3.2.dist-info}/METADATA +7 -7
- {google_meridian-1.3.1.dist-info → google_meridian-1.3.2.dist-info}/RECORD +35 -35
- meridian/analysis/__init__.py +1 -2
- meridian/analysis/analyzer.py +0 -1
- meridian/analysis/optimizer.py +5 -3
- meridian/analysis/review/checks.py +81 -30
- meridian/analysis/review/constants.py +4 -0
- meridian/analysis/review/results.py +40 -9
- meridian/analysis/summarizer.py +1 -1
- meridian/analysis/visualizer.py +1 -1
- meridian/backend/__init__.py +53 -5
- meridian/backend/test_utils.py +72 -0
- meridian/constants.py +1 -0
- meridian/data/load.py +2 -0
- meridian/model/eda/__init__.py +0 -1
- meridian/model/eda/constants.py +12 -2
- meridian/model/eda/eda_engine.py +299 -37
- meridian/model/eda/eda_outcome.py +21 -1
- meridian/model/knots.py +17 -0
- meridian/{analysis/templates → templates}/card.html.jinja +1 -1
- meridian/{analysis/templates → templates}/chart.html.jinja +1 -1
- meridian/{analysis/templates → templates}/chips.html.jinja +1 -1
- meridian/{analysis → templates}/formatter.py +12 -1
- meridian/templates/formatter_test.py +216 -0
- meridian/{analysis/templates → templates}/insights.html.jinja +1 -1
- meridian/{analysis/templates → templates}/stats.html.jinja +1 -1
- meridian/{analysis/templates → templates}/style.css +1 -1
- meridian/{analysis/templates → templates}/style.scss +1 -1
- meridian/{analysis/templates → templates}/summary.html.jinja +4 -2
- meridian/{analysis/templates → templates}/table.html.jinja +1 -1
- meridian/version.py +1 -1
- schema/__init__.py +12 -0
- meridian/model/eda/meridian_eda.py +0 -220
- {google_meridian-1.3.1.dist-info → google_meridian-1.3.2.dist-info}/WHEEL +0 -0
- {google_meridian-1.3.1.dist-info → google_meridian-1.3.2.dist-info}/licenses/LICENSE +0 -0
- {google_meridian-1.3.1.dist-info → google_meridian-1.3.2.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,216 @@
|
|
|
1
|
+
# Copyright 2025 The Meridian Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from xml.etree import ElementTree as ET
|
|
16
|
+
|
|
17
|
+
from absl.testing import absltest
|
|
18
|
+
from absl.testing import parameterized
|
|
19
|
+
from meridian.templates import formatter
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class FormatterTest(parameterized.TestCase):
  """Tests for the chart/number formatting and HTML helpers in `formatter`."""

  def test_custom_title_params_correct(self):
    title_params = formatter.custom_title_params('test title')
    self.assertEqual(
        title_params.to_dict(),
        {
            'anchor': 'start',
            'color': '#3C4043',
            'font': 'Google Sans Display',
            'fontSize': 18,
            'fontWeight': 'normal',
            'offset': 10,
            'text': 'test title',
        },
    )

  def test_bar_chart_width(self):
    num_bars = 3
    width = formatter.bar_chart_width(num_bars)
    self.assertEqual(width, 186)

  @parameterized.named_parameters(
      ('zero_percent', 0.0, '0%'),
      ('less_than_one_percent', 0.0005, '0.05%'),
      ('one_percent', 0.01, '1%'),
      ('greater_than_one_percent', 0.4257, '43%'),
  )
  def test_format_percent_correct(self, percent, expected):
    formatted_percent = formatter.format_percent(percent)
    self.assertEqual(formatted_percent, expected)

  def test_compact_number_expr_default(self):
    expr = formatter.compact_number_expr()
    self.assertEqual(expr, "replace(format(datum.value, '.3~s'), 'G', 'B')")

  def test_compact_number_expr_params(self):
    expr = formatter.compact_number_expr('other', 2)
    self.assertEqual(expr, "replace(format(datum.other, '.2~s'), 'G', 'B')")

  @parameterized.named_parameters(
      ('rounded_up_percent', 0.4257, 15, '42.6% (15)'),
      ('rounded_down_percent', 0.4251, 15, '42.5% (15)'),
      ('thousand_value', 0.42, 2e4, '42.0% (20k)'),
      ('million_value', 0.42, 3e7, '42.0% (30M)'),
      ('billion_value', 0.42, 4e9, '42.0% (4B)'),
  )
  def test_format_number_text_correct(self, percent, value, expected):
    formatted_text = formatter.format_number_text(percent, value)
    self.assertEqual(formatted_text, expected)

  @parameterized.named_parameters(
      ('zero_precision_thousands', 12345, '$', '$12k'),
      ('round_up_thousands', 14900, '€', '€15k'),
      ('million_value', 3.21e6, '£', '£3.2M'),
      ('billion_value_round_up', 4.28e9, '¥', '¥4.3B'),
      ('negative', -12345, '₮', '-₮12k'),
  )
  def test_format_monetary_num_correct(self, num, currency, expected):
    formatted_number = formatter.format_monetary_num(num, currency)
    self.assertEqual(formatted_number, expected)

  @parameterized.named_parameters(
      ('decimals', -0.1234, 2, '$', '-$0.12'),
      ('zero_precision_thousands', 12345, 0, '', '12k'),
      ('round_up_thousands', 14900, 0, '$', '$15k'),
      ('million_value', 3.21e6, 2, '$', '$3.21M'),
      ('negative', -12345, 0, '$', '-$12k'),
  )
  def test_compact_number_correct(self, num, precision, currency, expected):
    formatted_number = formatter.compact_number(num, precision, currency)
    self.assertEqual(formatted_number, expected)

  def test_create_summary_html(self):
    template_env = formatter.create_template_env()
    title = 'Integration Test Report'
    cards = ['<card>Card 1</card>', '<card>Card 2</card>']

    html_result = formatter.create_summary_html(template_env, title, cards)

    # Since summary.html contains DOCTYPE (which breaks ElementTree XML parser),
    # we verify the output using string assertions.
    self.assertIn('<!DOCTYPE html>', html_result)
    self.assertIn(title, html_result)
    self.assertIn('<card>Card 1</card>', html_result)
    self.assertIn('<card>Card 2</card>', html_result)

  def test_create_card_html_structure(self):
    template_env = formatter.create_template_env()
    card_spec = formatter.CardSpec(id='test_id', title='test_title')
    stats_spec = formatter.StatsSpec(title='stats_title', stat='test_stat')
    chart_spec = formatter.ChartSpec(
        'test_chart_id', 'test_chart_json', 'test_chart_description'
    )

    card_html = ET.fromstring(
        formatter.create_card_html(
            template_env, card_spec, 'test_insights', [chart_spec], [stats_spec]
        )
    )
    self.assertEqual(card_html.tag, 'card')
    self.assertLen(card_html, 4)
    self.assertEqual(card_html[0].tag, 'card-title')
    self.assertEqual(card_html[1].tag, 'card-insights')
    self.assertEqual(card_html[2].tag, 'stats-section')
    self.assertEqual(card_html[3].tag, 'charts')

  # NOTE: text checks below use assertIn rather than assertContainsSubset:
  # with string arguments, assertContainsSubset only verifies that the
  # *characters* of the expected string occur somewhere in the text, so it
  # would pass for scrambled or partial output.

  def test_create_card_html_text(self):
    template_env = formatter.create_template_env()
    card_spec = formatter.CardSpec(id='test_id', title='test_title')
    chart_spec = formatter.ChartSpec(
        'test_chart_id', 'test_chart_json', 'test_chart_description'
    )
    card_html = ET.fromstring(
        formatter.create_card_html(
            template_env, card_spec, 'test_insights', [chart_spec]
        )
    )
    self.assertIn('test_title', card_html[0].text)
    self.assertIn('test_insights', card_html[1][1].text)

  def test_create_card_html_multiple_charts(self):
    template_env = formatter.create_template_env()
    card_spec = formatter.CardSpec(id='test_id', title='test_title')
    chart_spec1 = formatter.ChartSpec(
        'test_chart_id1', 'test_chart_json1', 'test_chart_description1'
    )
    chart_spec2 = formatter.ChartSpec(
        'test_chart_id2', 'test_chart_json2', 'test_chart_description2'
    )
    card_html = ET.fromstring(
        formatter.create_card_html(
            template_env, card_spec, 'test_insights', [chart_spec1, chart_spec2]
        )
    )
    charts = card_html[2]
    # Make sure we are counting children of the charts section, not some
    # other section that happens to sit at this index.
    self.assertEqual(charts.tag, 'charts')
    self.assertLen(charts, 4)  # Each chart has 2 items, chart and script.

  def test_create_card_html_chart_structure(self):
    template_env = formatter.create_template_env()
    card_spec = formatter.CardSpec(id='test_id', title='test_title')
    chart_spec = formatter.ChartSpec(
        'test_chart_id', 'test_chart_json', 'test_chart_description'
    )
    card_html = ET.fromstring(
        formatter.create_card_html(
            template_env, card_spec, 'test_insights', [chart_spec]
        )
    )
    chart_html = card_html[2]
    self.assertEqual(chart_html.tag, 'charts')
    self.assertEqual(chart_html[0].tag, 'chart')
    self.assertEqual(chart_html[0][0].tag, 'chart-embed')
    self.assertEqual(chart_html[0][1].tag, 'chart-description')
    self.assertIn('test_chart_description', chart_html[0][1].text)
    self.assertEqual(chart_html[1].tag, 'script')
    self.assertIn('test_chart_json', chart_html[1].text)

  def test_create_card_html_multiple_stats(self):
    template_env = formatter.create_template_env()
    card_spec = formatter.CardSpec(id='test_id', title='test_title')
    stat1 = formatter.StatsSpec(title='stats_title1', stat='test_stat1')
    stat2 = formatter.StatsSpec(title='stats_title2', stat='test_stat2')
    card_html = ET.fromstring(
        formatter.create_card_html(
            template_env, card_spec, 'test_insights', stats_specs=[stat1, stat2]
        )
    )
    stats = card_html[2]
    self.assertEqual(stats.tag, 'stats-section')
    self.assertLen(stats, 2)

  def test_create_card_html_stats_structure(self):
    template_env = formatter.create_template_env()
    card_spec = formatter.CardSpec(id='test_id', title='test_title')
    stats_spec = formatter.StatsSpec(
        title='stats_title', stat='test_stat', delta='+0.3'
    )
    card_html = ET.fromstring(
        formatter.create_card_html(
            template_env, card_spec, 'test_insights', stats_specs=[stats_spec]
        )
    )
    stats_html = card_html[2]
    self.assertEqual(stats_html.tag, 'stats-section')
    self.assertEqual(stats_html[0].tag, 'stats')
    self.assertEqual(stats_html[0][0].tag, 'stats-title')
    self.assertEqual(stats_html[0][0].text, 'stats_title')
    self.assertEqual(stats_html[0][1].tag, 'stat')
    self.assertEqual(stats_html[0][1].text, 'test_stat')
    self.assertEqual(stats_html[0][2].tag, 'delta')
    self.assertIn('+0.3', stats_html[0][2].text)
|
|
213
|
+
|
|
214
|
+
|
|
215
|
+
def _run_tests() -> None:
  """Hands control to the absl test runner for this module."""
  absltest.main()


if __name__ == '__main__':
  _run_tests()
|
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
<!--
|
|
2
|
-
Copyright
|
|
2
|
+
Copyright 2025 Google LLC
|
|
3
3
|
|
|
4
4
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
5
|
you may not use this file except in compliance with the License.
|
|
@@ -42,7 +42,9 @@ limitations under the License.
|
|
|
42
42
|
</div>
|
|
43
43
|
</div>
|
|
44
44
|
|
|
45
|
-
{%
|
|
45
|
+
{% if start_date is defined and start_date %}
|
|
46
|
+
{% include "chips.html.jinja" %}
|
|
47
|
+
{% endif %}
|
|
46
48
|
|
|
47
49
|
<cards>
|
|
48
50
|
{# Each card is laid out in a grid. See; .card display layout. #}
|
meridian/version.py
CHANGED
schema/__init__.py
CHANGED
|
@@ -14,5 +14,17 @@
|
|
|
14
14
|
|
|
15
15
|
"""Module containing MMM schema library."""
|
|
16
16
|
|
|
17
|
+
try: # pylint: disable=g-statement-before-imports
|
|
18
|
+
# A quick check for schema dependencies.
|
|
19
|
+
# If this fails, it's likely because meridian was installed without
|
|
20
|
+
# `pip install google-meridian[schema]`.
|
|
21
|
+
from mmm.v1.model.meridian import meridian_model_pb2 # pylint: disable=g-import-not-at-top
|
|
22
|
+
except ModuleNotFoundError as exc:
|
|
23
|
+
raise ImportError(
|
|
24
|
+
'Schema dependencies not found. Please install meridian with '
|
|
25
|
+
'`pip install google-meridian[schema]`.'
|
|
26
|
+
) from exc
|
|
27
|
+
|
|
28
|
+
# pylint: disable=g-import-not-at-top
|
|
17
29
|
from schema import serde
|
|
18
30
|
from schema import utils
|
|
@@ -1,220 +0,0 @@
|
|
|
1
|
-
# Copyright 2025 The Meridian Authors.
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
|
|
15
|
-
"""Module containing Meridian related exploratory data analysis (EDA) functionalities."""
|
|
16
|
-
from __future__ import annotations
|
|
17
|
-
|
|
18
|
-
from typing import Literal, TYPE_CHECKING, Union
|
|
19
|
-
|
|
20
|
-
import altair as alt
|
|
21
|
-
from meridian import constants
|
|
22
|
-
from meridian.model.eda import constants as eda_constants
|
|
23
|
-
import pandas as pd
|
|
24
|
-
|
|
25
|
-
if TYPE_CHECKING:
|
|
26
|
-
from meridian.model import model # pylint: disable=g-bad-import-order,g-import-not-at-top
|
|
27
|
-
|
|
28
|
-
# Names exported by `from ... import *`.
__all__ = ['MeridianEDA']
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
class MeridianEDA:
  """Class for running pre-modeling exploratory data analysis for Meridian InputData.

  Wraps a `Meridian` model and renders the results of its EDA engine checks
  (currently the pairwise correlation check) as Altair charts.
  """

  # Diverging color scale for correlation values in [-1, 1]: negative values
  # render blue, values near zero light grey, positive values orange-red.
  _PAIRWISE_CORR_COLOR_SCALE = alt.Scale(
      domain=[-1.0, 0.0, 1.0],
      range=['#1f78b4', '#f7f7f7', '#e34a33'],  # Blue-light grey-Orange
      type='linear',
  )

  def __init__(
      self,
      meridian: model.Meridian,
  ):
    """Initializes the EDA helper.

    Args:
      meridian: The Meridian model whose `input_data`, `is_national` flag and
        `eda_engine` are used to compute and plot the EDA checks.
    """
    self._meridian = meridian

  def generate_and_save_report(self, filename: str, filepath: str):
    """Generates and saves the 2 page HTML report containing findings in EDA about given InputData.

    Args:
      filename: The filename for the generated HTML output.
      filepath: The path to the directory where the file will be saved.

    Raises:
      NotImplementedError: Always; the report is not implemented yet.
    """
    # TODO: Implement.
    raise NotImplementedError()

  def plot_pairwise_correlation(
      self, geos: Union[int, list[str], Literal['nationalize']] = 1
  ) -> alt.VConcatChart:
    """Plots the Pairwise Correlation data.

    Args:
      geos: Defines which geos to plot:
        - int: The number of top geos to plot, ranked by population.
        - list[str]: A specific list of geo names to plot. A single geo name
          passed as a plain string is treated as a one-element list.
        - 'nationalize': Aggregates all geos into a single national view.
        Defaults to 1 (plotting the top geo). If the data is already at a
        national level, this parameter is ignored and a national plot is
        generated.

    Returns:
      A vertically concatenated Altair chart holding one lower-triangle
      correlation heatmap per plotted geo.

    Raises:
      ValueError: If `geos` is invalid, or if the EDA outcome is missing the
        national (or geo) artifact needed for the requested view.
    """
    geos_to_plot = self._validate_and_get_geos_to_plot(geos)
    # National models and explicit nationalization share the national artifact.
    use_national_view = self._meridian.is_national or geos == 'nationalize'

    eda_engine = self._meridian.eda_engine
    if use_national_view:
      pairwise_corr_artifact = (
          eda_engine.check_national_pairwise_corr().get_national_artifact
      )
      if pairwise_corr_artifact is None:
        raise ValueError('EDAOutcome does not have national artifact.')
    else:
      pairwise_corr_artifact = (
          eda_engine.check_geo_pairwise_corr().get_geo_artifact
      )
      if pairwise_corr_artifact is None:
        raise ValueError('EDAOutcome does not have geo artifact.')
    pairwise_corr_data = pairwise_corr_artifact.corr_matrix.to_dataframe()

    charts = []
    for geo_to_plot in geos_to_plot:
      title = (
          'Pairwise correlations among all treatments and controls for'
          f' {geo_to_plot}'
      )

      # Geo-level data carries an extra geo index level that must be sliced
      # away before reshaping; national data is already two-level.
      geo_corr_data = (
          pairwise_corr_data
          if use_national_view
          else pairwise_corr_data.xs(geo_to_plot, level=constants.GEO)
      )
      plot_data = self._to_long_format(geo_corr_data)

      # A real list (not a numpy array) so Altair receives a plain sequence
      # for the axis sort order and scale domain, matching the annotation of
      # `_plot_2d_heatmap`.
      unique_variables = list(plot_data[eda_constants.VARIABLE_1].unique())
      variable_to_index = {name: i for i, name in enumerate(unique_variables)}

      plot_data['idx1'] = plot_data[eda_constants.VARIABLE_1].map(
          variable_to_index
      )
      plot_data['idx2'] = plot_data[eda_constants.VARIABLE_2].map(
          variable_to_index
      )
      # Keep only pairs strictly below the diagonal: drops self-correlations
      # and the mirrored duplicate of each pair.
      lower_triangle_data = plot_data[plot_data['idx2'] > plot_data['idx1']]

      charts.append(
          self._plot_2d_heatmap(lower_triangle_data, title, unique_variables)
      )
    final_chart = (
        alt.vconcat(*charts)
        .resolve_legend(color='independent')
        .configure_axis(labelAngle=315)
        .configure_title(anchor='start')
        .configure_view(stroke=None)
    )
    return final_chart

  @staticmethod
  def _to_long_format(corr_data: pd.DataFrame) -> pd.DataFrame:
    """Flattens a two-level-indexed correlation frame into three columns.

    Args:
      corr_data: Correlation values indexed by a pair of variable levels —
        presumably (variable, variable); TODO confirm against EDA engine.

    Returns:
      A fresh frame with columns VARIABLE_1, VARIABLE_2 and CORRELATION.
    """
    long_data = corr_data.rename_axis(
        index=[eda_constants.VARIABLE_1, eda_constants.VARIABLE_2]
    ).reset_index()
    # Assumes exactly one value column remains after reset_index().
    long_data.columns = [
        eda_constants.VARIABLE_1,
        eda_constants.VARIABLE_2,
        eda_constants.CORRELATION,
    ]
    return long_data

  def _plot_2d_heatmap(
      self, data: pd.DataFrame, title: str, unique_variables: list[str]
  ) -> alt.LayerChart:
    """Plots an annotated 2D heatmap.

    Args:
      data: Long-format data with VARIABLE_1, VARIABLE_2 and CORRELATION
        columns.
      title: Chart title.
      unique_variables: Variable names, used as both the sort order and the
        full scale domain of each axis so every variable gets a row/column
        even when its cells were filtered out.

    Returns:
      A layered chart of colored rectangles with the correlation value
      printed in each cell.
    """
    # Base chart with position encodings
    base = (
        alt.Chart(data)
        .encode(
            x=alt.X(
                f'{eda_constants.VARIABLE_1}:N',
                title=None,
                sort=unique_variables,
                scale=alt.Scale(domain=unique_variables),
            ),
            y=alt.Y(
                f'{eda_constants.VARIABLE_2}:N',
                title=None,
                sort=unique_variables,
                scale=alt.Scale(domain=unique_variables),
            ),
        )
        .properties(title=title)
    )

    # Heatmap layer (rectangles)
    heatmap = base.mark_rect().encode(
        color=alt.Color(
            f'{eda_constants.CORRELATION}:Q',
            scale=self._PAIRWISE_CORR_COLOR_SCALE,
            legend=alt.Legend(title=eda_constants.CORRELATION),
        ),
        tooltip=[
            eda_constants.VARIABLE_1,
            eda_constants.VARIABLE_2,
            alt.Tooltip(f'{eda_constants.CORRELATION}:Q', format='.3f'),
        ],
    )

    # Text annotation layer (values)
    text = base.mark_text().encode(
        text=alt.Text(f'{eda_constants.CORRELATION}:Q', format='.3f'),
        color=alt.value('black'),
    )

    # Combine layers and apply final configurations
    chart = (heatmap + text).properties(width=350, height=350)

    return chart

  def _generate_pairwise_correlation_report(self) -> str:
    """Creates the HTML snippet for Pairwise Correlation report section.

    Raises:
      NotImplementedError: Always; not implemented yet.
    """
    # TODO: Implement.
    raise NotImplementedError()

  def _validate_and_get_geos_to_plot(
      self, geos: Union[int, list[str], Literal['nationalize']]
  ) -> list[str]:
    """Validates `geos` and resolves it to the list of geo names to plot.

    Args:
      geos: See `plot_pairwise_correlation`.

    Returns:
      `[NATIONAL_MODEL_DEFAULT_GEO_NAME]` for national models or
      'nationalize'; otherwise the selected geo names.

    Raises:
      ValueError: If an int `geos` is not in [1, number of geos], if a named
        geo does not exist in the data, or if the names contain duplicates.
    """
    if self._meridian.is_national or geos == 'nationalize':
      # The geo selection is ignored for national views.
      return [constants.NATIONAL_MODEL_DEFAULT_GEO_NAME]

    input_geos = self._meridian.input_data.geo
    if isinstance(geos, int):
      if geos > len(input_geos) or geos <= 0:
        raise ValueError(
            'geos must be a positive integer less than or equal to the number'
            ' of geos in the data.'
        )
      geos_to_plot = self._meridian.input_data.get_n_top_largest_geos(geos)
    elif isinstance(geos, str):
      # A bare string is a single geo name; iterating it would otherwise
      # validate each character as a separate geo.
      geos_to_plot = [geos]
    else:
      # Copy so the caller's list is never aliased by the result.
      geos_to_plot = list(geos)

    for geo in geos_to_plot:
      if geo not in input_geos:
        raise ValueError(f'Geo {geo} does not exist in the data.')
    if len(geos_to_plot) != len(set(geos_to_plot)):
      raise ValueError('geos must not contain duplicate values.')

    return geos_to_plot
|
|
File without changes
|
|
File without changes
|
|
File without changes
|