eps 0.3.0 → 0.3.1

Sign up to get free protection for your applications and to get access to all the features.
@@ -1,3 +1,3 @@
1
1
  module Eps
2
- VERSION = "0.3.0"
2
+ VERSION = "0.3.1"
3
3
  end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: eps
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.3.0
4
+ version: 0.3.1
5
5
  platform: ruby
6
6
  authors:
7
7
  - Andrew Kane
8
8
  autorequire:
9
9
  bindir: bin
10
10
  cert_chain: []
11
- date: 2019-09-05 00:00:00.000000000 Z
11
+ date: 2019-12-06 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: lightgbm
@@ -16,14 +16,14 @@ dependencies:
16
16
  requirements:
17
17
  - - ">="
18
18
  - !ruby/object:Gem::Version
19
- version: 0.1.5
19
+ version: 0.1.7
20
20
  type: :runtime
21
21
  prerelease: false
22
22
  version_requirements: !ruby/object:Gem::Requirement
23
23
  requirements:
24
24
  - - ">="
25
25
  - !ruby/object:Gem::Version
26
- version: 0.1.5
26
+ version: 0.1.7
27
27
  - !ruby/object:Gem::Dependency
28
28
  name: nokogiri
29
29
  requirement: !ruby/object:Gem::Requirement
@@ -117,7 +117,9 @@ files:
117
117
  - lib/eps/metrics.rb
118
118
  - lib/eps/model.rb
119
119
  - lib/eps/naive_bayes.rb
120
- - lib/eps/pmml_generators/lightgbm.rb
120
+ - lib/eps/pmml.rb
121
+ - lib/eps/pmml/generator.rb
122
+ - lib/eps/pmml/loader.rb
121
123
  - lib/eps/statistics.rb
122
124
  - lib/eps/text_encoder.rb
123
125
  - lib/eps/utils.rb
@@ -1,187 +0,0 @@
1
- module Eps
2
- module PmmlGenerators
3
- module LightGBM
4
- private
5
-
6
- def generate_pmml
7
- feature_importance = @feature_importance
8
-
9
- data_fields = {}
10
- data_fields[@target] = @labels if @labels
11
- @features.each_with_index do |(k, type), i|
12
- # TODO remove zero importance features
13
- if type == "categorical"
14
- data_fields[k] = @label_encoders[k].labels.keys
15
- else
16
- data_fields[k] = nil
17
- end
18
- end
19
-
20
- build_pmml(data_fields) do |xml|
21
- function_name = @objective == "regression" ? "regression" : "classification"
22
- xml.MiningModel(functionName: function_name, algorithmName: "LightGBM") do
23
- xml.MiningSchema do
24
- xml.MiningField(name: @target, usageType: "target")
25
- @features.keys.each_with_index do |k, i|
26
- # next if feature_importance[i] == 0
27
- # TODO add importance, but need to handle text features
28
- xml.MiningField(name: k) #, importance: feature_importance[i].to_f, missingValueTreatment: "asIs")
29
- end
30
- end
31
- pmml_local_transformations(xml)
32
-
33
- case @objective
34
- when "regression"
35
- xml_segmentation(xml, @trees)
36
- when "binary"
37
- xml.Segmentation(multipleModelMethod: "modelChain") do
38
- xml.Segment(id: 1) do
39
- xml.True
40
- xml.MiningModel(functionName: "regression") do
41
- xml.MiningSchema do
42
- @features.each do |k, _|
43
- xml.MiningField(name: k)
44
- end
45
- end
46
- xml.Output do
47
- xml.OutputField(name: "lgbmValue", optype: "continuous", dataType: "double", feature: "predictedValue", isFinalResult: false) do
48
- xml.Apply(function: "/") do
49
- xml.Constant(dataType: "double") do
50
- 1.0
51
- end
52
- xml.Apply(function: "+") do
53
- xml.Constant(dataType: "double") do
54
- 1.0
55
- end
56
- xml.Apply(function: "exp") do
57
- xml.Apply(function: "*") do
58
- xml.Constant(dataType: "double") do
59
- -1.0
60
- end
61
- xml.FieldRef(field: "lgbmValue")
62
- end
63
- end
64
- end
65
- end
66
- end
67
- end
68
- xml_segmentation(xml, @trees)
69
- end
70
- end
71
- xml.Segment(id: 2) do
72
- xml.True
73
- xml.RegressionModel(functionName: "classification", normalizationMethod: "none") do
74
- xml.MiningSchema do
75
- xml.MiningField(name: @target, usageType: "target")
76
- xml.MiningField(name: "transformedLgbmValue")
77
- end
78
- xml.Output do
79
- @labels.each do |label|
80
- xml.OutputField(name: "probability(#{label})", optype: "continuous", dataType: "double", feature: "probability", value: label)
81
- end
82
- end
83
- xml.RegressionTable(intercept: 0.0, targetCategory: @labels.last) do
84
- xml.NumericPredictor(name: "transformedLgbmValue", coefficient: "1.0")
85
- end
86
- xml.RegressionTable(intercept: 0.0, targetCategory: @labels.first)
87
- end
88
- end
89
- end
90
- else # multiclass
91
- xml.Segmentation(multipleModelMethod: "modelChain") do
92
- n = @trees.size / @labels.size
93
- @trees.each_slice(n).each_with_index do |trees, idx|
94
- xml.Segment(id: idx + 1) do
95
- xml.True
96
- xml.MiningModel(functionName: "regression") do
97
- xml.MiningSchema do
98
- @features.each do |k, _|
99
- xml.MiningField(name: k)
100
- end
101
- end
102
- xml.Output do
103
- xml.OutputField(name: "lgbmValue(#{@labels[idx]})", optype: "continuous", dataType: "double", feature: "predictedValue", isFinalResult: false)
104
- end
105
- xml_segmentation(xml, trees)
106
- end
107
- end
108
- end
109
- xml.Segment(id: @labels.size + 1) do
110
- xml.True
111
- xml.RegressionModel(functionName: "classification", normalizationMethod: "softmax") do
112
- xml.MiningSchema do
113
- xml.MiningField(name: @target, usageType: "target")
114
- @labels.each do |label|
115
- xml.MiningField(name: "lgbmValue(#{label})")
116
- end
117
- end
118
- xml.Output do
119
- @labels.each do |label|
120
- xml.OutputField(name: "probability(#{label})", optype: "continuous", dataType: "double", feature: "probability", value: label)
121
- end
122
- end
123
- @labels.each do |label|
124
- xml.RegressionTable(intercept: 0.0, targetCategory: label) do
125
- xml.NumericPredictor(name: "lgbmValue(#{label})", coefficient: "1.0")
126
- end
127
- end
128
- end
129
- end
130
- end
131
- end
132
- end
133
- end
134
- end
135
-
136
- def xml_segmentation(xml, trees)
137
- xml.Segmentation(multipleModelMethod: "sum") do
138
- trees.each_with_index do |node, i|
139
- xml.Segment(id: i + 1) do
140
- xml.True
141
- xml.TreeModel(functionName: "regression", missingValueStrategy: "none", noTrueChildStrategy: "returnLastPrediction", splitCharacteristic: "multiSplit") do
142
- xml.MiningSchema do
143
- node_fields(node).uniq.each do |k|
144
- xml.MiningField(name: display_field(k))
145
- end
146
- end
147
- node_pmml(node, xml)
148
- end
149
- end
150
- end
151
- end
152
- end
153
-
154
- def node_fields(node)
155
- fields = []
156
- fields << node.field if node.predicate
157
- node.children.each do |n|
158
- fields.concat(node_fields(n))
159
- end
160
- fields
161
- end
162
-
163
- def node_pmml(node, xml)
164
- xml.Node(score: node.score) do
165
- if node.predicate.nil?
166
- xml.True
167
- elsif node.operator == "in"
168
- xml.SimpleSetPredicate(field: display_field(node.field), booleanOperator: "isIn") do
169
- xml.Array(type: "string") do
170
- xml.text node.value.map { |v| escape_element(v) }.join(" ")
171
- end
172
- end
173
- else
174
- xml.SimplePredicate(field: display_field(node.field), operator: node.operator, value: node.value)
175
- end
176
- node.children.each do |n|
177
- node_pmml(n, xml)
178
- end
179
- end
180
- end
181
-
182
- def escape_element(v)
183
- "\"#{v.gsub("\"", "\\\"")}\""
184
- end
185
- end
186
- end
187
- end