eps 0.3.0 → 0.3.1
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +12 -5
- data/README.md +34 -0
- data/lib/eps.rb +19 -10
- data/lib/eps/base_estimator.rb +35 -129
- data/lib/eps/data_frame.rb +7 -1
- data/lib/eps/evaluators/linear_regression.rb +1 -1
- data/lib/eps/label_encoder.rb +7 -3
- data/lib/eps/lightgbm.rb +36 -76
- data/lib/eps/linear_regression.rb +26 -79
- data/lib/eps/metrics.rb +24 -12
- data/lib/eps/model.rb +6 -6
- data/lib/eps/naive_bayes.rb +2 -139
- data/lib/eps/pmml.rb +14 -0
- data/lib/eps/pmml/generator.rb +422 -0
- data/lib/eps/pmml/loader.rb +241 -0
- data/lib/eps/version.rb +1 -1
- metadata +7 -5
- data/lib/eps/pmml_generators/lightgbm.rb +0 -187
data/lib/eps/pmml.rb
ADDED
@@ -0,0 +1,422 @@
|
|
1
|
+
module Eps
|
2
|
+
module PMML
|
3
|
+
class Generator
|
4
|
+
attr_reader :model
|
5
|
+
|
6
|
+
# Stores the model whose state will be serialized.
#
# @param model [Object] the estimator to export; #generate only knows how
#   to handle LightGBM, LinearRegression and NaiveBayes instances
def initialize(model)
  @model = model
end
|
9
|
+
|
10
|
+
# Serializes the wrapped model by dispatching on its class to the matching
# private generator.
#
# @return [String] the PMML document produced by the selected generator
# @raise [RuntimeError] "Unknown model" when the model is not one of the
#   supported estimator classes
def generate
  case @model
  when LightGBM then lightgbm
  when LinearRegression then linear_regression
  when NaiveBayes then naive_bayes
  else raise "Unknown model"
  end
end
|
22
|
+
|
23
|
+
private
|
24
|
+
|
25
|
+
# Builds the PMML document for a LightGBM model.
#
# The top-level MiningModel is laid out according to the objective:
# * "regression" - one "sum" segmentation over all trees
# * "binary"     - a model chain: the summed trees, then a RegressionModel
#                  that reads the sigmoid-transformed score
# * otherwise    - (multiclass) a model chain with one summed tree model per
#                  label, then a softmax RegressionModel over their outputs
#
# @return [String] PMML XML
def lightgbm
  data_fields = {}
  # The target is only declared in the data dictionary when labels exist.
  data_fields[target] = labels if labels
  features.each do |k, type|
    # TODO remove zero importance features
    data_fields[k] = type == "categorical" ? label_encoders[k].labels.keys : nil
  end

  build_pmml(data_fields) do |xml|
    function_name = objective == "regression" ? "regression" : "classification"
    xml.MiningModel(functionName: function_name, algorithmName: "LightGBM") do
      xml.MiningSchema do
        xml.MiningField(name: target, usageType: "target")
        # TODO add importance, but need to handle text features
        features.each_key { |k| xml.MiningField(name: k) }
      end
      pmml_local_transformations(xml)

      case objective
      when "regression"
        xml_segmentation(xml, trees)
      when "binary"
        lightgbm_binary_chain(xml)
      else # multiclass
        lightgbm_multiclass_chain(xml)
      end
    end
  end
end

# Emits the MiningSchema used by the inner tree models: one field per feature.
def lightgbm_feature_schema(xml)
  xml.MiningSchema do
    features.each_key { |k| xml.MiningField(name: k) }
  end
end

# Emits one probability OutputField per label.
def lightgbm_probability_outputs(xml)
  xml.Output do
    labels.each do |label|
      xml.OutputField(name: "probability(#{label})", optype: "continuous", dataType: "double", feature: "probability", value: label)
    end
  end
end

# Emits the two-step model chain for the binary objective.
def lightgbm_binary_chain(xml)
  xml.Segmentation(multipleModelMethod: "modelChain") do
    xml.Segment(id: 1) do
      xml.True
      xml.MiningModel(functionName: "regression") do
        lightgbm_feature_schema(xml)
        xml.Output do
          # Raw summed score of the trees.
          xml.OutputField(name: "lgbmValue", optype: "continuous", dataType: "double", feature: "predictedValue", isFinalResult: false)
          # Sigmoid of the raw score: 1 / (1 + exp(-1 * lgbmValue)).
          # Segment 2 reads this field, so it has to be declared here.
          xml.OutputField(name: "transformedLgbmValue", optype: "continuous", dataType: "double", feature: "transformedValue", isFinalResult: false) do
            xml.Apply(function: "/") do
              # Constant values are passed as element content; a block's
              # return value is not written to the document.
              xml.Constant("1.0", dataType: "double")
              xml.Apply(function: "+") do
                xml.Constant("1.0", dataType: "double")
                xml.Apply(function: "exp") do
                  xml.Apply(function: "*") do
                    xml.Constant("-1.0", dataType: "double")
                    xml.FieldRef(field: "lgbmValue")
                  end
                end
              end
            end
          end
        end
        xml_segmentation(xml, trees)
      end
    end
    xml.Segment(id: 2) do
      xml.True
      xml.RegressionModel(functionName: "classification", normalizationMethod: "none") do
        xml.MiningSchema do
          xml.MiningField(name: target, usageType: "target")
          xml.MiningField(name: "transformedLgbmValue")
        end
        lightgbm_probability_outputs(xml)
        # The transformed score is attached to the last label; the first
        # label gets an empty table.
        xml.RegressionTable(intercept: 0.0, targetCategory: labels.last) do
          xml.NumericPredictor(name: "transformedLgbmValue", coefficient: "1.0")
        end
        xml.RegressionTable(intercept: 0.0, targetCategory: labels.first)
      end
    end
  end
end

# Emits the model chain for the multiclass objective: one tree model per
# label followed by a softmax RegressionModel.
def lightgbm_multiclass_chain(xml)
  xml.Segmentation(multipleModelMethod: "modelChain") do
    # NOTE(review): assumes trees are stored as equally sized, contiguous
    # groups in label order - confirm against the trainer.
    trees_per_label = trees.size / labels.size
    trees.each_slice(trees_per_label).each_with_index do |label_trees, idx|
      xml.Segment(id: idx + 1) do
        xml.True
        xml.MiningModel(functionName: "regression") do
          lightgbm_feature_schema(xml)
          xml.Output do
            xml.OutputField(name: "lgbmValue(#{labels[idx]})", optype: "continuous", dataType: "double", feature: "predictedValue", isFinalResult: false)
          end
          xml_segmentation(xml, label_trees)
        end
      end
    end
    xml.Segment(id: labels.size + 1) do
      xml.True
      xml.RegressionModel(functionName: "classification", normalizationMethod: "softmax") do
        xml.MiningSchema do
          xml.MiningField(name: target, usageType: "target")
          labels.each do |label|
            xml.MiningField(name: "lgbmValue(#{label})")
          end
        end
        lightgbm_probability_outputs(xml)
        labels.each do |label|
          xml.RegressionTable(intercept: 0.0, targetCategory: label) do
            xml.NumericPredictor(name: "lgbmValue(#{label})", coefficient: "1.0")
          end
        end
      end
    end
  end
end
|
152
|
+
|
153
|
+
# Builds the PMML document for a linear regression model.
#
# Coefficients are read from the model's @coefficients hash. The
# "_intercept" entry becomes the RegressionTable intercept (0.0 when
# absent); every other entry becomes a predictor. Array keys are
# [field, value] pairs: text fields are written as numeric predictors on
# the derived term field, all other pairs as categorical predictors.
#
# @return [String] PMML XML
def linear_regression
  predictors = model.instance_variable_get("@coefficients").dup
  intercept = predictors.delete("_intercept") || 0.0

  data_fields = {}
  features.each do |name, type|
    # Categorical fields are declared without enumerated values. The earlier
    # filter over the coefficient keys reused the feature name as its block
    # parameter, so it never matched and always produced an empty list; that
    # output is kept as-is.
    # NOTE(review): listing values from the coefficient keys is only correct
    # if every level has its own coefficient - confirm against the trainer
    # before enumerating them here.
    data_fields[name] = type == "categorical" ? [] : nil
  end

  build_pmml(data_fields) do |xml|
    xml.RegressionModel(functionName: "regression") do
      xml.MiningSchema do
        features.each_key { |name| xml.MiningField(name: name) }
      end
      pmml_local_transformations(xml)
      xml.RegressionTable(intercept: intercept) do
        predictors.each do |key, coefficient|
          if !key.is_a?(Array)
            xml.NumericPredictor(name: key, coefficient: coefficient)
          elsif features[key.first] == "text"
            xml.NumericPredictor(name: display_field(key), coefficient: coefficient)
          else
            xml.CategoricalPredictor(name: key[0], value: key[1], coefficient: coefficient)
          end
        end
      end
    end
  end
end
|
190
|
+
|
191
|
+
# Builds the PMML document for a naive Bayes model from the model's
# @probabilities hash (:prior and :conditional).
#
# Categorical inputs are written as PairCounts/TargetValueCounts, all other
# inputs as Gaussian TargetValueStats. Entries are sorted by key so the
# element order does not depend on hash insertion order.
#
# @return [String] PMML XML
def naive_bayes
  data_fields = {}
  data_fields[target] = probabilities[:prior].keys
  probabilities[:conditional].each do |k, v|
    data_fields[k] = features[k] == "categorical" ? v.keys : nil
  end

  build_pmml(data_fields) do |xml|
    xml.NaiveBayesModel(functionName: "classification", threshold: 0.001) do
      xml.MiningSchema do
        data_fields.each_key { |k| xml.MiningField(name: k) }
      end
      xml.BayesInputs do
        probabilities[:conditional].each do |k, v|
          xml.BayesInput(fieldName: k) do
            if features[k] == "categorical"
              v.sort_by { |value, _| value }.each do |value, counts|
                xml.PairCounts(value: value) do
                  xml.TargetValueCounts do
                    counts.sort_by { |label, _| label }.each do |label, count|
                      xml.TargetValueCount(value: label, count: count)
                    end
                  end
                end
              end
            else
              xml.TargetValueStats do
                v.sort_by { |label, _| label }.each do |label, stats|
                  xml.TargetValueStat(value: label) do
                    # Only the standard deviation is stored; PMML wants the variance.
                    xml.GaussianDistribution(mean: stats[:mean], variance: stats[:stdev]**2)
                  end
                end
              end
            end
          end
        end
      end
      # Use the model's target name, i.e. the same field that was registered
      # in the data dictionary and mining schema above, not a fixed literal.
      xml.BayesOutput(fieldName: target) do
        xml.TargetValueCounts do
          probabilities[:prior].sort_by { |label, _| label }.each do |label, count|
            xml.TargetValueCount(value: label, count: count)
          end
        end
      end
    end
  end
end
|
244
|
+
|
245
|
+
# Returns the PMML field name for a feature key.
#
# @param k [Object] a plain field name, or a [field, value] Array
# @return [Object] "field(value)" for pairs on a text feature, "field=value"
#   for any other pair, and k itself when it is not an Array
def display_field(k)
  return k unless k.is_a?(Array)
  return "#{k.first}(#{k.last})" if features[k.first] == "text"

  k.join("=")
end
|
256
|
+
|
257
|
+
# Emits a "sum" Segmentation holding one TreeModel segment per tree.
#
# @param xml [Object] the builder being written to
# @param trees [Enumerable] root nodes of the trees; segment ids start at 1
def xml_segmentation(xml, trees)
  tree_attributes = {
    functionName: "regression",
    missingValueStrategy: "none",
    noTrueChildStrategy: "returnLastPrediction",
    splitCharacteristic: "multiSplit"
  }

  xml.Segmentation(multipleModelMethod: "sum") do
    trees.each_with_index do |root, index|
      xml.Segment(id: index + 1) do
        xml.True
        xml.TreeModel(tree_attributes) do
          # Each tree only declares the fields its own predicates use.
          xml.MiningSchema do
            node_fields(root).uniq.each do |field|
              xml.MiningField(name: display_field(field))
            end
          end
          node_pmml(root, xml)
        end
      end
    end
  end
end
|
274
|
+
|
275
|
+
# Collects the fields referenced by a tree in depth-first, parent-first order.
#
# @param node [Object] a tree node responding to predicate, field and children
# @return [Array] the field of every node that has a predicate; duplicates
#   are kept
def node_fields(node)
  own = node.predicate ? [node.field] : []
  node.children.reduce(own) { |collected, child| collected.concat(node_fields(child)) }
end
|
283
|
+
|
284
|
+
# Writes a tree node and, recursively, its children as nested Node elements.
#
# The predicate element comes first: True when the node has no predicate,
# SimpleSetPredicate for the "in" operator, SimplePredicate otherwise.
#
# @param node [Object] tree node responding to score, predicate, operator,
#   field, value and children
# @param xml [Object] the builder being written to
def node_pmml(node, xml)
  xml.Node(score: node.score) do
    if node.predicate.nil?
      xml.True
    elsif node.operator == "in"
      xml.SimpleSetPredicate(field: display_field(node.field), booleanOperator: "isIn") do
        xml.Array(type: "string") do
          # Members are quoted individually and separated by single spaces.
          quoted = node.value.map { |member| escape_element(member) }
          xml.text quoted.join(" ")
        end
      end
    else
      xml.SimplePredicate(field: display_field(node.field), operator: node.operator, value: node.value)
    end

    node.children.each { |child| node_pmml(child, xml) }
  end
end
|
302
|
+
|
303
|
+
# Quotes a string for use as a member of a PMML string Array.
#
# @param v [String] the raw value
# @return [String] v wrapped in double quotes, with each embedded double
#   quote preceded by a backslash
def escape_element(v)
  # Block form so the backslash in the replacement is taken literally.
  escaped = v.gsub('"') { '\"' }
  %("#{escaped}")
end
|
306
|
+
|
307
|
+
# Builds the PMML envelope (header, data dictionary and transformation
# dictionary) and yields the builder so the caller can append the model
# element inside the PMML root.
#
# @param data_fields [Hash] field name => list of values (or nil), passed
#   on to pmml_data_dictionary
# @yieldparam xml [Nokogiri::XML::Builder] the builder, positioned inside PMML
# @return [String] the serialized XML document
def build_pmml(data_fields)
  root_attributes = {
    version: "4.4",
    xmlns: "http://www.dmg.org/PMML-4_4",
    "xmlns:xsi" => "http://www.w3.org/2001/XMLSchema-instance"
  }

  builder = Nokogiri::XML::Builder.new do |xml|
    xml.PMML(root_attributes) do
      pmml_header(xml)
      pmml_data_dictionary(xml, data_fields)
      pmml_transformation_dictionary(xml)
      yield xml
    end
  end
  builder.to_xml
end
|
317
|
+
|
318
|
+
# Writes the Header element, naming Eps and its version as the application.
#
# @param xml [Object] the builder being written to
def pmml_header(xml)
  xml.Header do
    xml.Application(name: "Eps", version: Eps::VERSION)
    # No Timestamp element is written (kept disabled):
    # xml.Timestamp Time.now.utc.iso8601
  end
end
|
324
|
+
|
325
|
+
# Writes the DataDictionary with one DataField per entry of data_fields.
#
# Fields whose feature type is "categorical", or that have no feature type
# at all, become categorical string fields listing their values sorted as
# strings. "text" fields become categorical string fields without values;
# everything else becomes a continuous double.
#
# @param xml [Object] the builder being written to
# @param data_fields [Hash] field name => values; values must respond to
#   map for categorical and untyped fields
def pmml_data_dictionary(xml, data_fields)
  xml.DataDictionary do
    data_fields.each do |name, values|
      type = features[name]
      if type.nil? || type == "categorical"
        xml.DataField(name: name, optype: "categorical", dataType: "string") do
          values.map(&:to_s).sort.each { |value| xml.Value(value: value) }
        end
      elsif type == "text"
        xml.DataField(name: name, optype: "categorical", dataType: "string")
      else
        xml.DataField(name: name, optype: "continuous", dataType: "double")
      end
    end
  end
end
|
343
|
+
|
344
|
+
# Writes the TransformationDictionary: one "<field>Transform" function per
# text feature, wrapping a term-frequency TextIndex configured from the
# feature's tokenizer and case-sensitivity options. Nothing is written when
# the model has no text features.
#
# @param xml [Object] the builder being written to
def pmml_transformation_dictionary(xml)
  return unless text_features.any?

  xml.TransformationDictionary do
    text_features.each do |field, options|
      xml.DefineFunction(name: "#{field}Transform", optype: "continuous") do
        xml.ParameterField(name: "text")
        xml.ParameterField(name: "term")
        text_index_attributes = {
          textField: "text",
          localTermWeights: "termFrequency",
          wordSeparatorCharacterRE: options[:tokenizer].source,
          isCaseSensitive: !!options[:case_sensitive]
        }
        xml.TextIndex(text_index_attributes) do
          xml.FieldRef(field: "term")
        end
      end
    end
  end
end
|
359
|
+
|
360
|
+
# Writes LocalTransformations with one DerivedField per vocabulary term of
# every text feature. Each derived field applies the feature's
# "<field>Transform" function to the text field and the term. Nothing is
# written when the model has no text features.
#
# @param xml [Object] the builder being written to
def pmml_local_transformations(xml)
  return unless text_features.any?

  xml.LocalTransformations do
    text_features.each do |field, _options|
      text_encoders[field].vocabulary.each do |term|
        derived_name = display_field([field, term])
        xml.DerivedField(name: derived_name, optype: "continuous", dataType: "integer") do
          xml.Apply(function: "#{field}Transform") do
            xml.FieldRef(field: field)
            xml.Constant term
          end
        end
      end
    end
  end
end
|
376
|
+
|
377
|
+
# TODO create instance methods on model for all of these features
|
378
|
+
|
379
|
+
# @return [Object, nil] the model's @features (used here as a hash of
#   field name => type string); nil if the model never set it
def features
  model.instance_variable_get("@features")
end
|
382
|
+
|
383
|
+
# @return [Object, nil] the model's @text_features (used here as a hash of
#   field name => options); nil if the model never set it
def text_features
  model.instance_variable_get("@text_features")
end
|
386
|
+
|
387
|
+
# @return [Object, nil] the model's @text_encoders (indexed here by text
#   field name); nil if the model never set it
def text_encoders
  model.instance_variable_get("@text_encoders")
end
|
390
|
+
|
391
|
+
# @return [Object, nil] the model's @feature_importance; nil if the model
#   never set it. Only referenced from commented-out code in this class.
def feature_importance
  model.instance_variable_get("@feature_importance")
end
|
394
|
+
|
395
|
+
# @return [Object, nil] the model's @labels (the target categories); nil if
#   the model never set it
def labels
  model.instance_variable_get("@labels")
end
|
398
|
+
|
399
|
+
# @return [Object, nil] the model's @trees (root nodes of the boosted
#   trees); nil if the model never set it
def trees
  model.instance_variable_get("@trees")
end
|
402
|
+
|
403
|
+
# @return [Object, nil] the model's @target (name of the target field); nil
#   if the model never set it
def target
  model.instance_variable_get("@target")
end
|
406
|
+
|
407
|
+
# @return [Object, nil] the model's @label_encoders (indexed here by
#   categorical field name); nil if the model never set it
def label_encoders
  model.instance_variable_get("@label_encoders")
end
|
410
|
+
|
411
|
+
# @return [Object, nil] the model's @objective ("regression" and "binary"
#   are handled explicitly by the LightGBM generator); nil if never set
def objective
  model.instance_variable_get("@objective")
end
|
414
|
+
|
415
|
+
# @return [Object, nil] the model's @probabilities (read here through its
#   :prior and :conditional keys); nil if the model never set it
def probabilities
  model.instance_variable_get("@probabilities")
end
|
418
|
+
|
419
|
+
# end TODO
|
420
|
+
end
|
421
|
+
end
|
422
|
+
end
|
@@ -0,0 +1,241 @@
|
|
1
|
+
module Eps
|
2
|
+
module PMML
|
3
|
+
class Loader
|
4
|
+
attr_reader :data
|
5
|
+
|
6
|
+
# Stores the PMML document to load.
#
# @param pmml [String, Object] a String is parsed with Nokogiri using the
#   strict parse option; any other object is stored unchanged (presumably an
#   already parsed document - it must respond to css for #load)
def initialize(pmml)
  @data = pmml.is_a?(String) ? Nokogiri::XML(pmml) { |config| config.strict } : pmml
end
|
12
|
+
|
13
|
+
# Detects the model type from the elements present in the document and
# builds the matching evaluator. Segmentation is checked first, then
# RegressionModel, then NaiveBayesModel.
#
# @return [Object] the evaluator built by the matching private loader
# @raise [RuntimeError] "Unknown model" when none of the elements is found
def load
  return lightgbm if data.css("Segmentation").any?
  return linear_regression if data.css("RegressionModel").any?
  return naive_bayes if data.css("NaiveBayesModel").any?

  raise "Unknown model"
end
|
24
|
+
|
25
|
+
private
|
26
|
+
|
27
|
+
# Builds a LightGBM evaluator from the document.
#
# The objective comes from the first MiningModel's functionName. A
# "classification" model is refined to "binary" or "multiclass" by counting
# the labels found on the RegressionModel output fields. Feature types are
# read from the DataDictionary and trees from the segmented TreeModels.
#
# @return [Evaluators::LightGBM]
def lightgbm
  labels = nil
  objective = data.css("MiningModel").first.attribute("functionName").value
  if objective == "classification"
    labels = data.css("RegressionModel OutputField").map { |field| field.attribute("value").value }
    objective = labels.size > 2 ? "multiclass" : "binary"
  end

  features = {}
  text_features, derived_fields = extract_text_features(data, features)

  # Every DataField except the target describes a feature. The target is
  # identified by name instead of by position, because a document may carry
  # no DataField for the target at all; skipping the first entry blindly
  # would then drop a real feature.
  target_field = data.css("MiningField").find do |field|
    %w[target predicted].include?(field.attribute("usageType")&.value)
  end
  target_name = target_field&.attribute("name")&.value

  data_fields = data.css("DataDictionary").first.css("DataField").to_a
  feature_fields =
    if target_name
      data_fields.reject { |field| field.attribute("name").value == target_name }
    else
      # No target declared in a mining schema: fall back to treating the
      # first DataField as the target.
      data_fields.drop(1)
    end

  feature_fields.each do |field|
    categorical = field.attribute("optype").value == "categorical"
    features[field.attribute("name").value] = categorical ? "categorical" : "numeric"
  end

  trees = data.css("Segmentation TreeModel").map do |tree|
    find_nodes(tree.css("Node").first, derived_fields)
  end

  Evaluators::LightGBM.new(trees: trees, objective: objective, labels: labels, features: features, text_features: text_features)
end
|
54
|
+
|
55
|
+
# Builds a linear regression evaluator from the RegressionTable.
#
# The coefficient hash is keyed by "_intercept", by plain field name for
# numeric predictors, by the [field, term] pair for predictors on derived
# text fields, and by [field, value] for categorical predictors.
#
# @return [Evaluators::LinearRegression]
def linear_regression
  table = data.css("RegressionTable")
  coefficients = { "_intercept" => table.attribute("intercept").value.to_f }

  features = {}
  text_features, derived_fields = extract_text_features(data, features)

  table.css("NumericPredictor").each do |predictor|
    name = predictor.attribute("name").value
    text_key = derived_fields[name]
    # Derived text fields are already registered as "text" features.
    features[name] = "numeric" unless text_key
    coefficients[text_key || name] = predictor.attribute("coefficient").value.to_f
  end

  table.css("CategoricalPredictor").each do |predictor|
    name = predictor.attribute("name").value
    level = predictor.attribute("value").value
    coefficients[[name, level]] = predictor.attribute("coefficient").value.to_f
    features[name] = "categorical"
  end

  Evaluators::LinearRegression.new(coefficients: coefficients, features: features, text_features: text_features)
end
|
84
|
+
|
85
|
+
# Builds a naive Bayes evaluator from the NaiveBayesModel element.
#
# Priors come from the BayesOutput counts. Each BayesInput contributes
# Gaussian statistics (TargetValueStat) and/or pair counts (PairCounts). An
# input is "numeric" if it has any TargetValueStat, otherwise "categorical".
# Inputs written in the old Eps (< 0.3) layout are detected and the
# evaluator is told through the legacy flag.
#
# @return [Evaluators::NaiveBayes]
def naive_bayes
  model_node = data.css("NaiveBayesModel")

  prior = {}
  model_node.css("BayesOutput TargetValueCount").each do |count_node|
    prior[count_node.attribute("value").value] = count_node.attribute("count").value.to_f
  end

  legacy = false
  conditional = {}
  features = {}

  model_node.css("BayesInput").each do |input|
    prob = {}

    # numeric
    input.css("TargetValueStat").each do |stat|
      gaussian = stat.css("GaussianDistribution")
      prob[stat.attribute("value").value] = {
        mean: gaussian.attribute("mean").value.to_f,
        # The document stores the variance; the evaluator is given the stdev.
        stdev: Math.sqrt(gaussian.attribute("variance").value.to_f)
      }
    end

    pair_counts = input.css("PairCounts")

    # detect bad form in Eps < 0.3: the PairCounts values are exactly the
    # prior keys, in the same order
    bad_format = pair_counts.map { |pair| pair.attribute("value").value } == prior.keys

    # categorical
    pair_counts.each do |pair|
      pair_value = pair.attribute("value").value
      if bad_format
        # Old layout: the nesting is swapped while reading.
        pair.css("TargetValueCount").each do |count_node|
          inner = (prob[count_node.attribute("value").value] ||= {})
          inner[pair_value] = count_node.attribute("count").value.to_f
        end
      else
        counts = {}
        pair.css("TargetValueCount").each do |count_node|
          counts[count_node.attribute("value").value] = count_node.attribute("count").value.to_f
        end
        prob[pair_value] = counts
      end
    end

    if bad_format
      legacy = true
      # Give every entry an explicit 0.0 for prior keys it does not mention.
      prob.each_value do |counts|
        prior.keys.each { |label| counts[label] ||= 0.0 }
      end
    end

    name = input.attribute("fieldName").value
    conditional[name] = prob
    features[name] = input.css("TargetValueStat").any? ? "numeric" : "categorical"
  end

  # Read but not used further; still fails when BayesOutput has no fieldName.
  _target = model_node.css("BayesOutput").attribute("fieldName").value

  probabilities = {
    prior: prior,
    conditional: conditional
  }

  # get derived fields: each DerivedField is removed from the features and
  # its source field is registered as "derived" instead
  derived = {}
  data.css("DerivedField").each do |derived_node|
    name = derived_node.attribute("name").value
    field = derived_node.css("NormDiscrete").attribute("field").value
    value = derived_node.css("NormDiscrete").attribute("value").value
    features.delete(name)
    features[field] = "derived"
    (derived[field] ||= {})[name] = value
  end

  Evaluators::NaiveBayes.new(probabilities: probabilities, features: features, derived: derived, legacy: legacy)
end
|
163
|
+
|
164
|
+
# Reads the text-feature definitions from the document.
#
# Side effect: every text field found is registered in +features+ with the
# type "text".
#
# @param data [Object] the parsed document
# @param features [Hash] feature name => type; updated in place
# @return [Array(Hash, Hash)] text_features (field => tokenizer,
#   case_sensitive and vocabulary) and derived_fields (derived field name =>
#   [field, term])
def extract_text_features(data, features)
  vocabulary = {}
  function_mapping = {}
  derived_fields = {}

  derived_selector = "LocalTransformations DerivedField, TransformationDictionary DerivedField"
  data.css(derived_selector).each do |derived|
    name = derived.attribute("name")&.value
    field = derived.css("FieldRef").attribute("field").value
    term = derived.css("Constant").text

    # Unwrap "lowercase(field)" to "field": drop the 10-character prefix
    # and the closing parenthesis.
    field = field[10..-2] if field =~ /\Alowercase\(.+\)\z/
    next if term.empty?

    (vocabulary[field] ||= []) << term
    function_mapping[field] = derived.css("Apply").attribute("function").value
    derived_fields[name] = [field, term]
  end

  functions = {}
  data.css("TransformationDictionary DefineFunction").each do |definition|
    name = definition.attribute("name").value
    text_index = definition.css("TextIndex")
    functions[name] = {
      tokenizer: Regexp.new(text_index.attribute("wordSeparatorCharacterRE").value),
      case_sensitive: text_index.attribute("isCaseSensitive")&.value == "true"
    }
  end

  text_features = {}
  function_mapping.each do |field, function|
    text_features[field] = functions[function].merge(vocabulary: vocabulary[field])
    features[field] = "text"
  end

  [text_features, derived_fields]
end
|
202
|
+
|
203
|
+
# Converts a PMML Node element and its descendants into evaluator nodes.
#
# The first child element is the predicate: True maps to no predicate,
# SimpleSetPredicate to an "in" predicate over the parsed array members,
# and anything else is read as a SimplePredicate. All remaining child
# elements are converted recursively as child nodes.
#
# @param xml [Object] the Node element
# @param derived_fields [Hash] derived field name => replacement field
# @return [Evaluators::Node]
def find_nodes(xml, derived_fields)
  score = xml.attribute("score").value.to_f

  elements = xml.elements
  xml_predicate = elements.first

  predicate =
    case xml_predicate.name
    when "True"
      nil
    when "SimpleSetPredicate"
      # Members are either double-quoted (with \" for an embedded quote) or
      # bare whitespace-free tokens.
      members = xml_predicate.css("Array").text.scan(/"(.+?)(?<!\\)"|(\S+)/).flatten.compact.map { |v| v.gsub('\"', '"') }
      field = xml_predicate.attribute("field").value
      { field: derived_fields[field] || field, operator: "in", value: members }
    else
      operator = xml_predicate.attribute("operator").value
      value = xml_predicate.attribute("value").value
      # Only greaterThan thresholds are converted to numbers.
      value = value.to_f if operator == "greaterThan"
      field = xml_predicate.attribute("field").value
      { field: derived_fields[field] || field, operator: operator, value: value }
    end

  children = elements.drop(1).map { |child| find_nodes(child, derived_fields) }

  Evaluators::Node.new(score: score, predicate: predicate, children: children)
end
|
239
|
+
end
|
240
|
+
end
|
241
|
+
end
|