eps 0.3.0 → 0.3.5

Sign up to get free protection for your applications and to get access to all the features.
@@ -0,0 +1,241 @@
1
module Eps
  module PMML
    # Reads a PMML document and builds the matching Eps evaluator
    # (LightGBM, linear regression or naive Bayes).
    class Loader
      # Matches one entry of a PMML string array: either a double-quoted value
      # (closing quote must not be preceded by a backslash) or a bare token.
      # `.*?` (rather than `.+?`) lets an empty quoted value `""` match.
      SET_VALUE_PATTERN = /"(.*?)(?<!\\)"|(\S+)/.freeze

      # The document being read: the parsed String, or the object passed in.
      attr_reader :data

      # @param pmml [String, Object] PMML as an XML string, or an object that
      #   already answers #css. Strings are parsed with Nokogiri in strict
      #   mode; anything else is stored untouched.
      def initialize(pmml)
        pmml = Nokogiri::XML(pmml) { |config| config.strict } if pmml.is_a?(String)
        @data = pmml
      end

      # Builds the evaluator for the model found in the document.
      #
      # Segmentation is checked first because the LightGBM branch itself reads
      # RegressionModel elements (for class labels), so such a document would
      # otherwise be taken for a linear regression.
      #
      # @return [Object] an Evaluators::LightGBM, Evaluators::LinearRegression
      #   or Evaluators::NaiveBayes instance
      # @raise [RuntimeError] "Unknown model" when no known model element exists
      def load
        if data.css("Segmentation").any?
          lightgbm
        elsif data.css("RegressionModel").any?
          linear_regression
        elsif data.css("NaiveBayesModel").any?
          naive_bayes
        else
          raise "Unknown model"
        end
      end

      private

      # Builds a LightGBM evaluator from the segmented tree models.
      def lightgbm
        labels = nil
        objective = data.css("MiningModel").first.attribute("functionName").value
        if objective == "classification"
          labels = data.css("RegressionModel OutputField").map { |n| n.attribute("value").value }
          objective = labels.size > 2 ? "multiclass" : "binary"
        end

        features = {}
        text_features, derived_fields = extract_text_features(data, features)

        dictionary = data.css("DataDictionary").first
        # The first DataField is skipped (presumably the target - confirm
        # against the generator). The slice is nil when nothing is left,
        # hence to_a.
        dictionary.css("DataField")[1..-1].to_a.each do |field|
          name = field.attribute("name").value
          features[name] = field.attribute("optype").value == "categorical" ? "categorical" : "numeric"
        end

        trees = data.css("Segmentation TreeModel").map do |tree|
          find_nodes(tree.css("Node").first, derived_fields)
        end

        Evaluators::LightGBM.new(trees: trees, objective: objective, labels: labels, features: features, text_features: text_features)
      end

      # Builds a linear regression evaluator from the RegressionTable.
      # Coefficients are keyed by feature name, by [field, word] for text
      # features, and by [name, value] for categorical predictors.
      def linear_regression
        table = data.css("RegressionTable")

        coefficients = {
          "_intercept" => table.attribute("intercept").value.to_f
        }

        features = {}
        text_features, derived_fields = extract_text_features(data, features)

        table.css("NumericPredictor").each do |predictor|
          name = predictor.attribute("name").value
          if derived_fields[name]
            # predictor refers to a derived text field: key by [field, word]
            name = derived_fields[name]
          else
            features[name] = "numeric"
          end
          coefficients[name] = predictor.attribute("coefficient").value.to_f
        end

        table.css("CategoricalPredictor").each do |predictor|
          name = predictor.attribute("name").value
          coefficients[[name, predictor.attribute("value").value]] = predictor.attribute("coefficient").value.to_f
          features[name] = "categorical"
        end

        Evaluators::LinearRegression.new(coefficients: coefficients, features: features, text_features: text_features)
      end

      # Builds a naive Bayes evaluator: prior counts per target value plus, per
      # input field, either Gaussian parameters (numeric) or counts
      # (categorical).
      def naive_bayes
        model = data.css("NaiveBayesModel")

        prior = {}
        model.css("BayesOutput TargetValueCount").each do |count|
          prior[count.attribute("value").value] = count.attribute("count").value.to_f
        end

        legacy = false
        conditional = {}
        features = {}

        model.css("BayesInput").each do |input|
          prob = {}

          # numeric
          stats = input.css("TargetValueStat")
          stats.each do |stat|
            gaussian = stat.css("GaussianDistribution")
            prob[stat.attribute("value").value] = {
              mean: gaussian.attribute("mean").value.to_f,
              stdev: Math.sqrt(gaussian.attribute("variance").value.to_f)
            }
          end

          # detect bad form in Eps < 0.3: PairCounts are keyed by the target
          # values instead of by the input's own values
          pair_counts = input.css("PairCounts")
          bad_format = pair_counts.map { |pair| pair.attribute("value").value } == prior.keys

          # categorical
          pair_counts.each do |pair|
            pair_value = pair.attribute("value").value
            if bad_format
              # swap the nesting so the outer key is the TargetValueCount value
              pair.css("TargetValueCount").each do |count|
                (prob[count.attribute("value").value] ||= {})[pair_value] = count.attribute("count").value.to_f
              end
            else
              prob[pair_value] = pair.css("TargetValueCount").each_with_object({}) do |count, counts|
                counts[count.attribute("value").value] = count.attribute("count").value.to_f
              end
            end
          end

          if bad_format
            legacy = true
            # make sure every entry has a count for every target value
            prob.each_value do |counts|
              prior.each_key { |label| counts[label] ||= 0.0 }
            end
          end

          name = input.attribute("fieldName").value
          conditional[name] = prob
          features[name] = stats.any? ? "numeric" : "categorical"
        end

        probabilities = {
          prior: prior,
          conditional: conditional
        }

        # get derived fields: each DerivedField replaces its own feature entry
        # with a "derived" entry for the field it was computed from
        derived = {}
        data.css("DerivedField").each do |derived_field|
          name = derived_field.attribute("name").value
          norm = derived_field.css("NormDiscrete")
          field = norm.attribute("field").value
          value = norm.attribute("value").value
          features.delete(name)
          features[field] = "derived"
          (derived[field] ||= {})[name] = value
        end

        Evaluators::NaiveBayes.new(probabilities: probabilities, features: features, derived: derived, legacy: legacy)
      end

      # Collects text features declared through DerivedField / DefineFunction.
      #
      # @param data [Object] document to search
      # @param features [Hash] updated in place: each text field is set to "text"
      # @return [Array(Hash, Hash)] text_features (field => tokenizer,
      #   case_sensitive, vocabulary) and derived_fields
      #   (derived field name => [field, word])
      def extract_text_features(data, features)
        vocabulary = {}
        function_mapping = {}
        derived_fields = {}
        data.css("LocalTransformations DerivedField, TransformationDictionary DerivedField").each do |n|
          name = n.attribute("name")&.value
          field = n.css("FieldRef").attribute("field").value
          value = n.css("Constant").text

          # unwrap "lowercase(<field>)": drop the 10-character prefix and the ")"
          field = field[10..-2] if field =~ /\Alowercase\(.+\)\z/
          next if value.empty?

          (vocabulary[field] ||= []) << value
          function_mapping[field] = n.css("Apply").attribute("function").value
          derived_fields[name] = [field, value]
        end

        functions = {}
        data.css("TransformationDictionary DefineFunction").each do |n|
          text_index = n.css("TextIndex")
          functions[n.attribute("name").value] = {
            tokenizer: Regexp.new(text_index.attribute("wordSeparatorCharacterRE").value),
            case_sensitive: text_index.attribute("isCaseSensitive")&.value == "true"
          }
        end

        text_features = {}
        function_mapping.each do |field, function|
          text_features[field] = functions[function].merge(vocabulary: vocabulary[field])
          features[field] = "text"
        end

        [text_features, derived_fields]
      end

      # Recursively converts a tree Node element into an Evaluators::Node.
      # The first child element is the node's predicate; the remaining child
      # elements are its sub-nodes.
      #
      # @param xml [Object] a Node element
      # @param derived_fields [Hash] derived field name => [field, word]
      def find_nodes(xml, derived_fields)
        score = xml.attribute("score").value.to_f

        elements = xml.elements
        predicate = build_predicate(elements.first, derived_fields)
        children = elements[1..-1].map { |child| find_nodes(child, derived_fields) }

        Evaluators::Node.new(score: score, predicate: predicate, children: children)
      end

      # Translates a predicate element into a {field:, operator:, value:} hash,
      # or nil for a True element.
      def build_predicate(xml_predicate, derived_fields)
        case xml_predicate.name
        when "True"
          return nil
        when "SimpleSetPredicate"
          operator = "in"
          value = parse_set_values(xml_predicate.css("Array").text)
        else
          operator = xml_predicate.attribute("operator").value
          value = xml_predicate.attribute("value").value
          # only greaterThan thresholds become Floats; other values stay Strings
          value = value.to_f if operator == "greaterThan"
        end

        field = xml_predicate.attribute("field").value
        {
          field: derived_fields[field] || field,
          operator: operator,
          value: value
        }
      end

      # Splits the text of a string array into its values, removing the
      # surrounding quotes and turning \" back into ".
      def parse_set_values(text)
        text.scan(SET_VALUE_PATTERN).flatten.compact.map { |v| v.gsub('\"', '"') }
      end
    end
  end
end
@@ -1,3 +1,3 @@
1
1
  module Eps
2
- VERSION = "0.3.0"
2
+ VERSION = "0.3.5"
3
3
  end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: eps
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.3.0
4
+ version: 0.3.5
5
5
  platform: ruby
6
6
  authors:
7
7
  - Andrew Kane
8
8
  autorequire:
9
9
  bindir: bin
10
10
  cert_chain: []
11
- date: 2019-09-05 00:00:00.000000000 Z
11
+ date: 2020-06-11 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: lightgbm
@@ -16,14 +16,14 @@ dependencies:
16
16
  requirements:
17
17
  - - ">="
18
18
  - !ruby/object:Gem::Version
19
- version: 0.1.5
19
+ version: 0.1.7
20
20
  type: :runtime
21
21
  prerelease: false
22
22
  version_requirements: !ruby/object:Gem::Requirement
23
23
  requirements:
24
24
  - - ">="
25
25
  - !ruby/object:Gem::Version
26
- version: 0.1.5
26
+ version: 0.1.7
27
27
  - !ruby/object:Gem::Dependency
28
28
  name: nokogiri
29
29
  requirement: !ruby/object:Gem::Requirement
@@ -80,6 +80,20 @@ dependencies:
80
80
  - - ">="
81
81
  - !ruby/object:Gem::Version
82
82
  version: '0'
83
+ - !ruby/object:Gem::Dependency
84
+ name: numo-narray
85
+ requirement: !ruby/object:Gem::Requirement
86
+ requirements:
87
+ - - ">="
88
+ - !ruby/object:Gem::Version
89
+ version: '0'
90
+ type: :development
91
+ prerelease: false
92
+ version_requirements: !ruby/object:Gem::Requirement
93
+ requirements:
94
+ - - ">="
95
+ - !ruby/object:Gem::Version
96
+ version: '0'
83
97
  - !ruby/object:Gem::Dependency
84
98
  name: rake
85
99
  requirement: !ruby/object:Gem::Requirement
@@ -94,6 +108,20 @@ dependencies:
94
108
  - - ">="
95
109
  - !ruby/object:Gem::Version
96
110
  version: '0'
111
+ - !ruby/object:Gem::Dependency
112
+ name: rover-df
113
+ requirement: !ruby/object:Gem::Requirement
114
+ requirements:
115
+ - - ">="
116
+ - !ruby/object:Gem::Version
117
+ version: '0'
118
+ type: :development
119
+ prerelease: false
120
+ version_requirements: !ruby/object:Gem::Requirement
121
+ requirements:
122
+ - - ">="
123
+ - !ruby/object:Gem::Version
124
+ version: '0'
97
125
  description:
98
126
  email: andrew@chartkick.com
99
127
  executables: []
@@ -117,7 +145,9 @@ files:
117
145
  - lib/eps/metrics.rb
118
146
  - lib/eps/model.rb
119
147
  - lib/eps/naive_bayes.rb
120
- - lib/eps/pmml_generators/lightgbm.rb
148
+ - lib/eps/pmml.rb
149
+ - lib/eps/pmml/generator.rb
150
+ - lib/eps/pmml/loader.rb
121
151
  - lib/eps/statistics.rb
122
152
  - lib/eps/text_encoder.rb
123
153
  - lib/eps/utils.rb
@@ -141,7 +171,7 @@ required_rubygems_version: !ruby/object:Gem::Requirement
141
171
  - !ruby/object:Gem::Version
142
172
  version: '0'
143
173
  requirements: []
144
- rubygems_version: 3.0.3
174
+ rubygems_version: 3.1.2
145
175
  signing_key:
146
176
  specification_version: 4
147
177
  summary: Machine learning for Ruby. Supports regression (linear regression) and classification
@@ -1,187 +0,0 @@
1
module Eps
  module PmmlGenerators
    # PMML generation for LightGBM models.
    #
    # Meant to be mixed into an object that provides the instance variables
    # @target, @labels, @features, @label_encoders, @objective and @trees, and
    # the methods build_pmml, pmml_local_transformations and display_field
    # (none of which are defined here).
    module LightGBM
      private

      # Emits the whole PMML document through build_pmml.
      #
      # @return [Object] whatever build_pmml returns
      def generate_pmml
        data_fields = {}
        data_fields[@target] = @labels if @labels
        @features.each do |k, type|
          # TODO remove zero importance features
          data_fields[k] = type == "categorical" ? @label_encoders[k].labels.keys : nil
        end

        build_pmml(data_fields) do |xml|
          function_name = @objective == "regression" ? "regression" : "classification"
          xml.MiningModel(functionName: function_name, algorithmName: "LightGBM") do
            xml.MiningSchema do
              xml.MiningField(name: @target, usageType: "target")
              # TODO add importance (and skip zero-importance features),
              # but need to handle text features
              @features.each_key { |k| xml.MiningField(name: k) }
            end
            pmml_local_transformations(xml)

            case @objective
            when "regression"
              xml_segmentation(xml, @trees)
            when "binary"
              xml_binary_segmentation(xml)
            else # multiclass
              xml_multiclass_segmentation(xml)
            end
          end
        end
      end

      # Emits a MiningSchema listing every feature.
      def xml_feature_schema(xml)
        xml.MiningSchema do
          @features.each_key { |k| xml.MiningField(name: k) }
        end
      end

      # Emits an Output with one probability OutputField per label.
      def xml_probability_outputs(xml)
        xml.Output do
          @labels.each do |label|
            xml.OutputField(name: "probability(#{label})", optype: "continuous", dataType: "double", feature: "probability", value: label)
          end
        end
      end

      # Binary objective: a model chain of the summed trees followed by a
      # RegressionModel that maps the value onto the two labels.
      def xml_binary_segmentation(xml)
        xml.Segmentation(multipleModelMethod: "modelChain") do
          xml.Segment(id: 1) do
            xml.True
            xml.MiningModel(functionName: "regression") do
              xml_feature_schema(xml)
              xml.Output do
                # expression shape: 1 / (1 + exp(-1 * lgbmValue))
                # NOTE(review): the numbers below are only block return values;
                # the builder presumably discards them, leaving the Constant
                # elements without text - confirm against the generated PMML.
                xml.OutputField(name: "lgbmValue", optype: "continuous", dataType: "double", feature: "predictedValue", isFinalResult: false) do
                  xml.Apply(function: "/") do
                    xml.Constant(dataType: "double") do
                      1.0
                    end
                    xml.Apply(function: "+") do
                      xml.Constant(dataType: "double") do
                        1.0
                      end
                      xml.Apply(function: "exp") do
                        xml.Apply(function: "*") do
                          xml.Constant(dataType: "double") do
                            -1.0
                          end
                          xml.FieldRef(field: "lgbmValue")
                        end
                      end
                    end
                  end
                end
              end
              xml_segmentation(xml, @trees)
            end
          end
          xml.Segment(id: 2) do
            xml.True
            xml.RegressionModel(functionName: "classification", normalizationMethod: "none") do
              xml.MiningSchema do
                xml.MiningField(name: @target, usageType: "target")
                xml.MiningField(name: "transformedLgbmValue")
              end
              xml_probability_outputs(xml)
              xml.RegressionTable(intercept: 0.0, targetCategory: @labels.last) do
                xml.NumericPredictor(name: "transformedLgbmValue", coefficient: "1.0")
              end
              xml.RegressionTable(intercept: 0.0, targetCategory: @labels.first)
            end
          end
        end
      end

      # Multiclass objective: the trees are cut into equal consecutive slices,
      # one per label, each summed in its own segment; a final softmax
      # RegressionModel combines the per-label values.
      def xml_multiclass_segmentation(xml)
        xml.Segmentation(multipleModelMethod: "modelChain") do
          n = @trees.size / @labels.size
          @trees.each_slice(n).each_with_index do |trees, idx|
            xml.Segment(id: idx + 1) do
              xml.True
              xml.MiningModel(functionName: "regression") do
                xml_feature_schema(xml)
                xml.Output do
                  xml.OutputField(name: "lgbmValue(#{@labels[idx]})", optype: "continuous", dataType: "double", feature: "predictedValue", isFinalResult: false)
                end
                xml_segmentation(xml, trees)
              end
            end
          end
          xml.Segment(id: @labels.size + 1) do
            xml.True
            xml.RegressionModel(functionName: "classification", normalizationMethod: "softmax") do
              xml.MiningSchema do
                xml.MiningField(name: @target, usageType: "target")
                @labels.each do |label|
                  xml.MiningField(name: "lgbmValue(#{label})")
                end
              end
              xml_probability_outputs(xml)
              @labels.each do |label|
                xml.RegressionTable(intercept: 0.0, targetCategory: label) do
                  xml.NumericPredictor(name: "lgbmValue(#{label})", coefficient: "1.0")
                end
              end
            end
          end
        end
      end

      # Emits a "sum" Segmentation with one TreeModel segment per tree.
      #
      # @param xml [Object] the XML builder
      # @param trees [Array] root nodes, one per tree
      def xml_segmentation(xml, trees)
        xml.Segmentation(multipleModelMethod: "sum") do
          trees.each_with_index do |node, i|
            xml.Segment(id: i + 1) do
              xml.True
              xml.TreeModel(functionName: "regression", missingValueStrategy: "none", noTrueChildStrategy: "returnLastPrediction", splitCharacteristic: "multiSplit") do
                xml.MiningSchema do
                  node_fields(node).uniq.each do |k|
                    xml.MiningField(name: display_field(k))
                  end
                end
                node_pmml(node, xml)
              end
            end
          end
        end
      end

      # Lists the fields used by the predicates of a node and all of its
      # descendants, depth first, duplicates included.
      def node_fields(node)
        own = node.predicate ? [node.field] : []
        own + node.children.flat_map { |child| node_fields(child) }
      end

      # Emits a Node element: its predicate first, then its children.
      def node_pmml(node, xml)
        xml.Node(score: node.score) do
          if node.predicate.nil?
            xml.True
          elsif node.operator == "in"
            xml.SimpleSetPredicate(field: display_field(node.field), booleanOperator: "isIn") do
              xml.Array(type: "string") do
                xml.text node.value.map { |v| escape_element(v) }.join(" ")
              end
            end
          else
            xml.SimplePredicate(field: display_field(node.field), operator: node.operator, value: node.value)
          end
          node.children.each { |child| node_pmml(child, xml) }
        end
      end

      # Wraps a value in double quotes, writing embedded quotes as \".
      def escape_element(v)
        "\"#{v.gsub('"', '\"')}\""
      end
    end
  end
end