eps 0.3.0 → 0.3.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +12 -5
- data/README.md +34 -0
- data/lib/eps.rb +19 -10
- data/lib/eps/base_estimator.rb +35 -129
- data/lib/eps/data_frame.rb +7 -1
- data/lib/eps/evaluators/linear_regression.rb +1 -1
- data/lib/eps/label_encoder.rb +7 -3
- data/lib/eps/lightgbm.rb +36 -76
- data/lib/eps/linear_regression.rb +26 -79
- data/lib/eps/metrics.rb +24 -12
- data/lib/eps/model.rb +6 -6
- data/lib/eps/naive_bayes.rb +2 -139
- data/lib/eps/pmml.rb +14 -0
- data/lib/eps/pmml/generator.rb +422 -0
- data/lib/eps/pmml/loader.rb +241 -0
- data/lib/eps/version.rb +1 -1
- metadata +7 -5
- data/lib/eps/pmml_generators/lightgbm.rb +0 -187
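The main change in 0.3.1 is that PMML serialization moves out of the individual estimators (hence the large removals from linear_regression.rb, naive_bayes.rb, and lightgbm.rb, and the deletion of pmml_generators/lightgbm.rb) into a dedicated Eps::PMML namespace containing the Generator and Loader classes shown below. The contents of the new data/lib/eps/pmml.rb entry point are not rendered in this diff; a minimal sketch of what such a facade could look like follows, with the Eps::PMML.generate and Eps::PMML.load method names assumed rather than taken from the hunks shown here.

# Hypothetical sketch only -- data/lib/eps/pmml.rb itself is not shown in this diff.
module Eps
  module PMML
    class << self
      # serialize a trained model to a PMML XML string via the Generator below
      def generate(model)
        Generator.new(model).generate
      end

      # parse a PMML document back into an evaluator via the Loader below
      def load(pmml)
        Loader.new(pmml).load
      end
    end
  end
end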
data/lib/eps/pmml/generator.rb
ADDED
@@ -0,0 +1,422 @@
module Eps
  module PMML
    class Generator
      attr_reader :model

      def initialize(model)
        @model = model
      end

      def generate
        case @model
        when LightGBM
          lightgbm
        when LinearRegression
          linear_regression
        when NaiveBayes
          naive_bayes
        else
          raise "Unknown model"
        end
      end

      private

      def lightgbm
        data_fields = {}
        data_fields[target] = labels if labels
        features.each_with_index do |(k, type), i|
          # TODO remove zero importance features
          if type == "categorical"
            data_fields[k] = label_encoders[k].labels.keys
          else
            data_fields[k] = nil
          end
        end

        build_pmml(data_fields) do |xml|
          function_name = objective == "regression" ? "regression" : "classification"
          xml.MiningModel(functionName: function_name, algorithmName: "LightGBM") do
            xml.MiningSchema do
              xml.MiningField(name: target, usageType: "target")
              features.keys.each_with_index do |k, i|
                # next if feature_importance[i] == 0
                # TODO add importance, but need to handle text features
                xml.MiningField(name: k) #, importance: feature_importance[i].to_f, missingValueTreatment: "asIs")
              end
            end
            pmml_local_transformations(xml)

            case objective
            when "regression"
              xml_segmentation(xml, trees)
            when "binary"
              xml.Segmentation(multipleModelMethod: "modelChain") do
                xml.Segment(id: 1) do
                  xml.True
                  xml.MiningModel(functionName: "regression") do
                    xml.MiningSchema do
                      features.each do |k, _|
                        xml.MiningField(name: k)
                      end
                    end
                    xml.Output do
                      xml.OutputField(name: "lgbmValue", optype: "continuous", dataType: "double", feature: "predictedValue", isFinalResult: false) do
                        xml.Apply(function: "/") do
                          xml.Constant(dataType: "double") do
                            1.0
                          end
                          xml.Apply(function: "+") do
                            xml.Constant(dataType: "double") do
                              1.0
                            end
                            xml.Apply(function: "exp") do
                              xml.Apply(function: "*") do
                                xml.Constant(dataType: "double") do
                                  -1.0
                                end
                                xml.FieldRef(field: "lgbmValue")
                              end
                            end
                          end
                        end
                      end
                    end
                    xml_segmentation(xml, trees)
                  end
                end
                xml.Segment(id: 2) do
                  xml.True
                  xml.RegressionModel(functionName: "classification", normalizationMethod: "none") do
                    xml.MiningSchema do
                      xml.MiningField(name: target, usageType: "target")
                      xml.MiningField(name: "transformedLgbmValue")
                    end
                    xml.Output do
                      labels.each do |label|
                        xml.OutputField(name: "probability(#{label})", optype: "continuous", dataType: "double", feature: "probability", value: label)
                      end
                    end
                    xml.RegressionTable(intercept: 0.0, targetCategory: labels.last) do
                      xml.NumericPredictor(name: "transformedLgbmValue", coefficient: "1.0")
                    end
                    xml.RegressionTable(intercept: 0.0, targetCategory: labels.first)
                  end
                end
              end
            else # multiclass
              xml.Segmentation(multipleModelMethod: "modelChain") do
                n = trees.size / labels.size
                trees.each_slice(n).each_with_index do |trees, idx|
                  xml.Segment(id: idx + 1) do
                    xml.True
                    xml.MiningModel(functionName: "regression") do
                      xml.MiningSchema do
                        features.each do |k, _|
                          xml.MiningField(name: k)
                        end
                      end
                      xml.Output do
                        xml.OutputField(name: "lgbmValue(#{labels[idx]})", optype: "continuous", dataType: "double", feature: "predictedValue", isFinalResult: false)
                      end
                      xml_segmentation(xml, trees)
                    end
                  end
                end
                xml.Segment(id: labels.size + 1) do
                  xml.True
                  xml.RegressionModel(functionName: "classification", normalizationMethod: "softmax") do
                    xml.MiningSchema do
                      xml.MiningField(name: target, usageType: "target")
                      labels.each do |label|
                        xml.MiningField(name: "lgbmValue(#{label})")
                      end
                    end
                    xml.Output do
                      labels.each do |label|
                        xml.OutputField(name: "probability(#{label})", optype: "continuous", dataType: "double", feature: "probability", value: label)
                      end
                    end
                    labels.each do |label|
                      xml.RegressionTable(intercept: 0.0, targetCategory: label) do
                        xml.NumericPredictor(name: "lgbmValue(#{label})", coefficient: "1.0")
                      end
                    end
                  end
                end
              end
            end
          end
        end
      end

      def linear_regression
        predictors = model.instance_variable_get("@coefficients").dup
        intercept = predictors.delete("_intercept") || 0.0

        data_fields = {}
        features.each do |k, type|
          if type == "categorical"
            data_fields[k] = predictors.keys.select { |k, v| k.is_a?(Array) && k.first == k }.map(&:last)
          else
            data_fields[k] = nil
          end
        end

        build_pmml(data_fields) do |xml|
          xml.RegressionModel(functionName: "regression") do
            xml.MiningSchema do
              features.each do |k, _|
                xml.MiningField(name: k)
              end
            end
            pmml_local_transformations(xml)
            xml.RegressionTable(intercept: intercept) do
              predictors.each do |k, v|
                if k.is_a?(Array)
                  if features[k.first] == "text"
                    xml.NumericPredictor(name: display_field(k), coefficient: v)
                  else
                    xml.CategoricalPredictor(name: k[0], value: k[1], coefficient: v)
                  end
                else
                  xml.NumericPredictor(name: k, coefficient: v)
                end
              end
            end
          end
        end
      end

      def naive_bayes
        data_fields = {}
        data_fields[target] = probabilities[:prior].keys
        probabilities[:conditional].each do |k, v|
          if features[k] == "categorical"
            data_fields[k] = v.keys
          else
            data_fields[k] = nil
          end
        end

        build_pmml(data_fields) do |xml|
          xml.NaiveBayesModel(functionName: "classification", threshold: 0.001) do
            xml.MiningSchema do
              data_fields.each do |k, _|
                xml.MiningField(name: k)
              end
            end
            xml.BayesInputs do
              probabilities[:conditional].each do |k, v|
                xml.BayesInput(fieldName: k) do
                  if features[k] == "categorical"
                    v.sort_by { |k2, _| k2 }.each do |k2, v2|
                      xml.PairCounts(value: k2) do
                        xml.TargetValueCounts do
                          v2.sort_by { |k2, _| k2 }.each do |k3, v3|
                            xml.TargetValueCount(value: k3, count: v3)
                          end
                        end
                      end
                    end
                  else
                    xml.TargetValueStats do
                      v.sort_by { |k2, _| k2 }.each do |k2, v2|
                        xml.TargetValueStat(value: k2) do
                          xml.GaussianDistribution(mean: v2[:mean], variance: v2[:stdev]**2)
                        end
                      end
                    end
                  end
                end
              end
            end
            xml.BayesOutput(fieldName: "target") do
              xml.TargetValueCounts do
                probabilities[:prior].sort_by { |k, _| k }.each do |k, v|
                  xml.TargetValueCount(value: k, count: v)
                end
              end
            end
          end
        end
      end

      def display_field(k)
        if k.is_a?(Array)
          if features[k.first] == "text"
            "#{k.first}(#{k.last})"
          else
            k.join("=")
          end
        else
          k
        end
      end

      def xml_segmentation(xml, trees)
        xml.Segmentation(multipleModelMethod: "sum") do
          trees.each_with_index do |node, i|
            xml.Segment(id: i + 1) do
              xml.True
              xml.TreeModel(functionName: "regression", missingValueStrategy: "none", noTrueChildStrategy: "returnLastPrediction", splitCharacteristic: "multiSplit") do
                xml.MiningSchema do
                  node_fields(node).uniq.each do |k|
                    xml.MiningField(name: display_field(k))
                  end
                end
                node_pmml(node, xml)
              end
            end
          end
        end
      end

      def node_fields(node)
        fields = []
        fields << node.field if node.predicate
        node.children.each do |n|
          fields.concat(node_fields(n))
        end
        fields
      end

      def node_pmml(node, xml)
        xml.Node(score: node.score) do
          if node.predicate.nil?
            xml.True
          elsif node.operator == "in"
            xml.SimpleSetPredicate(field: display_field(node.field), booleanOperator: "isIn") do
              xml.Array(type: "string") do
                xml.text node.value.map { |v| escape_element(v) }.join(" ")
              end
            end
          else
            xml.SimplePredicate(field: display_field(node.field), operator: node.operator, value: node.value)
          end
          node.children.each do |n|
            node_pmml(n, xml)
          end
        end
      end

      def escape_element(v)
        "\"#{v.gsub("\"", "\\\"")}\""
      end

      def build_pmml(data_fields)
        Nokogiri::XML::Builder.new do |xml|
          xml.PMML(version: "4.4", xmlns: "http://www.dmg.org/PMML-4_4", "xmlns:xsi" => "http://www.w3.org/2001/XMLSchema-instance") do
            pmml_header(xml)
            pmml_data_dictionary(xml, data_fields)
            pmml_transformation_dictionary(xml)
            yield xml
          end
        end.to_xml
      end

      def pmml_header(xml)
        xml.Header do
          xml.Application(name: "Eps", version: Eps::VERSION)
          # xml.Timestamp Time.now.utc.iso8601
        end
      end

      def pmml_data_dictionary(xml, data_fields)
        xml.DataDictionary do
          data_fields.each do |k, vs|
            case features[k]
            when "categorical", nil
              xml.DataField(name: k, optype: "categorical", dataType: "string") do
                vs.map(&:to_s).sort.each do |v|
                  xml.Value(value: v)
                end
              end
            when "text"
              xml.DataField(name: k, optype: "categorical", dataType: "string")
            else
              xml.DataField(name: k, optype: "continuous", dataType: "double")
            end
          end
        end
      end

      def pmml_transformation_dictionary(xml)
        if text_features.any?
          xml.TransformationDictionary do
            text_features.each do |k, text_options|
              xml.DefineFunction(name: "#{k}Transform", optype: "continuous") do
                xml.ParameterField(name: "text")
                xml.ParameterField(name: "term")
                xml.TextIndex(textField: "text", localTermWeights: "termFrequency", wordSeparatorCharacterRE: text_options[:tokenizer].source, isCaseSensitive: !!text_options[:case_sensitive]) do
                  xml.FieldRef(field: "term")
                end
              end
            end
          end
        end
      end

      def pmml_local_transformations(xml)
        if text_features.any?
          xml.LocalTransformations do
            text_features.each do |k, _|
              text_encoders[k].vocabulary.each do |v|
                xml.DerivedField(name: display_field([k, v]), optype: "continuous", dataType: "integer") do
                  xml.Apply(function: "#{k}Transform") do
                    xml.FieldRef(field: k)
                    xml.Constant v
                  end
                end
              end
            end
          end
        end
      end

      # TODO create instance methods on model for all of these features

      def features
        model.instance_variable_get("@features")
      end

      def text_features
        model.instance_variable_get("@text_features")
      end

      def text_encoders
        model.instance_variable_get("@text_encoders")
      end

      def feature_importance
        model.instance_variable_get("@feature_importance")
      end

      def labels
        model.instance_variable_get("@labels")
      end

      def trees
        model.instance_variable_get("@trees")
      end

      def target
        model.instance_variable_get("@target")
      end

      def label_encoders
        model.instance_variable_get("@label_encoders")
      end

      def objective
        model.instance_variable_get("@objective")
      end

      def probabilities
        model.instance_variable_get("@probabilities")
      end

      # end TODO
    end
  end
end
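The Generator above backs PMML export of a trained model. A brief usage sketch, assuming the public Eps::Model API documented in the gem's README (the housing data and :price target are purely illustrative):

require "eps"

# illustrative training data
houses = [
  { bedrooms: 2, bathrooms: 1, size: 1100, price: 450000 },
  { bedrooms: 3, bathrooms: 2, size: 1850, price: 620000 },
  { bedrooms: 4, bathrooms: 3, size: 2400, price: 790000 }
]

model = Eps::Model.new(houses, target: :price)  # trains an estimator
File.write("model.pmml", model.to_pmml)         # PMML produced by the generator above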
data/lib/eps/pmml/loader.rb
ADDED
@@ -0,0 +1,241 @@
module Eps
  module PMML
    class Loader
      attr_reader :data

      def initialize(pmml)
        if pmml.is_a?(String)
          pmml = Nokogiri::XML(pmml) { |config| config.strict }
        end
        @data = pmml
      end

      def load
        if data.css("Segmentation").any?
          lightgbm
        elsif data.css("RegressionModel").any?
          linear_regression
        elsif data.css("NaiveBayesModel").any?
          naive_bayes
        else
          raise "Unknown model"
        end
      end

      private

      def lightgbm
        objective = data.css("MiningModel").first.attribute("functionName").value
        if objective == "classification"
          labels = data.css("RegressionModel OutputField").map { |n| n.attribute("value").value }
          objective = labels.size > 2 ? "multiclass" : "binary"
        end

        features = {}
        text_features, derived_fields = extract_text_features(data, features)
        node = data.css("DataDictionary").first
        node.css("DataField")[1..-1].to_a.each do |node|
          features[node.attribute("name").value] =
            if node.attribute("optype").value == "categorical"
              "categorical"
            else
              "numeric"
            end
        end

        trees = []
        data.css("Segmentation TreeModel").each do |tree|
          node = find_nodes(tree.css("Node").first, derived_fields)
          trees << node
        end

        Evaluators::LightGBM.new(trees: trees, objective: objective, labels: labels, features: features, text_features: text_features)
      end

      def linear_regression
        node = data.css("RegressionTable")

        coefficients = {
          "_intercept" => node.attribute("intercept").value.to_f
        }

        features = {}

        text_features, derived_fields = extract_text_features(data, features)

        node.css("NumericPredictor").each do |n|
          name = n.attribute("name").value
          if derived_fields[name]
            name = derived_fields[name]
          else
            features[name] = "numeric"
          end
          coefficients[name] = n.attribute("coefficient").value.to_f
        end

        node.css("CategoricalPredictor").each do |n|
          name = n.attribute("name").value
          coefficients[[name, n.attribute("value").value]] = n.attribute("coefficient").value.to_f
          features[name] = "categorical"
        end

        Evaluators::LinearRegression.new(coefficients: coefficients, features: features, text_features: text_features)
      end

      def naive_bayes
        node = data.css("NaiveBayesModel")

        prior = {}
        node.css("BayesOutput TargetValueCount").each do |n|
          prior[n.attribute("value").value] = n.attribute("count").value.to_f
        end

        legacy = false

        conditional = {}
        features = {}
        node.css("BayesInput").each do |n|
          prob = {}

          # numeric
          n.css("TargetValueStat").each do |n2|
            n3 = n2.css("GaussianDistribution")
            prob[n2.attribute("value").value] = {
              mean: n3.attribute("mean").value.to_f,
              stdev: Math.sqrt(n3.attribute("variance").value.to_f)
            }
          end

          # detect bad form in Eps < 0.3
          bad_format = n.css("PairCounts").map { |n2| n2.attribute("value").value } == prior.keys

          # categorical
          n.css("PairCounts").each do |n2|
            if bad_format
              n2.css("TargetValueCount").each do |n3|
                prob[n3.attribute("value").value] ||= {}
                prob[n3.attribute("value").value][n2.attribute("value").value] = n3.attribute("count").value.to_f
              end
            else
              boom = {}
              n2.css("TargetValueCount").each do |n3|
                boom[n3.attribute("value").value] = n3.attribute("count").value.to_f
              end
              prob[n2.attribute("value").value] = boom
            end
          end

          if bad_format
            legacy = true
            prob.each do |k, v|
              prior.keys.each do |k|
                v[k] ||= 0.0
              end
            end
          end

          name = n.attribute("fieldName").value
          conditional[name] = prob
          features[name] = n.css("TargetValueStat").any? ? "numeric" : "categorical"
        end

        target = node.css("BayesOutput").attribute("fieldName").value

        probabilities = {
          prior: prior,
          conditional: conditional
        }

        # get derived fields
        derived = {}
        data.css("DerivedField").each do |n|
          name = n.attribute("name").value
          field = n.css("NormDiscrete").attribute("field").value
          value = n.css("NormDiscrete").attribute("value").value
          features.delete(name)
          features[field] = "derived"
          derived[field] ||= {}
          derived[field][name] = value
        end

        Evaluators::NaiveBayes.new(probabilities: probabilities, features: features, derived: derived, legacy: legacy)
      end

      def extract_text_features(data, features)
        # updates features object
        vocabulary = {}
        function_mapping = {}
        derived_fields = {}
        data.css("LocalTransformations DerivedField, TransformationDictionary DerivedField").each do |n|
          name = n.attribute("name")&.value
          field = n.css("FieldRef").attribute("field").value
          value = n.css("Constant").text

          field = field[10..-2] if field =~ /\Alowercase\(.+\)\z/
          next if value.empty?

          (vocabulary[field] ||= []) << value

          function_mapping[field] = n.css("Apply").attribute("function").value

          derived_fields[name] = [field, value]
        end

        functions = {}
        data.css("TransformationDictionary DefineFunction").each do |n|
          name = n.attribute("name").value
          text_index = n.css("TextIndex")
          functions[name] = {
            tokenizer: Regexp.new(text_index.attribute("wordSeparatorCharacterRE").value),
            case_sensitive: text_index.attribute("isCaseSensitive")&.value == "true"
          }
        end

        text_features = {}
        function_mapping.each do |field, function|
          text_features[field] = functions[function].merge(vocabulary: vocabulary[field])
          features[field] = "text"
        end

        [text_features, derived_fields]
      end

      def find_nodes(xml, derived_fields)
        score = xml.attribute("score").value.to_f

        elements = xml.elements
        xml_predicate = elements.first

        predicate =
          if xml_predicate.name == "True"
            nil
          elsif xml_predicate.name == "SimpleSetPredicate"
            operator = "in"
            value = xml_predicate.css("Array").text.scan(/"(.+?)(?<!\\)"|(\S+)/).flatten.compact.map { |v| v.gsub('\"', '"') }
            field = xml_predicate.attribute("field").value
            field = derived_fields[field] if derived_fields[field]
            {
              field: field,
              operator: operator,
              value: value
            }
          else
            operator = xml_predicate.attribute("operator").value
            value = xml_predicate.attribute("value").value
            value = value.to_f if operator == "greaterThan"
            field = xml_predicate.attribute("field").value
            field = derived_fields[field] if derived_fields[field]
            {
              field: field,
              operator: operator,
              value: value
            }
          end

        children = elements[1..-1].map { |n| find_nodes(n, derived_fields) }

        Evaluators::Node.new(score: score, predicate: predicate, children: children)
      end
    end
  end
end
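The Loader covers the reverse path, turning a PMML document (including ones written by older Eps versions, handled by the bad_format/legacy branch above) back into an evaluator. A usage sketch, again assuming the public Eps::Model API:

require "eps"

pmml = File.read("model.pmml")
model = Eps::Model.load_pmml(pmml)  # parsed by the loader above

model.predict(bedrooms: 3, bathrooms: 2, size: 1850)  # => predicted price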