google-meridian 1.0.5__py3-none-any.whl → 1.0.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {google_meridian-1.0.5.dist-info → google_meridian-1.0.7.dist-info}/METADATA +12 -11
- {google_meridian-1.0.5.dist-info → google_meridian-1.0.7.dist-info}/RECORD +16 -16
- {google_meridian-1.0.5.dist-info → google_meridian-1.0.7.dist-info}/WHEEL +1 -1
- meridian/__init__.py +1 -1
- meridian/analysis/analyzer.py +677 -817
- meridian/analysis/optimizer.py +192 -134
- meridian/analysis/summarizer.py +7 -3
- meridian/analysis/test_utils.py +72 -20
- meridian/analysis/visualizer.py +10 -10
- meridian/constants.py +3 -0
- meridian/data/input_data.py +49 -3
- meridian/data/load.py +10 -7
- meridian/data/test_utils.py +18 -11
- meridian/data/time_coordinates.py +38 -17
- {google_meridian-1.0.5.dist-info → google_meridian-1.0.7.dist-info/licenses}/LICENSE +0 -0
- {google_meridian-1.0.5.dist-info → google_meridian-1.0.7.dist-info}/top_level.txt +0 -0
meridian/analysis/visualizer.py
CHANGED

@@ -239,7 +239,7 @@ class ModelDiagnostics:
     groupby = posterior_df.columns.tolist()
     groupby.remove(parameter)
     plot = (
-        alt.Chart(prior_posterior_df)
+        alt.Chart(prior_posterior_df, width=c.VEGALITE_FACET_DEFAULT_WIDTH)
         .transform_density(
             parameter, groupby=groupby, as_=[parameter, 'density']
         )

@@ -332,7 +332,7 @@ class ModelDiagnostics:
     rhat = rhat.dropna(subset=[c.RHAT])

     boxplot = (
-        alt.Chart(rhat)
+        alt.Chart(rhat, width=c.VEGALITE_FACET_DEFAULT_WIDTH)
         .mark_boxplot(median={'color': c.BLUE_300}, outliers={'filled': True})
         .encode(
             x=alt.X(c.PARAMETER, axis=alt.Axis(labelAngle=-45)),

@@ -461,7 +461,7 @@ class ModelFit:
     else:
       y_axis_label = summary_text.KPI_LABEL
     plot = (
-        alt.Chart(model_fit_df)
+        alt.Chart(model_fit_df, width=c.VEGALITE_FACET_DEFAULT_WIDTH)
         .mark_line()
         .encode(
             x=alt.X(

@@ -762,7 +762,7 @@ class ReachAndFrequency:
         range=[c.BLUE_600, c.RED_600],
     )

-    base = alt.Chart().transform_calculate(
+    base = alt.Chart(width=c.VEGALITE_FACET_DEFAULT_WIDTH).transform_calculate(
         optimal_freq=f"'{summary_text.OPTIMAL_FREQ_LABEL}'",
         expected_roi=f"'{summary_text.EXPECTED_ROI_LABEL}'",
     )

@@ -1012,7 +1012,7 @@ class MediaEffects:
     else:
       y_axis_label = summary_text.INC_KPI_LABEL
     base = (
-        alt.Chart(response_curves_df)
+        alt.Chart(response_curves_df, width=c.VEGALITE_FACET_DEFAULT_WIDTH)
        .transform_calculate(
            spend_level=(
                'datum.spend_multiplier >= 1.0 ? "Above current spend" : "Below'

@@ -1099,7 +1099,7 @@ class MediaEffects:
       An Altair plot showing the Adstock decay prior and posterior per media.
     """
     dataframe = self.adstock_decay_dataframe(confidence_level=confidence_level)
-    base = alt.Chart(dataframe)
+    base = alt.Chart(dataframe, width=c.VEGALITE_FACET_DEFAULT_WIDTH)

     scaled_confidence_level = int(confidence_level * 100)

@@ -1254,7 +1254,7 @@ class MediaEffects:
     ]
     range_list = [c.BLUE_700, c.GREY_600]

-    base = alt.Chart(df_channel_type)
+    base = alt.Chart(df_channel_type, width=c.VEGALITE_FACET_DEFAULT_WIDTH)
     color_scale = alt.Scale(
         domain=domain_list,
         range=range_list,

@@ -1274,7 +1274,7 @@ class MediaEffects:
         y2=f'{c.CI_HI}:Q',
         color=alt.Color(f'{c.DISTRIBUTION}:N', scale=color_scale),
     )
-    histogram = base.
+    histogram = base.mark_rect(color=c.GREY_600, opacity=0.4).encode(
         x=f'{c.START_INTERVAL_HISTOGRAM}:Q',
         x2=f'{c.END_INTERVAL_HISTOGRAM}:Q',
         y=alt.Y(f'{c.SCALED_COUNT_HISTOGRAM}:Q'),

@@ -1700,7 +1700,7 @@ class MediaSummary:

     domain = [c.BASELINE, c.ALL_CHANNELS]
     colors = [c.YELLOW_600, c.BLUE_700]
-    base = alt.Chart(outcome_df).encode(
+    base = alt.Chart(outcome_df, width=c.VEGALITE_FACET_DEFAULT_WIDTH).encode(
        alt.Theta(f'{c.PCT_OF_CONTRIBUTION}:Q', stack=True),
        alt.Color(
            f'{c.CHANNEL}:N',

@@ -1985,7 +1985,7 @@ class MediaSummary:
     axes_scale = alt.Scale(domain=(0, max_roi), nice=True)

     plot = (
-        alt.Chart(plot_df)
+        alt.Chart(plot_df, width=c.VEGALITE_FACET_DEFAULT_WIDTH)
        .mark_circle(tooltip=True, size=c.POINT_SIZE)
        .encode(
            x=alt.X(c.ROI, title='ROI', scale=axes_scale),
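The pattern is the same in all ten hunks: every chart that is later faceted now pins an explicit width via the new `c.VEGALITE_FACET_DEFAULT_WIDTH` constant. A minimal sketch of what that keyword does in Altair, using a placeholder value for the constant (the real value lives in meridian/constants.py):

# A minimal sketch of the pattern these hunks introduce, assuming a
# placeholder width; the real constant is VEGALITE_FACET_DEFAULT_WIDTH
# in meridian/constants.py.
import altair as alt
import pandas as pd

VEGALITE_FACET_DEFAULT_WIDTH = 300  # placeholder, not Meridian's actual value

df = pd.DataFrame({
    'x': [0, 1, 2, 0, 1, 2],
    'y': [1.0, 3.0, 2.0, 2.0, 1.0, 3.0],
    'channel': ['ch_0'] * 3 + ['ch_1'] * 3,
})

# Passing width= to alt.Chart fixes the width of each facet panel, rather
# than leaving the sizing to Vega-Lite's defaults.
chart = (
    alt.Chart(df, width=VEGALITE_FACET_DEFAULT_WIDTH)
    .mark_line()
    .encode(x='x:Q', y='y:Q')
    .facet(column='channel:N')
)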
meridian/constants.py
CHANGED
meridian/data/input_data.py
CHANGED

@@ -18,6 +18,7 @@ The `InputData` class is used to store all the input data to the model.
 """

 from collections import abc
+from collections.abc import Sequence
 import dataclasses
 import datetime as dt
 import functools

@@ -59,7 +60,7 @@ def _check_dim_collection(
   )


-def _check_dim_match(dim, arrays):
+def _check_dim_match(dim: str, arrays: Sequence[xr.DataArray]):
   """Verifies that the dimensions of the appropriate arrays match."""
   lengths = [len(array.coords[dim]) for array in arrays if array is not None]
   names = [array.name for array in arrays if array is not None]

@@ -69,6 +70,19 @@ def _check_dim_match(dim, arrays):
   )


+def _check_coords_match(dim: str, arrays: Sequence[xr.DataArray]):
+  """Verifies that the coordinates of the appropriate arrays match."""
+  arrays = [arr for arr in arrays if arr is not None and dim in arr.coords]
+  if not arrays:
+    return
+  first_coords = arrays[0].coords[dim].values
+  for arr in arrays[1:]:
+    if not np.array_equal(arr.coords[dim].values, first_coords):
+      raise ValueError(
+          f"`{dim}` coordinates of array `{arr.name}` don't match."
+      )
+
+
 @dataclasses.dataclass
 class InputData:
   """A data container for advertising data in a format supported by Meridian.

@@ -242,6 +256,7 @@ class InputData:
     self._validate_media_channels()
     self._validate_time_formats()
     self._validate_times()
+    self._validate_geos()

   def _convert_geos_to_strings(self):
     """Converts geo coordinates to strings in all relevant DataArrays."""

@@ -542,11 +557,13 @@ class InputData:
     try:
       _ = self.time_coordinates.interval_days
     except ValueError as exc:
-      raise ValueError("Time coordinates must be
+      raise ValueError("Time coordinates must be regularly spaced.") from exc
     try:
       _ = self.media_time_coordinates.interval_days
     except ValueError as exc:
-      raise ValueError(
+      raise ValueError(
+          "Media time coordinates must be regularly spaced."
+      ) from exc

   def _validate_time(self, array: xr.DataArray | None):
     """Validates the `time` dimension of the given `DataArray`.

@@ -617,6 +634,35 @@ class InputData:
           f" {constants.DATE_FORMAT}"
       ) from exc

+  def _check_unique_names(self, dim: str, array: xr.DataArray | None):
+    """Checks if a DataArray contains unique names on the specified dimension."""
+    if array is not None and dim in array.coords:
+      names = array.coords[dim].values.tolist()
+      if len(names) != len(set(names)):
+        raise ValueError(
+            f"`{dim}` names must be unique within the array `{array.name}`."
+        )
+
+  def _validate_geos(self):
+    """Validates geo coordinates across relevant DataArrays."""
+    arrays_with_geos = [
+        self.kpi,
+        self.revenue_per_kpi,
+        self.media,
+        self.controls,
+        self.population,
+        self.reach,
+        self.frequency,
+        self.organic_media,
+        self.organic_reach,
+        self.organic_frequency,
+        self.non_media_treatments,
+    ]
+    for array in arrays_with_geos:
+      self._check_unique_names(constants.GEO, array)
+
+    _check_coords_match(constants.GEO, arrays_with_geos)
+
   def as_dataset(self) -> xr.Dataset:
     """Returns data as a single `xarray.Dataset` object."""
     data = [
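The net effect is that `InputData` construction now fails fast on duplicated or inconsistent geo labels. A standalone sketch of the coordinate check on toy xarray arrays; the arrays below are illustrative stand-ins, not Meridian's real fields:

# Standalone sketch of the geo-coordinate consistency check added above;
# the toy arrays stand in for InputData's real DataArray fields.
import numpy as np
import xarray as xr

kpi = xr.DataArray(
    [[1.0, 2.0]],
    dims=['time', 'geo'],
    coords={'time': ['2024-01-01'], 'geo': ['geo_0', 'geo_1']},
    name='kpi',
)
population = xr.DataArray(
    [100.0, 200.0],
    dims=['geo'],
    coords={'geo': ['geo_0', 'geo_2']},  # 'geo_2' disagrees with kpi's labels
    name='population',
)

first_coords = kpi.coords['geo'].values
for arr in [population]:
  if not np.array_equal(arr.coords['geo'].values, first_coords):
    # Mirrors _check_coords_match: raises because 'geo_2' != 'geo_1'.
    raise ValueError(f"`geo` coordinates of array `{arr.name}` don't match.")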
meridian/data/load.py
CHANGED

@@ -950,7 +950,7 @@ class DataFrameDataLoader(InputDataLoader):
       raise ValueError('NA values found in the organic_frequency columns.')

     # Determine columns in which NAs are expected in the lagged-media period.
-
+    not_lagged_columns = []
     coords = [
         constants.KPI,
         constants.CONTROLS,

@@ -967,12 +967,12 @@ class DataFrameDataLoader(InputDataLoader):
     for coord in coords:
       columns = getattr(self.coord_to_columns, coord)
       columns = [columns] if isinstance(columns, str) else columns
-
+      not_lagged_columns.extend(columns)

     # Dates with at least one non-NA value in columns different from media,
     # reach, frequency, organic_media, organic_reach, and organic_frequency.
     time_column_name = self.coord_to_columns.time
-    no_na_period = self.df[(~self.df[
+    no_na_period = self.df[(~self.df[not_lagged_columns].isna()).any(axis=1)][
         time_column_name
     ].unique()

@@ -999,13 +999,16 @@ class DataFrameDataLoader(InputDataLoader):
     # organic_frequency.
     not_lagged_data = self.df.loc[
         self.df[time_column_name].isin(no_na_period),
-
+        not_lagged_columns,
     ]
     if not_lagged_data.isna().any(axis=None):
+      incorrect_columns = []
+      for column in not_lagged_columns:
+        if not_lagged_data[column].isna().any(axis=None):
+          incorrect_columns.append(column)
       raise ValueError(
-          'NA values found in
-
-          ' non-media columns).'
+          f'NA values found in columns {incorrect_columns} within the modeling'
+          ' time window (time periods where the KPI is modeled).'
       )

   def load(self) -> input_data.InputData:
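The change is about diagnosability: rather than a generic message, the loader now lists exactly which columns carry NA values inside the modeled time window. A toy reproduction of that reporting step (column names are illustrative, not names the loader requires):

# Toy reproduction of the new NA reporting; 'kpi' and 'control_0' are
# illustrative column names.
import numpy as np
import pandas as pd

not_lagged_columns = ['kpi', 'control_0']
not_lagged_data = pd.DataFrame({
    'kpi': [1.0, np.nan, 3.0],     # NA inside the modeled window
    'control_0': [0.1, 0.2, 0.3],  # clean
})

if not_lagged_data.isna().any(axis=None):
  incorrect_columns = [
      column for column in not_lagged_columns
      if not_lagged_data[column].isna().any()
  ]
  # Raises: "NA values found in columns ['kpi'] within the modeling ..."
  raise ValueError(
      f'NA values found in columns {incorrect_columns} within the modeling'
      ' time window (time periods where the KPI is modeled).'
  )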
meridian/data/test_utils.py
CHANGED

@@ -37,7 +37,7 @@ def _sample_names(prefix: str, n_names: int | None) -> list[str] | None:
   return [prefix + str(n) for n in range(n_names)] if n_names else None


-def
+def sample_geos(
     n_geos: int | None, integer_geos: bool = False
 ) -> list[str] | list[int] | None:
   """Generates a list of sample geos."""

@@ -519,6 +519,7 @@ def random_media_da(
     n_media_channels: int,
     seed: int = 0,
     date_format: str = c.DATE_FORMAT,
+    explicit_geo_names: Sequence[str] | None = None,
     explicit_time_index: Sequence[str] | None = None,
     explicit_media_channel_names: Sequence[str] | None = None,
     array_name: str = 'media',

@@ -535,6 +536,7 @@ def random_media_da(
     n_media_channels: Number of media channels
     seed: Random seed used by `np.random.seed()`
     date_format: The date format to use for time coordinate labels
+    explicit_geo_names: If given, ignore `n_geos` and use this as is.
     explicit_time_index: If given, ignore `date_format` and use this as is
     explicit_media_channel_names: If given, ignore `n_media_channels` and use
       this as is

@@ -558,6 +560,11 @@ def random_media_da(
           np.random.normal(5, 5, size=(n_geos, n_media_times, n_media_channels))
       )
   )
+  if explicit_geo_names is None:
+    geos = sample_geos(n_geos, integer_geos)
+  else:
+    geos = explicit_geo_names
+
   if explicit_time_index is None:
     media_time = _sample_times(
         n_times=n_media_times,

@@ -576,7 +583,7 @@ def random_media_da(
       media,
       dims=['geo', 'media_time', channel_variable_name],
       coords={
-          'geo':
+          'geo': geos,
          'media_time': media_time,
          channel_variable_name: media_channels,
      },

@@ -647,7 +654,7 @@ def random_media_spend_nd_da(
   coords = {}
   if n_geos is not None:
     dims.append('geo')
-    coords['geo'] =
+    coords['geo'] = sample_geos(n_geos, integer_geos)
   if n_times is not None:
     dims.append('time')
     coords['time'] = _sample_times(n_times=n_times)

@@ -719,7 +726,7 @@ def random_controls_da(
       controls,
       dims=['geo', 'time', 'control_variable'],
       coords={
-          'geo':
+          'geo': sample_geos(n_geos, integer_geos),
          'time': (
              _sample_times(n_times=n_times, date_format=date_format)
              if explicit_time_index is None

@@ -775,7 +782,7 @@ def random_kpi_da(
       kpi,
       dims=['geo', 'time'],
       coords={
-          'geo':
+          'geo': sample_geos(n_geos, integer_geos),
          'time': _sample_times(n_times=n_times),
      },
      name=c.KPI,

@@ -796,7 +803,7 @@ def constant_revenue_per_kpi(
       revenue_per_kpi,
       dims=['geo', 'time'],
       coords={
-          'geo':
+          'geo': sample_geos(n_geos, integer_geos),
          'time': _sample_times(n_times=n_times),
      },
      name='revenue_per_kpi',

@@ -815,7 +822,7 @@ def random_population(
   return xr.DataArray(
       population,
       dims=['geo'],
-      coords={'geo':
+      coords={'geo': sample_geos(n_geos, integer_geos)},
      name='population',
  )

@@ -857,7 +864,7 @@ def random_reach_da(
       reach,
       dims=['geo', 'media_time', channel_variable_name],
       coords={
-          'geo':
+          'geo': sample_geos(n_geos, integer_geos),
          'media_time': _sample_times(
              n_times=n_media_times, start_date=start_date
          ),

@@ -925,7 +932,7 @@ def random_frequency_da(
       frequency,
       dims=['geo', 'media_time', channel_variable_name],
       coords={
-          'geo':
+          'geo': sample_geos(n_geos, integer_geos),
          'media_time': _sample_times(
              n_times=n_media_times, start_date=start_date
          ),

@@ -992,7 +999,7 @@ def random_rf_spend_nd_da(
   coords = {}
   if n_geos is not None:
     dims.append('geo')
-    coords['geo'] =
+    coords['geo'] = sample_geos(n_geos, integer_geos)
   if n_times is not None:
     dims.append('time')
     coords['time'] = _sample_times(n_times=n_times)

@@ -1060,7 +1067,7 @@ def random_non_media_treatments_da(
       non_media_treatments,
       dims=['geo', 'time', 'non_media_channel'],
       coords={
-          'geo':
+          'geo': sample_geos(n_geos, integer_geos),
          'time': (
              _sample_times(n_times=n_times, date_format=date_format)
              if explicit_time_index is None
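Taken together, these hunks make geo naming controllable from tests: the `sample_geos` helper is now part of the module's surface, and `random_media_da` accepts `explicit_geo_names`. A hedged usage sketch; the keyword names follow the hunks above, while any parameters not shown in the diff are assumed to keep their defaults:

# Usage sketch for the new geo-naming hooks; parameters not visible in the
# diff are assumed to have defaults.
from meridian.data import test_utils

# Generated labels: derived from n_geos via sample_geos (integer labels
# when integer_geos=True).
media = test_utils.random_media_da(
    n_geos=2, n_media_times=6, n_media_channels=3)

# Explicit labels: override the generated names, so a test can build arrays
# whose geo coordinates deliberately match (or mismatch) other arrays.
media_named = test_utils.random_media_da(
    n_geos=2, n_media_times=6, n_media_channels=3,
    explicit_geo_names=['geo_a', 'geo_b'])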
meridian/data/time_coordinates.py
CHANGED

@@ -19,7 +19,6 @@ import dataclasses
 import datetime
 import functools
 from typing import TypeAlias
-import warnings

 from meridian import constants
 import numpy as np

@@ -145,6 +144,10 @@ class TimeCoordinates:
     return cls(datetime_index=_to_pandas_datetime_index(dates))

   def __post_init__(self):
+    if len(self.datetime_index) <= 1:
+      raise ValueError(
+          "There must be more than one date index in the time coordinates."
+      )
     if not self.datetime_index.is_monotonic_increasing:
       raise ValueError(
           "Time coordinates must be strictly monotonically increasing."

@@ -162,28 +165,46 @@ class TimeCoordinates:

   @functools.cached_property
   def interval_days(self) -> int:
-    """Returns the interval between two neighboring dates in `all_dates`.
+    """Returns the *mean* interval between two neighboring dates in `all_dates`.

     Raises:
-      ValueError if the date index is not regularly spaced.
+      ValueError if the date index is not "regularly spaced".
     """
-
-
+    if not self._is_regular_time_index():
+      raise ValueError("Time coordinates are not regularly spaced!")

-
-
-
-
-        "The time coordinates only have one date. Returning an interval of 0."
-      )
-      return 0
+    # Calculate the difference between consecutive dates, in days.
+    diffs = self._interval_days
+    # Return the rounded mean interval.
+    return int(np.round(np.mean(diffs)))

-
-
-
+  @property
+  def _timedelta_index(self) -> pd.TimedeltaIndex:
+    """Returns the timedeltas between consecutive dates in `datetime_index`."""
+    return self.datetime_index.diff().dropna()

-
-
+  @property
+  def _interval_days(self) -> Sequence[int]:
+    """Converts `_timedelta_index` to a sequence of days for easier compute."""
+    return self._timedelta_index.days.to_numpy()
+
+  def _is_regular_time_index(self) -> bool:
+    """Returns True if the time index is "regularly spaced"."""
+    if np.all(self._interval_days == self._interval_days[0]):
+      # All intervals are regular. Base case.
+      return True
+    # Special cases:
+    # * Monthly cadences
+    if np.all(np.isin(self._interval_days, [28, 29, 30, 31])):
+      return True
+    # * Quarterly cadences
+    if np.all(np.isin(self._interval_days, [90, 91, 92])):
+      return True
+    # * Yearly cadences
+    if np.all(np.isin(self._interval_days, [365, 366])):
+      return True
+
+    return False

   def get_selected_dates(
       self,
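The rewritten `interval_days` accepts calendar cadences whose day counts vary (monthly, quarterly, yearly) and returns the rounded mean spacing. A condensed, standalone version of that logic with the class plumbing omitted; it assumes a pandas recent enough (2.1+) to provide `Index.diff`:

# Condensed, standalone version of _is_regular_time_index plus interval_days;
# assumes pandas >= 2.1 for DatetimeIndex.diff(). Logic mirrors the hunks.
import numpy as np
import pandas as pd

def mean_interval_days(dates: list[str]) -> int:
  diffs = pd.DatetimeIndex(dates).diff().dropna().days.to_numpy()
  regular = (
      np.all(diffs == diffs[0])                    # constant spacing, e.g. weekly
      or np.all(np.isin(diffs, [28, 29, 30, 31]))  # monthly cadence
      or np.all(np.isin(diffs, [90, 91, 92]))      # quarterly cadence
      or np.all(np.isin(diffs, [365, 366]))        # yearly cadence
  )
  if not regular:
    raise ValueError("Time coordinates are not regularly spaced!")
  return int(np.round(np.mean(diffs)))

print(mean_interval_days(['2024-01-01', '2024-02-01', '2024-03-01']))  # 30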
{google_meridian-1.0.5.dist-info → google_meridian-1.0.7.dist-info/licenses}/LICENSE
File without changes

{google_meridian-1.0.5.dist-info → google_meridian-1.0.7.dist-info}/top_level.txt
File without changes