eps 0.3.0 → 0.3.5
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +31 -5
- data/README.md +77 -9
- data/lib/eps.rb +19 -10
- data/lib/eps/base_estimator.rb +63 -145
- data/lib/eps/data_frame.rb +19 -3
- data/lib/eps/evaluators/lightgbm.rb +20 -7
- data/lib/eps/evaluators/linear_regression.rb +7 -4
- data/lib/eps/evaluators/naive_bayes.rb +9 -7
- data/lib/eps/label_encoder.rb +7 -3
- data/lib/eps/lightgbm.rb +43 -78
- data/lib/eps/linear_regression.rb +53 -83
- data/lib/eps/metrics.rb +24 -12
- data/lib/eps/model.rb +6 -6
- data/lib/eps/naive_bayes.rb +3 -140
- data/lib/eps/pmml.rb +14 -0
- data/lib/eps/pmml/generator.rb +422 -0
- data/lib/eps/pmml/loader.rb +241 -0
- data/lib/eps/version.rb +1 -1
- metadata +36 -6
- data/lib/eps/pmml_generators/lightgbm.rb +0 -187
data/lib/eps/metrics.rb
CHANGED
@@ -1,31 +1,39 @@
|
|
1
1
|
module Eps
|
2
2
|
module Metrics
|
3
3
|
class << self
|
4
|
-
def rmse(y_true, y_pred)
|
4
|
+
def rmse(y_true, y_pred, weight: nil)
|
5
5
|
check_size(y_true, y_pred)
|
6
|
-
Math.sqrt(mean(errors(y_true, y_pred).map { |v| v**2 }))
|
6
|
+
Math.sqrt(mean(errors(y_true, y_pred).map { |v| v**2 }, weight: weight))
|
7
7
|
end
|
8
8
|
|
9
|
-
def mae(y_true, y_pred)
|
9
|
+
def mae(y_true, y_pred, weight: nil)
|
10
10
|
check_size(y_true, y_pred)
|
11
|
-
mean(errors(y_true, y_pred).map { |v| v.abs })
|
11
|
+
mean(errors(y_true, y_pred).map { |v| v.abs }, weight: weight)
|
12
12
|
end
|
13
13
|
|
14
|
-
def me(y_true, y_pred)
|
14
|
+
def me(y_true, y_pred, weight: nil)
|
15
15
|
check_size(y_true, y_pred)
|
16
|
-
mean(errors(y_true, y_pred))
|
16
|
+
mean(errors(y_true, y_pred), weight: weight)
|
17
17
|
end
|
18
18
|
|
19
|
-
def accuracy(y_true, y_pred)
|
19
|
+
def accuracy(y_true, y_pred, weight: nil)
|
20
20
|
check_size(y_true, y_pred)
|
21
|
-
y_true.zip(y_pred).
|
21
|
+
values = y_true.zip(y_pred).map { |yt, yp| yt == yp ? 1 : 0 }
|
22
|
+
if weight
|
23
|
+
values.each_with_index do |v, i|
|
24
|
+
values[i] *= weight[i]
|
25
|
+
end
|
26
|
+
values.sum / weight.sum.to_f
|
27
|
+
else
|
28
|
+
values.sum / y_true.size.to_f
|
29
|
+
end
|
22
30
|
end
|
23
31
|
|
24
32
|
# http://wiki.fast.ai/index.php/Log_Loss
|
25
|
-
def log_loss(y_true, y_pred, eps: 1e-15)
|
33
|
+
def log_loss(y_true, y_pred, eps: 1e-15, weight: nil)
|
26
34
|
check_size(y_true, y_pred)
|
27
35
|
p = y_pred.map { |yp| yp.clamp(eps, 1 - eps) }
|
28
|
-
mean(y_true.zip(p).map { |yt, pi| yt == 1 ? -Math.log(pi) : -Math.log(1 - pi) })
|
36
|
+
mean(y_true.zip(p).map { |yt, pi| yt == 1 ? -Math.log(pi) : -Math.log(1 - pi) }, weight: weight)
|
29
37
|
end
|
30
38
|
|
31
39
|
private
|
@@ -34,8 +42,12 @@ module Eps
|
|
34
42
|
raise ArgumentError, "Different sizes" if y_true.size != y_pred.size
|
35
43
|
end
|
36
44
|
|
37
|
-
def mean(arr)
|
38
|
-
|
45
|
+
def mean(arr, weight: nil)
|
46
|
+
if weight
|
47
|
+
arr.map.with_index { |v, i| v * weight[i] }.sum / weight.sum.to_f
|
48
|
+
else
|
49
|
+
arr.sum / arr.size.to_f
|
50
|
+
end
|
39
51
|
end
|
40
52
|
|
41
53
|
def errors(y_true, y_pred)
|
data/lib/eps/model.rb
CHANGED
@@ -17,11 +17,11 @@ module Eps
|
|
17
17
|
|
18
18
|
estimator_class =
|
19
19
|
if data.css("Segmentation").any?
|
20
|
-
|
20
|
+
LightGBM
|
21
21
|
elsif data.css("RegressionModel").any?
|
22
|
-
|
22
|
+
LinearRegression
|
23
23
|
elsif data.css("NaiveBayesModel").any?
|
24
|
-
|
24
|
+
NaiveBayes
|
25
25
|
else
|
26
26
|
raise "Unknown model"
|
27
27
|
end
|
@@ -35,11 +35,11 @@ module Eps
|
|
35
35
|
estimator_class =
|
36
36
|
case algorithm
|
37
37
|
when :lightgbm
|
38
|
-
|
38
|
+
LightGBM
|
39
39
|
when :linear_regression
|
40
|
-
|
40
|
+
LinearRegression
|
41
41
|
when :naive_bayes
|
42
|
-
|
42
|
+
NaiveBayes
|
43
43
|
else
|
44
44
|
raise ArgumentError, "Unknown algorithm: #{algorithm}"
|
45
45
|
end
|
data/lib/eps/naive_bayes.rb
CHANGED
@@ -3,91 +3,7 @@ module Eps
|
|
3
3
|
attr_reader :probabilities
|
4
4
|
|
5
5
|
def accuracy
|
6
|
-
Eps::Metrics.accuracy(@train_set.label, predict(@train_set))
|
7
|
-
end
|
8
|
-
|
9
|
-
# pmml
|
10
|
-
|
11
|
-
def self.load_pmml(data)
|
12
|
-
super do |data|
|
13
|
-
# TODO more validation
|
14
|
-
node = data.css("NaiveBayesModel")
|
15
|
-
|
16
|
-
prior = {}
|
17
|
-
node.css("BayesOutput TargetValueCount").each do |n|
|
18
|
-
prior[n.attribute("value").value] = n.attribute("count").value.to_f
|
19
|
-
end
|
20
|
-
|
21
|
-
legacy = false
|
22
|
-
|
23
|
-
conditional = {}
|
24
|
-
features = {}
|
25
|
-
node.css("BayesInput").each do |n|
|
26
|
-
prob = {}
|
27
|
-
|
28
|
-
# numeric
|
29
|
-
n.css("TargetValueStat").each do |n2|
|
30
|
-
n3 = n2.css("GaussianDistribution")
|
31
|
-
prob[n2.attribute("value").value] = {
|
32
|
-
mean: n3.attribute("mean").value.to_f,
|
33
|
-
stdev: Math.sqrt(n3.attribute("variance").value.to_f)
|
34
|
-
}
|
35
|
-
end
|
36
|
-
|
37
|
-
# detect bad form in Eps < 0.3
|
38
|
-
bad_format = n.css("PairCounts").map { |n2| n2.attribute("value").value } == prior.keys
|
39
|
-
|
40
|
-
# categorical
|
41
|
-
n.css("PairCounts").each do |n2|
|
42
|
-
if bad_format
|
43
|
-
n2.css("TargetValueCount").each do |n3|
|
44
|
-
prob[n3.attribute("value").value] ||= {}
|
45
|
-
prob[n3.attribute("value").value][n2.attribute("value").value] = BigDecimal(n3.attribute("count").value)
|
46
|
-
end
|
47
|
-
else
|
48
|
-
boom = {}
|
49
|
-
n2.css("TargetValueCount").each do |n3|
|
50
|
-
boom[n3.attribute("value").value] = BigDecimal(n3.attribute("count").value)
|
51
|
-
end
|
52
|
-
prob[n2.attribute("value").value] = boom
|
53
|
-
end
|
54
|
-
end
|
55
|
-
|
56
|
-
if bad_format
|
57
|
-
legacy = true
|
58
|
-
prob.each do |k, v|
|
59
|
-
prior.keys.each do |k|
|
60
|
-
v[k] ||= 0.0
|
61
|
-
end
|
62
|
-
end
|
63
|
-
end
|
64
|
-
|
65
|
-
name = n.attribute("fieldName").value
|
66
|
-
conditional[name] = prob
|
67
|
-
features[name] = n.css("TargetValueStat").any? ? "numeric" : "categorical"
|
68
|
-
end
|
69
|
-
|
70
|
-
target = node.css("BayesOutput").attribute("fieldName").value
|
71
|
-
|
72
|
-
probabilities = {
|
73
|
-
prior: prior,
|
74
|
-
conditional: conditional
|
75
|
-
}
|
76
|
-
|
77
|
-
# get derived fields
|
78
|
-
derived = {}
|
79
|
-
data.css("DerivedField").each do |n|
|
80
|
-
name = n.attribute("name").value
|
81
|
-
field = n.css("NormDiscrete").attribute("field").value
|
82
|
-
value = n.css("NormDiscrete").attribute("value").value
|
83
|
-
features.delete(name)
|
84
|
-
features[field] = "derived"
|
85
|
-
derived[field] ||= {}
|
86
|
-
derived[field][name] = value
|
87
|
-
end
|
88
|
-
|
89
|
-
Evaluators::NaiveBayes.new(probabilities: probabilities, features: features, derived: derived, legacy: legacy)
|
90
|
-
end
|
6
|
+
Eps::Metrics.accuracy(@train_set.label, predict(@train_set), weight: @train_set.weight)
|
91
7
|
end
|
92
8
|
|
93
9
|
private
|
@@ -101,10 +17,11 @@ module Eps
|
|
101
17
|
str
|
102
18
|
end
|
103
19
|
|
104
|
-
def _train(smoothing: 1
|
20
|
+
def _train(smoothing: 1)
|
105
21
|
raise "Target must be strings" if @target_type != "categorical"
|
106
22
|
check_missing_value(@train_set)
|
107
23
|
check_missing_value(@validation_set) if @validation_set
|
24
|
+
raise ArgumentError, "weight not supported" if @train_set.weight
|
108
25
|
|
109
26
|
data = @train_set
|
110
27
|
|
@@ -185,60 +102,6 @@ module Eps
|
|
185
102
|
Evaluators::NaiveBayes.new(probabilities: probabilities, features: @features)
|
186
103
|
end
|
187
104
|
|
188
|
-
def generate_pmml
|
189
|
-
data_fields = {}
|
190
|
-
data_fields[@target] = probabilities[:prior].keys
|
191
|
-
probabilities[:conditional].each do |k, v|
|
192
|
-
if @features[k] == "categorical"
|
193
|
-
data_fields[k] = v.keys
|
194
|
-
else
|
195
|
-
data_fields[k] = nil
|
196
|
-
end
|
197
|
-
end
|
198
|
-
|
199
|
-
build_pmml(data_fields) do |xml|
|
200
|
-
xml.NaiveBayesModel(functionName: "classification", threshold: 0.001) do
|
201
|
-
xml.MiningSchema do
|
202
|
-
data_fields.each do |k, _|
|
203
|
-
xml.MiningField(name: k)
|
204
|
-
end
|
205
|
-
end
|
206
|
-
xml.BayesInputs do
|
207
|
-
probabilities[:conditional].each do |k, v|
|
208
|
-
xml.BayesInput(fieldName: k) do
|
209
|
-
if @features[k] == "categorical"
|
210
|
-
v.sort_by { |k2, _| k2 }.each do |k2, v2|
|
211
|
-
xml.PairCounts(value: k2) do
|
212
|
-
xml.TargetValueCounts do
|
213
|
-
v2.sort_by { |k2, _| k2 }.each do |k3, v3|
|
214
|
-
xml.TargetValueCount(value: k3, count: v3)
|
215
|
-
end
|
216
|
-
end
|
217
|
-
end
|
218
|
-
end
|
219
|
-
else
|
220
|
-
xml.TargetValueStats do
|
221
|
-
v.sort_by { |k2, _| k2 }.each do |k2, v2|
|
222
|
-
xml.TargetValueStat(value: k2) do
|
223
|
-
xml.GaussianDistribution(mean: v2[:mean], variance: v2[:stdev]**2)
|
224
|
-
end
|
225
|
-
end
|
226
|
-
end
|
227
|
-
end
|
228
|
-
end
|
229
|
-
end
|
230
|
-
end
|
231
|
-
xml.BayesOutput(fieldName: "target") do
|
232
|
-
xml.TargetValueCounts do
|
233
|
-
probabilities[:prior].sort_by { |k, _| k }.each do |k, v|
|
234
|
-
xml.TargetValueCount(value: k, count: v)
|
235
|
-
end
|
236
|
-
end
|
237
|
-
end
|
238
|
-
end
|
239
|
-
end
|
240
|
-
end
|
241
|
-
|
242
105
|
def group_count(arr, start)
|
243
106
|
arr.inject(start) { |h, e| h[e] += 1; h }
|
244
107
|
end
|
data/lib/eps/pmml/generator.rb
ADDED
@@ -0,0 +1,422 @@
|
|
1
|
+
module Eps
  module PMML
    # Converts a trained Eps estimator into a PMML 4.4 XML document.
    #
    # Supported estimators: LightGBM, LinearRegression and NaiveBayes.
    # Estimator state is read through instance_variable_get (see the accessor
    # methods at the end of this class).
    class Generator
      attr_reader :model

      # @param model [LightGBM, LinearRegression, NaiveBayes] trained estimator
      def initialize(model)
        @model = model
      end

      # Builds the PMML document for the wrapped model.
      #
      # @return [String] the PMML document as XML
      # @raise [RuntimeError] when the model is not a supported estimator
      def generate
        case @model
        when LightGBM
          lightgbm
        when LinearRegression
          linear_regression
        when NaiveBayes
          naive_bayes
        else
          raise "Unknown model"
        end
      end

      private

      # MiningModel wrapping the boosted trees. For classification the summed
      # tree scores are chained into a RegressionModel producing probabilities.
      def lightgbm
        data_fields = lightgbm_data_fields

        build_pmml(data_fields) do |xml|
          function_name = objective == "regression" ? "regression" : "classification"
          xml.MiningModel(functionName: function_name, algorithmName: "LightGBM") do
            xml.MiningSchema do
              xml.MiningField(name: target, usageType: "target")
              features.each_key do |k|
                # TODO skip zero-importance features and add
                # importance: feature_importance[i], but need to handle text features
                xml.MiningField(name: k)
              end
            end
            pmml_local_transformations(xml)

            case objective
            when "regression"
              xml_segmentation(xml, trees)
            when "binary"
              lightgbm_binary_segmentation(xml)
            else # multiclass
              lightgbm_multiclass_segmentation(xml)
            end
          end
        end
      end

      # DataDictionary entries for a LightGBM model: the target (only when the
      # model has labels), categorical features with their encoded labels, and
      # every other feature with no value list.
      def lightgbm_data_fields
        data_fields = {}
        data_fields[target] = labels if labels
        features.each do |k, type|
          # TODO remove zero importance features
          data_fields[k] = type == "categorical" ? label_encoders[k].labels.keys : nil
        end
        data_fields
      end

      # MiningSchema listing every model feature (used by the tree sub-models).
      def lightgbm_features_mining_schema(xml)
        xml.MiningSchema do
          features.each_key do |k|
            xml.MiningField(name: k)
          end
        end
      end

      # Segment 1 sums the trees into lgbmValue and derives
      # transformedLgbmValue = 1 / (1 + exp(-lgbmValue)). Segment 2 maps that
      # value onto the two labels.
      def lightgbm_binary_segmentation(xml)
        xml.Segmentation(multipleModelMethod: "modelChain") do
          xml.Segment(id: 1) do
            xml.True
            xml.MiningModel(functionName: "regression") do
              lightgbm_features_mining_schema(xml)
              xml.Output do
                xml.OutputField(name: "lgbmValue", optype: "continuous", dataType: "double", feature: "predictedValue", isFinalResult: false)
                # the field segment 2 consumes; must be declared explicitly
                xml.OutputField(name: "transformedLgbmValue", optype: "continuous", dataType: "double", feature: "transformedValue", isFinalResult: false) do
                  xml.Apply(function: "/") do
                    # constant values are passed as element text: Nokogiri's
                    # builder ignores the return value of a block
                    xml.Constant("1.0", dataType: "double")
                    xml.Apply(function: "+") do
                      xml.Constant("1.0", dataType: "double")
                      xml.Apply(function: "exp") do
                        xml.Apply(function: "*") do
                          xml.Constant("-1.0", dataType: "double")
                          xml.FieldRef(field: "lgbmValue")
                        end
                      end
                    end
                  end
                end
              end
              xml_segmentation(xml, trees)
            end
          end
          xml.Segment(id: 2) do
            xml.True
            xml.RegressionModel(functionName: "classification", normalizationMethod: "none") do
              xml.MiningSchema do
                xml.MiningField(name: target, usageType: "target")
                xml.MiningField(name: "transformedLgbmValue")
              end
              xml.Output do
                labels.each do |label|
                  xml.OutputField(name: "probability(#{label})", optype: "continuous", dataType: "double", feature: "probability", value: label)
                end
              end
              xml.RegressionTable(intercept: 0.0, targetCategory: labels.last) do
                xml.NumericPredictor(name: "transformedLgbmValue", coefficient: "1.0")
              end
              xml.RegressionTable(intercept: 0.0, targetCategory: labels.first)
            end
          end
        end
      end

      # One regression sub-model per label, then a softmax RegressionModel
      # over the per-label scores. The trees are split into labels.size
      # contiguous slices, one per label.
      # NOTE(review): this assumes @trees is stored grouped by label; confirm
      # against the LightGBM estimator/loader.
      def lightgbm_multiclass_segmentation(xml)
        xml.Segmentation(multipleModelMethod: "modelChain") do
          n = trees.size / labels.size
          trees.each_slice(n).each_with_index do |label_trees, idx|
            xml.Segment(id: idx + 1) do
              xml.True
              xml.MiningModel(functionName: "regression") do
                lightgbm_features_mining_schema(xml)
                xml.Output do
                  xml.OutputField(name: "lgbmValue(#{labels[idx]})", optype: "continuous", dataType: "double", feature: "predictedValue", isFinalResult: false)
                end
                xml_segmentation(xml, label_trees)
              end
            end
          end
          xml.Segment(id: labels.size + 1) do
            xml.True
            xml.RegressionModel(functionName: "classification", normalizationMethod: "softmax") do
              xml.MiningSchema do
                xml.MiningField(name: target, usageType: "target")
                labels.each do |label|
                  xml.MiningField(name: "lgbmValue(#{label})")
                end
              end
              xml.Output do
                labels.each do |label|
                  xml.OutputField(name: "probability(#{label})", optype: "continuous", dataType: "double", feature: "probability", value: label)
                end
              end
              labels.each do |label|
                xml.RegressionTable(intercept: 0.0, targetCategory: label) do
                  xml.NumericPredictor(name: "lgbmValue(#{label})", coefficient: "1.0")
                end
              end
            end
          end
        end
      end

      # RegressionModel with one RegressionTable. Array coefficient keys are
      # [feature, value] pairs: text terms become NumericPredictors on the
      # derived term field, and other pairs become CategoricalPredictors.
      def linear_regression
        predictors = model.instance_variable_get("@coefficients").dup
        intercept = predictors.delete("_intercept") || 0.0

        data_fields = linear_regression_data_fields(predictors)

        build_pmml(data_fields) do |xml|
          xml.RegressionModel(functionName: "regression") do
            xml.MiningSchema do
              features.each_key do |k|
                xml.MiningField(name: k)
              end
            end
            pmml_local_transformations(xml)
            xml.RegressionTable(intercept: intercept) do
              predictors.each do |k, v|
                if k.is_a?(Array)
                  if features[k.first] == "text"
                    xml.NumericPredictor(name: display_field(k), coefficient: v)
                  else
                    xml.CategoricalPredictor(name: k[0], value: k[1], coefficient: v)
                  end
                else
                  xml.NumericPredictor(name: k, coefficient: v)
                end
              end
            end
          end
        end
      end

      # DataDictionary entries for a linear regression. A categorical feature
      # lists the values that have a [feature, value] coefficient; every other
      # feature has no value list.
      # NOTE(review): a category without a coefficient (e.g. a dropped
      # reference level, if the estimator drops one) would not be listed.
      def linear_regression_data_fields(predictors)
        features.each_with_object({}) do |(feature, type), data_fields|
          data_fields[feature] =
            if type == "categorical"
              # block param must not shadow `feature` (it previously compared
              # the key with itself and always returned [])
              predictors.keys.select { |key| key.is_a?(Array) && key.first == feature }.map(&:last)
            end
        end
      end

      # NaiveBayesModel built from the prior counts and the per-feature
      # conditional statistics (counts for categorical, Gaussian otherwise).
      def naive_bayes
        data_fields = naive_bayes_data_fields

        build_pmml(data_fields) do |xml|
          xml.NaiveBayesModel(functionName: "classification", threshold: 0.001) do
            xml.MiningSchema do
              data_fields.each_key do |k|
                if k == target
                  xml.MiningField(name: k, usageType: "target")
                else
                  xml.MiningField(name: k)
                end
              end
            end
            xml.BayesInputs do
              probabilities[:conditional].each do |k, v|
                xml.BayesInput(fieldName: k) do
                  if features[k] == "categorical"
                    v.sort_by { |value, _| value.to_s }.each do |value, counts|
                      xml.PairCounts(value: value) do
                        xml.TargetValueCounts do
                          counts.sort_by { |label, _| label.to_s }.each do |label, count|
                            xml.TargetValueCount(value: label, count: count)
                          end
                        end
                      end
                    end
                  else
                    xml.TargetValueStats do
                      v.sort_by { |label, _| label.to_s }.each do |label, stats|
                        xml.TargetValueStat(value: label) do
                          xml.GaussianDistribution(mean: stats[:mean], variance: stats[:stdev]**2)
                        end
                      end
                    end
                  end
                end
              end
            end
            # must name the actual target field declared in the DataDictionary
            xml.BayesOutput(fieldName: target) do
              xml.TargetValueCounts do
                probabilities[:prior].sort_by { |label, _| label.to_s }.each do |label, count|
                  xml.TargetValueCount(value: label, count: count)
                end
              end
            end
          end
        end
      end

      # DataDictionary entries for naive Bayes: the target with its prior
      # labels first, then each conditional feature (categorical ones with
      # their observed values, others with no value list).
      def naive_bayes_data_fields
        data_fields = { target => probabilities[:prior].keys }
        probabilities[:conditional].each do |k, v|
          data_fields[k] = features[k] == "categorical" ? v.keys : nil
        end
        data_fields
      end

      # PMML field name for a feature key: "feature(term)" for text terms,
      # "feature=value" for other [feature, value] pairs, the name otherwise.
      def display_field(k)
        if k.is_a?(Array)
          if features[k.first] == "text"
            "#{k.first}(#{k.last})"
          else
            k.join("=")
          end
        else
          k
        end
      end

      # Segmentation summing one TreeModel per tree.
      def xml_segmentation(xml, trees)
        xml.Segmentation(multipleModelMethod: "sum") do
          trees.each_with_index do |node, i|
            xml.Segment(id: i + 1) do
              xml.True
              xml.TreeModel(functionName: "regression", missingValueStrategy: "none", noTrueChildStrategy: "returnLastPrediction", splitCharacteristic: "multiSplit") do
                xml.MiningSchema do
                  node_fields(node).uniq.each do |k|
                    xml.MiningField(name: display_field(k))
                  end
                end
                node_pmml(node, xml)
              end
            end
          end
        end
      end

      # Fields used by split predicates anywhere in the tree rooted at node
      # (may contain duplicates).
      def node_fields(node)
        fields = []
        fields << node.field if node.predicate
        node.children.each { |child| fields.concat(node_fields(child)) }
        fields
      end

      # Emits node (and, recursively, its children) as a PMML Node element.
      def node_pmml(node, xml)
        xml.Node(score: node.score) do
          if node.predicate.nil?
            xml.True
          elsif node.operator == "in"
            xml.SimpleSetPredicate(field: display_field(node.field), booleanOperator: "isIn") do
              xml.Array(type: "string") do
                xml.text node.value.map { |v| escape_element(v) }.join(" ")
              end
            end
          else
            xml.SimplePredicate(field: display_field(node.field), operator: node.operator, value: node.value)
          end
          node.children.each do |n|
            node_pmml(n, xml)
          end
        end
      end

      # Quotes a value for a whitespace-separated PMML string Array, escaping
      # embedded double quotes with a backslash. Non-string values are
      # converted with to_s.
      def escape_element(v)
        # block form avoids gsub's backslash handling in replacement strings
        "\"#{v.to_s.gsub('"') { '\\"' }}\""
      end

      # Wraps the model element yielded by the caller in a PMML document with
      # header, data dictionary and transformation dictionary.
      def build_pmml(data_fields)
        Nokogiri::XML::Builder.new do |xml|
          xml.PMML(version: "4.4", xmlns: "http://www.dmg.org/PMML-4_4", "xmlns:xsi" => "http://www.w3.org/2001/XMLSchema-instance") do
            pmml_header(xml)
            pmml_data_dictionary(xml, data_fields)
            pmml_transformation_dictionary(xml)
            yield xml
          end
        end.to_xml
      end

      def pmml_header(xml)
        xml.Header do
          xml.Application(name: "Eps", version: Eps::VERSION)
          # xml.Timestamp Time.now.utc.iso8601
        end
      end

      # Categorical fields (and fields with no known feature type, such as the
      # target) are strings with a sorted Value list; text fields are strings;
      # everything else is a continuous double.
      # NOTE(review): fields with no feature type are treated as categorical,
      # so a regression target added here would need a value list.
      def pmml_data_dictionary(xml, data_fields)
        xml.DataDictionary do
          data_fields.each do |k, vs|
            case features[k]
            when "categorical", nil
              xml.DataField(name: k, optype: "categorical", dataType: "string") do
                vs.map(&:to_s).sort.each do |v|
                  xml.Value(value: v)
                end
              end
            when "text"
              xml.DataField(name: k, optype: "categorical", dataType: "string")
            else
              xml.DataField(name: k, optype: "continuous", dataType: "double")
            end
          end
        end
      end

      # One "<feature>Transform" TextIndex function per text feature.
      def pmml_transformation_dictionary(xml)
        return unless text_features&.any?

        xml.TransformationDictionary do
          text_features.each do |k, text_options|
            xml.DefineFunction(name: "#{k}Transform", optype: "continuous") do
              xml.ParameterField(name: "text")
              xml.ParameterField(name: "term")
              xml.TextIndex(textField: "text", localTermWeights: "termFrequency", wordSeparatorCharacterRE: text_options[:tokenizer].source, isCaseSensitive: !!text_options[:case_sensitive]) do
                xml.FieldRef(field: "term")
              end
            end
          end
        end
      end

      # One derived term-frequency field per vocabulary term of each text feature.
      def pmml_local_transformations(xml)
        return unless text_features&.any?

        xml.LocalTransformations do
          text_features.each_key do |k|
            text_encoders[k].vocabulary.each do |v|
              xml.DerivedField(name: display_field([k, v]), optype: "continuous", dataType: "integer") do
                xml.Apply(function: "#{k}Transform") do
                  xml.FieldRef(field: k)
                  xml.Constant v
                end
              end
            end
          end
        end
      end

      # TODO create instance methods on model for all of these features

      def features
        model.instance_variable_get("@features")
      end

      def text_features
        model.instance_variable_get("@text_features")
      end

      def text_encoders
        model.instance_variable_get("@text_encoders")
      end

      def feature_importance
        model.instance_variable_get("@feature_importance")
      end

      def labels
        model.instance_variable_get("@labels")
      end

      def trees
        model.instance_variable_get("@trees")
      end

      def target
        model.instance_variable_get("@target")
      end

      def label_encoders
        model.instance_variable_get("@label_encoders")
      end

      def objective
        model.instance_variable_get("@objective")
      end

      def probabilities
        model.instance_variable_get("@probabilities")
      end

      # end TODO
    end
  end
end
|