eps 0.3.0 → 0.3.1

Sign up to get free protection for your applications and to get access to all the features.
@@ -0,0 +1,14 @@
1
# PMML entry points, shaped like the Marshal/JSON/YAML load/generate interface.
module Eps
  module PMML
    # Builds an evaluator from PMML by delegating to Loader#load.
    def self.load(pmml)
      loader = Loader.new(pmml)
      loader.load
    end

    # Serializes a model to PMML by delegating to Generator#generate.
    def self.generate(model)
      generator = Generator.new(model)
      generator.generate
    end
  end
end
@@ -0,0 +1,422 @@
1
module Eps
  module PMML
    # Serializes a trained Eps model (LightGBM, LinearRegression or NaiveBayes)
    # into a PMML 4.4 document built with Nokogiri::XML::Builder.
    #
    # Model internals are read through instance_variable_get; see the accessor
    # methods at the bottom of the class.
    class Generator
      attr_reader :model

      # @param model the trained Eps model to serialize
      def initialize(model)
        @model = model
      end

      # @return [String] the PMML XML for the wrapped model
      # @raise [RuntimeError] "Unknown model" when the model class is unsupported
      def generate
        case @model
        when LightGBM
          lightgbm
        when LinearRegression
          linear_regression
        when NaiveBayes
          naive_bayes
        else
          raise "Unknown model"
        end
      end

      private

      # LightGBM model as a MiningModel: a summed tree segmentation for
      # regression, or a modelChain ending in a RegressionModel that turns raw
      # scores into class probabilities for binary/multiclass objectives.
      def lightgbm
        data_fields = {}
        data_fields[target] = labels if labels
        features.each do |k, type|
          # TODO remove zero importance features
          data_fields[k] = type == "categorical" ? label_encoders[k].labels.keys : nil
        end

        build_pmml(data_fields) do |xml|
          function_name = objective == "regression" ? "regression" : "classification"
          xml.MiningModel(functionName: function_name, algorithmName: "LightGBM") do
            xml.MiningSchema do
              xml.MiningField(name: target, usageType: "target")
              features.each_key do |k|
                # next if feature_importance[i] == 0
                # TODO add importance, but need to handle text features
                xml.MiningField(name: k) #, importance: feature_importance[i].to_f, missingValueTreatment: "asIs")
              end
            end
            pmml_local_transformations(xml)

            case objective
            when "regression"
              xml_segmentation(xml, trees)
            when "binary"
              lightgbm_binary_segmentation(xml)
            else # multiclass
              lightgbm_multiclass_segmentation(xml)
            end
          end
        end
      end

      # Binary classification chain. Segment 1 sums the trees into the raw score
      # "lgbmValue" and publishes its logistic transform as
      # "transformedLgbmValue". Segment 2 maps that value to class probabilities.
      def lightgbm_binary_segmentation(xml)
        xml.Segmentation(multipleModelMethod: "modelChain") do
          xml.Segment(id: 1) do
            xml.True
            xml.MiningModel(functionName: "regression") do
              features_mining_schema(xml)
              xml.Output do
                xml.OutputField(name: "lgbmValue", optype: "continuous", dataType: "double", feature: "predictedValue", isFinalResult: false)
                # 1 / (1 + exp(-1 * lgbmValue)). Constants are passed as element
                # content: Builder ignores a block's return value, so
                # `xml.Constant(...) do 1.0 end` would emit an empty <Constant/>.
                xml.OutputField(name: "transformedLgbmValue", optype: "continuous", dataType: "double", feature: "transformedValue", isFinalResult: false) do
                  xml.Apply(function: "/") do
                    xml.Constant("1.0", dataType: "double")
                    xml.Apply(function: "+") do
                      xml.Constant("1.0", dataType: "double")
                      xml.Apply(function: "exp") do
                        xml.Apply(function: "*") do
                          xml.Constant("-1.0", dataType: "double")
                          xml.FieldRef(field: "lgbmValue")
                        end
                      end
                    end
                  end
                end
              end
              xml_segmentation(xml, trees)
            end
          end
          xml.Segment(id: 2) do
            xml.True
            xml.RegressionModel(functionName: "classification", normalizationMethod: "none") do
              xml.MiningSchema do
                xml.MiningField(name: target, usageType: "target")
                xml.MiningField(name: "transformedLgbmValue")
              end
              probability_output(xml)
              xml.RegressionTable(intercept: 0.0, targetCategory: labels.last) do
                xml.NumericPredictor(name: "transformedLgbmValue", coefficient: "1.0")
              end
              xml.RegressionTable(intercept: 0.0, targetCategory: labels.first)
            end
          end
        end
      end

      # Multiclass chain: one summed-tree segment per label producing
      # "lgbmValue(<label>)", then a softmax RegressionModel over those scores.
      def lightgbm_multiclass_segmentation(xml)
        xml.Segmentation(multipleModelMethod: "modelChain") do
          # each_slice assumes @trees is ordered class by class (all trees of
          # labels[0], then labels[1], ...).
          # NOTE(review): confirm against how the model populates @trees.
          per_class = trees.size / labels.size
          trees.each_slice(per_class).each_with_index do |class_trees, idx|
            xml.Segment(id: idx + 1) do
              xml.True
              xml.MiningModel(functionName: "regression") do
                features_mining_schema(xml)
                xml.Output do
                  xml.OutputField(name: "lgbmValue(#{labels[idx]})", optype: "continuous", dataType: "double", feature: "predictedValue", isFinalResult: false)
                end
                xml_segmentation(xml, class_trees)
              end
            end
          end
          xml.Segment(id: labels.size + 1) do
            xml.True
            xml.RegressionModel(functionName: "classification", normalizationMethod: "softmax") do
              xml.MiningSchema do
                xml.MiningField(name: target, usageType: "target")
                labels.each do |label|
                  xml.MiningField(name: "lgbmValue(#{label})")
                end
              end
              probability_output(xml)
              labels.each do |label|
                xml.RegressionTable(intercept: 0.0, targetCategory: label) do
                  xml.NumericPredictor(name: "lgbmValue(#{label})", coefficient: "1.0")
                end
              end
            end
          end
        end
      end

      # MiningSchema listing every model feature as a plain MiningField.
      def features_mining_schema(xml)
        xml.MiningSchema do
          features.each_key { |k| xml.MiningField(name: k) }
        end
      end

      # Output with one probability(<label>) field per class label.
      def probability_output(xml)
        xml.Output do
          labels.each do |label|
            xml.OutputField(name: "probability(#{label})", optype: "continuous", dataType: "double", feature: "probability", value: label)
          end
        end
      end

      # RegressionModel for a linear regression. Coefficient keys are feature
      # names or [feature, value] pairs. Pairs become NumericPredictors on the
      # derived term field for text features, and CategoricalPredictors otherwise.
      def linear_regression
        predictors = model.instance_variable_get("@coefficients").dup
        intercept = predictors.delete("_intercept") || 0.0

        data_fields = {}
        features.each do |k, type|
          data_fields[k] =
            if type == "categorical"
              # The block variable must not be called k. Named |k, v| it both
              # auto-splats the [feature, value] key and shadows the feature
              # name, so no category values were ever collected.
              predictors.keys.select { |key| key.is_a?(Array) && key.first == k }.map(&:last)
            end
        end

        build_pmml(data_fields) do |xml|
          xml.RegressionModel(functionName: "regression") do
            features_mining_schema(xml)
            pmml_local_transformations(xml)
            xml.RegressionTable(intercept: intercept) do
              predictors.each do |k, v|
                if k.is_a?(Array)
                  if features[k.first] == "text"
                    xml.NumericPredictor(name: display_field(k), coefficient: v)
                  else
                    xml.CategoricalPredictor(name: k[0], value: k[1], coefficient: v)
                  end
                else
                  xml.NumericPredictor(name: k, coefficient: v)
                end
              end
            end
          end
        end
      end

      # NaiveBayesModel built from probabilities[:prior] (class counts) and
      # probabilities[:conditional]. Categorical features hold value => class
      # => count; numeric features hold class => {mean:, stdev:}.
      def naive_bayes
        data_fields = {}
        data_fields[target] = probabilities[:prior].keys
        probabilities[:conditional].each do |k, v|
          data_fields[k] = features[k] == "categorical" ? v.keys : nil
        end

        build_pmml(data_fields) do |xml|
          xml.NaiveBayesModel(functionName: "classification", threshold: 0.001) do
            xml.MiningSchema do
              data_fields.each_key do |k|
                if k == target
                  xml.MiningField(name: k, usageType: "target")
                else
                  xml.MiningField(name: k)
                end
              end
            end
            xml.BayesInputs do
              probabilities[:conditional].each do |k, v|
                xml.BayesInput(fieldName: k) do
                  if features[k] == "categorical"
                    v.sort_by { |value, _| value }.each do |value, counts|
                      xml.PairCounts(value: value) do
                        xml.TargetValueCounts do
                          counts.sort_by { |klass, _| klass }.each do |klass, count|
                            xml.TargetValueCount(value: klass, count: count)
                          end
                        end
                      end
                    end
                  else
                    xml.TargetValueStats do
                      v.sort_by { |klass, _| klass }.each do |klass, stats|
                        xml.TargetValueStat(value: klass) do
                          xml.GaussianDistribution(mean: stats[:mean], variance: stats[:stdev]**2)
                        end
                      end
                    end
                  end
                end
              end
            end
            # Must name the real target field declared in the DataDictionary,
            # not the literal "target".
            xml.BayesOutput(fieldName: target) do
              xml.TargetValueCounts do
                probabilities[:prior].sort_by { |k, _| k }.each do |k, v|
                  xml.TargetValueCount(value: k, count: v)
                end
              end
            end
          end
        end
      end

      # PMML field name for a feature key. A [feature, value] pair becomes
      # "feature(value)" for text features and "feature=value" otherwise.
      # Plain keys are returned unchanged.
      def display_field(k)
        if k.is_a?(Array)
          if features[k.first] == "text"
            "#{k.first}(#{k.last})"
          else
            k.join("=")
          end
        else
          k
        end
      end

      # Summing Segmentation with one TreeModel per tree. Each tree's
      # MiningSchema lists the distinct fields its predicates use.
      def xml_segmentation(xml, trees)
        xml.Segmentation(multipleModelMethod: "sum") do
          trees.each_with_index do |node, i|
            xml.Segment(id: i + 1) do
              xml.True
              xml.TreeModel(functionName: "regression", missingValueStrategy: "none", noTrueChildStrategy: "returnLastPrediction", splitCharacteristic: "multiSplit") do
                xml.MiningSchema do
                  node_fields(node).uniq.each do |k|
                    xml.MiningField(name: display_field(k))
                  end
                end
                node_pmml(node, xml)
              end
            end
          end
        end
      end

      # Fields referenced by predicates in the subtree, pre-order, with duplicates.
      def node_fields(node)
        own = node.predicate ? [node.field] : []
        own + node.children.flat_map { |child| node_fields(child) }
      end

      # Writes a tree node and its children recursively. A node without a
      # predicate gets True, "in" splits get a SimpleSetPredicate, and anything
      # else gets a SimplePredicate.
      def node_pmml(node, xml)
        xml.Node(score: node.score) do
          if node.predicate.nil?
            xml.True
          elsif node.operator == "in"
            xml.SimpleSetPredicate(field: display_field(node.field), booleanOperator: "isIn") do
              xml.Array(type: "string") do
                xml.text node.value.map { |v| escape_element(v) }.join(" ")
              end
            end
          else
            xml.SimplePredicate(field: display_field(node.field), operator: node.operator, value: node.value)
          end
          node.children.each do |n|
            node_pmml(n, xml)
          end
        end
      end

      # Quotes one entry of a PMML string Array and backslash-escapes embedded
      # double quotes. to_s lets non-String category values through instead of
      # raising NoMethodError on gsub.
      def escape_element(v)
        "\"#{v.to_s.gsub('"', '\"')}\""
      end

      # Wraps the yielded model-specific content in the PMML root with Header,
      # DataDictionary and TransformationDictionary.
      #
      # @return [String] the serialized document
      def build_pmml(data_fields)
        Nokogiri::XML::Builder.new do |xml|
          xml.PMML(version: "4.4", xmlns: "http://www.dmg.org/PMML-4_4", "xmlns:xsi" => "http://www.w3.org/2001/XMLSchema-instance") do
            pmml_header(xml)
            pmml_data_dictionary(xml, data_fields)
            pmml_transformation_dictionary(xml)
            yield xml
          end
        end.to_xml
      end

      def pmml_header(xml)
        xml.Header do
          xml.Application(name: "Eps", version: Eps::VERSION)
          # xml.Timestamp Time.now.utc.iso8601
        end
      end

      # One DataField per entry of data_fields (name => allowed values or nil).
      # Categorical fields, and fields with no feature type such as the target,
      # list their values as sorted strings. Text fields are categorical strings
      # without values. Everything else is continuous double.
      def pmml_data_dictionary(xml, data_fields)
        xml.DataDictionary do
          data_fields.each do |k, vs|
            case features[k]
            when "categorical", nil
              xml.DataField(name: k, optype: "categorical", dataType: "string") do
                Array(vs).map(&:to_s).sort.each do |v|
                  xml.Value(value: v)
                end
              end
            when "text"
              xml.DataField(name: k, optype: "categorical", dataType: "string")
            else
              xml.DataField(name: k, optype: "continuous", dataType: "double")
            end
          end
        end
      end

      # Defines a "<feature>Transform"(text, term) function per text feature,
      # computing term frequency with TextIndex using the feature's tokenizer
      # and case sensitivity.
      def pmml_transformation_dictionary(xml)
        if text_features.any?
          xml.TransformationDictionary do
            text_features.each do |k, text_options|
              xml.DefineFunction(name: "#{k}Transform", optype: "continuous") do
                xml.ParameterField(name: "text")
                xml.ParameterField(name: "term")
                xml.TextIndex(textField: "text", localTermWeights: "termFrequency", wordSeparatorCharacterRE: text_options[:tokenizer].source, isCaseSensitive: !!text_options[:case_sensitive]) do
                  xml.FieldRef(field: "term")
                end
              end
            end
          end
        end
      end

      # One integer DerivedField per (text feature, vocabulary term), named with
      # display_field and computed by the feature's Transform function.
      def pmml_local_transformations(xml)
        if text_features.any?
          xml.LocalTransformations do
            text_features.each do |k, _|
              text_encoders[k].vocabulary.each do |v|
                xml.DerivedField(name: display_field([k, v]), optype: "continuous", dataType: "integer") do
                  xml.Apply(function: "#{k}Transform") do
                    xml.FieldRef(field: k)
                    xml.Constant v
                  end
                end
              end
            end
          end
        end
      end

      # TODO create instance methods on model for all of these features

      def features
        model.instance_variable_get("@features")
      end

      def text_features
        model.instance_variable_get("@text_features")
      end

      def text_encoders
        model.instance_variable_get("@text_encoders")
      end

      def feature_importance
        model.instance_variable_get("@feature_importance")
      end

      def labels
        model.instance_variable_get("@labels")
      end

      def trees
        model.instance_variable_get("@trees")
      end

      def target
        model.instance_variable_get("@target")
      end

      def label_encoders
        model.instance_variable_get("@label_encoders")
      end

      def objective
        model.instance_variable_get("@objective")
      end

      def probabilities
        model.instance_variable_get("@probabilities")
      end

      # end TODO
    end
  end
end
@@ -0,0 +1,241 @@
1
module Eps
  module PMML
    # Parses a PMML document into an Eps evaluator: LightGBM (tree
    # Segmentation), linear regression (RegressionModel) or naive Bayes
    # (NaiveBayesModel).
    class Loader
      # SimplePredicate operators whose value is a number rather than a category.
      NUMERIC_OPERATORS = %w[lessThan lessOrEqual greaterThan greaterOrEqual].freeze

      attr_reader :data

      # @param pmml [String, Nokogiri::XML::Document] PMML text, parsed in strict
      #   mode, or an already-parsed document that is used as-is
      def initialize(pmml)
        pmml = Nokogiri::XML(pmml) { |config| config.strict } if pmml.is_a?(String)
        @data = pmml
      end

      # @return the evaluator for the model found in the document
      # @raise [RuntimeError] "Unknown model" when no supported model is present
      def load
        # Segmentation is checked first because LightGBM classification PMML
        # also contains a RegressionModel (its probability stage).
        if data.css("Segmentation").any?
          lightgbm
        elsif data.css("RegressionModel").any?
          linear_regression
        elsif data.css("NaiveBayesModel").any?
          naive_bayes
        else
          raise "Unknown model"
        end
      end

      private

      def lightgbm
        objective = data.css("MiningModel").first.attribute("functionName").value
        labels = nil
        if objective == "classification"
          # class labels are the `value`s of the RegressionModel probability OutputFields
          labels = data.css("RegressionModel OutputField").map { |n| n.attribute("value").value }
          objective = labels.size > 2 ? "multiclass" : "binary"
        end

        features = {}
        text_features, derived_fields = extract_text_features(data, features)
        lightgbm_feature_fields.each do |field|
          # NOTE(review): text fields are also written with optype "categorical",
          # so this overwrites the "text" marker set by extract_text_features.
          # Kept as-is; confirm the evaluator relies on text_features instead.
          features[field.attribute("name").value] =
            field.attribute("optype").value == "categorical" ? "categorical" : "numeric"
        end

        trees = data.css("Segmentation TreeModel").map do |tree|
          find_nodes(tree.css("Node").first, derived_fields)
        end

        Evaluators::LightGBM.new(trees: trees, objective: objective, labels: labels, features: features, text_features: text_features)
      end

      # DataDictionary fields that are model inputs. The target is excluded by
      # name, using MiningFields marked usageType target/predicted, because Eps
      # writes no target DataField for regression. Blindly skipping the first
      # entry dropped the first feature there. When no field is marked, the old
      # assumption (target is the first DataField) is kept.
      def lightgbm_feature_fields
        target_names = data.css("MiningField")
          .select { |n| %w[target predicted].include?(n.attribute("usageType")&.value) }
          .map { |n| n.attribute("name").value }
        fields = data.css("DataDictionary").first.css("DataField").to_a
        if target_names.empty?
          fields.drop(1)
        else
          fields.reject { |f| target_names.include?(f.attribute("name").value) }
        end
      end

      def linear_regression
        node = data.css("RegressionTable")

        coefficients = {
          "_intercept" => node.attribute("intercept").value.to_f
        }

        features = {}

        text_features, derived_fields = extract_text_features(data, features)

        node.css("NumericPredictor").each do |n|
          name = n.attribute("name").value
          if derived_fields[name]
            # term-count field of a text feature => [text_field, term] key
            name = derived_fields[name]
          else
            features[name] = "numeric"
          end
          coefficients[name] = n.attribute("coefficient").value.to_f
        end

        node.css("CategoricalPredictor").each do |n|
          name = n.attribute("name").value
          coefficients[[name, n.attribute("value").value]] = n.attribute("coefficient").value.to_f
          features[name] = "categorical"
        end

        Evaluators::LinearRegression.new(coefficients: coefficients, features: features, text_features: text_features)
      end

      def naive_bayes
        node = data.css("NaiveBayesModel")

        prior = {}
        node.css("BayesOutput TargetValueCount").each do |n|
          prior[n.attribute("value").value] = n.attribute("count").value.to_f
        end

        legacy = false

        conditional = {}
        features = {}
        node.css("BayesInput").each do |n|
          prob = {}

          # numeric
          n.css("TargetValueStat").each do |n2|
            n3 = n2.css("GaussianDistribution")
            prob[n2.attribute("value").value] = {
              mean: n3.attribute("mean").value.to_f,
              stdev: Math.sqrt(n3.attribute("variance").value.to_f)
            }
          end

          # detect bad form in Eps < 0.3: PairCounts keyed by the target classes
          bad_format = n.css("PairCounts").map { |n2| n2.attribute("value").value } == prior.keys

          # categorical
          n.css("PairCounts").each do |n2|
            if bad_format
              # legacy layout: swap to value => class => count
              n2.css("TargetValueCount").each do |n3|
                prob[n3.attribute("value").value] ||= {}
                prob[n3.attribute("value").value][n2.attribute("value").value] = n3.attribute("count").value.to_f
              end
            else
              counts = {}
              n2.css("TargetValueCount").each do |n3|
                counts[n3.attribute("value").value] = n3.attribute("count").value.to_f
              end
              prob[n2.attribute("value").value] = counts
            end
          end

          if bad_format
            legacy = true
            # every value gets an entry for every class
            prob.each_value do |v|
              prior.each_key { |klass| v[klass] ||= 0.0 }
            end
          end

          name = n.attribute("fieldName").value
          conditional[name] = prob
          features[name] = n.css("TargetValueStat").any? ? "numeric" : "categorical"
        end

        probabilities = {
          prior: prior,
          conditional: conditional
        }

        # get derived fields (NormDiscrete indicator columns)
        derived = {}
        data.css("DerivedField").each do |n|
          name = n.attribute("name").value
          field = n.css("NormDiscrete").attribute("field").value
          value = n.css("NormDiscrete").attribute("value").value
          features.delete(name)
          features[field] = "derived"
          derived[field] ||= {}
          derived[field][name] = value
        end

        Evaluators::NaiveBayes.new(probabilities: probabilities, features: features, derived: derived, legacy: legacy)
      end

      # Reads text-feature transformations. Mutates `features`, marking each
      # text field as "text".
      #
      # @return [Array(Hash, Hash)] text_features (field => tokenizer,
      #   case_sensitive, vocabulary) and derived_fields (derived field name =>
      #   [text_field, term])
      def extract_text_features(data, features)
        vocabulary = {}
        function_mapping = {}
        derived_fields = {}
        data.css("LocalTransformations DerivedField, TransformationDictionary DerivedField").each do |n|
          name = n.attribute("name")&.value
          field = n.css("FieldRef").attribute("field").value
          value = n.css("Constant").text

          # unwrap "lowercase(<field>)" references
          field = field[10..-2] if field =~ /\Alowercase\(.+\)\z/
          next if value.empty?

          (vocabulary[field] ||= []) << value

          function_mapping[field] = n.css("Apply").attribute("function").value

          derived_fields[name] = [field, value]
        end

        functions = {}
        data.css("TransformationDictionary DefineFunction").each do |n|
          name = n.attribute("name").value
          text_index = n.css("TextIndex")
          functions[name] = {
            tokenizer: Regexp.new(text_index.attribute("wordSeparatorCharacterRE").value),
            case_sensitive: text_index.attribute("isCaseSensitive")&.value == "true"
          }
        end

        text_features = {}
        function_mapping.each do |field, function|
          text_features[field] = functions[function].merge(vocabulary: vocabulary[field])
          features[field] = "text"
        end

        [text_features, derived_fields]
      end

      # Converts a PMML <Node> into an Evaluators::Node tree. The first child
      # element is the predicate; the remaining <Node> children are recursed.
      # Other elements such as ScoreDistribution are ignored.
      def find_nodes(xml, derived_fields)
        score = xml.attribute("score").value.to_f

        elements = xml.elements
        xml_predicate = elements.first

        predicate =
          case xml_predicate.name
          when "True"
            nil
          when "SimpleSetPredicate"
            {
              field: resolve_field(xml_predicate, derived_fields),
              operator: "in",
              value: parse_string_array(xml_predicate.css("Array").text)
            }
          else
            operator = xml_predicate.attribute("operator").value
            value = xml_predicate.attribute("value").value
            # every ordering comparison is numeric, not only greaterThan
            value = value.to_f if NUMERIC_OPERATORS.include?(operator)
            {
              field: resolve_field(xml_predicate, derived_fields),
              operator: operator,
              value: value
            }
          end

        children = elements.drop(1).select { |n| n.name == "Node" }.map { |n| find_nodes(n, derived_fields) }

        Evaluators::Node.new(score: score, predicate: predicate, children: children)
      end

      # Predicate field name, replaced by its [text_field, term] pair when it
      # names a derived text-count field.
      def resolve_field(xml_predicate, derived_fields)
        field = xml_predicate.attribute("field").value
        derived_fields[field] || field
      end

      # Splits the text of a PMML string Array into values. Entries are either
      # double-quoted (may contain spaces and \" escapes, and may be empty) or
      # bare tokens separated by whitespace.
      def parse_string_array(text)
        text.scan(/"(.*?)(?<!\\)"|(\S+)/).flatten.compact.map { |v| v.gsub('\"', '"') }
      end
    end
  end
end