eps 0.3.0 → 0.3.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,241 @@
1
+ module Eps
2
+ module PMML
3
+ class Loader
4
+ attr_reader :data
5
+
6
# Wraps a PMML document so it can be turned into an evaluator.
#
# pmml - a String, which is parsed with Nokogiri in strict mode, or any
#        other object, which is stored unchanged and exposed through the
#        data reader.
def initialize(pmml)
  @data =
    if pmml.is_a?(String)
      Nokogiri::XML(pmml) { |options| options.strict }
    else
      pmml
    end
end
12
+
13
# Builds the evaluator matching the model element found in the document.
#
# The checks run in a fixed order (Segmentation, then RegressionModel, then
# NaiveBayesModel) and the first element present selects the loader.
# Raises "Unknown model" when none of them is present.
def load
  return lightgbm if data.css("Segmentation").any?
  return linear_regression if data.css("RegressionModel").any?
  return naive_bayes if data.css("NaiveBayesModel").any?

  raise "Unknown model"
end
24
+
25
+ private
26
+
27
# Builds a LightGBM evaluator from the document.
#
# Reads the objective from the first MiningModel, the class labels (for
# classification only) from the RegressionModel output fields, the feature
# types from the DataDictionary and one tree per "Segmentation TreeModel".
def lightgbm
  objective = data.css("MiningModel").first.attribute("functionName").value

  # labels stay nil unless the model is a classifier
  labels = nil
  if objective == "classification"
    labels = data.css("RegressionModel OutputField").map { |field| field.attribute("value").value }
    objective = labels.size > 2 ? "multiclass" : "binary"
  end

  features = {}
  text_features, derived_fields = extract_text_features(data, features)

  dictionary = data.css("DataDictionary").first
  # The first DataField is skipped; to_a turns a nil slice into an empty list.
  # NOTE(review): this assumes the first DataField is always the target -
  # confirm against the generator for regression models, which have no labels.
  dictionary.css("DataField")[1..-1].to_a.each do |data_field|
    name = data_field.attribute("name").value
    categorical = data_field.attribute("optype").value == "categorical"
    features[name] = categorical ? "categorical" : "numeric"
  end

  # each tree is represented by its root node
  trees = data.css("Segmentation TreeModel").map do |tree|
    find_nodes(tree.css("Node").first, derived_fields)
  end

  Evaluators::LightGBM.new(trees: trees, objective: objective, labels: labels, features: features, text_features: text_features)
end
54
+
55
# Builds a linear regression evaluator from the RegressionTable.
#
# Coefficients are keyed by "_intercept", by feature name (or by the
# [field, value] pair of a derived text field) for numeric predictors, and by
# [name, value] for categorical predictors.
def linear_regression
  table = data.css("RegressionTable")
  coefficients = { "_intercept" => table.attribute("intercept").value.to_f }

  features = {}
  text_features, derived_fields = extract_text_features(data, features)

  table.css("NumericPredictor").each do |predictor|
    name = predictor.attribute("name").value
    derived = derived_fields[name]
    # predictors backed by a derived field are not features in their own right
    features[name] = "numeric" unless derived
    coefficients[derived || name] = predictor.attribute("coefficient").value.to_f
  end

  table.css("CategoricalPredictor").each do |predictor|
    name = predictor.attribute("name").value
    category = predictor.attribute("value").value
    coefficients[[name, category]] = predictor.attribute("coefficient").value.to_f
    features[name] = "categorical"
  end

  Evaluators::LinearRegression.new(coefficients: coefficients, features: features, text_features: text_features)
end
84
+
85
# Builds a naive Bayes evaluator from the NaiveBayesModel element.
#
# Collects the prior counts from BayesOutput, the conditional statistics of
# every BayesInput (Gaussian mean/stdev for numeric inputs, counts for
# categorical ones) and the NormDiscrete-based derived fields.
def naive_bayes
  model = data.css("NaiveBayesModel")

  prior = {}
  model.css("BayesOutput TargetValueCount").each do |count_node|
    prior[count_node.attribute("value").value] = count_node.attribute("count").value.to_f
  end

  legacy = false
  conditional = {}
  features = {}

  model.css("BayesInput").each do |input|
    prob = {}

    # numeric
    stats = input.css("TargetValueStat")
    stats.each do |stat|
      gaussian = stat.css("GaussianDistribution")
      prob[stat.attribute("value").value] = {
        mean: gaussian.attribute("mean").value.to_f,
        stdev: Math.sqrt(gaussian.attribute("variance").value.to_f)
      }
    end

    pair_counts = input.css("PairCounts")

    # detect bad form in Eps < 0.3: the PairCounts values equal the prior keys
    bad_format = pair_counts.map { |pair| pair.attribute("value").value } == prior.keys

    # categorical
    pair_counts.each do |pair|
      pair_value = pair.attribute("value").value

      if bad_format
        # legacy layout: outer key is the TargetValueCount value
        pair.css("TargetValueCount").each do |count_node|
          target_value = count_node.attribute("value").value
          (prob[target_value] ||= {})[pair_value] = count_node.attribute("count").value.to_f
        end
      else
        counts = {}
        pair.css("TargetValueCount").each do |count_node|
          counts[count_node.attribute("value").value] = count_node.attribute("count").value.to_f
        end
        prob[pair_value] = counts
      end
    end

    if bad_format
      legacy = true
      # make sure every entry has a count for every prior key
      prob.each_value do |counts|
        prior.each_key { |label| counts[label] ||= 0.0 }
      end
    end

    name = input.attribute("fieldName").value
    conditional[name] = prob
    features[name] = stats.any? ? "numeric" : "categorical"
  end

  # The value is not used below; the lookup is kept so that a document
  # without a BayesOutput fieldName still raises here, as it always has.
  _target = model.css("BayesOutput").attribute("fieldName").value

  probabilities = {
    prior: prior,
    conditional: conditional
  }

  # get derived fields
  derived = {}
  data.css("DerivedField").each do |derived_field|
    name = derived_field.attribute("name").value
    norm = derived_field.css("NormDiscrete")
    field = norm.attribute("field").value
    value = norm.attribute("value").value
    features.delete(name)
    features[field] = "derived"
    (derived[field] ||= {})[name] = value
  end

  Evaluators::NaiveBayes.new(probabilities: probabilities, features: features, derived: derived, legacy: legacy)
end
163
+
164
# Collects the text features described by the document's derived fields.
#
# data     - document (or node) answering css queries
# features - Hash that is updated in place: every text field is set to "text"
#
# Returns [text_features, derived_fields] where text_features maps a field to
# { tokenizer:, case_sensitive:, vocabulary: } and derived_fields maps a
# derived field name to its [field, value] pair.
#
# Raises "Unknown function: <name>" when a derived field applies a function
# that has no DefineFunction element.
def extract_text_features(data, features)
  vocabulary = {}
  function_mapping = {}
  derived_fields = {}

  data.css("LocalTransformations DerivedField, TransformationDictionary DerivedField").each do |derived|
    name = derived.attribute("name")&.value
    field = derived.css("FieldRef").attribute("field").value
    value = derived.css("Constant").text

    # unwrap lowercase(<field>) to the bare field name
    field = field[10..-2] if field =~ /\Alowercase\(.+\)\z/
    # derived fields without a constant carry no vocabulary entry
    next if value.empty?

    (vocabulary[field] ||= []) << value
    function_mapping[field] = derived.css("Apply").attribute("function").value
    derived_fields[name] = [field, value]
  end

  functions = {}
  data.css("TransformationDictionary DefineFunction").each do |definition|
    name = definition.attribute("name").value
    text_index = definition.css("TextIndex")
    functions[name] = {
      tokenizer: Regexp.new(text_index.attribute("wordSeparatorCharacterRE").value),
      case_sensitive: text_index.attribute("isCaseSensitive")&.value == "true"
    }
  end

  text_features = {}
  function_mapping.each do |field, function|
    definition = functions[function]
    # fail with a readable message instead of a NoMethodError on nil
    raise "Unknown function: #{function}" unless definition

    text_features[field] = definition.merge(vocabulary: vocabulary[field])
    features[field] = "text"
  end

  [text_features, derived_fields]
end
202
+
203
# Recursively converts a tree Node element into an Evaluators::Node.
#
# xml            - Node element; its first child element is the predicate and
#                  the remaining child elements are nested Node elements
# derived_fields - Hash mapping a derived field name to its [field, value]
#                  pair; a predicate on a derived field gets that pair as its
#                  field
#
# The predicate is nil for a True element, otherwise a Hash with :field,
# :operator and :value. Set predicates use the operator "in" and a list of
# strings; "greaterThan" values are converted to Float.
def find_nodes(xml, derived_fields)
  score = xml.attribute("score").value.to_f

  elements = xml.elements
  xml_predicate = elements.first

  predicate =
    case xml_predicate.name
    when "True"
      nil
    else
      if xml_predicate.name == "SimpleSetPredicate"
        operator = "in"
        # Elements are either double-quoted (with \" for embedded quotes) or
        # bare tokens. The quoted form must also accept "", otherwise an
        # empty-string element swallows the tokens that follow it.
        value = xml_predicate.css("Array").text
          .scan(/"(.*?)(?<!\\)"|(\S+)/)
          .flatten
          .compact
          .map { |v| v.gsub('\"', '"') }
      else
        operator = xml_predicate.attribute("operator").value
        value = xml_predicate.attribute("value").value
        value = value.to_f if operator == "greaterThan"
      end

      field = xml_predicate.attribute("field").value
      {
        field: derived_fields[field] || field,
        operator: operator,
        value: value
      }
    end

  children = elements[1..-1].map { |child| find_nodes(child, derived_fields) }

  Evaluators::Node.new(score: score, predicate: predicate, children: children)
end
239
+ end
240
+ end
241
+ end
@@ -1,3 +1,3 @@
1
1
  module Eps
2
- VERSION = "0.3.0"
2
+ VERSION = "0.3.5"
3
3
  end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: eps
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.3.0
4
+ version: 0.3.5
5
5
  platform: ruby
6
6
  authors:
7
7
  - Andrew Kane
8
8
  autorequire:
9
9
  bindir: bin
10
10
  cert_chain: []
11
- date: 2019-09-05 00:00:00.000000000 Z
11
+ date: 2020-06-11 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: lightgbm
@@ -16,14 +16,14 @@ dependencies:
16
16
  requirements:
17
17
  - - ">="
18
18
  - !ruby/object:Gem::Version
19
- version: 0.1.5
19
+ version: 0.1.7
20
20
  type: :runtime
21
21
  prerelease: false
22
22
  version_requirements: !ruby/object:Gem::Requirement
23
23
  requirements:
24
24
  - - ">="
25
25
  - !ruby/object:Gem::Version
26
- version: 0.1.5
26
+ version: 0.1.7
27
27
  - !ruby/object:Gem::Dependency
28
28
  name: nokogiri
29
29
  requirement: !ruby/object:Gem::Requirement
@@ -80,6 +80,20 @@ dependencies:
80
80
  - - ">="
81
81
  - !ruby/object:Gem::Version
82
82
  version: '0'
83
+ - !ruby/object:Gem::Dependency
84
+ name: numo-narray
85
+ requirement: !ruby/object:Gem::Requirement
86
+ requirements:
87
+ - - ">="
88
+ - !ruby/object:Gem::Version
89
+ version: '0'
90
+ type: :development
91
+ prerelease: false
92
+ version_requirements: !ruby/object:Gem::Requirement
93
+ requirements:
94
+ - - ">="
95
+ - !ruby/object:Gem::Version
96
+ version: '0'
83
97
  - !ruby/object:Gem::Dependency
84
98
  name: rake
85
99
  requirement: !ruby/object:Gem::Requirement
@@ -94,6 +108,20 @@ dependencies:
94
108
  - - ">="
95
109
  - !ruby/object:Gem::Version
96
110
  version: '0'
111
+ - !ruby/object:Gem::Dependency
112
+ name: rover-df
113
+ requirement: !ruby/object:Gem::Requirement
114
+ requirements:
115
+ - - ">="
116
+ - !ruby/object:Gem::Version
117
+ version: '0'
118
+ type: :development
119
+ prerelease: false
120
+ version_requirements: !ruby/object:Gem::Requirement
121
+ requirements:
122
+ - - ">="
123
+ - !ruby/object:Gem::Version
124
+ version: '0'
97
125
  description:
98
126
  email: andrew@chartkick.com
99
127
  executables: []
@@ -117,7 +145,9 @@ files:
117
145
  - lib/eps/metrics.rb
118
146
  - lib/eps/model.rb
119
147
  - lib/eps/naive_bayes.rb
120
- - lib/eps/pmml_generators/lightgbm.rb
148
+ - lib/eps/pmml.rb
149
+ - lib/eps/pmml/generator.rb
150
+ - lib/eps/pmml/loader.rb
121
151
  - lib/eps/statistics.rb
122
152
  - lib/eps/text_encoder.rb
123
153
  - lib/eps/utils.rb
@@ -141,7 +171,7 @@ required_rubygems_version: !ruby/object:Gem::Requirement
141
171
  - !ruby/object:Gem::Version
142
172
  version: '0'
143
173
  requirements: []
144
- rubygems_version: 3.0.3
174
+ rubygems_version: 3.1.2
145
175
  signing_key:
146
176
  specification_version: 4
147
177
  summary: Machine learning for Ruby. Supports regression (linear regression) and classification
@@ -1,187 +0,0 @@
1
- module Eps
2
- module PmmlGenerators
3
- module LightGBM
4
- private
5
-
6
- def generate_pmml
7
- feature_importance = @feature_importance
8
-
9
- data_fields = {}
10
- data_fields[@target] = @labels if @labels
11
- @features.each_with_index do |(k, type), i|
12
- # TODO remove zero importance features
13
- if type == "categorical"
14
- data_fields[k] = @label_encoders[k].labels.keys
15
- else
16
- data_fields[k] = nil
17
- end
18
- end
19
-
20
- build_pmml(data_fields) do |xml|
21
- function_name = @objective == "regression" ? "regression" : "classification"
22
- xml.MiningModel(functionName: function_name, algorithmName: "LightGBM") do
23
- xml.MiningSchema do
24
- xml.MiningField(name: @target, usageType: "target")
25
- @features.keys.each_with_index do |k, i|
26
- # next if feature_importance[i] == 0
27
- # TODO add importance, but need to handle text features
28
- xml.MiningField(name: k) #, importance: feature_importance[i].to_f, missingValueTreatment: "asIs")
29
- end
30
- end
31
- pmml_local_transformations(xml)
32
-
33
- case @objective
34
- when "regression"
35
- xml_segmentation(xml, @trees)
36
- when "binary"
37
- xml.Segmentation(multipleModelMethod: "modelChain") do
38
- xml.Segment(id: 1) do
39
- xml.True
40
- xml.MiningModel(functionName: "regression") do
41
- xml.MiningSchema do
42
- @features.each do |k, _|
43
- xml.MiningField(name: k)
44
- end
45
- end
46
- xml.Output do
47
- xml.OutputField(name: "lgbmValue", optype: "continuous", dataType: "double", feature: "predictedValue", isFinalResult: false) do
48
- xml.Apply(function: "/") do
49
- xml.Constant(dataType: "double") do
50
- 1.0
51
- end
52
- xml.Apply(function: "+") do
53
- xml.Constant(dataType: "double") do
54
- 1.0
55
- end
56
- xml.Apply(function: "exp") do
57
- xml.Apply(function: "*") do
58
- xml.Constant(dataType: "double") do
59
- -1.0
60
- end
61
- xml.FieldRef(field: "lgbmValue")
62
- end
63
- end
64
- end
65
- end
66
- end
67
- end
68
- xml_segmentation(xml, @trees)
69
- end
70
- end
71
- xml.Segment(id: 2) do
72
- xml.True
73
- xml.RegressionModel(functionName: "classification", normalizationMethod: "none") do
74
- xml.MiningSchema do
75
- xml.MiningField(name: @target, usageType: "target")
76
- xml.MiningField(name: "transformedLgbmValue")
77
- end
78
- xml.Output do
79
- @labels.each do |label|
80
- xml.OutputField(name: "probability(#{label})", optype: "continuous", dataType: "double", feature: "probability", value: label)
81
- end
82
- end
83
- xml.RegressionTable(intercept: 0.0, targetCategory: @labels.last) do
84
- xml.NumericPredictor(name: "transformedLgbmValue", coefficient: "1.0")
85
- end
86
- xml.RegressionTable(intercept: 0.0, targetCategory: @labels.first)
87
- end
88
- end
89
- end
90
- else # multiclass
91
- xml.Segmentation(multipleModelMethod: "modelChain") do
92
- n = @trees.size / @labels.size
93
- @trees.each_slice(n).each_with_index do |trees, idx|
94
- xml.Segment(id: idx + 1) do
95
- xml.True
96
- xml.MiningModel(functionName: "regression") do
97
- xml.MiningSchema do
98
- @features.each do |k, _|
99
- xml.MiningField(name: k)
100
- end
101
- end
102
- xml.Output do
103
- xml.OutputField(name: "lgbmValue(#{@labels[idx]})", optype: "continuous", dataType: "double", feature: "predictedValue", isFinalResult: false)
104
- end
105
- xml_segmentation(xml, trees)
106
- end
107
- end
108
- end
109
- xml.Segment(id: @labels.size + 1) do
110
- xml.True
111
- xml.RegressionModel(functionName: "classification", normalizationMethod: "softmax") do
112
- xml.MiningSchema do
113
- xml.MiningField(name: @target, usageType: "target")
114
- @labels.each do |label|
115
- xml.MiningField(name: "lgbmValue(#{label})")
116
- end
117
- end
118
- xml.Output do
119
- @labels.each do |label|
120
- xml.OutputField(name: "probability(#{label})", optype: "continuous", dataType: "double", feature: "probability", value: label)
121
- end
122
- end
123
- @labels.each do |label|
124
- xml.RegressionTable(intercept: 0.0, targetCategory: label) do
125
- xml.NumericPredictor(name: "lgbmValue(#{label})", coefficient: "1.0")
126
- end
127
- end
128
- end
129
- end
130
- end
131
- end
132
- end
133
- end
134
- end
135
-
136
- def xml_segmentation(xml, trees)
137
- xml.Segmentation(multipleModelMethod: "sum") do
138
- trees.each_with_index do |node, i|
139
- xml.Segment(id: i + 1) do
140
- xml.True
141
- xml.TreeModel(functionName: "regression", missingValueStrategy: "none", noTrueChildStrategy: "returnLastPrediction", splitCharacteristic: "multiSplit") do
142
- xml.MiningSchema do
143
- node_fields(node).uniq.each do |k|
144
- xml.MiningField(name: display_field(k))
145
- end
146
- end
147
- node_pmml(node, xml)
148
- end
149
- end
150
- end
151
- end
152
- end
153
-
154
- def node_fields(node)
155
- fields = []
156
- fields << node.field if node.predicate
157
- node.children.each do |n|
158
- fields.concat(node_fields(n))
159
- end
160
- fields
161
- end
162
-
163
- def node_pmml(node, xml)
164
- xml.Node(score: node.score) do
165
- if node.predicate.nil?
166
- xml.True
167
- elsif node.operator == "in"
168
- xml.SimpleSetPredicate(field: display_field(node.field), booleanOperator: "isIn") do
169
- xml.Array(type: "string") do
170
- xml.text node.value.map { |v| escape_element(v) }.join(" ")
171
- end
172
- end
173
- else
174
- xml.SimplePredicate(field: display_field(node.field), operator: node.operator, value: node.value)
175
- end
176
- node.children.each do |n|
177
- node_pmml(n, xml)
178
- end
179
- end
180
- end
181
-
182
- def escape_element(v)
183
- "\"#{v.gsub("\"", "\\\"")}\""
184
- end
185
- end
186
- end
187
- end