eps 0.3.0 → 0.3.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +12 -5
- data/README.md +34 -0
- data/lib/eps.rb +19 -10
- data/lib/eps/base_estimator.rb +35 -129
- data/lib/eps/data_frame.rb +7 -1
- data/lib/eps/evaluators/linear_regression.rb +1 -1
- data/lib/eps/label_encoder.rb +7 -3
- data/lib/eps/lightgbm.rb +36 -76
- data/lib/eps/linear_regression.rb +26 -79
- data/lib/eps/metrics.rb +24 -12
- data/lib/eps/model.rb +6 -6
- data/lib/eps/naive_bayes.rb +2 -139
- data/lib/eps/pmml.rb +14 -0
- data/lib/eps/pmml/generator.rb +422 -0
- data/lib/eps/pmml/loader.rb +241 -0
- data/lib/eps/version.rb +1 -1
- metadata +7 -5
- data/lib/eps/pmml_generators/lightgbm.rb +0 -187
data/lib/eps/lightgbm.rb
CHANGED
@@ -1,39 +1,5 @@
|
|
1
|
-
require "eps/pmml_generators/lightgbm"
|
2
|
-
|
3
1
|
module Eps
|
4
2
|
class LightGBM < BaseEstimator
|
5
|
-
include PmmlGenerators::LightGBM
|
6
|
-
|
7
|
-
def self.load_pmml(data)
|
8
|
-
super do |data|
|
9
|
-
objective = data.css("MiningModel").first.attribute("functionName").value
|
10
|
-
if objective == "classification"
|
11
|
-
labels = data.css("RegressionModel OutputField").map { |n| n.attribute("value").value }
|
12
|
-
objective = labels.size > 2 ? "multiclass" : "binary"
|
13
|
-
end
|
14
|
-
|
15
|
-
features = {}
|
16
|
-
text_features, derived_fields = extract_text_features(data, features)
|
17
|
-
node = data.css("DataDictionary").first
|
18
|
-
node.css("DataField")[1..-1].to_a.each do |node|
|
19
|
-
features[node.attribute("name").value] =
|
20
|
-
if node.attribute("optype").value == "categorical"
|
21
|
-
"categorical"
|
22
|
-
else
|
23
|
-
"numeric"
|
24
|
-
end
|
25
|
-
end
|
26
|
-
|
27
|
-
trees = []
|
28
|
-
data.css("Segmentation TreeModel").each do |tree|
|
29
|
-
node = find_nodes(tree.css("Node").first, derived_fields)
|
30
|
-
trees << node
|
31
|
-
end
|
32
|
-
|
33
|
-
Evaluators::LightGBM.new(trees: trees, objective: objective, labels: labels, features: features, text_features: text_features)
|
34
|
-
end
|
35
|
-
end
|
36
|
-
|
37
3
|
private
|
38
4
|
|
39
5
|
def _summary(extended: false)
|
@@ -51,48 +17,16 @@ module Eps
|
|
51
17
|
str
|
52
18
|
end
|
53
19
|
|
54
|
-
def self.find_nodes(xml, derived_fields)
|
55
|
-
score = BigDecimal(xml.attribute("score").value).to_f
|
56
|
-
|
57
|
-
elements = xml.elements
|
58
|
-
xml_predicate = elements.first
|
59
|
-
|
60
|
-
predicate =
|
61
|
-
if xml_predicate.name == "True"
|
62
|
-
nil
|
63
|
-
elsif xml_predicate.name == "SimpleSetPredicate"
|
64
|
-
operator = "in"
|
65
|
-
value = xml_predicate.css("Array").text.scan(/"(.+?)(?<!\\)"|(\S+)/).flatten.compact.map { |v| v.gsub('\"', '"') }
|
66
|
-
field = xml_predicate.attribute("field").value
|
67
|
-
field = derived_fields[field] if derived_fields[field]
|
68
|
-
{
|
69
|
-
field: field,
|
70
|
-
operator: operator,
|
71
|
-
value: value
|
72
|
-
}
|
73
|
-
else
|
74
|
-
operator = xml_predicate.attribute("operator").value
|
75
|
-
value = xml_predicate.attribute("value").value
|
76
|
-
value = BigDecimal(value).to_f if operator == "greaterThan"
|
77
|
-
field = xml_predicate.attribute("field").value
|
78
|
-
field = derived_fields[field] if derived_fields[field]
|
79
|
-
{
|
80
|
-
field: field,
|
81
|
-
operator: operator,
|
82
|
-
value: value
|
83
|
-
}
|
84
|
-
end
|
85
|
-
|
86
|
-
children = elements[1..-1].map { |n| find_nodes(n, derived_fields) }
|
87
|
-
|
88
|
-
Evaluators::Node.new(score: score, predicate: predicate, children: children)
|
89
|
-
end
|
90
|
-
|
91
20
|
def _train(verbose: nil, early_stopping: nil)
|
92
21
|
train_set = @train_set
|
93
22
|
validation_set = @validation_set.dup
|
94
23
|
summary_label = train_set.label
|
95
24
|
|
25
|
+
# create check set
|
26
|
+
evaluator_set = validation_set || train_set
|
27
|
+
check_idx = 100.times.map { rand(evaluator_set.size) }.uniq
|
28
|
+
evaluator_set = evaluator_set[check_idx]
|
29
|
+
|
96
30
|
# objective
|
97
31
|
objective =
|
98
32
|
if @target_type == "numeric"
|
@@ -135,8 +69,8 @@ module Eps
|
|
135
69
|
|
136
70
|
# create datasets
|
137
71
|
categorical_idx = @features.values.map.with_index.select { |type, _| type == "categorical" }.map(&:last)
|
138
|
-
train_ds = ::LightGBM::Dataset.new(train_set.map_rows(&:to_a), label: train_set.label, categorical_feature: categorical_idx, params: params)
|
139
|
-
validation_ds = ::LightGBM::Dataset.new(validation_set.map_rows(&:to_a), label: validation_set.label, categorical_feature: categorical_idx, params: params, reference: train_ds) if validation_set
|
72
|
+
train_ds = ::LightGBM::Dataset.new(train_set.map_rows(&:to_a), label: train_set.label, weight: train_set.weight, categorical_feature: categorical_idx, params: params)
|
73
|
+
validation_ds = ::LightGBM::Dataset.new(validation_set.map_rows(&:to_a), label: validation_set.label, weight: validation_set.weight, categorical_feature: categorical_idx, params: params, reference: train_ds) if validation_set
|
140
74
|
|
141
75
|
# train
|
142
76
|
valid_sets = [train_ds]
|
@@ -176,11 +110,37 @@ module Eps
|
|
176
110
|
# reset pmml
|
177
111
|
@pmml = nil
|
178
112
|
|
179
|
-
Evaluators::LightGBM.new(trees: trees, objective: objective, labels: labels, features: @features, text_features: @text_features)
|
113
|
+
evaluator = Evaluators::LightGBM.new(trees: trees, objective: objective, labels: labels, features: @features, text_features: @text_features)
|
114
|
+
booster_set = validation_set ? validation_set[check_idx] : train_set[check_idx]
|
115
|
+
check_evaluator(objective, labels, booster, booster_set, evaluator, evaluator_set)
|
116
|
+
evaluator
|
180
117
|
end
|
181
118
|
|
182
|
-
|
183
|
-
|
119
|
+
# compare a subset of predictions to check for possible bugs in evaluator
|
120
|
+
# NOTE LightGBM must use double data type for prediction input for these to be consistent
|
121
|
+
def check_evaluator(objective, labels, booster, booster_set, evaluator, evaluator_set)
|
122
|
+
expected = @booster.predict(booster_set.map_rows(&:to_a))
|
123
|
+
if objective == "multiclass"
|
124
|
+
expected.map! do |v|
|
125
|
+
labels[v.map.with_index.max_by { |v2, _| v2 }.last]
|
126
|
+
end
|
127
|
+
elsif objective == "binary"
|
128
|
+
expected.map! { |v| labels[v >= 0.5 ? 1 : 0] }
|
129
|
+
end
|
130
|
+
actual = evaluator.predict(evaluator_set)
|
131
|
+
|
132
|
+
regression = objective == "regression"
|
133
|
+
bad_observations = []
|
134
|
+
expected.zip(actual).each_with_index do |(exp, act), i|
|
135
|
+
success = regression ? (act - exp).abs < 0.001 : act == exp
|
136
|
+
unless success
|
137
|
+
bad_observations << {expected: exp, actual: act, data_point: evaluator_set[i].map(&:itself).first}
|
138
|
+
end
|
139
|
+
end
|
140
|
+
|
141
|
+
if bad_observations.any?
|
142
|
+
raise "Bug detected in evaluator. Please report an issue. Bad data points: #{bad_observations.inspect}"
|
143
|
+
end
|
184
144
|
end
|
185
145
|
|
186
146
|
# for evaluator
|
@@ -1,40 +1,5 @@
|
|
1
1
|
module Eps
|
2
2
|
class LinearRegression < BaseEstimator
|
3
|
-
# pmml
|
4
|
-
|
5
|
-
def self.load_pmml(data)
|
6
|
-
super do |data|
|
7
|
-
# TODO more validation
|
8
|
-
node = data.css("RegressionTable")
|
9
|
-
|
10
|
-
coefficients = {
|
11
|
-
"_intercept" => node.attribute("intercept").value.to_f
|
12
|
-
}
|
13
|
-
|
14
|
-
features = {}
|
15
|
-
|
16
|
-
text_features, derived_fields = extract_text_features(data, features)
|
17
|
-
|
18
|
-
node.css("NumericPredictor").each do |n|
|
19
|
-
name = n.attribute("name").value
|
20
|
-
if derived_fields[name]
|
21
|
-
name = derived_fields[name]
|
22
|
-
else
|
23
|
-
features[name] = "numeric"
|
24
|
-
end
|
25
|
-
coefficients[name] = n.attribute("coefficient").value.to_f
|
26
|
-
end
|
27
|
-
|
28
|
-
node.css("CategoricalPredictor").each do |n|
|
29
|
-
name = n.attribute("name").value
|
30
|
-
coefficients[[name, n.attribute("value").value]] = n.attribute("coefficient").value.to_f
|
31
|
-
features[name] = "categorical"
|
32
|
-
end
|
33
|
-
|
34
|
-
Evaluators::LinearRegression.new(coefficients: coefficients, features: features, text_features: text_features)
|
35
|
-
end
|
36
|
-
end
|
37
|
-
|
38
3
|
def coefficients
|
39
4
|
@evaluator.coefficients
|
40
5
|
end
|
@@ -84,9 +49,12 @@ module Eps
|
|
84
49
|
end
|
85
50
|
|
86
51
|
x = data.map_rows(&:to_a)
|
87
|
-
|
88
|
-
|
89
|
-
|
52
|
+
|
53
|
+
intercept = @options.key?(:intercept) ? @options[:intercept] : true
|
54
|
+
if intercept
|
55
|
+
data.size.times do |i|
|
56
|
+
x[i].unshift(1)
|
57
|
+
end
|
90
58
|
end
|
91
59
|
|
92
60
|
gsl = options.key?(:gsl) ? options[:gsl] : defined?(GSL)
|
@@ -95,22 +63,32 @@ module Eps
|
|
95
63
|
if gsl
|
96
64
|
x = GSL::Matrix.alloc(*x)
|
97
65
|
y = GSL::Vector.alloc(data.label)
|
98
|
-
|
66
|
+
w = GSL::Vector.alloc(data.weight) if data.weight
|
67
|
+
c, @covariance, _, _ = w ? GSL::MultiFit.wlinear(x, w, y) : GSL::MultiFit.linear(x, y)
|
99
68
|
c.to_a
|
100
69
|
else
|
101
70
|
x = Matrix.rows(x)
|
102
71
|
y = Matrix.column_vector(data.label)
|
72
|
+
|
73
|
+
# weighted OLS
|
74
|
+
# http://www.real-statistics.com/multiple-regression/weighted-linear-regression/weighted-regression-basics/
|
75
|
+
w = Matrix.diagonal(*data.weight) if data.weight
|
76
|
+
|
103
77
|
removed = []
|
104
78
|
|
105
79
|
# https://statsmaths.github.io/stat612/lectures/lec13/lecture13.pdf
|
106
|
-
#
|
80
|
+
# unfortunately, this method is unstable
|
107
81
|
# haven't found an efficient way to do QR-factorization in Ruby
|
108
82
|
# the extendmatrix gem has householder and givens (givens has bug)
|
109
83
|
# but methods are too slow
|
110
84
|
xt = x.t
|
85
|
+
xt *= w if w
|
111
86
|
begin
|
112
87
|
@xtxi = (xt * x).inverse
|
113
88
|
rescue ExceptionForMatrix::ErrNotRegular
|
89
|
+
# matrix cannot be inverted
|
90
|
+
# https://en.wikipedia.org/wiki/Multicollinearity
|
91
|
+
|
114
92
|
constant = {}
|
115
93
|
(1...x.column_count).each do |i|
|
116
94
|
constant[i] = constant?(x.column(i))
|
@@ -134,6 +112,7 @@ module Eps
|
|
134
112
|
end
|
135
113
|
x = Matrix.columns(vectors)
|
136
114
|
xt = x.t
|
115
|
+
xt *= w if w
|
137
116
|
|
138
117
|
# try again
|
139
118
|
begin
|
@@ -144,6 +123,7 @@ module Eps
|
|
144
123
|
end
|
145
124
|
# huge performance boost
|
146
125
|
# by multiplying xt * y first
|
126
|
+
# for weighted, w is already included in wt
|
147
127
|
v2 = @xtxi * (xt * y)
|
148
128
|
|
149
129
|
# convert to array
|
@@ -158,47 +138,14 @@ module Eps
|
|
158
138
|
v2
|
159
139
|
end
|
160
140
|
|
161
|
-
@
|
162
|
-
|
163
|
-
Evaluators::LinearRegression.new(coefficients: @coefficients, features: @features, text_features: @text_features)
|
164
|
-
end
|
165
|
-
|
166
|
-
def generate_pmml
|
167
|
-
predictors = @coefficients.dup
|
168
|
-
predictors.delete("_intercept")
|
169
|
-
|
170
|
-
data_fields = {}
|
171
|
-
@features.each do |k, type|
|
172
|
-
if type == "categorical"
|
173
|
-
data_fields[k] = predictors.keys.select { |k, v| k.is_a?(Array) && k.first == k }.map(&:last)
|
174
|
-
else
|
175
|
-
data_fields[k] = nil
|
176
|
-
end
|
141
|
+
if @xtxi && @xtxi.each(:diagonal).any? { |v| v < 0 }
|
142
|
+
raise UnstableSolution, "GSL is needed to find a stable solution for this dataset"
|
177
143
|
end
|
178
144
|
|
179
|
-
|
180
|
-
|
181
|
-
|
182
|
-
|
183
|
-
xml.MiningField(name: k)
|
184
|
-
end
|
185
|
-
end
|
186
|
-
pmml_local_transformations(xml)
|
187
|
-
xml.RegressionTable(intercept: @coefficients["_intercept"]) do
|
188
|
-
predictors.each do |k, v|
|
189
|
-
if k.is_a?(Array)
|
190
|
-
if @features[k.first] == "text"
|
191
|
-
xml.NumericPredictor(name: display_field(k), coefficient: v)
|
192
|
-
else
|
193
|
-
xml.CategoricalPredictor(name: k[0], value: k[1], coefficient: v)
|
194
|
-
end
|
195
|
-
else
|
196
|
-
xml.NumericPredictor(name: k, coefficient: v)
|
197
|
-
end
|
198
|
-
end
|
199
|
-
end
|
200
|
-
end
|
201
|
-
end
|
145
|
+
@coefficient_names = data.columns.keys
|
146
|
+
@coefficient_names.unshift("_intercept") if intercept
|
147
|
+
@coefficients = Hash[@coefficient_names.zip(v3)]
|
148
|
+
Evaluators::LinearRegression.new(coefficients: @coefficients, features: @features, text_features: @text_features)
|
202
149
|
end
|
203
150
|
|
204
151
|
def prep_x(x)
|
data/lib/eps/metrics.rb
CHANGED
@@ -1,31 +1,39 @@
|
|
1
1
|
module Eps
|
2
2
|
module Metrics
|
3
3
|
class << self
|
4
|
-
def rmse(y_true, y_pred)
|
4
|
+
def rmse(y_true, y_pred, weight: nil)
|
5
5
|
check_size(y_true, y_pred)
|
6
|
-
Math.sqrt(mean(errors(y_true, y_pred).map { |v| v**2 }))
|
6
|
+
Math.sqrt(mean(errors(y_true, y_pred).map { |v| v**2 }, weight: weight))
|
7
7
|
end
|
8
8
|
|
9
|
-
def mae(y_true, y_pred)
|
9
|
+
def mae(y_true, y_pred, weight: nil)
|
10
10
|
check_size(y_true, y_pred)
|
11
|
-
mean(errors(y_true, y_pred).map { |v| v.abs })
|
11
|
+
mean(errors(y_true, y_pred).map { |v| v.abs }, weight: weight)
|
12
12
|
end
|
13
13
|
|
14
|
-
def me(y_true, y_pred)
|
14
|
+
def me(y_true, y_pred, weight: nil)
|
15
15
|
check_size(y_true, y_pred)
|
16
|
-
mean(errors(y_true, y_pred))
|
16
|
+
mean(errors(y_true, y_pred), weight: weight)
|
17
17
|
end
|
18
18
|
|
19
|
-
def accuracy(y_true, y_pred)
|
19
|
+
def accuracy(y_true, y_pred, weight: nil)
|
20
20
|
check_size(y_true, y_pred)
|
21
|
-
y_true.zip(y_pred).
|
21
|
+
values = y_true.zip(y_pred).map { |yt, yp| yt == yp ? 1 : 0 }
|
22
|
+
if weight
|
23
|
+
values.each_with_index do |v, i|
|
24
|
+
values[i] *= weight[i]
|
25
|
+
end
|
26
|
+
values.sum / weight.sum.to_f
|
27
|
+
else
|
28
|
+
values.sum / y_true.size.to_f
|
29
|
+
end
|
22
30
|
end
|
23
31
|
|
24
32
|
# http://wiki.fast.ai/index.php/Log_Loss
|
25
|
-
def log_loss(y_true, y_pred, eps: 1e-15)
|
33
|
+
def log_loss(y_true, y_pred, eps: 1e-15, weight: nil)
|
26
34
|
check_size(y_true, y_pred)
|
27
35
|
p = y_pred.map { |yp| yp.clamp(eps, 1 - eps) }
|
28
|
-
mean(y_true.zip(p).map { |yt, pi| yt == 1 ? -Math.log(pi) : -Math.log(1 - pi) })
|
36
|
+
mean(y_true.zip(p).map { |yt, pi| yt == 1 ? -Math.log(pi) : -Math.log(1 - pi) }, weight: weight)
|
29
37
|
end
|
30
38
|
|
31
39
|
private
|
@@ -34,8 +42,12 @@ module Eps
|
|
34
42
|
raise ArgumentError, "Different sizes" if y_true.size != y_pred.size
|
35
43
|
end
|
36
44
|
|
37
|
-
def mean(arr)
|
38
|
-
|
45
|
+
def mean(arr, weight: nil)
|
46
|
+
if weight
|
47
|
+
arr.map.with_index { |v, i| v * weight[i] }.sum / weight.sum.to_f
|
48
|
+
else
|
49
|
+
arr.sum / arr.size.to_f
|
50
|
+
end
|
39
51
|
end
|
40
52
|
|
41
53
|
def errors(y_true, y_pred)
|
data/lib/eps/model.rb
CHANGED
@@ -17,11 +17,11 @@ module Eps
|
|
17
17
|
|
18
18
|
estimator_class =
|
19
19
|
if data.css("Segmentation").any?
|
20
|
-
|
20
|
+
LightGBM
|
21
21
|
elsif data.css("RegressionModel").any?
|
22
|
-
|
22
|
+
LinearRegression
|
23
23
|
elsif data.css("NaiveBayesModel").any?
|
24
|
-
|
24
|
+
NaiveBayes
|
25
25
|
else
|
26
26
|
raise "Unknown model"
|
27
27
|
end
|
@@ -35,11 +35,11 @@ module Eps
|
|
35
35
|
estimator_class =
|
36
36
|
case algorithm
|
37
37
|
when :lightgbm
|
38
|
-
|
38
|
+
LightGBM
|
39
39
|
when :linear_regression
|
40
|
-
|
40
|
+
LinearRegression
|
41
41
|
when :naive_bayes
|
42
|
-
|
42
|
+
NaiveBayes
|
43
43
|
else
|
44
44
|
raise ArgumentError, "Unknown algorithm: #{algorithm}"
|
45
45
|
end
|
data/lib/eps/naive_bayes.rb
CHANGED
@@ -3,91 +3,7 @@ module Eps
|
|
3
3
|
attr_reader :probabilities
|
4
4
|
|
5
5
|
def accuracy
|
6
|
-
Eps::Metrics.accuracy(@train_set.label, predict(@train_set))
|
7
|
-
end
|
8
|
-
|
9
|
-
# pmml
|
10
|
-
|
11
|
-
def self.load_pmml(data)
|
12
|
-
super do |data|
|
13
|
-
# TODO more validation
|
14
|
-
node = data.css("NaiveBayesModel")
|
15
|
-
|
16
|
-
prior = {}
|
17
|
-
node.css("BayesOutput TargetValueCount").each do |n|
|
18
|
-
prior[n.attribute("value").value] = n.attribute("count").value.to_f
|
19
|
-
end
|
20
|
-
|
21
|
-
legacy = false
|
22
|
-
|
23
|
-
conditional = {}
|
24
|
-
features = {}
|
25
|
-
node.css("BayesInput").each do |n|
|
26
|
-
prob = {}
|
27
|
-
|
28
|
-
# numeric
|
29
|
-
n.css("TargetValueStat").each do |n2|
|
30
|
-
n3 = n2.css("GaussianDistribution")
|
31
|
-
prob[n2.attribute("value").value] = {
|
32
|
-
mean: n3.attribute("mean").value.to_f,
|
33
|
-
stdev: Math.sqrt(n3.attribute("variance").value.to_f)
|
34
|
-
}
|
35
|
-
end
|
36
|
-
|
37
|
-
# detect bad form in Eps < 0.3
|
38
|
-
bad_format = n.css("PairCounts").map { |n2| n2.attribute("value").value } == prior.keys
|
39
|
-
|
40
|
-
# categorical
|
41
|
-
n.css("PairCounts").each do |n2|
|
42
|
-
if bad_format
|
43
|
-
n2.css("TargetValueCount").each do |n3|
|
44
|
-
prob[n3.attribute("value").value] ||= {}
|
45
|
-
prob[n3.attribute("value").value][n2.attribute("value").value] = BigDecimal(n3.attribute("count").value)
|
46
|
-
end
|
47
|
-
else
|
48
|
-
boom = {}
|
49
|
-
n2.css("TargetValueCount").each do |n3|
|
50
|
-
boom[n3.attribute("value").value] = BigDecimal(n3.attribute("count").value)
|
51
|
-
end
|
52
|
-
prob[n2.attribute("value").value] = boom
|
53
|
-
end
|
54
|
-
end
|
55
|
-
|
56
|
-
if bad_format
|
57
|
-
legacy = true
|
58
|
-
prob.each do |k, v|
|
59
|
-
prior.keys.each do |k|
|
60
|
-
v[k] ||= 0.0
|
61
|
-
end
|
62
|
-
end
|
63
|
-
end
|
64
|
-
|
65
|
-
name = n.attribute("fieldName").value
|
66
|
-
conditional[name] = prob
|
67
|
-
features[name] = n.css("TargetValueStat").any? ? "numeric" : "categorical"
|
68
|
-
end
|
69
|
-
|
70
|
-
target = node.css("BayesOutput").attribute("fieldName").value
|
71
|
-
|
72
|
-
probabilities = {
|
73
|
-
prior: prior,
|
74
|
-
conditional: conditional
|
75
|
-
}
|
76
|
-
|
77
|
-
# get derived fields
|
78
|
-
derived = {}
|
79
|
-
data.css("DerivedField").each do |n|
|
80
|
-
name = n.attribute("name").value
|
81
|
-
field = n.css("NormDiscrete").attribute("field").value
|
82
|
-
value = n.css("NormDiscrete").attribute("value").value
|
83
|
-
features.delete(name)
|
84
|
-
features[field] = "derived"
|
85
|
-
derived[field] ||= {}
|
86
|
-
derived[field][name] = value
|
87
|
-
end
|
88
|
-
|
89
|
-
Evaluators::NaiveBayes.new(probabilities: probabilities, features: features, derived: derived, legacy: legacy)
|
90
|
-
end
|
6
|
+
Eps::Metrics.accuracy(@train_set.label, predict(@train_set), weight: @train_set.weight)
|
91
7
|
end
|
92
8
|
|
93
9
|
private
|
@@ -105,6 +21,7 @@ module Eps
|
|
105
21
|
raise "Target must be strings" if @target_type != "categorical"
|
106
22
|
check_missing_value(@train_set)
|
107
23
|
check_missing_value(@validation_set) if @validation_set
|
24
|
+
raise ArgumentError, "weight not supported" if @train_set.weight
|
108
25
|
|
109
26
|
data = @train_set
|
110
27
|
|
@@ -185,60 +102,6 @@ module Eps
|
|
185
102
|
Evaluators::NaiveBayes.new(probabilities: probabilities, features: @features)
|
186
103
|
end
|
187
104
|
|
188
|
-
def generate_pmml
|
189
|
-
data_fields = {}
|
190
|
-
data_fields[@target] = probabilities[:prior].keys
|
191
|
-
probabilities[:conditional].each do |k, v|
|
192
|
-
if @features[k] == "categorical"
|
193
|
-
data_fields[k] = v.keys
|
194
|
-
else
|
195
|
-
data_fields[k] = nil
|
196
|
-
end
|
197
|
-
end
|
198
|
-
|
199
|
-
build_pmml(data_fields) do |xml|
|
200
|
-
xml.NaiveBayesModel(functionName: "classification", threshold: 0.001) do
|
201
|
-
xml.MiningSchema do
|
202
|
-
data_fields.each do |k, _|
|
203
|
-
xml.MiningField(name: k)
|
204
|
-
end
|
205
|
-
end
|
206
|
-
xml.BayesInputs do
|
207
|
-
probabilities[:conditional].each do |k, v|
|
208
|
-
xml.BayesInput(fieldName: k) do
|
209
|
-
if @features[k] == "categorical"
|
210
|
-
v.sort_by { |k2, _| k2 }.each do |k2, v2|
|
211
|
-
xml.PairCounts(value: k2) do
|
212
|
-
xml.TargetValueCounts do
|
213
|
-
v2.sort_by { |k2, _| k2 }.each do |k3, v3|
|
214
|
-
xml.TargetValueCount(value: k3, count: v3)
|
215
|
-
end
|
216
|
-
end
|
217
|
-
end
|
218
|
-
end
|
219
|
-
else
|
220
|
-
xml.TargetValueStats do
|
221
|
-
v.sort_by { |k2, _| k2 }.each do |k2, v2|
|
222
|
-
xml.TargetValueStat(value: k2) do
|
223
|
-
xml.GaussianDistribution(mean: v2[:mean], variance: v2[:stdev]**2)
|
224
|
-
end
|
225
|
-
end
|
226
|
-
end
|
227
|
-
end
|
228
|
-
end
|
229
|
-
end
|
230
|
-
end
|
231
|
-
xml.BayesOutput(fieldName: "target") do
|
232
|
-
xml.TargetValueCounts do
|
233
|
-
probabilities[:prior].sort_by { |k, _| k }.each do |k, v|
|
234
|
-
xml.TargetValueCount(value: k, count: v)
|
235
|
-
end
|
236
|
-
end
|
237
|
-
end
|
238
|
-
end
|
239
|
-
end
|
240
|
-
end
|
241
|
-
|
242
105
|
def group_count(arr, start)
|
243
106
|
arr.inject(start) { |h, e| h[e] += 1; h }
|
244
107
|
end
|