eps 0.3.0 → 0.3.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,14 @@
1
# Module-level entry points, modelled on the Marshal/JSON/YAML load/dump style.
module Eps
  module PMML
    # Build an evaluator from PMML.
    #
    # @param pmml [String, Object] PMML XML text or an already-parsed document
    # @return the evaluator produced by Loader#load
    def self.load(pmml)
      loader = Loader.new(pmml)
      loader.load
    end

    # Serialize a trained model to PMML.
    #
    # @param model the trained model to export
    # @return [String] the PMML produced by Generator#generate
    def self.generate(model)
      generator = Generator.new(model)
      generator.generate
    end
  end
end
@@ -0,0 +1,422 @@
1
module Eps
  module PMML
    # Serializes a trained Eps model (LightGBM, LinearRegression or NaiveBayes)
    # into a PMML 4.4 XML string built with Nokogiri::XML::Builder.
    #
    # Model internals are read through instance variables; see the accessor
    # section at the bottom of the class.
    class Generator
      attr_reader :model

      # @param model the trained model to serialize
      def initialize(model)
        @model = model
      end

      # @return [String] the PMML document
      # @raise [RuntimeError] "Unknown model" for unsupported model classes
      def generate
        case @model
        when LightGBM
          lightgbm
        when LinearRegression
          linear_regression
        when NaiveBayes
          naive_bayes
        else
          raise "Unknown model"
        end
      end

      private

      # MiningModel for LightGBM. Regression sums the trees directly; binary and
      # multiclass use a modelChain whose last segment produces probabilities.
      def lightgbm
        data_fields = {}
        data_fields[target] = labels if labels
        features.each do |k, type|
          # TODO remove zero importance features
          data_fields[k] = type == "categorical" ? label_encoders[k].labels.keys : nil
        end

        build_pmml(data_fields) do |xml|
          function_name = objective == "regression" ? "regression" : "classification"
          xml.MiningModel(functionName: function_name, algorithmName: "LightGBM") do
            xml.MiningSchema do
              xml.MiningField(name: target, usageType: "target")
              features.each_key do |k|
                # TODO add importance (feature_importance by feature index, e.g.
                # importance: ..., missingValueTreatment: "asIs") and skip
                # zero-importance features, but need to handle text features
                xml.MiningField(name: k)
              end
            end
            pmml_local_transformations(xml)

            case objective
            when "regression"
              xml_segmentation(xml, trees)
            when "binary"
              xml_binary_segmentation(xml)
            else # multiclass
              xml_multiclass_segmentation(xml)
            end
          end
        end
      end

      # Binary classification chain: segment 1 sums the trees and applies the
      # logistic function; segment 2 maps that value onto the two labels.
      def xml_binary_segmentation(xml)
        xml.Segmentation(multipleModelMethod: "modelChain") do
          xml.Segment(id: 1) do
            xml.True
            xml.MiningModel(functionName: "regression") do
              xml.MiningSchema do
                features.each do |k, _|
                  xml.MiningField(name: k)
                end
              end
              xml_binary_output(xml)
              xml_segmentation(xml, trees)
            end
          end
          xml.Segment(id: 2) do
            xml.True
            xml.RegressionModel(functionName: "classification", normalizationMethod: "none") do
              xml.MiningSchema do
                xml.MiningField(name: target, usageType: "target")
                xml.MiningField(name: "transformedLgbmValue")
              end
              xml.Output do
                labels.each do |label|
                  xml.OutputField(name: "probability(#{label})", optype: "continuous", dataType: "double", feature: "probability", value: label)
                end
              end
              xml.RegressionTable(intercept: 0.0, targetCategory: labels.last) do
                xml.NumericPredictor(name: "transformedLgbmValue", coefficient: "1.0")
              end
              xml.RegressionTable(intercept: 0.0, targetCategory: labels.first)
            end
          end
        end
      end

      # Output of the binary tree-sum model: the raw score "lgbmValue" and its
      # logistic transform 1 / (1 + exp(-lgbmValue)) as "transformedLgbmValue",
      # which is the field segment 2 of the chain reads.
      def xml_binary_output(xml)
        xml.Output do
          xml.OutputField(name: "lgbmValue", optype: "continuous", dataType: "double", feature: "predictedValue", isFinalResult: false)
          xml.OutputField(name: "transformedLgbmValue", optype: "continuous", dataType: "double", feature: "transformedValue", isFinalResult: false) do
            xml.Apply(function: "/") do
              # Constant values are passed as text arguments: the builder does
              # not write a block's return value into the element.
              xml.Constant("1.0", dataType: "double")
              xml.Apply(function: "+") do
                xml.Constant("1.0", dataType: "double")
                xml.Apply(function: "exp") do
                  xml.Apply(function: "*") do
                    xml.Constant("-1.0", dataType: "double")
                    xml.FieldRef(field: "lgbmValue")
                  end
                end
              end
            end
          end
        end
      end

      # Multiclass chain: one tree-sum segment per label, then a softmax
      # RegressionModel over the per-label scores. The trees are cut into
      # consecutive equal slices, one per label (assumes @trees is grouped by
      # label — TODO confirm against the model).
      def xml_multiclass_segmentation(xml)
        xml.Segmentation(multipleModelMethod: "modelChain") do
          n = trees.size / labels.size
          trees.each_slice(n).each_with_index do |label_trees, idx|
            xml.Segment(id: idx + 1) do
              xml.True
              xml.MiningModel(functionName: "regression") do
                xml.MiningSchema do
                  features.each do |k, _|
                    xml.MiningField(name: k)
                  end
                end
                xml.Output do
                  xml.OutputField(name: "lgbmValue(#{labels[idx]})", optype: "continuous", dataType: "double", feature: "predictedValue", isFinalResult: false)
                end
                xml_segmentation(xml, label_trees)
              end
            end
          end
          xml.Segment(id: labels.size + 1) do
            xml.True
            xml.RegressionModel(functionName: "classification", normalizationMethod: "softmax") do
              xml.MiningSchema do
                xml.MiningField(name: target, usageType: "target")
                labels.each do |label|
                  xml.MiningField(name: "lgbmValue(#{label})")
                end
              end
              xml.Output do
                labels.each do |label|
                  xml.OutputField(name: "probability(#{label})", optype: "continuous", dataType: "double", feature: "probability", value: label)
                end
              end
              labels.each do |label|
                xml.RegressionTable(intercept: 0.0, targetCategory: label) do
                  xml.NumericPredictor(name: "lgbmValue(#{label})", coefficient: "1.0")
                end
              end
            end
          end
        end
      end

      # RegressionModel for linear regression. Coefficient keys are a feature
      # name (numeric), [feature, value] (categorical) or [feature, term] (text);
      # "_intercept" holds the intercept.
      def linear_regression
        predictors = model.instance_variable_get("@coefficients").dup
        intercept = predictors.delete("_intercept") || 0.0

        data_fields = {}
        features.each do |k, type|
          data_fields[k] = type == "categorical" ? categorical_values(predictors, k) : nil
        end

        build_pmml(data_fields) do |xml|
          xml.RegressionModel(functionName: "regression") do
            xml.MiningSchema do
              features.each do |k, _|
                xml.MiningField(name: k)
              end
            end
            pmml_local_transformations(xml)
            xml.RegressionTable(intercept: intercept) do
              predictors.each do |k, v|
                if k.is_a?(Array)
                  if features[k.first] == "text"
                    xml.NumericPredictor(name: display_field(k), coefficient: v)
                  else
                    xml.CategoricalPredictor(name: k[0], value: k[1], coefficient: v)
                  end
                else
                  xml.NumericPredictor(name: k, coefficient: v)
                end
              end
            end
          end
        end
      end

      # Category values of +field+ that have a coefficient, i.e. the last element
      # of every [field, value] predictor key.
      def categorical_values(predictors, field)
        predictors.keys.select { |key| key.is_a?(Array) && key.first == field }.map(&:last)
      end

      # NaiveBayesModel: categorical inputs as PairCounts, numeric inputs as
      # Gaussian TargetValueStats, and the priors as BayesOutput counts.
      def naive_bayes
        data_fields = {}
        data_fields[target] = probabilities[:prior].keys
        probabilities[:conditional].each do |k, v|
          data_fields[k] = features[k] == "categorical" ? v.keys : nil
        end

        build_pmml(data_fields) do |xml|
          xml.NaiveBayesModel(functionName: "classification", threshold: 0.001) do
            xml.MiningSchema do
              data_fields.each_key do |k|
                if k == target
                  xml.MiningField(name: k, usageType: "target")
                else
                  xml.MiningField(name: k)
                end
              end
            end
            xml.BayesInputs do
              probabilities[:conditional].each do |k, v|
                xml.BayesInput(fieldName: k) do
                  if features[k] == "categorical"
                    v.sort_by { |value, _| value }.each do |value, counts|
                      xml.PairCounts(value: value) do
                        xml.TargetValueCounts do
                          counts.sort_by { |label, _| label }.each do |label, count|
                            xml.TargetValueCount(value: label, count: count)
                          end
                        end
                      end
                    end
                  else
                    xml.TargetValueStats do
                      v.sort_by { |label, _| label }.each do |label, stats|
                        xml.TargetValueStat(value: label) do
                          xml.GaussianDistribution(mean: stats[:mean], variance: stats[:stdev]**2)
                        end
                      end
                    end
                  end
                end
              end
            end
            # reference the real target field (it is the field declared in the DataDictionary)
            xml.BayesOutput(fieldName: target) do
              xml.TargetValueCounts do
                probabilities[:prior].sort_by { |k, _| k }.each do |k, v|
                  xml.TargetValueCount(value: k, count: v)
                end
              end
            end
          end
        end
      end

      # PMML field name for a key: "feature(term)" for text terms,
      # "feature=value" for other array keys, the key itself otherwise.
      def display_field(k)
        if k.is_a?(Array)
          if features[k.first] == "text"
            "#{k.first}(#{k.last})"
          else
            k.join("=")
          end
        else
          k
        end
      end

      # Sum-segmentation of TreeModels, one Segment per tree.
      def xml_segmentation(xml, trees)
        xml.Segmentation(multipleModelMethod: "sum") do
          trees.each_with_index do |node, i|
            xml.Segment(id: i + 1) do
              xml.True
              xml.TreeModel(functionName: "regression", missingValueStrategy: "none", noTrueChildStrategy: "returnLastPrediction", splitCharacteristic: "multiSplit") do
                xml.MiningSchema do
                  node_fields(node).uniq.each do |k|
                    xml.MiningField(name: display_field(k))
                  end
                end
                node_pmml(node, xml)
              end
            end
          end
        end
      end

      # Fields used by predicates anywhere in the tree rooted at +node+ (may repeat).
      def node_fields(node)
        fields = []
        fields << node.field if node.predicate
        node.children.each do |n|
          fields.concat(node_fields(n))
        end
        fields
      end

      # Recursively writes a tree node, its predicate, and its children.
      def node_pmml(node, xml)
        xml.Node(score: node.score) do
          if node.predicate.nil?
            xml.True
          elsif node.operator == "in"
            xml.SimpleSetPredicate(field: display_field(node.field), booleanOperator: "isIn") do
              xml.Array(type: "string") do
                xml.text node.value.map { |v| escape_element(v) }.join(" ")
              end
            end
          else
            xml.SimplePredicate(field: display_field(node.field), operator: node.operator, value: node.value)
          end
          node.children.each do |n|
            node_pmml(n, xml)
          end
        end
      end

      # Quotes a value for a PMML string Array, escaping embedded double quotes.
      # Non-String values (e.g. numeric categories) are converted with to_s.
      def escape_element(v)
        "\"#{v.to_s.gsub("\"", "\\\"")}\""
      end

      # Wraps the model-specific XML (yielded) in the PMML root, header,
      # DataDictionary and TransformationDictionary.
      def build_pmml(data_fields)
        Nokogiri::XML::Builder.new do |xml|
          xml.PMML(version: "4.4", xmlns: "http://www.dmg.org/PMML-4_4", "xmlns:xsi" => "http://www.w3.org/2001/XMLSchema-instance") do
            pmml_header(xml)
            pmml_data_dictionary(xml, data_fields)
            pmml_transformation_dictionary(xml)
            yield xml
          end
        end.to_xml
      end

      def pmml_header(xml)
        xml.Header do
          xml.Application(name: "Eps", version: Eps::VERSION)
          # xml.Timestamp Time.now.utc.iso8601
        end
      end

      # One DataField per entry. Categorical fields (and the target, which has no
      # feature type) list their values; text is categorical without values;
      # everything else is continuous.
      def pmml_data_dictionary(xml, data_fields)
        xml.DataDictionary do
          data_fields.each do |k, vs|
            case features[k]
            when "categorical", nil
              xml.DataField(name: k, optype: "categorical", dataType: "string") do
                vs.map(&:to_s).sort.each do |v|
                  xml.Value(value: v)
                end
              end
            when "text"
              xml.DataField(name: k, optype: "categorical", dataType: "string")
            else
              xml.DataField(name: k, optype: "continuous", dataType: "double")
            end
          end
        end
      end

      # Defines one "<feature>Transform" TextIndex function per text feature.
      def pmml_transformation_dictionary(xml)
        if text_features.any?
          xml.TransformationDictionary do
            text_features.each do |k, text_options|
              xml.DefineFunction(name: "#{k}Transform", optype: "continuous") do
                xml.ParameterField(name: "text")
                xml.ParameterField(name: "term")
                xml.TextIndex(textField: "text", localTermWeights: "termFrequency", wordSeparatorCharacterRE: text_options[:tokenizer].source, isCaseSensitive: !!text_options[:case_sensitive]) do
                  xml.FieldRef(field: "term")
                end
              end
            end
          end
        end
      end

      # One derived term-frequency field per (text feature, vocabulary term).
      def pmml_local_transformations(xml)
        if text_features.any?
          xml.LocalTransformations do
            text_features.each do |k, _|
              text_encoders[k].vocabulary.each do |v|
                xml.DerivedField(name: display_field([k, v]), optype: "continuous", dataType: "integer") do
                  xml.Apply(function: "#{k}Transform") do
                    xml.FieldRef(field: k)
                    xml.Constant v
                  end
                end
              end
            end
          end
        end
      end

      # TODO create instance methods on model for all of these features

      def features
        model.instance_variable_get("@features")
      end

      def text_features
        model.instance_variable_get("@text_features")
      end

      def text_encoders
        model.instance_variable_get("@text_encoders")
      end

      def feature_importance
        model.instance_variable_get("@feature_importance")
      end

      def labels
        model.instance_variable_get("@labels")
      end

      def trees
        model.instance_variable_get("@trees")
      end

      def target
        model.instance_variable_get("@target")
      end

      def label_encoders
        model.instance_variable_get("@label_encoders")
      end

      def objective
        model.instance_variable_get("@objective")
      end

      def probabilities
        model.instance_variable_get("@probabilities")
      end

      # end TODO
    end
  end
end
@@ -0,0 +1,241 @@
1
module Eps
  module PMML
    # Parses a PMML document back into an Eps evaluator
    # (Evaluators::LightGBM, Evaluators::LinearRegression or Evaluators::NaiveBayes).
    class Loader
      # Comparison operators whose SimplePredicate value is parsed as a Float.
      NUMERIC_OPERATORS = %w[lessThan lessOrEqual greaterThan greaterOrEqual].freeze

      attr_reader :data

      # @param pmml [String, Object] XML text (parsed strictly with Nokogiri) or an already-parsed document
      def initialize(pmml)
        if pmml.is_a?(String)
          pmml = Nokogiri::XML(pmml) { |config| config.strict }
        end
        @data = pmml
      end

      # Picks the evaluator from the elements present. Segmentation is checked
      # first because LightGBM classification documents also contain a
      # RegressionModel (the last segment of the modelChain).
      #
      # @raise [RuntimeError] "Unknown model" when no supported model is found
      def load
        if data.css("Segmentation").any?
          lightgbm
        elsif data.css("RegressionModel").any?
          linear_regression
        elsif data.css("NaiveBayesModel").any?
          naive_bayes
        else
          raise "Unknown model"
        end
      end

      private

      def lightgbm
        objective = data.css("MiningModel").first.attribute("functionName").value
        if objective == "classification"
          labels = data.css("RegressionModel OutputField").map { |n| n.attribute("value").value }
          objective = labels.size > 2 ? "multiclass" : "binary"
        end

        features = {}
        text_features, derived_fields = extract_text_features(data, features)
        lightgbm_input_fields.each do |field|
          name = field.attribute("name").value
          # ||= keeps the "text" type set by extract_text_features; text fields
          # are also declared as categorical strings in the DataDictionary
          features[name] ||= (field.attribute("optype").value == "categorical" ? "categorical" : "numeric")
        end

        trees = data.css("Segmentation TreeModel").map do |tree|
          find_nodes(tree.css("Node").first, derived_fields)
        end

        Evaluators::LightGBM.new(trees: trees, objective: objective, labels: labels, features: features, text_features: text_features)
      end

      # DataDictionary fields that are model inputs. Fields named by a MiningField
      # with usageType "target" (or legacy "predicted") are skipped; the target is
      # not always declared (e.g. regression), so position cannot be relied on.
      # Documents without such a MiningField fall back to dropping the first field.
      def lightgbm_input_fields
        fields = data.css("DataDictionary").first.css("DataField").to_a
        targets = data.css("MiningField")
          .select { |n| %w[target predicted].include?(n.attribute("usageType")&.value) }
          .map { |n| n.attribute("name").value }

        if targets.empty?
          fields.drop(1)
        else
          fields.reject { |f| targets.include?(f.attribute("name").value) }
        end
      end

      # Coefficients come from the RegressionTable: "_intercept", plain names for
      # numeric predictors, [feature, term] for text-derived predictors and
      # [feature, value] for categorical predictors.
      def linear_regression
        node = data.css("RegressionTable")

        coefficients = {
          "_intercept" => node.attribute("intercept").value.to_f
        }

        features = {}

        text_features, derived_fields = extract_text_features(data, features)

        node.css("NumericPredictor").each do |n|
          name = n.attribute("name").value
          if derived_fields[name]
            name = derived_fields[name]
          else
            features[name] = "numeric"
          end
          coefficients[name] = n.attribute("coefficient").value.to_f
        end

        node.css("CategoricalPredictor").each do |n|
          name = n.attribute("name").value
          coefficients[[name, n.attribute("value").value]] = n.attribute("coefficient").value.to_f
          features[name] = "categorical"
        end

        Evaluators::LinearRegression.new(coefficients: coefficients, features: features, text_features: text_features)
      end

      # Priors from BayesOutput counts; per-input conditionals from Gaussian stats
      # (numeric) or PairCounts (categorical). Also detects the legacy layout
      # written by Eps < 0.3.
      def naive_bayes
        node = data.css("NaiveBayesModel")

        prior = {}
        node.css("BayesOutput TargetValueCount").each do |n|
          prior[n.attribute("value").value] = n.attribute("count").value.to_f
        end

        legacy = false

        conditional = {}
        features = {}
        node.css("BayesInput").each do |n|
          prob = {}

          # numeric
          n.css("TargetValueStat").each do |n2|
            n3 = n2.css("GaussianDistribution")
            prob[n2.attribute("value").value] = {
              mean: n3.attribute("mean").value.to_f,
              stdev: Math.sqrt(n3.attribute("variance").value.to_f)
            }
          end

          # detect bad form in Eps < 0.3: PairCounts keyed by target label
          bad_format = n.css("PairCounts").map { |n2| n2.attribute("value").value } == prior.keys

          # categorical
          n.css("PairCounts").each do |n2|
            if bad_format
              n2.css("TargetValueCount").each do |n3|
                prob[n3.attribute("value").value] ||= {}
                prob[n3.attribute("value").value][n2.attribute("value").value] = n3.attribute("count").value.to_f
              end
            else
              counts = {}
              n2.css("TargetValueCount").each do |n3|
                counts[n3.attribute("value").value] = n3.attribute("count").value.to_f
              end
              prob[n2.attribute("value").value] = counts
            end
          end

          if bad_format
            legacy = true
            prob.each_value do |v|
              prior.each_key do |label|
                v[label] ||= 0.0
              end
            end
          end

          name = n.attribute("fieldName").value
          conditional[name] = prob
          features[name] = n.css("TargetValueStat").any? ? "numeric" : "categorical"
        end

        # Read (and thereby require) BayesOutput's fieldName; it is not passed
        # to the evaluator.
        _target = node.css("BayesOutput").attribute("fieldName").value

        probabilities = {
          prior: prior,
          conditional: conditional
        }

        # get derived fields
        derived = {}
        data.css("DerivedField").each do |n|
          name = n.attribute("name").value
          field = n.css("NormDiscrete").attribute("field").value
          value = n.css("NormDiscrete").attribute("value").value
          features.delete(name)
          features[field] = "derived"
          derived[field] ||= {}
          derived[field][name] = value
        end

        Evaluators::NaiveBayes.new(probabilities: probabilities, features: features, derived: derived, legacy: legacy)
      end

      # Collects text features from DerivedFields that apply a TextIndex function.
      # Marks each such field as "text" in +features+ (mutated in place).
      #
      # @return [Array(Hash, Hash)] text feature options by field (tokenizer,
      #   case_sensitive, vocabulary) and derived field name => [field, term]
      def extract_text_features(data, features)
        # updates features object
        vocabulary = {}
        function_mapping = {}
        derived_fields = {}
        data.css("LocalTransformations DerivedField, TransformationDictionary DerivedField").each do |n|
          name = n.attribute("name")&.value
          field = n.css("FieldRef").attribute("field").value
          value = n.css("Constant").text

          # strip a "lowercase(...)" wrapper around the field name
          field = field[10..-2] if field =~ /\Alowercase\(.+\)\z/
          next if value.empty?

          (vocabulary[field] ||= []) << value

          function_mapping[field] = n.css("Apply").attribute("function").value

          derived_fields[name] = [field, value]
        end

        functions = {}
        data.css("TransformationDictionary DefineFunction").each do |n|
          name = n.attribute("name").value
          text_index = n.css("TextIndex")
          functions[name] = {
            tokenizer: Regexp.new(text_index.attribute("wordSeparatorCharacterRE").value),
            case_sensitive: text_index.attribute("isCaseSensitive")&.value == "true"
          }
        end

        text_features = {}
        function_mapping.each do |field, function|
          text_features[field] = functions[function].merge(vocabulary: vocabulary[field])
          features[field] = "text"
        end

        [text_features, derived_fields]
      end

      # Converts a PMML Node element (and its descendants) into Evaluators::Node.
      # The first child element is the predicate; the remaining ones are child nodes.
      def find_nodes(xml, derived_fields)
        score = xml.attribute("score").value.to_f

        elements = xml.elements
        xml_predicate = elements.first

        predicate =
          case xml_predicate.name
          when "True"
            nil
          when "SimpleSetPredicate"
            {
              field: predicate_field(xml_predicate, derived_fields),
              operator: "in",
              value: parse_string_array(xml_predicate.css("Array").text)
            }
          else
            operator = xml_predicate.attribute("operator").value
            value = xml_predicate.attribute("value").value
            value = value.to_f if NUMERIC_OPERATORS.include?(operator)
            {
              field: predicate_field(xml_predicate, derived_fields),
              operator: operator,
              value: value
            }
          end

        children = elements[1..-1].map { |n| find_nodes(n, derived_fields) }

        Evaluators::Node.new(score: score, predicate: predicate, children: children)
      end

      # Field referenced by a predicate, mapped to [feature, term] for text-derived fields.
      def predicate_field(xml_predicate, derived_fields)
        field = xml_predicate.attribute("field").value
        derived_fields[field] || field
      end

      # Splits a PMML string Array into values: quoted tokens (\" is an escaped
      # quote, "" is the empty string) or bare whitespace-separated words.
      def parse_string_array(text)
        text.scan(/"(.*?)(?<!\\)"|(\S+)/).flatten.compact.map { |v| v.gsub('\"', '"') }
      end
    end
  end
end
+ end