prophet-rb 0.4.2 → 0.5.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -44,7 +44,7 @@ module Prophet
44
44
  @yearly_seasonality = yearly_seasonality
45
45
  @weekly_seasonality = weekly_seasonality
46
46
  @daily_seasonality = daily_seasonality
47
- @holidays = holidays
47
+ @holidays = convert_df(holidays)
48
48
 
49
49
  @seasonality_mode = seasonality_mode
50
50
  @seasonality_prior_scale = seasonality_prior_scale.to_f
@@ -75,6 +75,7 @@ module Prophet
75
75
  validate_inputs
76
76
 
77
77
  @logger = ::Logger.new($stderr)
78
+ @logger.level = ::Logger::WARN
78
79
  @logger.formatter = proc do |severity, datetime, progname, msg|
79
80
  "[prophet] #{msg}\n"
80
81
  end
@@ -133,10 +134,10 @@ module Prophet
133
134
  if reserved_names.include?(name)
134
135
  raise ArgumentError, "Name #{name.inspect} is reserved."
135
136
  end
136
- if check_holidays && @holidays && @holidays["holiday"].uniq.include?(name)
137
+ if check_holidays && @holidays && @holidays["holiday"].uniq.to_a.include?(name)
137
138
  raise ArgumentError, "Name #{name.inspect} already used for a holiday."
138
139
  end
139
- if check_holidays && @country_holidays && get_holiday_names(@country_holidays).include?(name)
140
+ if check_holidays && @country_holidays && get_holiday_names(@country_holidays).to_a.include?(name)
140
141
  raise ArgumentError, "Name #{name.inspect} is a holiday name in #{@country_holidays.inspect}."
141
142
  end
142
143
  if check_seasonalities && @seasonalities[name]
@@ -631,9 +632,7 @@ module Prophet
631
632
  def fit(df, **kwargs)
632
633
  raise Error, "Prophet object can only be fit once" if @history
633
634
 
634
- if defined?(Daru::DataFrame) && df.is_a?(Daru::DataFrame)
635
- df = Rover::DataFrame.new(df.to_h)
636
- end
635
+ df = convert_df(df)
637
636
  raise ArgumentError, "Must be a data frame" unless df.is_a?(Rover::DataFrame)
638
637
 
639
638
  unless df.include?("ds") && df.include?("y")
@@ -993,6 +992,16 @@ module Prophet
993
992
 
994
993
  private
995
994
 
995
+ def convert_df(df)
996
+ if defined?(Daru::DataFrame) && df.is_a?(Daru::DataFrame)
997
+ Rover::DataFrame.new(df.to_h)
998
+ elsif defined?(Polars::DataFrame) && df.is_a?(Polars::DataFrame)
999
+ Rover::DataFrame.new(df.to_h(as_series: false))
1000
+ else
1001
+ df
1002
+ end
1003
+ end
1004
+
996
1005
  # Time is preferred over DateTime in Ruby docs
997
1006
  # use UTC to be consistent with Python
998
1007
  # and so days have equal length (no DST)
@@ -1077,7 +1086,7 @@ module Prophet
1077
1086
  d = {
1078
1087
  "name" => "ds",
1079
1088
  "index" => v.size.times.to_a,
1080
- "data" => v.to_a.map { |v| v.iso8601(3) }
1089
+ "data" => v.to_a.map { |v| v.iso8601(3).chomp("Z") }
1081
1090
  }
1082
1091
  model_dict[attribute] = JSON.generate(d)
1083
1092
  end
@@ -1096,7 +1105,7 @@ module Prophet
1096
1105
  v = instance_variable_get("@#{attribute}")
1097
1106
 
1098
1107
  v = v.dup
1099
- v["ds"] = v["ds"].map { |v| v.iso8601(3) } if v["ds"]
1108
+ v["ds"] = v["ds"].map { |v| v.iso8601(3).chomp("Z") } if v["ds"]
1100
1109
  v.delete("col")
1101
1110
 
1102
1111
  fields =
@@ -1116,7 +1125,7 @@ module Prophet
1116
1125
  d = {
1117
1126
  "schema" => {
1118
1127
  "fields" => fields,
1119
- "pandas_version" => "0.20.0"
1128
+ "pandas_version" => "1.4.0"
1120
1129
  },
1121
1130
  "data" => v.to_a
1122
1131
  }
@@ -1151,8 +1160,7 @@ module Prophet
1151
1160
  # Params (Dict[str, np.ndarray])
1152
1161
  model_dict["params"] = params.transform_values(&:to_a)
1153
1162
  # Attributes that are skipped: stan_fit, stan_backend
1154
- # Returns 1.0 for Prophet 1.1
1155
- model_dict["__prophet_version"] = "1.0"
1163
+ model_dict["__prophet_version"] = "1.1.2"
1156
1164
  model_dict
1157
1165
  end
1158
1166
 
@@ -1174,7 +1182,7 @@ module Prophet
1174
1182
  d = JSON.parse(model_dict.fetch(attribute))
1175
1183
  s = Rover::Vector.new(d["data"])
1176
1184
  if d["name"] == "ds"
1177
- s = s.map { |v| Time.parse(v).utc }
1185
+ s = s.map { |v| DateTime.parse(v).to_time.utc }
1178
1186
  end
1179
1187
  model.instance_variable_set("@#{attribute}", s)
1180
1188
  end
@@ -1191,7 +1199,7 @@ module Prophet
1191
1199
  else
1192
1200
  d = JSON.parse(model_dict.fetch(attribute))
1193
1201
  df = Rover::DataFrame.new(d["data"])
1194
- df["ds"] = df["ds"].map { |v| Time.parse(v).utc } if df["ds"]
1202
+ df["ds"] = df["ds"].map { |v| DateTime.parse(v).to_time.utc } if df["ds"]
1195
1203
  if attribute == "train_component_cols"
1196
1204
  # Special handling because of named index column
1197
1205
  # df.columns.name = 'component'
@@ -4,9 +4,7 @@ module Prophet
4
4
  years = (1995..2045).to_a
5
5
  holiday_names = make_holidays_df(years, country)["holiday"].uniq
6
6
  if holiday_names.size == 0
7
- # TODO raise error in 0.5.0
8
- # raise ArgumentError, "Holidays in #{country} are not currently supported"
9
- logger.warn "Holidays in #{country} are not currently supported"
7
+ raise ArgumentError, "Holidays in #{country} are not currently supported"
10
8
  end
11
9
  holiday_names
12
10
  end
@@ -21,14 +21,16 @@ module Prophet
21
21
  kwargs[:algorithm] ||= stan_data["T"] < 100 ? "Newton" : "LBFGS"
22
22
  iterations = 10000
23
23
 
24
+ args = {
25
+ data: stan_data,
26
+ inits: stan_init,
27
+ iter: iterations
28
+ }
29
+ args.merge!(kwargs)
30
+
24
31
  stan_fit = nil
25
32
  begin
26
- stan_fit = @model.optimize(
27
- data: stan_data,
28
- inits: stan_init,
29
- iter: iterations,
30
- **kwargs
31
- )
33
+ stan_fit = @model.optimize(**args)
32
34
  rescue => e
33
35
  if kwargs[:algorithm] != "Newton"
34
36
  @logger.warn "Optimization terminated abnormally. Falling back to Newton."
@@ -1,3 +1,3 @@
1
1
  module Prophet
2
- VERSION = "0.4.2"
2
+ VERSION = "0.5.0"
3
3
  end
data/lib/prophet-rb.rb CHANGED
@@ -1 +1 @@
1
- require "prophet"
1
+ require_relative "prophet"
data/lib/prophet.rb CHANGED
@@ -1,19 +1,19 @@
1
1
  # dependencies
2
2
  require "cmdstan"
3
- require "rover"
4
3
  require "numo/narray"
4
+ require "rover"
5
5
 
6
6
  # stdlib
7
7
  require "logger"
8
8
  require "set"
9
9
 
10
10
  # modules
11
- require "prophet/diagnostics"
12
- require "prophet/holidays"
13
- require "prophet/plot"
14
- require "prophet/forecaster"
15
- require "prophet/stan_backend"
16
- require "prophet/version"
11
+ require_relative "prophet/diagnostics"
12
+ require_relative "prophet/holidays"
13
+ require_relative "prophet/plot"
14
+ require_relative "prophet/forecaster"
15
+ require_relative "prophet/stan_backend"
16
+ require_relative "prophet/version"
17
17
 
18
18
  module Prophet
19
19
  class Error < StandardError; end
@@ -68,7 +68,7 @@ module Prophet
68
68
  df = Rover::DataFrame.new({"ds" => series.keys, "y" => series.values})
69
69
  df["cap"] = cap if cap
70
70
 
71
- m.logger.level = ::Logger::FATAL unless verbose
71
+ m.logger.level = verbose ? ::Logger::INFO : ::Logger::FATAL
72
72
  m.add_country_holidays(country_holidays) if country_holidays
73
73
  m.fit(df)
74
74
 
@@ -87,7 +87,7 @@ module Prophet
87
87
  else
88
88
  result.each { |v| v["ds"] = v["ds"].localtime }
89
89
  end
90
- result.map { |v| [v["ds"], v["yhat"]] }.to_h
90
+ result.to_h { |v| [v["ds"], v["yhat"]] }
91
91
  end
92
92
 
93
93
  # TODO better name for interval_width
@@ -97,7 +97,7 @@ module Prophet
97
97
  df["cap"] = cap if cap
98
98
 
99
99
  m = Prophet.new(interval_width: interval_width, **options)
100
- m.logger.level = ::Logger::FATAL unless verbose
100
+ m.logger.level = verbose ? ::Logger::INFO : ::Logger::FATAL
101
101
  m.add_country_holidays(country_holidays) if country_holidays
102
102
  m.fit(df)
103
103
 
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: prophet-rb
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.4.2
4
+ version: 0.5.0
5
5
  platform: ruby
6
6
  authors:
7
7
  - Andrew Kane
8
8
  autorequire:
9
9
  bindir: bin
10
10
  cert_chain: []
11
- date: 2022-07-12 00:00:00.000000000 Z
11
+ date: 2023-09-05 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: cmdstan
@@ -159,14 +159,14 @@ required_ruby_version: !ruby/object:Gem::Requirement
159
159
  requirements:
160
160
  - - ">="
161
161
  - !ruby/object:Gem::Version
162
- version: '2.7'
162
+ version: '3'
163
163
  required_rubygems_version: !ruby/object:Gem::Requirement
164
164
  requirements:
165
165
  - - ">="
166
166
  - !ruby/object:Gem::Version
167
167
  version: '0'
168
168
  requirements: []
169
- rubygems_version: 3.3.7
169
+ rubygems_version: 3.4.10
170
170
  signing_key:
171
171
  specification_version: 4
172
172
  summary: Time series forecasting for Ruby