eps 0.3.0 → 0.3.1
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +12 -5
- data/README.md +34 -0
- data/lib/eps.rb +19 -10
- data/lib/eps/base_estimator.rb +35 -129
- data/lib/eps/data_frame.rb +7 -1
- data/lib/eps/evaluators/linear_regression.rb +1 -1
- data/lib/eps/label_encoder.rb +7 -3
- data/lib/eps/lightgbm.rb +36 -76
- data/lib/eps/linear_regression.rb +26 -79
- data/lib/eps/metrics.rb +24 -12
- data/lib/eps/model.rb +6 -6
- data/lib/eps/naive_bayes.rb +2 -139
- data/lib/eps/pmml.rb +14 -0
- data/lib/eps/pmml/generator.rb +422 -0
- data/lib/eps/pmml/loader.rb +241 -0
- data/lib/eps/version.rb +1 -1
- metadata +7 -5
- data/lib/eps/pmml_generators/lightgbm.rb +0 -187
data/lib/eps/lightgbm.rb
CHANGED
@@ -1,39 +1,5 @@
|
|
1
|
-
require "eps/pmml_generators/lightgbm"
|
2
|
-
|
3
1
|
module Eps
|
4
2
|
class LightGBM < BaseEstimator
|
5
|
-
include PmmlGenerators::LightGBM
|
6
|
-
|
7
|
-
def self.load_pmml(data)
|
8
|
-
super do |data|
|
9
|
-
objective = data.css("MiningModel").first.attribute("functionName").value
|
10
|
-
if objective == "classification"
|
11
|
-
labels = data.css("RegressionModel OutputField").map { |n| n.attribute("value").value }
|
12
|
-
objective = labels.size > 2 ? "multiclass" : "binary"
|
13
|
-
end
|
14
|
-
|
15
|
-
features = {}
|
16
|
-
text_features, derived_fields = extract_text_features(data, features)
|
17
|
-
node = data.css("DataDictionary").first
|
18
|
-
node.css("DataField")[1..-1].to_a.each do |node|
|
19
|
-
features[node.attribute("name").value] =
|
20
|
-
if node.attribute("optype").value == "categorical"
|
21
|
-
"categorical"
|
22
|
-
else
|
23
|
-
"numeric"
|
24
|
-
end
|
25
|
-
end
|
26
|
-
|
27
|
-
trees = []
|
28
|
-
data.css("Segmentation TreeModel").each do |tree|
|
29
|
-
node = find_nodes(tree.css("Node").first, derived_fields)
|
30
|
-
trees << node
|
31
|
-
end
|
32
|
-
|
33
|
-
Evaluators::LightGBM.new(trees: trees, objective: objective, labels: labels, features: features, text_features: text_features)
|
34
|
-
end
|
35
|
-
end
|
36
|
-
|
37
3
|
private
|
38
4
|
|
39
5
|
def _summary(extended: false)
|
@@ -51,48 +17,16 @@ module Eps
|
|
51
17
|
str
|
52
18
|
end
|
53
19
|
|
54
|
-
def self.find_nodes(xml, derived_fields)
|
55
|
-
score = BigDecimal(xml.attribute("score").value).to_f
|
56
|
-
|
57
|
-
elements = xml.elements
|
58
|
-
xml_predicate = elements.first
|
59
|
-
|
60
|
-
predicate =
|
61
|
-
if xml_predicate.name == "True"
|
62
|
-
nil
|
63
|
-
elsif xml_predicate.name == "SimpleSetPredicate"
|
64
|
-
operator = "in"
|
65
|
-
value = xml_predicate.css("Array").text.scan(/"(.+?)(?<!\\)"|(\S+)/).flatten.compact.map { |v| v.gsub('\"', '"') }
|
66
|
-
field = xml_predicate.attribute("field").value
|
67
|
-
field = derived_fields[field] if derived_fields[field]
|
68
|
-
{
|
69
|
-
field: field,
|
70
|
-
operator: operator,
|
71
|
-
value: value
|
72
|
-
}
|
73
|
-
else
|
74
|
-
operator = xml_predicate.attribute("operator").value
|
75
|
-
value = xml_predicate.attribute("value").value
|
76
|
-
value = BigDecimal(value).to_f if operator == "greaterThan"
|
77
|
-
field = xml_predicate.attribute("field").value
|
78
|
-
field = derived_fields[field] if derived_fields[field]
|
79
|
-
{
|
80
|
-
field: field,
|
81
|
-
operator: operator,
|
82
|
-
value: value
|
83
|
-
}
|
84
|
-
end
|
85
|
-
|
86
|
-
children = elements[1..-1].map { |n| find_nodes(n, derived_fields) }
|
87
|
-
|
88
|
-
Evaluators::Node.new(score: score, predicate: predicate, children: children)
|
89
|
-
end
|
90
|
-
|
91
20
|
def _train(verbose: nil, early_stopping: nil)
|
92
21
|
train_set = @train_set
|
93
22
|
validation_set = @validation_set.dup
|
94
23
|
summary_label = train_set.label
|
95
24
|
|
25
|
+
# create check set
|
26
|
+
evaluator_set = validation_set || train_set
|
27
|
+
check_idx = 100.times.map { rand(evaluator_set.size) }.uniq
|
28
|
+
evaluator_set = evaluator_set[check_idx]
|
29
|
+
|
96
30
|
# objective
|
97
31
|
objective =
|
98
32
|
if @target_type == "numeric"
|
@@ -135,8 +69,8 @@ module Eps
|
|
135
69
|
|
136
70
|
# create datasets
|
137
71
|
categorical_idx = @features.values.map.with_index.select { |type, _| type == "categorical" }.map(&:last)
|
138
|
-
train_ds = ::LightGBM::Dataset.new(train_set.map_rows(&:to_a), label: train_set.label, categorical_feature: categorical_idx, params: params)
|
139
|
-
validation_ds = ::LightGBM::Dataset.new(validation_set.map_rows(&:to_a), label: validation_set.label, categorical_feature: categorical_idx, params: params, reference: train_ds) if validation_set
|
72
|
+
train_ds = ::LightGBM::Dataset.new(train_set.map_rows(&:to_a), label: train_set.label, weight: train_set.weight, categorical_feature: categorical_idx, params: params)
|
73
|
+
validation_ds = ::LightGBM::Dataset.new(validation_set.map_rows(&:to_a), label: validation_set.label, weight: validation_set.weight, categorical_feature: categorical_idx, params: params, reference: train_ds) if validation_set
|
140
74
|
|
141
75
|
# train
|
142
76
|
valid_sets = [train_ds]
|
@@ -176,11 +110,37 @@ module Eps
|
|
176
110
|
# reset pmml
|
177
111
|
@pmml = nil
|
178
112
|
|
179
|
-
Evaluators::LightGBM.new(trees: trees, objective: objective, labels: labels, features: @features, text_features: @text_features)
|
113
|
+
evaluator = Evaluators::LightGBM.new(trees: trees, objective: objective, labels: labels, features: @features, text_features: @text_features)
|
114
|
+
booster_set = validation_set ? validation_set[check_idx] : train_set[check_idx]
|
115
|
+
check_evaluator(objective, labels, booster, booster_set, evaluator, evaluator_set)
|
116
|
+
evaluator
|
180
117
|
end
|
181
118
|
|
182
|
-
|
183
|
-
|
119
|
+
# compare a subset of predictions to check for possible bugs in evaluator
|
120
|
+
# NOTE LightGBM must use double data type for prediction input for these to be consistent
|
121
|
+
def check_evaluator(objective, labels, booster, booster_set, evaluator, evaluator_set)
|
122
|
+
expected = @booster.predict(booster_set.map_rows(&:to_a))
|
123
|
+
if objective == "multiclass"
|
124
|
+
expected.map! do |v|
|
125
|
+
labels[v.map.with_index.max_by { |v2, _| v2 }.last]
|
126
|
+
end
|
127
|
+
elsif objective == "binary"
|
128
|
+
expected.map! { |v| labels[v >= 0.5 ? 1 : 0] }
|
129
|
+
end
|
130
|
+
actual = evaluator.predict(evaluator_set)
|
131
|
+
|
132
|
+
regression = objective == "regression"
|
133
|
+
bad_observations = []
|
134
|
+
expected.zip(actual).each_with_index do |(exp, act), i|
|
135
|
+
success = regression ? (act - exp).abs < 0.001 : act == exp
|
136
|
+
unless success
|
137
|
+
bad_observations << {expected: exp, actual: act, data_point: evaluator_set[i].map(&:itself).first}
|
138
|
+
end
|
139
|
+
end
|
140
|
+
|
141
|
+
if bad_observations.any?
|
142
|
+
raise "Bug detected in evaluator. Please report an issue. Bad data points: #{bad_observations.inspect}"
|
143
|
+
end
|
184
144
|
end
|
185
145
|
|
186
146
|
# for evaluator
|
@@ -1,40 +1,5 @@
|
|
1
1
|
module Eps
|
2
2
|
class LinearRegression < BaseEstimator
|
3
|
-
# pmml
|
4
|
-
|
5
|
-
def self.load_pmml(data)
|
6
|
-
super do |data|
|
7
|
-
# TODO more validation
|
8
|
-
node = data.css("RegressionTable")
|
9
|
-
|
10
|
-
coefficients = {
|
11
|
-
"_intercept" => node.attribute("intercept").value.to_f
|
12
|
-
}
|
13
|
-
|
14
|
-
features = {}
|
15
|
-
|
16
|
-
text_features, derived_fields = extract_text_features(data, features)
|
17
|
-
|
18
|
-
node.css("NumericPredictor").each do |n|
|
19
|
-
name = n.attribute("name").value
|
20
|
-
if derived_fields[name]
|
21
|
-
name = derived_fields[name]
|
22
|
-
else
|
23
|
-
features[name] = "numeric"
|
24
|
-
end
|
25
|
-
coefficients[name] = n.attribute("coefficient").value.to_f
|
26
|
-
end
|
27
|
-
|
28
|
-
node.css("CategoricalPredictor").each do |n|
|
29
|
-
name = n.attribute("name").value
|
30
|
-
coefficients[[name, n.attribute("value").value]] = n.attribute("coefficient").value.to_f
|
31
|
-
features[name] = "categorical"
|
32
|
-
end
|
33
|
-
|
34
|
-
Evaluators::LinearRegression.new(coefficients: coefficients, features: features, text_features: text_features)
|
35
|
-
end
|
36
|
-
end
|
37
|
-
|
38
3
|
def coefficients
|
39
4
|
@evaluator.coefficients
|
40
5
|
end
|
@@ -84,9 +49,12 @@ module Eps
|
|
84
49
|
end
|
85
50
|
|
86
51
|
x = data.map_rows(&:to_a)
|
87
|
-
|
88
|
-
|
89
|
-
|
52
|
+
|
53
|
+
intercept = @options.key?(:intercept) ? @options[:intercept] : true
|
54
|
+
if intercept
|
55
|
+
data.size.times do |i|
|
56
|
+
x[i].unshift(1)
|
57
|
+
end
|
90
58
|
end
|
91
59
|
|
92
60
|
gsl = options.key?(:gsl) ? options[:gsl] : defined?(GSL)
|
@@ -95,22 +63,32 @@ module Eps
|
|
95
63
|
if gsl
|
96
64
|
x = GSL::Matrix.alloc(*x)
|
97
65
|
y = GSL::Vector.alloc(data.label)
|
98
|
-
|
66
|
+
w = GSL::Vector.alloc(data.weight) if data.weight
|
67
|
+
c, @covariance, _, _ = w ? GSL::MultiFit.wlinear(x, w, y) : GSL::MultiFit.linear(x, y)
|
99
68
|
c.to_a
|
100
69
|
else
|
101
70
|
x = Matrix.rows(x)
|
102
71
|
y = Matrix.column_vector(data.label)
|
72
|
+
|
73
|
+
# weighted OLS
|
74
|
+
# http://www.real-statistics.com/multiple-regression/weighted-linear-regression/weighted-regression-basics/
|
75
|
+
w = Matrix.diagonal(*data.weight) if data.weight
|
76
|
+
|
103
77
|
removed = []
|
104
78
|
|
105
79
|
# https://statsmaths.github.io/stat612/lectures/lec13/lecture13.pdf
|
106
|
-
#
|
80
|
+
# unfortunately, this method is unstable
|
107
81
|
# haven't found an efficient way to do QR-factorization in Ruby
|
108
82
|
# the extendmatrix gem has householder and givens (givens has bug)
|
109
83
|
# but methods are too slow
|
110
84
|
xt = x.t
|
85
|
+
xt *= w if w
|
111
86
|
begin
|
112
87
|
@xtxi = (xt * x).inverse
|
113
88
|
rescue ExceptionForMatrix::ErrNotRegular
|
89
|
+
# matrix cannot be inverted
|
90
|
+
# https://en.wikipedia.org/wiki/Multicollinearity
|
91
|
+
|
114
92
|
constant = {}
|
115
93
|
(1...x.column_count).each do |i|
|
116
94
|
constant[i] = constant?(x.column(i))
|
@@ -134,6 +112,7 @@ module Eps
|
|
134
112
|
end
|
135
113
|
x = Matrix.columns(vectors)
|
136
114
|
xt = x.t
|
115
|
+
xt *= w if w
|
137
116
|
|
138
117
|
# try again
|
139
118
|
begin
|
@@ -144,6 +123,7 @@ module Eps
|
|
144
123
|
end
|
145
124
|
# huge performance boost
|
146
125
|
# by multiplying xt * y first
|
126
|
+
# for weighted, w is already included in wt
|
147
127
|
v2 = @xtxi * (xt * y)
|
148
128
|
|
149
129
|
# convert to array
|
@@ -158,47 +138,14 @@ module Eps
|
|
158
138
|
v2
|
159
139
|
end
|
160
140
|
|
161
|
-
@
|
162
|
-
|
163
|
-
Evaluators::LinearRegression.new(coefficients: @coefficients, features: @features, text_features: @text_features)
|
164
|
-
end
|
165
|
-
|
166
|
-
def generate_pmml
|
167
|
-
predictors = @coefficients.dup
|
168
|
-
predictors.delete("_intercept")
|
169
|
-
|
170
|
-
data_fields = {}
|
171
|
-
@features.each do |k, type|
|
172
|
-
if type == "categorical"
|
173
|
-
data_fields[k] = predictors.keys.select { |k, v| k.is_a?(Array) && k.first == k }.map(&:last)
|
174
|
-
else
|
175
|
-
data_fields[k] = nil
|
176
|
-
end
|
141
|
+
if @xtxi && @xtxi.each(:diagonal).any? { |v| v < 0 }
|
142
|
+
raise UnstableSolution, "GSL is needed to find a stable solution for this dataset"
|
177
143
|
end
|
178
144
|
|
179
|
-
|
180
|
-
|
181
|
-
|
182
|
-
|
183
|
-
xml.MiningField(name: k)
|
184
|
-
end
|
185
|
-
end
|
186
|
-
pmml_local_transformations(xml)
|
187
|
-
xml.RegressionTable(intercept: @coefficients["_intercept"]) do
|
188
|
-
predictors.each do |k, v|
|
189
|
-
if k.is_a?(Array)
|
190
|
-
if @features[k.first] == "text"
|
191
|
-
xml.NumericPredictor(name: display_field(k), coefficient: v)
|
192
|
-
else
|
193
|
-
xml.CategoricalPredictor(name: k[0], value: k[1], coefficient: v)
|
194
|
-
end
|
195
|
-
else
|
196
|
-
xml.NumericPredictor(name: k, coefficient: v)
|
197
|
-
end
|
198
|
-
end
|
199
|
-
end
|
200
|
-
end
|
201
|
-
end
|
145
|
+
@coefficient_names = data.columns.keys
|
146
|
+
@coefficient_names.unshift("_intercept") if intercept
|
147
|
+
@coefficients = Hash[@coefficient_names.zip(v3)]
|
148
|
+
Evaluators::LinearRegression.new(coefficients: @coefficients, features: @features, text_features: @text_features)
|
202
149
|
end
|
203
150
|
|
204
151
|
def prep_x(x)
|
data/lib/eps/metrics.rb
CHANGED
@@ -1,31 +1,39 @@
|
|
1
1
|
module Eps
|
2
2
|
module Metrics
|
3
3
|
class << self
|
4
|
-
def rmse(y_true, y_pred)
|
4
|
+
def rmse(y_true, y_pred, weight: nil)
|
5
5
|
check_size(y_true, y_pred)
|
6
|
-
Math.sqrt(mean(errors(y_true, y_pred).map { |v| v**2 }))
|
6
|
+
Math.sqrt(mean(errors(y_true, y_pred).map { |v| v**2 }, weight: weight))
|
7
7
|
end
|
8
8
|
|
9
|
-
def mae(y_true, y_pred)
|
9
|
+
def mae(y_true, y_pred, weight: nil)
|
10
10
|
check_size(y_true, y_pred)
|
11
|
-
mean(errors(y_true, y_pred).map { |v| v.abs })
|
11
|
+
mean(errors(y_true, y_pred).map { |v| v.abs }, weight: weight)
|
12
12
|
end
|
13
13
|
|
14
|
-
def me(y_true, y_pred)
|
14
|
+
def me(y_true, y_pred, weight: nil)
|
15
15
|
check_size(y_true, y_pred)
|
16
|
-
mean(errors(y_true, y_pred))
|
16
|
+
mean(errors(y_true, y_pred), weight: weight)
|
17
17
|
end
|
18
18
|
|
19
|
-
def accuracy(y_true, y_pred)
|
19
|
+
def accuracy(y_true, y_pred, weight: nil)
|
20
20
|
check_size(y_true, y_pred)
|
21
|
-
y_true.zip(y_pred).
|
21
|
+
values = y_true.zip(y_pred).map { |yt, yp| yt == yp ? 1 : 0 }
|
22
|
+
if weight
|
23
|
+
values.each_with_index do |v, i|
|
24
|
+
values[i] *= weight[i]
|
25
|
+
end
|
26
|
+
values.sum / weight.sum.to_f
|
27
|
+
else
|
28
|
+
values.sum / y_true.size.to_f
|
29
|
+
end
|
22
30
|
end
|
23
31
|
|
24
32
|
# http://wiki.fast.ai/index.php/Log_Loss
|
25
|
-
def log_loss(y_true, y_pred, eps: 1e-15)
|
33
|
+
def log_loss(y_true, y_pred, eps: 1e-15, weight: nil)
|
26
34
|
check_size(y_true, y_pred)
|
27
35
|
p = y_pred.map { |yp| yp.clamp(eps, 1 - eps) }
|
28
|
-
mean(y_true.zip(p).map { |yt, pi| yt == 1 ? -Math.log(pi) : -Math.log(1 - pi) })
|
36
|
+
mean(y_true.zip(p).map { |yt, pi| yt == 1 ? -Math.log(pi) : -Math.log(1 - pi) }, weight: weight)
|
29
37
|
end
|
30
38
|
|
31
39
|
private
|
@@ -34,8 +42,12 @@ module Eps
|
|
34
42
|
raise ArgumentError, "Different sizes" if y_true.size != y_pred.size
|
35
43
|
end
|
36
44
|
|
37
|
-
def mean(arr)
|
38
|
-
|
45
|
+
def mean(arr, weight: nil)
|
46
|
+
if weight
|
47
|
+
arr.map.with_index { |v, i| v * weight[i] }.sum / weight.sum.to_f
|
48
|
+
else
|
49
|
+
arr.sum / arr.size.to_f
|
50
|
+
end
|
39
51
|
end
|
40
52
|
|
41
53
|
def errors(y_true, y_pred)
|
data/lib/eps/model.rb
CHANGED
@@ -17,11 +17,11 @@ module Eps
|
|
17
17
|
|
18
18
|
estimator_class =
|
19
19
|
if data.css("Segmentation").any?
|
20
|
-
|
20
|
+
LightGBM
|
21
21
|
elsif data.css("RegressionModel").any?
|
22
|
-
|
22
|
+
LinearRegression
|
23
23
|
elsif data.css("NaiveBayesModel").any?
|
24
|
-
|
24
|
+
NaiveBayes
|
25
25
|
else
|
26
26
|
raise "Unknown model"
|
27
27
|
end
|
@@ -35,11 +35,11 @@ module Eps
|
|
35
35
|
estimator_class =
|
36
36
|
case algorithm
|
37
37
|
when :lightgbm
|
38
|
-
|
38
|
+
LightGBM
|
39
39
|
when :linear_regression
|
40
|
-
|
40
|
+
LinearRegression
|
41
41
|
when :naive_bayes
|
42
|
-
|
42
|
+
NaiveBayes
|
43
43
|
else
|
44
44
|
raise ArgumentError, "Unknown algorithm: #{algorithm}"
|
45
45
|
end
|
data/lib/eps/naive_bayes.rb
CHANGED
@@ -3,91 +3,7 @@ module Eps
|
|
3
3
|
attr_reader :probabilities
|
4
4
|
|
5
5
|
def accuracy
|
6
|
-
Eps::Metrics.accuracy(@train_set.label, predict(@train_set))
|
7
|
-
end
|
8
|
-
|
9
|
-
# pmml
|
10
|
-
|
11
|
-
def self.load_pmml(data)
|
12
|
-
super do |data|
|
13
|
-
# TODO more validation
|
14
|
-
node = data.css("NaiveBayesModel")
|
15
|
-
|
16
|
-
prior = {}
|
17
|
-
node.css("BayesOutput TargetValueCount").each do |n|
|
18
|
-
prior[n.attribute("value").value] = n.attribute("count").value.to_f
|
19
|
-
end
|
20
|
-
|
21
|
-
legacy = false
|
22
|
-
|
23
|
-
conditional = {}
|
24
|
-
features = {}
|
25
|
-
node.css("BayesInput").each do |n|
|
26
|
-
prob = {}
|
27
|
-
|
28
|
-
# numeric
|
29
|
-
n.css("TargetValueStat").each do |n2|
|
30
|
-
n3 = n2.css("GaussianDistribution")
|
31
|
-
prob[n2.attribute("value").value] = {
|
32
|
-
mean: n3.attribute("mean").value.to_f,
|
33
|
-
stdev: Math.sqrt(n3.attribute("variance").value.to_f)
|
34
|
-
}
|
35
|
-
end
|
36
|
-
|
37
|
-
# detect bad form in Eps < 0.3
|
38
|
-
bad_format = n.css("PairCounts").map { |n2| n2.attribute("value").value } == prior.keys
|
39
|
-
|
40
|
-
# categorical
|
41
|
-
n.css("PairCounts").each do |n2|
|
42
|
-
if bad_format
|
43
|
-
n2.css("TargetValueCount").each do |n3|
|
44
|
-
prob[n3.attribute("value").value] ||= {}
|
45
|
-
prob[n3.attribute("value").value][n2.attribute("value").value] = BigDecimal(n3.attribute("count").value)
|
46
|
-
end
|
47
|
-
else
|
48
|
-
boom = {}
|
49
|
-
n2.css("TargetValueCount").each do |n3|
|
50
|
-
boom[n3.attribute("value").value] = BigDecimal(n3.attribute("count").value)
|
51
|
-
end
|
52
|
-
prob[n2.attribute("value").value] = boom
|
53
|
-
end
|
54
|
-
end
|
55
|
-
|
56
|
-
if bad_format
|
57
|
-
legacy = true
|
58
|
-
prob.each do |k, v|
|
59
|
-
prior.keys.each do |k|
|
60
|
-
v[k] ||= 0.0
|
61
|
-
end
|
62
|
-
end
|
63
|
-
end
|
64
|
-
|
65
|
-
name = n.attribute("fieldName").value
|
66
|
-
conditional[name] = prob
|
67
|
-
features[name] = n.css("TargetValueStat").any? ? "numeric" : "categorical"
|
68
|
-
end
|
69
|
-
|
70
|
-
target = node.css("BayesOutput").attribute("fieldName").value
|
71
|
-
|
72
|
-
probabilities = {
|
73
|
-
prior: prior,
|
74
|
-
conditional: conditional
|
75
|
-
}
|
76
|
-
|
77
|
-
# get derived fields
|
78
|
-
derived = {}
|
79
|
-
data.css("DerivedField").each do |n|
|
80
|
-
name = n.attribute("name").value
|
81
|
-
field = n.css("NormDiscrete").attribute("field").value
|
82
|
-
value = n.css("NormDiscrete").attribute("value").value
|
83
|
-
features.delete(name)
|
84
|
-
features[field] = "derived"
|
85
|
-
derived[field] ||= {}
|
86
|
-
derived[field][name] = value
|
87
|
-
end
|
88
|
-
|
89
|
-
Evaluators::NaiveBayes.new(probabilities: probabilities, features: features, derived: derived, legacy: legacy)
|
90
|
-
end
|
6
|
+
Eps::Metrics.accuracy(@train_set.label, predict(@train_set), weight: @train_set.weight)
|
91
7
|
end
|
92
8
|
|
93
9
|
private
|
@@ -105,6 +21,7 @@ module Eps
|
|
105
21
|
raise "Target must be strings" if @target_type != "categorical"
|
106
22
|
check_missing_value(@train_set)
|
107
23
|
check_missing_value(@validation_set) if @validation_set
|
24
|
+
raise ArgumentError, "weight not supported" if @train_set.weight
|
108
25
|
|
109
26
|
data = @train_set
|
110
27
|
|
@@ -185,60 +102,6 @@ module Eps
|
|
185
102
|
Evaluators::NaiveBayes.new(probabilities: probabilities, features: @features)
|
186
103
|
end
|
187
104
|
|
188
|
-
def generate_pmml
|
189
|
-
data_fields = {}
|
190
|
-
data_fields[@target] = probabilities[:prior].keys
|
191
|
-
probabilities[:conditional].each do |k, v|
|
192
|
-
if @features[k] == "categorical"
|
193
|
-
data_fields[k] = v.keys
|
194
|
-
else
|
195
|
-
data_fields[k] = nil
|
196
|
-
end
|
197
|
-
end
|
198
|
-
|
199
|
-
build_pmml(data_fields) do |xml|
|
200
|
-
xml.NaiveBayesModel(functionName: "classification", threshold: 0.001) do
|
201
|
-
xml.MiningSchema do
|
202
|
-
data_fields.each do |k, _|
|
203
|
-
xml.MiningField(name: k)
|
204
|
-
end
|
205
|
-
end
|
206
|
-
xml.BayesInputs do
|
207
|
-
probabilities[:conditional].each do |k, v|
|
208
|
-
xml.BayesInput(fieldName: k) do
|
209
|
-
if @features[k] == "categorical"
|
210
|
-
v.sort_by { |k2, _| k2 }.each do |k2, v2|
|
211
|
-
xml.PairCounts(value: k2) do
|
212
|
-
xml.TargetValueCounts do
|
213
|
-
v2.sort_by { |k2, _| k2 }.each do |k3, v3|
|
214
|
-
xml.TargetValueCount(value: k3, count: v3)
|
215
|
-
end
|
216
|
-
end
|
217
|
-
end
|
218
|
-
end
|
219
|
-
else
|
220
|
-
xml.TargetValueStats do
|
221
|
-
v.sort_by { |k2, _| k2 }.each do |k2, v2|
|
222
|
-
xml.TargetValueStat(value: k2) do
|
223
|
-
xml.GaussianDistribution(mean: v2[:mean], variance: v2[:stdev]**2)
|
224
|
-
end
|
225
|
-
end
|
226
|
-
end
|
227
|
-
end
|
228
|
-
end
|
229
|
-
end
|
230
|
-
end
|
231
|
-
xml.BayesOutput(fieldName: "target") do
|
232
|
-
xml.TargetValueCounts do
|
233
|
-
probabilities[:prior].sort_by { |k, _| k }.each do |k, v|
|
234
|
-
xml.TargetValueCount(value: k, count: v)
|
235
|
-
end
|
236
|
-
end
|
237
|
-
end
|
238
|
-
end
|
239
|
-
end
|
240
|
-
end
|
241
|
-
|
242
105
|
def group_count(arr, start)
|
243
106
|
arr.inject(start) { |h, e| h[e] += 1; h }
|
244
107
|
end
|