eps 0.3.0 → 0.3.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +31 -5
- data/README.md +77 -9
- data/lib/eps.rb +19 -10
- data/lib/eps/base_estimator.rb +63 -145
- data/lib/eps/data_frame.rb +19 -3
- data/lib/eps/evaluators/lightgbm.rb +20 -7
- data/lib/eps/evaluators/linear_regression.rb +7 -4
- data/lib/eps/evaluators/naive_bayes.rb +9 -7
- data/lib/eps/label_encoder.rb +7 -3
- data/lib/eps/lightgbm.rb +43 -78
- data/lib/eps/linear_regression.rb +53 -83
- data/lib/eps/metrics.rb +24 -12
- data/lib/eps/model.rb +6 -6
- data/lib/eps/naive_bayes.rb +3 -140
- data/lib/eps/pmml.rb +14 -0
- data/lib/eps/pmml/generator.rb +422 -0
- data/lib/eps/pmml/loader.rb +241 -0
- data/lib/eps/version.rb +1 -1
- metadata +36 -6
- data/lib/eps/pmml_generators/lightgbm.rb +0 -187
data/lib/eps/metrics.rb
CHANGED
@@ -1,31 +1,39 @@
|
|
1
1
|
module Eps
|
2
2
|
module Metrics
|
3
3
|
class << self
|
4
|
-
def rmse(y_true, y_pred)
|
4
|
+
def rmse(y_true, y_pred, weight: nil)
|
5
5
|
check_size(y_true, y_pred)
|
6
|
-
Math.sqrt(mean(errors(y_true, y_pred).map { |v| v**2 }))
|
6
|
+
Math.sqrt(mean(errors(y_true, y_pred).map { |v| v**2 }, weight: weight))
|
7
7
|
end
|
8
8
|
|
9
|
-
def mae(y_true, y_pred)
|
9
|
+
def mae(y_true, y_pred, weight: nil)
|
10
10
|
check_size(y_true, y_pred)
|
11
|
-
mean(errors(y_true, y_pred).map { |v| v.abs })
|
11
|
+
mean(errors(y_true, y_pred).map { |v| v.abs }, weight: weight)
|
12
12
|
end
|
13
13
|
|
14
|
-
def me(y_true, y_pred)
|
14
|
+
def me(y_true, y_pred, weight: nil)
|
15
15
|
check_size(y_true, y_pred)
|
16
|
-
mean(errors(y_true, y_pred))
|
16
|
+
mean(errors(y_true, y_pred), weight: weight)
|
17
17
|
end
|
18
18
|
|
19
|
-
def accuracy(y_true, y_pred)
|
19
|
+
def accuracy(y_true, y_pred, weight: nil)
|
20
20
|
check_size(y_true, y_pred)
|
21
|
-
y_true.zip(y_pred).
|
21
|
+
values = y_true.zip(y_pred).map { |yt, yp| yt == yp ? 1 : 0 }
|
22
|
+
if weight
|
23
|
+
values.each_with_index do |v, i|
|
24
|
+
values[i] *= weight[i]
|
25
|
+
end
|
26
|
+
values.sum / weight.sum.to_f
|
27
|
+
else
|
28
|
+
values.sum / y_true.size.to_f
|
29
|
+
end
|
22
30
|
end
|
23
31
|
|
24
32
|
# http://wiki.fast.ai/index.php/Log_Loss
|
25
|
-
def log_loss(y_true, y_pred, eps: 1e-15)
|
33
|
+
def log_loss(y_true, y_pred, eps: 1e-15, weight: nil)
|
26
34
|
check_size(y_true, y_pred)
|
27
35
|
p = y_pred.map { |yp| yp.clamp(eps, 1 - eps) }
|
28
|
-
mean(y_true.zip(p).map { |yt, pi| yt == 1 ? -Math.log(pi) : -Math.log(1 - pi) })
|
36
|
+
mean(y_true.zip(p).map { |yt, pi| yt == 1 ? -Math.log(pi) : -Math.log(1 - pi) }, weight: weight)
|
29
37
|
end
|
30
38
|
|
31
39
|
private
|
@@ -34,8 +42,12 @@ module Eps
|
|
34
42
|
raise ArgumentError, "Different sizes" if y_true.size != y_pred.size
|
35
43
|
end
|
36
44
|
|
37
|
-
def mean(arr)
|
38
|
-
|
45
|
+
def mean(arr, weight: nil)
|
46
|
+
if weight
|
47
|
+
arr.map.with_index { |v, i| v * weight[i] }.sum / weight.sum.to_f
|
48
|
+
else
|
49
|
+
arr.sum / arr.size.to_f
|
50
|
+
end
|
39
51
|
end
|
40
52
|
|
41
53
|
def errors(y_true, y_pred)
|
data/lib/eps/model.rb
CHANGED
@@ -17,11 +17,11 @@ module Eps
|
|
17
17
|
|
18
18
|
estimator_class =
|
19
19
|
if data.css("Segmentation").any?
|
20
|
-
|
20
|
+
LightGBM
|
21
21
|
elsif data.css("RegressionModel").any?
|
22
|
-
|
22
|
+
LinearRegression
|
23
23
|
elsif data.css("NaiveBayesModel").any?
|
24
|
-
|
24
|
+
NaiveBayes
|
25
25
|
else
|
26
26
|
raise "Unknown model"
|
27
27
|
end
|
@@ -35,11 +35,11 @@ module Eps
|
|
35
35
|
estimator_class =
|
36
36
|
case algorithm
|
37
37
|
when :lightgbm
|
38
|
-
|
38
|
+
LightGBM
|
39
39
|
when :linear_regression
|
40
|
-
|
40
|
+
LinearRegression
|
41
41
|
when :naive_bayes
|
42
|
-
|
42
|
+
NaiveBayes
|
43
43
|
else
|
44
44
|
raise ArgumentError, "Unknown algorithm: #{algorithm}"
|
45
45
|
end
|
data/lib/eps/naive_bayes.rb
CHANGED
@@ -3,91 +3,7 @@ module Eps
|
|
3
3
|
attr_reader :probabilities
|
4
4
|
|
5
5
|
def accuracy
|
6
|
-
Eps::Metrics.accuracy(@train_set.label, predict(@train_set))
|
7
|
-
end
|
8
|
-
|
9
|
-
# pmml
|
10
|
-
|
11
|
-
def self.load_pmml(data)
|
12
|
-
super do |data|
|
13
|
-
# TODO more validation
|
14
|
-
node = data.css("NaiveBayesModel")
|
15
|
-
|
16
|
-
prior = {}
|
17
|
-
node.css("BayesOutput TargetValueCount").each do |n|
|
18
|
-
prior[n.attribute("value").value] = n.attribute("count").value.to_f
|
19
|
-
end
|
20
|
-
|
21
|
-
legacy = false
|
22
|
-
|
23
|
-
conditional = {}
|
24
|
-
features = {}
|
25
|
-
node.css("BayesInput").each do |n|
|
26
|
-
prob = {}
|
27
|
-
|
28
|
-
# numeric
|
29
|
-
n.css("TargetValueStat").each do |n2|
|
30
|
-
n3 = n2.css("GaussianDistribution")
|
31
|
-
prob[n2.attribute("value").value] = {
|
32
|
-
mean: n3.attribute("mean").value.to_f,
|
33
|
-
stdev: Math.sqrt(n3.attribute("variance").value.to_f)
|
34
|
-
}
|
35
|
-
end
|
36
|
-
|
37
|
-
# detect bad form in Eps < 0.3
|
38
|
-
bad_format = n.css("PairCounts").map { |n2| n2.attribute("value").value } == prior.keys
|
39
|
-
|
40
|
-
# categorical
|
41
|
-
n.css("PairCounts").each do |n2|
|
42
|
-
if bad_format
|
43
|
-
n2.css("TargetValueCount").each do |n3|
|
44
|
-
prob[n3.attribute("value").value] ||= {}
|
45
|
-
prob[n3.attribute("value").value][n2.attribute("value").value] = BigDecimal(n3.attribute("count").value)
|
46
|
-
end
|
47
|
-
else
|
48
|
-
boom = {}
|
49
|
-
n2.css("TargetValueCount").each do |n3|
|
50
|
-
boom[n3.attribute("value").value] = BigDecimal(n3.attribute("count").value)
|
51
|
-
end
|
52
|
-
prob[n2.attribute("value").value] = boom
|
53
|
-
end
|
54
|
-
end
|
55
|
-
|
56
|
-
if bad_format
|
57
|
-
legacy = true
|
58
|
-
prob.each do |k, v|
|
59
|
-
prior.keys.each do |k|
|
60
|
-
v[k] ||= 0.0
|
61
|
-
end
|
62
|
-
end
|
63
|
-
end
|
64
|
-
|
65
|
-
name = n.attribute("fieldName").value
|
66
|
-
conditional[name] = prob
|
67
|
-
features[name] = n.css("TargetValueStat").any? ? "numeric" : "categorical"
|
68
|
-
end
|
69
|
-
|
70
|
-
target = node.css("BayesOutput").attribute("fieldName").value
|
71
|
-
|
72
|
-
probabilities = {
|
73
|
-
prior: prior,
|
74
|
-
conditional: conditional
|
75
|
-
}
|
76
|
-
|
77
|
-
# get derived fields
|
78
|
-
derived = {}
|
79
|
-
data.css("DerivedField").each do |n|
|
80
|
-
name = n.attribute("name").value
|
81
|
-
field = n.css("NormDiscrete").attribute("field").value
|
82
|
-
value = n.css("NormDiscrete").attribute("value").value
|
83
|
-
features.delete(name)
|
84
|
-
features[field] = "derived"
|
85
|
-
derived[field] ||= {}
|
86
|
-
derived[field][name] = value
|
87
|
-
end
|
88
|
-
|
89
|
-
Evaluators::NaiveBayes.new(probabilities: probabilities, features: features, derived: derived, legacy: legacy)
|
90
|
-
end
|
6
|
+
Eps::Metrics.accuracy(@train_set.label, predict(@train_set), weight: @train_set.weight)
|
91
7
|
end
|
92
8
|
|
93
9
|
private
|
@@ -101,10 +17,11 @@ module Eps
|
|
101
17
|
str
|
102
18
|
end
|
103
19
|
|
104
|
-
def _train(smoothing: 1
|
20
|
+
def _train(smoothing: 1)
|
105
21
|
raise "Target must be strings" if @target_type != "categorical"
|
106
22
|
check_missing_value(@train_set)
|
107
23
|
check_missing_value(@validation_set) if @validation_set
|
24
|
+
raise ArgumentError, "weight not supported" if @train_set.weight
|
108
25
|
|
109
26
|
data = @train_set
|
110
27
|
|
@@ -185,60 +102,6 @@ module Eps
|
|
185
102
|
Evaluators::NaiveBayes.new(probabilities: probabilities, features: @features)
|
186
103
|
end
|
187
104
|
|
188
|
-
def generate_pmml
|
189
|
-
data_fields = {}
|
190
|
-
data_fields[@target] = probabilities[:prior].keys
|
191
|
-
probabilities[:conditional].each do |k, v|
|
192
|
-
if @features[k] == "categorical"
|
193
|
-
data_fields[k] = v.keys
|
194
|
-
else
|
195
|
-
data_fields[k] = nil
|
196
|
-
end
|
197
|
-
end
|
198
|
-
|
199
|
-
build_pmml(data_fields) do |xml|
|
200
|
-
xml.NaiveBayesModel(functionName: "classification", threshold: 0.001) do
|
201
|
-
xml.MiningSchema do
|
202
|
-
data_fields.each do |k, _|
|
203
|
-
xml.MiningField(name: k)
|
204
|
-
end
|
205
|
-
end
|
206
|
-
xml.BayesInputs do
|
207
|
-
probabilities[:conditional].each do |k, v|
|
208
|
-
xml.BayesInput(fieldName: k) do
|
209
|
-
if @features[k] == "categorical"
|
210
|
-
v.sort_by { |k2, _| k2 }.each do |k2, v2|
|
211
|
-
xml.PairCounts(value: k2) do
|
212
|
-
xml.TargetValueCounts do
|
213
|
-
v2.sort_by { |k2, _| k2 }.each do |k3, v3|
|
214
|
-
xml.TargetValueCount(value: k3, count: v3)
|
215
|
-
end
|
216
|
-
end
|
217
|
-
end
|
218
|
-
end
|
219
|
-
else
|
220
|
-
xml.TargetValueStats do
|
221
|
-
v.sort_by { |k2, _| k2 }.each do |k2, v2|
|
222
|
-
xml.TargetValueStat(value: k2) do
|
223
|
-
xml.GaussianDistribution(mean: v2[:mean], variance: v2[:stdev]**2)
|
224
|
-
end
|
225
|
-
end
|
226
|
-
end
|
227
|
-
end
|
228
|
-
end
|
229
|
-
end
|
230
|
-
end
|
231
|
-
xml.BayesOutput(fieldName: "target") do
|
232
|
-
xml.TargetValueCounts do
|
233
|
-
probabilities[:prior].sort_by { |k, _| k }.each do |k, v|
|
234
|
-
xml.TargetValueCount(value: k, count: v)
|
235
|
-
end
|
236
|
-
end
|
237
|
-
end
|
238
|
-
end
|
239
|
-
end
|
240
|
-
end
|
241
|
-
|
242
105
|
def group_count(arr, start)
|
243
106
|
arr.inject(start) { |h, e| h[e] += 1; h }
|
244
107
|
end
|
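data/lib/eps/pmml.rb
ADDED
[extraction note: the 14 added lines of this file are not present in this extract]
data/lib/eps/pmml/generator.rb
ADDED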
@@ -0,0 +1,422 @@
|
|
1
|
+
module Eps
  module PMML
    # Serializes a trained Eps model into a PMML 4.4 XML document.
    #
    # The generator reads the model's state through instance variables (see the
    # accessor helpers at the bottom of the class) and assembles the XML with
    # Nokogiri::XML::Builder.
    class Generator
      attr_reader :model

      # @param model the estimator to serialize; #generate dispatches on
      #   LightGBM, LinearRegression and NaiveBayes
      def initialize(model)
        @model = model
      end

      # Builds the PMML document for the wrapped model.
      #
      # @return [String] the XML returned by the builder's #to_xml
      # @raise [RuntimeError] "Unknown model" when the model matches none of
      #   the supported estimator classes
      def generate
        case @model
        when LightGBM
          lightgbm
        when LinearRegression
          linear_regression
        when NaiveBayes
          naive_bayes
        else
          raise "Unknown model"
        end
      end

      private

      # PMML for a LightGBM model: a MiningModel whose shape depends on the
      # objective ("regression", "binary", anything else is treated as
      # multiclass).
      def lightgbm
        data_fields = {}
        data_fields[target] = labels if labels
        features.each do |k, type|
          # TODO remove zero importance features
          data_fields[k] = type == "categorical" ? label_encoders[k].labels.keys : nil
        end

        build_pmml(data_fields) do |xml|
          function_name = objective == "regression" ? "regression" : "classification"
          xml.MiningModel(functionName: function_name, algorithmName: "LightGBM") do
            xml.MiningSchema do
              xml.MiningField(name: target, usageType: "target")
              features.keys.each do |k|
                # next if feature_importance[i] == 0
                # TODO add importance, but need to handle text features
                xml.MiningField(name: k) #, importance: feature_importance[i].to_f, missingValueTreatment: "asIs")
              end
            end
            pmml_local_transformations(xml)

            case objective
            when "regression"
              xml_segmentation(xml, trees)
            when "binary"
              xml.Segmentation(multipleModelMethod: "modelChain") do
                xml.Segment(id: 1) do
                  xml.True
                  xml.MiningModel(functionName: "regression") do
                    xml.MiningSchema do
                      features.each do |k, _|
                        xml.MiningField(name: k)
                      end
                    end
                    xml.Output do
                      # Logistic transform 1 / (1 + exp(-1 * lgbmValue)).
                      # The constants are passed as element content: the
                      # builder ignores a block's return value, so a
                      # `do 1.0 end` block would leave <Constant> empty.
                      # NOTE(review): this field is named "lgbmValue" and
                      # references itself, while segment 2 reads
                      # "transformedLgbmValue", which is never defined here.
                      # Left as-is; confirm against the loader before
                      # restructuring.
                      xml.OutputField(name: "lgbmValue", optype: "continuous", dataType: "double", feature: "predictedValue", isFinalResult: false) do
                        xml.Apply(function: "/") do
                          xml.Constant("1.0", dataType: "double")
                          xml.Apply(function: "+") do
                            xml.Constant("1.0", dataType: "double")
                            xml.Apply(function: "exp") do
                              xml.Apply(function: "*") do
                                xml.Constant("-1.0", dataType: "double")
                                xml.FieldRef(field: "lgbmValue")
                              end
                            end
                          end
                        end
                      end
                    end
                    xml_segmentation(xml, trees)
                  end
                end
                xml.Segment(id: 2) do
                  xml.True
                  xml.RegressionModel(functionName: "classification", normalizationMethod: "none") do
                    xml.MiningSchema do
                      xml.MiningField(name: target, usageType: "target")
                      xml.MiningField(name: "transformedLgbmValue")
                    end
                    xml.Output do
                      labels.each do |label|
                        xml.OutputField(name: "probability(#{label})", optype: "continuous", dataType: "double", feature: "probability", value: label)
                      end
                    end
                    xml.RegressionTable(intercept: 0.0, targetCategory: labels.last) do
                      xml.NumericPredictor(name: "transformedLgbmValue", coefficient: "1.0")
                    end
                    xml.RegressionTable(intercept: 0.0, targetCategory: labels.first)
                  end
                end
              end
            else # multiclass
              xml.Segmentation(multipleModelMethod: "modelChain") do
                # trees are split into equally sized consecutive groups, one per label
                trees_per_label = trees.size / labels.size
                trees.each_slice(trees_per_label).each_with_index do |label_trees, idx|
                  xml.Segment(id: idx + 1) do
                    xml.True
                    xml.MiningModel(functionName: "regression") do
                      xml.MiningSchema do
                        features.each do |k, _|
                          xml.MiningField(name: k)
                        end
                      end
                      xml.Output do
                        xml.OutputField(name: "lgbmValue(#{labels[idx]})", optype: "continuous", dataType: "double", feature: "predictedValue", isFinalResult: false)
                      end
                      xml_segmentation(xml, label_trees)
                    end
                  end
                end
                xml.Segment(id: labels.size + 1) do
                  xml.True
                  xml.RegressionModel(functionName: "classification", normalizationMethod: "softmax") do
                    xml.MiningSchema do
                      xml.MiningField(name: target, usageType: "target")
                      labels.each do |label|
                        xml.MiningField(name: "lgbmValue(#{label})")
                      end
                    end
                    xml.Output do
                      labels.each do |label|
                        xml.OutputField(name: "probability(#{label})", optype: "continuous", dataType: "double", feature: "probability", value: label)
                      end
                    end
                    labels.each do |label|
                      xml.RegressionTable(intercept: 0.0, targetCategory: label) do
                        xml.NumericPredictor(name: "lgbmValue(#{label})", coefficient: "1.0")
                      end
                    end
                  end
                end
              end
            end
          end
        end
      end

      # PMML for a linear regression: one RegressionTable built from the
      # model's @coefficients, with the "_intercept" entry used as intercept
      # (0.0 when absent).
      def linear_regression
        predictors = model.instance_variable_get("@coefficients").dup
        intercept = predictors.delete("_intercept") || 0.0

        data_fields = {}
        features.each do |k, type|
          if type == "categorical"
            # NOTE(review): the block parameter shadows the outer `k`, so
            # `k.first == k` compares a coefficient key with its own first
            # element and the selection is always empty; categorical
            # DataFields are therefore written without <Value> children.
            # Left unchanged because listing only the categories that appear
            # in the coefficients could omit a category that has no
            # coefficient of its own - confirm the intended output first.
            data_fields[k] = predictors.keys.select { |k, v| k.is_a?(Array) && k.first == k }.map(&:last)
          else
            data_fields[k] = nil
          end
        end

        build_pmml(data_fields) do |xml|
          xml.RegressionModel(functionName: "regression") do
            xml.MiningSchema do
              features.each do |k, _|
                xml.MiningField(name: k)
              end
            end
            pmml_local_transformations(xml)
            xml.RegressionTable(intercept: intercept) do
              predictors.each do |k, v|
                if k.is_a?(Array)
                  # array keys are [feature, term] for text features and
                  # [feature, category] otherwise
                  if features[k.first] == "text"
                    xml.NumericPredictor(name: display_field(k), coefficient: v)
                  else
                    xml.CategoricalPredictor(name: k[0], value: k[1], coefficient: v)
                  end
                else
                  xml.NumericPredictor(name: k, coefficient: v)
                end
              end
            end
          end
        end
      end

      # PMML for a naive Bayes model, built from probabilities[:prior] and
      # probabilities[:conditional].
      def naive_bayes
        data_fields = {}
        data_fields[target] = probabilities[:prior].keys
        probabilities[:conditional].each do |k, v|
          data_fields[k] = features[k] == "categorical" ? v.keys : nil
        end

        build_pmml(data_fields) do |xml|
          xml.NaiveBayesModel(functionName: "classification", threshold: 0.001) do
            xml.MiningSchema do
              data_fields.each do |k, _|
                xml.MiningField(name: k)
              end
            end
            xml.BayesInputs do
              probabilities[:conditional].each do |k, v|
                xml.BayesInput(fieldName: k) do
                  if features[k] == "categorical"
                    v.sort_by { |k2, _| k2.to_s }.each do |k2, v2|
                      xml.PairCounts(value: k2) do
                        xml.TargetValueCounts do
                          v2.sort_by { |k3, _| k3.to_s }.each do |k3, v3|
                            xml.TargetValueCount(value: k3, count: v3)
                          end
                        end
                      end
                    end
                  else
                    xml.TargetValueStats do
                      v.sort_by { |k2, _| k2.to_s }.each do |k2, v2|
                        xml.TargetValueStat(value: k2) do
                          # the model stores a standard deviation; PMML wants the variance
                          xml.GaussianDistribution(mean: v2[:mean], variance: v2[:stdev]**2)
                        end
                      end
                    end
                  end
                end
              end
            end
            # Refer to the same target field that the DataDictionary and
            # MiningSchema above declare, instead of the literal "target".
            xml.BayesOutput(fieldName: target) do
              xml.TargetValueCounts do
                probabilities[:prior].sort_by { |k, _| k.to_s }.each do |k, v|
                  xml.TargetValueCount(value: k, count: v)
                end
              end
            end
          end
        end
      end

      # Name used in the document for a field key: "feature(term)" for
      # [feature, term] keys of text features, "a=b" for other array keys, and
      # the key itself when it is not an array.
      def display_field(k)
        return k unless k.is_a?(Array)

        if features[k.first] == "text"
          "#{k.first}(#{k.last})"
        else
          k.join("=")
        end
      end

      # Writes the given trees as a "sum" Segmentation with one TreeModel
      # segment per tree.
      def xml_segmentation(xml, trees)
        xml.Segmentation(multipleModelMethod: "sum") do
          trees.each_with_index do |node, i|
            xml.Segment(id: i + 1) do
              xml.True
              xml.TreeModel(functionName: "regression", missingValueStrategy: "none", noTrueChildStrategy: "returnLastPrediction", splitCharacteristic: "multiSplit") do
                xml.MiningSchema do
                  node_fields(node).uniq.each do |k|
                    xml.MiningField(name: display_field(k))
                  end
                end
                node_pmml(node, xml)
              end
            end
          end
        end
      end

      # Collects, depth first, the field of every node in the subtree that has
      # a predicate (duplicates included).
      def node_fields(node)
        fields = []
        fields << node.field if node.predicate
        node.children.each do |n|
          fields.concat(node_fields(n))
        end
        fields
      end

      # Writes a tree node, its predicate and, recursively, its children.
      def node_pmml(node, xml)
        xml.Node(score: node.score) do
          if node.predicate.nil?
            xml.True
          elsif node.operator == "in"
            xml.SimpleSetPredicate(field: display_field(node.field), booleanOperator: "isIn") do
              xml.Array(type: "string") do
                xml.text node.value.map { |v| escape_element(v) }.join(" ")
              end
            end
          else
            xml.SimplePredicate(field: display_field(node.field), operator: node.operator, value: node.value)
          end
          node.children.each do |n|
            node_pmml(n, xml)
          end
        end
      end

      # Wraps a string in double quotes, backslash-escaping embedded double
      # quotes, for use as one entry of a space-separated <Array>.
      def escape_element(v)
        "\"#{v.gsub('"') { '\\"' }}\""
      end

      # Builds the <PMML> envelope (header, data dictionary, transformation
      # dictionary), yields the builder for the model element and returns the
      # resulting XML.
      def build_pmml(data_fields)
        Nokogiri::XML::Builder.new do |xml|
          xml.PMML(version: "4.4", xmlns: "http://www.dmg.org/PMML-4_4", "xmlns:xsi" => "http://www.w3.org/2001/XMLSchema-instance") do
            pmml_header(xml)
            pmml_data_dictionary(xml, data_fields)
            pmml_transformation_dictionary(xml)
            yield xml
          end
        end.to_xml
      end

      # Writes the <Header> with the generating application and its version.
      def pmml_header(xml)
        xml.Header do
          xml.Application(name: "Eps", version: Eps::VERSION)
          # xml.Timestamp Time.now.utc.iso8601
        end
      end

      # Writes one DataField per entry of data_fields (field name => list of
      # values, or nil). Fields without a feature type (nil) are handled like
      # categorical ones and get their values listed in sorted string form.
      def pmml_data_dictionary(xml, data_fields)
        xml.DataDictionary do
          data_fields.each do |k, vs|
            case features[k]
            when "categorical", nil
              xml.DataField(name: k, optype: "categorical", dataType: "string") do
                vs.map(&:to_s).sort.each do |v|
                  xml.Value(value: v)
                end
              end
            when "text"
              xml.DataField(name: k, optype: "categorical", dataType: "string")
            else
              xml.DataField(name: k, optype: "continuous", dataType: "double")
            end
          end
        end
      end

      # Defines one "<feature>Transform" term-frequency function per text
      # feature; writes nothing when the model has no text features.
      def pmml_transformation_dictionary(xml)
        return unless text_features.any?

        xml.TransformationDictionary do
          text_features.each do |k, text_options|
            xml.DefineFunction(name: "#{k}Transform", optype: "continuous") do
              xml.ParameterField(name: "text")
              xml.ParameterField(name: "term")
              xml.TextIndex(textField: "text", localTermWeights: "termFrequency", wordSeparatorCharacterRE: text_options[:tokenizer].source, isCaseSensitive: !!text_options[:case_sensitive]) do
                xml.FieldRef(field: "term")
              end
            end
          end
        end
      end

      # Declares one DerivedField per vocabulary term of every text feature,
      # each applying that feature's "<feature>Transform" function; writes
      # nothing when the model has no text features.
      def pmml_local_transformations(xml)
        return unless text_features.any?

        xml.LocalTransformations do
          text_features.each do |k, _|
            text_encoders[k].vocabulary.each do |v|
              xml.DerivedField(name: display_field([k, v]), optype: "continuous", dataType: "integer") do
                xml.Apply(function: "#{k}Transform") do
                  xml.FieldRef(field: k)
                  xml.Constant v
                end
              end
            end
          end
        end
      end

      # TODO create instance methods on model for all of these features

      def features
        model.instance_variable_get("@features")
      end

      def text_features
        model.instance_variable_get("@text_features")
      end

      def text_encoders
        model.instance_variable_get("@text_encoders")
      end

      def feature_importance
        model.instance_variable_get("@feature_importance")
      end

      def labels
        model.instance_variable_get("@labels")
      end

      def trees
        model.instance_variable_get("@trees")
      end

      def target
        model.instance_variable_get("@target")
      end

      def label_encoders
        model.instance_variable_get("@label_encoders")
      end

      def objective
        model.instance_variable_get("@objective")
      end

      def probabilities
        model.instance_variable_get("@probabilities")
      end

      # end TODO
    end
  end
end
|