eps 0.1.1 → 0.2.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +6 -0
- data/README.md +235 -84
- data/lib/eps.rb +9 -4
- data/lib/eps/base.rb +19 -0
- data/lib/eps/base_estimator.rb +84 -0
- data/lib/eps/linear_regression.rb +558 -0
- data/lib/eps/model.rb +108 -0
- data/lib/eps/naive_bayes.rb +240 -0
- data/lib/eps/version.rb +1 -1
- metadata +13 -18
- data/.gitignore +0 -9
- data/.travis.yml +0 -15
- data/Gemfile +0 -11
- data/Rakefile +0 -34
- data/eps.gemspec +0 -30
- data/guides/Modeling.md +0 -152
- data/lib/eps/base_regressor.rb +0 -232
- data/lib/eps/metrics.rb +0 -35
- data/lib/eps/regressor.rb +0 -314
data/lib/eps.rb
CHANGED
@@ -3,13 +3,18 @@ require "matrix"
|
|
3
3
|
require "json"
|
4
4
|
|
5
5
|
# modules
|
6
|
-
require "eps/
|
7
|
-
require "eps/
|
8
|
-
require "eps/
|
6
|
+
require "eps/base"
|
7
|
+
require "eps/base_estimator"
|
8
|
+
require "eps/linear_regression"
|
9
|
+
require "eps/model"
|
10
|
+
require "eps/naive_bayes"
|
9
11
|
require "eps/version"
|
10
12
|
|
11
13
|
module Eps
|
12
14
|
def self.metrics(actual, estimated)
|
13
|
-
Eps::
|
15
|
+
Eps::Model.metrics(actual, estimated)
|
14
16
|
end
|
17
|
+
|
18
|
+
# backwards compatibility
|
19
|
+
Regressor = Model
|
15
20
|
end
|
data/lib/eps/base_estimator.rb
ADDED
@@ -0,0 +1,84 @@
|
|
1
|
+
module Eps
  # Shared train / predict / evaluate plumbing for estimators.
  #
  # This class calls a private `_predict(x)` instance method and a class-level
  # `metrics(actual, estimated)` method; neither is defined here, so concrete
  # estimators must provide them.
  class BaseEstimator
    # Stores a copy of the features and the validated target values.
    #
    # data    - an Array of rows or a Daru::DataFrame; it is duplicated, never
    #           mutated.
    # y       - target values; converted with `to_a`.
    # target  - optional key (or Daru vector name) stripped from every row of
    #           the copied data.
    # options - accepted for subclasses; unused here.
    #
    # Sets @x, @y and @target. Raises RuntimeError when a target value is nil
    # or when the number of rows differs from the number of target values.
    def train(data, y, target: nil, **options)
      # TODO more performant conversion
      if daru?(data)
        x = data.dup
        x = x.delete_vector(target) if target
      else
        x = data.map(&:dup)
        x.each { |r| r.delete(target) } if target
      end

      y = prep_y(y.to_a)

      if x.size != y.size
        raise "Number of samples differs from target"
      end

      @x = x
      @y = y
      @target = target
    end

    # Predicts for one observation or for many.
    #
    # Anything that is neither an Array nor a Daru::DataFrame is treated as a
    # single observation: it is wrapped in an Array for `_predict` and the
    # single result is unwrapped again.
    def predict(x)
      singular = !(x.is_a?(Array) || daru?(x))
      x = [x] if singular

      pred = _predict(x)

      singular ? pred[0] : pred
    end

    # Compares predictions for `data` against known values.
    #
    # data   - rows (or Daru::DataFrame) to predict on.
    # y      - actual values; when omitted they are read from `data` using
    #          `target` (falling back to the target remembered by #train).
    # target - key / vector name holding the actual values in `data`.
    #
    # Returns whatever the subclass' `metrics` class method returns.
    # Raises ArgumentError when neither `y` nor a target is available.
    def evaluate(data, y = nil, target: nil)
      target ||= @target
      raise ArgumentError, "missing target" if !target && !y

      actual = y
      actual ||=
        if daru?(data)
          data[target].to_a
        else
          data.map { |v| v[target] }
        end

      actual = prep_y(actual)
      estimated = predict(data)

      self.class.metrics(actual, estimated)
    end

    private

    # Every non-Numeric value is considered categorical.
    def categorical?(v)
      !v.is_a?(Numeric)
    end

    # True only when Daru is loaded and x is a Daru::DataFrame.
    def daru?(x)
      defined?(Daru) && x.is_a?(Daru::DataFrame)
    end

    # String -> Symbol, anything else -> String.
    def flip_target(target)
      target.is_a?(String) ? target.to_sym : target.to_s
    end

    # Returns y unchanged after checking that no value is nil.
    def prep_y(y)
      y.each do |yi|
        raise "Target missing in data" if yi.nil?
      end
      y
    end

    # determine if target is a string or symbol
    #
    # Returns `target` when the data contains it, otherwise the flipped form.
    # Only the first row is inspected for non-Daru data.
    def prep_target(target, data)
      if daru?(data)
        data.has_vector?(target) ? target : flip_target(target)
      else
        x = data[0] || {}
        # key? rather than a truthiness test, so a first-row value of
        # false or nil does not make the column look absent
        x.key?(target) ? target : flip_target(target)
      end
    end
  end
end
|
data/lib/eps/linear_regression.rb
ADDED
@@ -0,0 +1,558 @@
|
|
1
|
+
module Eps
  # Ordinary least squares linear regression.
  #
  # Coefficients are kept in @coefficients, keyed by :_intercept, by a Symbol
  # for numeric features, and by a [Symbol, String] pair for one level of a
  # categorical feature. Fitting uses GSL when @gsl is truthy and the stdlib
  # Matrix class otherwise.
  class LinearRegression < BaseEstimator
    # coefficients - optional Hash of pre-computed coefficients; keys are
    #                symbolized (for Array keys only the first element is).
    # gsl          - force GSL on/off; when nil, GSL is used if the GSL
    #                constant is defined.
    def initialize(coefficients: nil, gsl: nil)
      @coefficients = Hash[coefficients.map { |k, v| [k.is_a?(Array) ? [k[0].to_sym, k[1]] : k.to_sym, v] }] if coefficients
      @gsl = gsl.nil? ? defined?(GSL) : gsl
    end

    # Fits the model. Arguments are those of BaseEstimator#train.
    #
    # Keywords are captured separately from positional arguments so that the
    # implicit `super` forwards them as keywords; with only `*args` a
    # `target:` keyword reaches the parent as a positional Hash on Ruby 3.
    #
    # Raises RuntimeError when there are too few samples, or when the normal
    # equations stay singular after dropping redundant columns.
    def train(*args, **options)
      super

      x, @coefficient_names = prep_x(@x)

      if x.size <= @coefficient_names.size
        raise "Number of samples must be at least two more than number of features"
      end

      values =
        if @gsl
          gsl_x = GSL::Matrix.alloc(*x)
          gsl_y = GSL::Vector.alloc(@y)
          c, @covariance, _, _ = GSL::MultiFit::linear(gsl_x, gsl_y)
          c.to_a
        else
          solve_normal_equations(x)
        end

      @coefficients = Hash[@coefficient_names.zip(values)]
    end

    # legacy

    # Coefficients with every key flattened to a single Symbol
    # (Array keys are joined without a separator).
    def coefficients
      Hash[@coefficients.map { |k, v| [Array(k).join.to_sym, v] }]
    end

    # ruby

    # Builds a model from the Hash produced by #dump. Keys may be Strings or
    # Symbols. The Hash is splatted because #initialize only takes keywords.
    def self.load(data)
      new(**Hash[data.map { |k, v| [k.to_sym, v] }])
    end

    def dump
      {coefficients: coefficients}
    end

    # json

    # data - a JSON String or an already parsed Hash with a "coefficients" key.
    def self.load_json(data)
      data = JSON.parse(data) if data.is_a?(String)
      coefficients = data["coefficients"]

      # for R models
      if coefficients["(Intercept)"]
        coefficients = coefficients.dup
        coefficients["_intercept"] = coefficients.delete("(Intercept)")
      end

      new(coefficients: coefficients)
    end

    def to_json
      JSON.generate(dump)
    end

    # pmml

    # data - a parsed document responding to `css` (Nokogiri-style API).
    def self.load_pmml(data)
      # TODO more validation
      node = data.css("RegressionTable")
      coefficients = {
        _intercept: node.attribute("intercept").value.to_f
      }
      node.css("NumericPredictor").each do |n|
        coefficients[n.attribute("name").value] = n.attribute("coefficient").value.to_f
      end
      node.css("CategoricalPredictor").each do |n|
        coefficients[[n.attribute("name").value.to_sym, n.attribute("value").value]] = n.attribute("coefficient").value.to_f
      end
      new(coefficients: coefficients)
    end

    # Serializes the model as a PMML 4.3 RegressionModel document (String).
    def to_pmml
      predictors = @coefficients.reject { |k, _| k == :_intercept }

      # field name => list of category values, or nil for a numeric field
      data_fields = {}
      predictors.each do |k, _|
        if k.is_a?(Array)
          (data_fields[k[0]] ||= []) << k[1]
        else
          data_fields[k] = nil
        end
      end

      builder = Nokogiri::XML::Builder.new do |xml|
        xml.PMML(version: "4.3", xmlns: "http://www.dmg.org/PMML-4_3", "xmlns:xsi" => "http://www.w3.org/2001/XMLSchema-instance") do
          xml.Header
          xml.DataDictionary do
            data_fields.each do |k, vs|
              if vs
                xml.DataField(name: k, optype: "categorical", dataType: "string") do
                  vs.each do |v|
                    xml.Value(value: v)
                  end
                end
              else
                xml.DataField(name: k, optype: "continuous", dataType: "double")
              end
            end
          end
          xml.RegressionModel(functionName: "regression") do
            xml.MiningSchema do
              data_fields.each do |k, _|
                xml.MiningField(name: k)
              end
            end
            xml.RegressionTable(intercept: @coefficients[:_intercept]) do
              predictors.each do |k, v|
                if k.is_a?(Array)
                  xml.CategoricalPredictor(name: k[0], value: k[1], coefficient: v)
                else
                  xml.NumericPredictor(name: k, coefficient: v)
                end
              end
            end
          end
        end
      end
      builder.to_xml
    end

    # pfa

    # data - a PFA document as a JSON String or parsed Hash. Coefficient names
    #        come from data["input"]["fields"] when present, else x0, x1, ...
    def self.load_pfa(data)
      data = JSON.parse(data) if data.is_a?(String)
      init = data["cells"].first[1]["init"]
      names =
        if data["input"]["fields"]
          data["input"]["fields"].map { |f| f["name"] }
        else
          init["coeff"].map.with_index { |_, i| "x#{i}" }
        end
      coefficients = {
        _intercept: init["const"]
      }
      init["coeff"].each_with_index do |c, i|
        name = names[i]
        # R can export coefficients with same name
        raise "Coefficients with same name" if coefficients[name]
        coefficients[name] = c
      end
      new(coefficients: coefficients)
    end

    # metrics

    # Returns a Hash with mean error (:me), mean absolute error (:mae) and
    # root mean squared error (:rmse) of actual - estimated.
    def self.metrics(actual, estimated)
      errors = actual.zip(estimated).map { |yi, yi2| yi - yi2 }

      {
        me: mean(errors),
        mae: mean(errors.map { |v| v.abs }),
        rmse: Math.sqrt(mean(errors.map { |v| v**2 }))
      }
    end

    # private
    def self.mean(arr)
      arr.inject(0, &:+) / arr.size.to_f
    end

    # https://people.richland.edu/james/ictcm/2004/multiple.html
    #
    # Returns a text table of coefficients and p-values (plus standard error,
    # t value and r2 when `extended`). The result is cached separately for
    # each value of `extended`, so asking for both forms returns both forms.
    def summary(extended: false)
      @summary_strs ||= {}
      @summary_strs[extended ? true : false] ||= begin
        str = String.new("")
        len = [coefficients.keys.map(&:size).max, 15].max
        if extended
          str << "%-#{len}s %12s %12s %12s %12s\n" % ["", "coef", "stderr", "t", "p"]
        else
          str << "%-#{len}s %12s %12s\n" % ["", "coef", "p"]
        end
        # iterate the raw keys: std_err / t_value / p_value are keyed by
        # them, and display_field flattens Array keys for printing
        @coefficients.each do |k, v|
          if extended
            str << "%-#{len}s %12.2f %12.2f %12.2f %12.3f\n" % [display_field(k), v, std_err[k], t_value[k], p_value[k]]
          else
            str << "%-#{len}s %12.2f %12.3f\n" % [display_field(k), v, p_value[k]]
          end
        end
        str << "\n"
        str << "r2: %.3f\n" % [r2] if extended
        str << "adjusted r2: %.3f\n" % [adjusted_r2]
        str
      end
    end

    def r2
      @r2 ||= (sst - sse) / sst
    end

    def adjusted_r2
      @adjusted_r2 ||= (mst - mse) / mst
    end

    private

    # Solves the normal equations with the stdlib Matrix class.
    #
    # rows - design matrix as an Array of rows; column 0 is the intercept
    #        column of ones that prep_x prepends.
    #
    # Sets @xtxi and @removed and returns the coefficient values, with a 0 in
    # the position of every column that had to be dropped.
    def solve_normal_equations(rows)
      x = Matrix.rows(rows)
      y = Matrix.column_vector(@y)
      removed = []

      # https://statsmaths.github.io/stat612/lectures/lec13/lecture13.pdf
      # unfortunately, this method is unstable
      # haven't found an efficient way to do QR-factorization in Ruby
      # the extendmatrix gem has householder and givens (givens has bug)
      # but methods are too slow
      xt = x.t
      begin
        @xtxi = (xt * x).inverse
      rescue ExceptionForMatrix::ErrNotRegular
        # remove constant columns (the intercept column is skipped)
        constant, varying = (1...x.column_count).partition { |i| constant?(x.column(i)) }
        removed = constant

        # remove non-independent columns
        varying.combination(2) do |a, b|
          removed << b unless x.column(a).independent?(x.column(b))
        end

        # a column can be flagged by several pairs; every index must appear
        # once, in ascending order, for the deletes and re-inserts to line up
        removed = removed.uniq.sort

        vectors = x.column_vectors
        # delete in reverse so the remaining indexes stay the same
        removed.reverse_each do |i|
          vectors.delete_at(i)
        end
        x = Matrix.columns(vectors)
        xt = x.t

        # try again
        begin
          @xtxi = (xt * x).inverse
        rescue ExceptionForMatrix::ErrNotRegular
          raise "Multiple solutions - GSL is needed to select one"
        end
      end

      # huge performance boost
      # by multiplying xt * y first
      values = matrix_arr(@xtxi * (xt * y))

      # add back removed
      removed.each do |i|
        values.insert(i, 0)
      end
      @removed = removed

      values
    end

    def _predict(x)
      x, c = prep_x(x, train: false)
      coef = c.map do |v|
        # use 0 if coefficient does not exist
        # this can happen for categorical features
        # since only n-1 coefficients are stored
        @coefficients[v] || 0
      end

      x = Matrix.rows(x)
      c = Matrix.column_vector(coef)
      matrix_arr(x * c)
    end

    def display_field(k)
      k.is_a?(Array) ? k.join("") : k
    end

    def constant?(arr)
      arr.all? { |x| x == arr[0] }
    end

    # add epsilon for perfect fits
    # consistent with GSL
    def t_value
      @t_value ||= Hash[@coefficients.map { |k, v| [k, v / (std_err[k] + Float::EPSILON)] }]
    end

    # Two-sided p-value per coefficient, from the t distribution CDF.
    def p_value
      @p_value ||= begin
        pairs = @coefficients.keys.map do |k|
          tp =
            if @gsl
              GSL::Cdf.tdist_P(t_value[k].abs, degrees_of_freedom)
            else
              tdist_p(t_value[k].abs, degrees_of_freedom)
            end

          [k, 2 * (1 - tp)]
        end
        Hash[pairs]
      end
    end

    # Standard errors keyed by the raw coefficient names.
    def std_err
      @std_err ||= begin
        Hash[@coefficient_names.zip(diagonal.map { |v| Math.sqrt(v) })]
      end
    end

    # Diagonal of the covariance matrix, with a 0 re-inserted for every
    # column dropped while fitting.
    def diagonal
      @diagonal ||= begin
        if covariance.respond_to?(:each)
          d = covariance.each(:diagonal).to_a
          # @removed is only assigned by the Matrix solver
          Array(@removed).sort.each do |i|
            d.insert(i, 0)
          end
          d
        else
          covariance.diagonal.to_a
        end
      end
    end

    def covariance
      @covariance ||= mse * @xtxi
    end

    def y_bar
      @y_bar ||= mean(@y)
    end

    def y_hat
      @y_hat ||= predict(@x)
    end

    # total sum of squares
    def sst
      @sst ||= @y.map { |y| (y - y_bar)**2 }.sum
    end

    # sum of squared errors of prediction
    # not to be confused with "explained sum of squares"
    def sse
      @sse ||= @y.zip(y_hat).map { |y, yh| (y - yh)**2 }.sum
    end

    def mst
      @mst ||= sst / (@y.size - 1)
    end

    def mse
      @mse ||= sse / degrees_of_freedom
    end

    def degrees_of_freedom
      @y.size - @coefficients.size
    end

    def mean(arr)
      arr.sum / arr.size.to_f
    end

    ### Extracted from https://github.com/estebanz01/ruby-statistics
    ### The Ruby author is Esteban Zapata Rojas
    ###
    ### Originally extracted from https://codeplea.com/incomplete-beta-function-c
    ### This function is shared under zlib license and the author is Lewis Van Winkle
    def tdist_p(value, degrees_of_freedom)
      upper = (value + Math.sqrt(value * value + degrees_of_freedom))
      lower = (2.0 * Math.sqrt(value * value + degrees_of_freedom))

      x = upper/lower

      alpha = degrees_of_freedom/2.0
      beta = degrees_of_freedom/2.0

      incomplete_beta_function(x, alpha, beta)
    end

    ### Extracted from https://github.com/estebanz01/ruby-statistics
    ### The Ruby author is Esteban Zapata Rojas
    ###
    ### This implementation is an adaptation of the incomplete beta function made in C by
    ### Lewis Van Winkle, which released the code under the zlib license.
    ### The whole math behind this code is described in the following post: https://codeplea.com/incomplete-beta-function-c
    #
    # Returns nil when x is negative, or when the continued fraction has not
    # converged after the iteration limit.
    def incomplete_beta_function(x, alp, bet)
      return if x < 0.0
      return 1.0 if x > 1.0

      tiny = 1.0E-50

      if x > ((alp + 1.0)/(alp + bet + 2.0))
        return 1.0 - incomplete_beta_function(1.0 - x, bet, alp)
      end

      # To avoid overflow problems, the implementation applies the logarithm properties
      # to calculate in a faster and safer way the values.
      lbet_ab = (Math.lgamma(alp)[0] + Math.lgamma(bet)[0] - Math.lgamma(alp + bet)[0]).freeze
      front = (Math.exp(Math.log(x) * alp + Math.log(1.0 - x) * bet - lbet_ab) / alp.to_f).freeze

      # This is the non-log version of the left part of the formula (before the continuous fraction)
      # down_left = alp * self.beta_function(alp, bet)
      # upper_left = (x ** alp) * ((1.0 - x) ** bet)
      # front = upper_left/down_left

      f, c, d = 1.0, 1.0, 0.0

      returned_value = nil

      # Let's do more iterations than the proposed implementation (200 iters)
      (0..500).each do |number|
        m = number/2

        numerator = if number == 0
          1.0
        elsif number % 2 == 0
          (m * (bet - m) * x)/((alp + 2.0 * m - 1.0)* (alp + 2.0 * m))
        else
          top = -((alp + m) * (alp + bet + m) * x)
          down = ((alp + 2.0 * m) * (alp + 2.0 * m + 1.0))

          top/down
        end

        d = 1.0 + numerator * d
        d = tiny if d.abs < tiny
        d = 1.0 / d

        c = 1.0 + numerator / c
        c = tiny if c.abs < tiny

        cd = (c*d).freeze
        f = f * cd

        if (1.0 - cd).abs < 1.0E-10
          returned_value = front * (f - 1.0)
          break
        end
      end

      returned_value
    end

    # Turns input rows into a numeric design matrix.
    #
    # x     - Daru::DataFrame, or an Array whose items are Hashes, Arrays
    #         (columns are named x0, x1, ...) or scalars (column x0).
    # train - true while fitting; false while predicting, in which case
    #         column types come from @coefficients and columns without a
    #         stored coefficient are skipped.
    #
    # Categorical values become indicator columns keyed [name, value.to_s];
    # while training, the first level seen of each categorical feature is
    # dropped. Every row is prefixed with 1 for the intercept.
    #
    # Returns [rows, names], where names[i] labels column i of every row.
    # Raises RuntimeError ("Missing data") for a nil numeric value.
    def prep_x(x, train: true)
      coefficients = @coefficients

      if daru?(x)
        x = x.to_a[0]
      else
        x = x.map do |xi|
          case xi
          when Hash
            xi
          when Array
            Hash[xi.map.with_index { |v, i| [:"x#{i}", v] }]
          else
            {x0: xi}
          end
        end
      end

      # get column types
      column_types = {}
      if train
        if x.any?
          x.first.each do |k, v|
            # stored under the symbolized name because that is how the
            # lookup below reads it; this keeps string-keyed rows working
            column_types[k.to_sym] = categorical?(v) ? "categorical" : "numeric"
          end
        end
      else
        # get column types for prediction
        coefficients.each do |k, _|
          next if k == :_intercept
          if k.is_a?(Array)
            column_types[k.first] = "categorical"
          else
            column_types[k] = "numeric"
          end
        end
      end

      # if !train && x.any?
      #   # check first row against coefficients
      #   ckeys = coefficients.keys.map(&:to_s)
      #   bad_keys = x[0].keys.map(&:to_s).reject { |k| ckeys.any? { |c| c.start_with?(k) } }
      #   raise "Unknown keys: #{bad_keys.join(", ")}" if bad_keys.any?
      # end

      supports_categorical = train || coefficients.any? { |k, _| k.is_a?(Array) }

      cache = {}      # column key => column index
      first_key = {}  # feature name => first categorical key seen for it
      i = 0
      rows = []
      x.each do |xi|
        row = {}
        xi.each do |k, v|
          categorical = column_types[k.to_sym] == "categorical" || (!supports_categorical && categorical?(v))

          key = categorical ? [k.to_sym, v.to_s] : k.to_sym
          v2 = categorical ? 1 : v

          # TODO make more efficient
          check_key = supports_categorical ? key : symbolize_coef(key)
          next if !train && !coefficients.key?(check_key)

          raise "Missing data" if v2.nil?

          unless cache[key]
            cache[key] = i
            first_key[k] ||= key if categorical
            i += 1
          end

          row[key] = v2
        end
        rows << row
      end

      if train
        # remove one degree of freedom
        first_key.values.each do |v|
          num = cache.delete(v)
          cache.each do |k, v2|
            cache[k] -= 1 if v2 > num
          end
        end
      end

      ret2 = []
      rows.each do |row|
        ret = [0] * cache.size
        row.each do |k, v|
          if cache[k]
            ret[cache[k]] = v
          end
        end
        ret2 << ([1] + ret)
      end

      # flatten keys
      c = [:_intercept] + cache.sort_by { |_, v| v }.map(&:first)

      unless supports_categorical
        c = c.map { |v| symbolize_coef(v) }
      end

      [ret2, c]
    end

    def symbolize_coef(k)
      (k.is_a?(Array) ? k.join("") : k).to_sym
    end

    # First column of a matrix as an Array of Floats.
    def matrix_arr(matrix)
      matrix.to_a.map { |xi| xi[0].to_f }
    end
  end
end
|