prophet-rb 0.2.3 → 0.2.4

Sign up to get free protection for your applications and to get access to all the features.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 756e42a4e2c39e114610d2f41e2403d9b160e77da02c0483e43b4bf460dd828b
4
- data.tar.gz: e7bce55f47410227131afcba9d2f90c9812dab093efd1c17465383b3a5e6dc73
3
+ metadata.gz: 506ab7cfb738d7f7289db812134b2fb64d6371da66e7586a1bcc254b26c6fa1c
4
+ data.tar.gz: ed97303bb3563bdebe86cf9865ec04142e9b3fa0e3f7bfdea8ae6a8a2ff8570f
5
5
  SHA512:
6
- metadata.gz: ec014f32ff39abd49195d7e9a7f38b3b297bc7f6fc28da3d1be6fb19b6edbe816a681c4fbb35d5ea79fc2ced0aa948b0468cff9f6d9d8293590761711944b71b
7
- data.tar.gz: 87ecc8706c73e7f1063c8d75cc5c687ebc3fddb92d89e60bb115d21c90a60fbb6172648a0d414baeb11f04572b37e06777d4100acb07b8559dea8e8a711741fb
6
+ metadata.gz: 5339ac3f8e7f26539137dc23b40d481013bc1ec082a008221edff00ea574ab08617956849e7cd221af40e9b12105630cbe294f8b5e2cfa1a33670d9c1fcd970c
7
+ data.tar.gz: 63b64fbef8414f65dfb39b7c266e780790f28b42cc6f2a0da415885b293c57ba309d50243e9d9ad2d61d93fe034017a31871cef26eed5e33b01cc651816eed3c
data/CHANGELOG.md CHANGED
@@ -1,3 +1,7 @@
1
+ ## 0.2.4 (2021-04-02)
2
+
3
+ - Added support for flat growth
4
+
1
5
  ## 0.2.3 (2020-10-14)
2
6
 
3
7
  - Added support for times to `forecast` method
data/README.md CHANGED
@@ -10,7 +10,7 @@ Supports:
10
10
 
11
11
  And gracefully handles missing data
12
12
 
13
- [![Build Status](https://travis-ci.org/ankane/prophet.svg?branch=master)](https://travis-ci.org/ankane/prophet) [![Build status](https://ci.appveyor.com/api/projects/status/8ahmsvvhum4ivnmv/branch/master?svg=true)](https://ci.appveyor.com/project/ankane/prophet/branch/master)
13
+ [![Build Status](https://github.com/ankane/prophet/workflows/build/badge.svg?branch=master)](https://github.com/ankane/prophet/actions)
14
14
 
15
15
  ## Installation
16
16
 
@@ -75,8 +75,8 @@ module Prophet
75
75
  end
76
76
 
77
77
  def validate_inputs
78
- if !["linear", "logistic"].include?(@growth)
79
- raise ArgumentError, "Parameter \"growth\" should be \"linear\" or \"logistic\"."
78
+ if !["linear", "logistic", "flat"].include?(@growth)
79
+ raise ArgumentError, "Parameter \"growth\" should be \"linear\", \"logistic\", or \"flat\"."
80
80
  end
81
81
  if @changepoint_range < 0 || @changepoint_range > 1
82
82
  raise ArgumentError, "Parameter \"changepoint_range\" must be in [0, 1]"
@@ -602,6 +602,12 @@ module Prophet
602
602
  [k, m]
603
603
  end
604
604
 
605
+ def flat_growth_init(df)
606
+ k = 0
607
+ m = df["y_scaled"].mean
608
+ [k, m]
609
+ end
610
+
605
611
  def fit(df, **kwargs)
606
612
  raise Error, "Prophet object can only be fit once" if @history
607
613
 
@@ -624,6 +630,8 @@ module Prophet
624
630
 
625
631
  set_changepoints
626
632
 
633
+ trend_indicator = {"linear" => 0, "logistic" => 1, "flat" => 2}
634
+
627
635
  dat = {
628
636
  "T" => history.shape[0],
629
637
  "K" => seasonal_features.shape[1],
@@ -634,7 +642,7 @@ module Prophet
634
642
  "X" => seasonal_features,
635
643
  "sigmas" => prior_scales,
636
644
  "tau" => @changepoint_prior_scale,
637
- "trend_indicator" => @growth == "logistic" ? 1 : 0,
645
+ "trend_indicator" => trend_indicator[@growth],
638
646
  "s_a" => component_cols["additive_terms"],
639
647
  "s_m" => component_cols["multiplicative_terms"]
640
648
  }
@@ -642,6 +650,9 @@ module Prophet
642
650
  if @growth == "linear"
643
651
  dat["cap"] = Numo::DFloat.zeros(@history.shape[0])
644
652
  kinit = linear_growth_init(history)
653
+ elsif @growth == "flat"
654
+ dat["cap"] = Numo::DFloat.zeros(@history.shape[0])
655
+ kinit = flat_growth_init(history)
645
656
  else
646
657
  dat["cap"] = history["cap_scaled"]
647
658
  kinit = logistic_growth_init(history)
@@ -655,7 +666,7 @@ module Prophet
655
666
  "sigma_obs" => 1
656
667
  }
657
668
 
658
- if history["y"].min == history["y"].max && @growth == "linear"
669
+ if history["y"].min == history["y"].max && (@growth == "linear" || @growth == "flat")
659
670
  # Nothing to fit.
660
671
  @params = stan_init
661
672
  @params["sigma_obs"] = 1e-9
@@ -741,6 +752,11 @@ module Prophet
741
752
  cap.to_numo / (1 + Numo::NMath.exp(-k_t * (t - m_t)))
742
753
  end
743
754
 
755
+ def flat_trend(t, m)
756
+ m_t = m * t.new_ones
757
+ m_t
758
+ end
759
+
744
760
  def predict_trend(df)
745
761
  k = @params["k"].mean(nan: true)
746
762
  m = @params["m"].mean(nan: true)
@@ -749,9 +765,11 @@ module Prophet
749
765
  t = Numo::NArray.asarray(df["t"].to_a)
750
766
  if @growth == "linear"
751
767
  trend = piecewise_linear(t, deltas, k, m, @changepoints_t)
752
- else
768
+ elsif @growth == "logistic"
753
769
  cap = df["cap_scaled"]
754
770
  trend = piecewise_logistic(t, cap, deltas, k, m, @changepoints_t)
771
+ elsif @growth == "flat"
772
+ trend = flat_trend(t, m)
755
773
  end
756
774
 
757
775
  trend * @y_scale + Numo::NArray.asarray(df["floor"].to_a)
@@ -887,9 +905,11 @@ module Prophet
887
905
 
888
906
  if @growth == "linear"
889
907
  trend = piecewise_linear(t, deltas, k, m, changepoint_ts)
890
- else
908
+ elsif @growth == "logistic"
891
909
  cap = df["cap_scaled"]
892
910
  trend = piecewise_logistic(t, cap, deltas, k, m, changepoint_ts)
911
+ elsif @growth == "flat"
912
+ trend = flat_trend(t, m)
893
913
  end
894
914
 
895
915
  trend * @y_scale + Numo::NArray.asarray(df["floor"].to_a)
@@ -1,3 +1,3 @@
1
1
  module Prophet
2
- VERSION = "0.2.3"
2
+ VERSION = "0.2.4"
3
3
  end
@@ -73,6 +73,15 @@ functions {
73
73
  ) {
74
74
  return (k + A * delta) .* t + (m + A * (-t_change .* delta));
75
75
  }
76
+
77
+ // Flat trend function
78
+
79
+ vector flat_trend(
80
+ real m,
81
+ int T
82
+ ) {
83
+ return rep_vector(m, T);
84
+ }
76
85
  }
77
86
 
78
87
  data {
@@ -86,7 +95,7 @@ data {
86
95
  matrix[T,K] X; // Regressors
87
96
  vector[K] sigmas; // Scale on seasonality prior
88
97
  real<lower=0> tau; // Scale on changepoints prior
89
- int trend_indicator; // 0 for linear, 1 for logistic
98
+ int trend_indicator; // 0 for linear, 1 for logistic, 2 for flat
90
99
  vector[K] s_a; // Indicator of additive features
91
100
  vector[K] s_m; // Indicator of multiplicative features
92
101
  }
@@ -104,6 +113,17 @@ parameters {
104
113
  vector[K] beta; // Regressor coefficients
105
114
  }
106
115
 
116
+ transformed parameters {
117
+ vector[T] trend;
118
+ if (trend_indicator == 0) {
119
+ trend = linear_trend(k, m, delta, t, A, t_change);
120
+ } else if (trend_indicator == 1) {
121
+ trend = logistic_trend(k, m, delta, t, cap, A, t_change, S);
122
+ } else if (trend_indicator == 2) {
123
+ trend = flat_trend(m, T);
124
+ }
125
+ }
126
+
107
127
  model {
108
128
  //priors
109
129
  k ~ normal(0, 5);
@@ -113,19 +133,10 @@ model {
113
133
  beta ~ normal(0, sigmas);
114
134
 
115
135
  // Likelihood
116
- if (trend_indicator == 0) {
117
- y ~ normal(
118
- linear_trend(k, m, delta, t, A, t_change)
119
- .* (1 + X * (beta .* s_m))
120
- + X * (beta .* s_a),
121
- sigma_obs
122
- );
123
- } else if (trend_indicator == 1) {
124
- y ~ normal(
125
- logistic_trend(k, m, delta, t, cap, A, t_change, S)
126
- .* (1 + X * (beta .* s_m))
127
- + X * (beta .* s_a),
128
- sigma_obs
129
- );
130
- }
136
+ y ~ normal(
137
+ trend
138
+ .* (1 + X * (beta .* s_m))
139
+ + X * (beta .* s_a),
140
+ sigma_obs
141
+ );
131
142
  }
@@ -47,7 +47,7 @@ functions {
47
47
  }
48
48
  return gamma;
49
49
  }
50
-
50
+
51
51
  real[] logistic_trend(
52
52
  real k,
53
53
  real m,
@@ -94,6 +94,17 @@ functions {
94
94
  }
95
95
  return Y;
96
96
  }
97
+
98
+ // Flat trend function
99
+
100
+ real[] flat_trend(
101
+ real m,
102
+ int T
103
+ ) {
104
+ return rep_array(m, T);
105
+ }
106
+
107
+
97
108
  }
98
109
 
99
110
  data {
@@ -107,7 +118,7 @@ data {
107
118
  real X[T,K]; // Regressors
108
119
  vector[K] sigmas; // Scale on seasonality prior
109
120
  real<lower=0> tau; // Scale on changepoints prior
110
- int trend_indicator; // 0 for linear, 1 for logistic
121
+ int trend_indicator; // 0 for linear, 1 for logistic, 2 for flat
111
122
  real s_a[K]; // Indicator of additive features
112
123
  real s_m[K]; // Indicator of multiplicative features
113
124
  }
@@ -135,6 +146,8 @@ transformed parameters {
135
146
  trend = linear_trend(k, m, delta, t, A, t_change, S, T);
136
147
  } else if (trend_indicator == 1) {
137
148
  trend = logistic_trend(k, m, delta, t, cap, A, t_change, S, T);
149
+ } else if (trend_indicator == 2){
150
+ trend = flat_trend(m, T);
138
151
  }
139
152
 
140
153
  for (i in 1:K) {
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: prophet-rb
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.2.3
4
+ version: 0.2.4
5
5
  platform: ruby
6
6
  authors:
7
7
  - Andrew Kane
8
- autorequire:
8
+ autorequire:
9
9
  bindir: bin
10
10
  cert_chain: []
11
- date: 2020-10-15 00:00:00.000000000 Z
11
+ date: 2021-04-03 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: cmdstan
@@ -52,106 +52,8 @@ dependencies:
52
52
  - - ">="
53
53
  - !ruby/object:Gem::Version
54
54
  version: '0'
55
- - !ruby/object:Gem::Dependency
56
- name: bundler
57
- requirement: !ruby/object:Gem::Requirement
58
- requirements:
59
- - - ">="
60
- - !ruby/object:Gem::Version
61
- version: '0'
62
- type: :development
63
- prerelease: false
64
- version_requirements: !ruby/object:Gem::Requirement
65
- requirements:
66
- - - ">="
67
- - !ruby/object:Gem::Version
68
- version: '0'
69
- - !ruby/object:Gem::Dependency
70
- name: rake
71
- requirement: !ruby/object:Gem::Requirement
72
- requirements:
73
- - - ">="
74
- - !ruby/object:Gem::Version
75
- version: '0'
76
- type: :development
77
- prerelease: false
78
- version_requirements: !ruby/object:Gem::Requirement
79
- requirements:
80
- - - ">="
81
- - !ruby/object:Gem::Version
82
- version: '0'
83
- - !ruby/object:Gem::Dependency
84
- name: minitest
85
- requirement: !ruby/object:Gem::Requirement
86
- requirements:
87
- - - ">="
88
- - !ruby/object:Gem::Version
89
- version: '5'
90
- type: :development
91
- prerelease: false
92
- version_requirements: !ruby/object:Gem::Requirement
93
- requirements:
94
- - - ">="
95
- - !ruby/object:Gem::Version
96
- version: '5'
97
- - !ruby/object:Gem::Dependency
98
- name: daru
99
- requirement: !ruby/object:Gem::Requirement
100
- requirements:
101
- - - ">="
102
- - !ruby/object:Gem::Version
103
- version: '0'
104
- type: :development
105
- prerelease: false
106
- version_requirements: !ruby/object:Gem::Requirement
107
- requirements:
108
- - - ">="
109
- - !ruby/object:Gem::Version
110
- version: '0'
111
- - !ruby/object:Gem::Dependency
112
- name: matplotlib
113
- requirement: !ruby/object:Gem::Requirement
114
- requirements:
115
- - - ">="
116
- - !ruby/object:Gem::Version
117
- version: '0'
118
- type: :development
119
- prerelease: false
120
- version_requirements: !ruby/object:Gem::Requirement
121
- requirements:
122
- - - ">="
123
- - !ruby/object:Gem::Version
124
- version: '0'
125
- - !ruby/object:Gem::Dependency
126
- name: activesupport
127
- requirement: !ruby/object:Gem::Requirement
128
- requirements:
129
- - - ">="
130
- - !ruby/object:Gem::Version
131
- version: '0'
132
- type: :development
133
- prerelease: false
134
- version_requirements: !ruby/object:Gem::Requirement
135
- requirements:
136
- - - ">="
137
- - !ruby/object:Gem::Version
138
- version: '0'
139
- - !ruby/object:Gem::Dependency
140
- name: tzinfo-data
141
- requirement: !ruby/object:Gem::Requirement
142
- requirements:
143
- - - ">="
144
- - !ruby/object:Gem::Version
145
- version: '0'
146
- type: :development
147
- prerelease: false
148
- version_requirements: !ruby/object:Gem::Requirement
149
- requirements:
150
- - - ">="
151
- - !ruby/object:Gem::Version
152
- version: '0'
153
- description:
154
- email: andrew@chartkick.com
55
+ description:
56
+ email: andrew@ankane.org
155
57
  executables: []
156
58
  extensions:
157
59
  - ext/prophet/extconf.rb
@@ -176,7 +78,7 @@ homepage: https://github.com/ankane/prophet
176
78
  licenses:
177
79
  - MIT
178
80
  metadata: {}
179
- post_install_message:
81
+ post_install_message:
180
82
  rdoc_options: []
181
83
  require_paths:
182
84
  - lib
@@ -191,8 +93,8 @@ required_rubygems_version: !ruby/object:Gem::Requirement
191
93
  - !ruby/object:Gem::Version
192
94
  version: '0'
193
95
  requirements: []
194
- rubygems_version: 3.1.2
195
- signing_key:
96
+ rubygems_version: 3.2.3
97
+ signing_key:
196
98
  specification_version: 4
197
99
  summary: Time series forecasting for Ruby
198
100
  test_files: []