eps 0.3.0 → 0.3.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,3 +1,3 @@
1
1
  module Eps
2
- VERSION = "0.3.0"
2
+ VERSION = "0.3.1"
3
3
  end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: eps
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.3.0
4
+ version: 0.3.1
5
5
  platform: ruby
6
6
  authors:
7
7
  - Andrew Kane
8
8
  autorequire:
9
9
  bindir: bin
10
10
  cert_chain: []
11
- date: 2019-09-05 00:00:00.000000000 Z
11
+ date: 2019-12-06 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: lightgbm
@@ -16,14 +16,14 @@ dependencies:
16
16
  requirements:
17
17
  - - ">="
18
18
  - !ruby/object:Gem::Version
19
- version: 0.1.5
19
+ version: 0.1.7
20
20
  type: :runtime
21
21
  prerelease: false
22
22
  version_requirements: !ruby/object:Gem::Requirement
23
23
  requirements:
24
24
  - - ">="
25
25
  - !ruby/object:Gem::Version
26
- version: 0.1.5
26
+ version: 0.1.7
27
27
  - !ruby/object:Gem::Dependency
28
28
  name: nokogiri
29
29
  requirement: !ruby/object:Gem::Requirement
@@ -117,7 +117,9 @@ files:
117
117
  - lib/eps/metrics.rb
118
118
  - lib/eps/model.rb
119
119
  - lib/eps/naive_bayes.rb
120
- - lib/eps/pmml_generators/lightgbm.rb
120
+ - lib/eps/pmml.rb
121
+ - lib/eps/pmml/generator.rb
122
+ - lib/eps/pmml/loader.rb
121
123
  - lib/eps/statistics.rb
122
124
  - lib/eps/text_encoder.rb
123
125
  - lib/eps/utils.rb
@@ -1,187 +0,0 @@
1
- module Eps
2
- module PmmlGenerators
3
- module LightGBM
4
- private
5
-
6
- def generate_pmml
7
- feature_importance = @feature_importance
8
-
9
- data_fields = {}
10
- data_fields[@target] = @labels if @labels
11
- @features.each_with_index do |(k, type), i|
12
- # TODO remove zero importance features
13
- if type == "categorical"
14
- data_fields[k] = @label_encoders[k].labels.keys
15
- else
16
- data_fields[k] = nil
17
- end
18
- end
19
-
20
- build_pmml(data_fields) do |xml|
21
- function_name = @objective == "regression" ? "regression" : "classification"
22
- xml.MiningModel(functionName: function_name, algorithmName: "LightGBM") do
23
- xml.MiningSchema do
24
- xml.MiningField(name: @target, usageType: "target")
25
- @features.keys.each_with_index do |k, i|
26
- # next if feature_importance[i] == 0
27
- # TODO add importance, but need to handle text features
28
- xml.MiningField(name: k) #, importance: feature_importance[i].to_f, missingValueTreatment: "asIs")
29
- end
30
- end
31
- pmml_local_transformations(xml)
32
-
33
- case @objective
34
- when "regression"
35
- xml_segmentation(xml, @trees)
36
- when "binary"
37
- xml.Segmentation(multipleModelMethod: "modelChain") do
38
- xml.Segment(id: 1) do
39
- xml.True
40
- xml.MiningModel(functionName: "regression") do
41
- xml.MiningSchema do
42
- @features.each do |k, _|
43
- xml.MiningField(name: k)
44
- end
45
- end
46
- xml.Output do
47
- xml.OutputField(name: "lgbmValue", optype: "continuous", dataType: "double", feature: "predictedValue", isFinalResult: false) do
48
- xml.Apply(function: "/") do
49
- xml.Constant(dataType: "double") do
50
- 1.0
51
- end
52
- xml.Apply(function: "+") do
53
- xml.Constant(dataType: "double") do
54
- 1.0
55
- end
56
- xml.Apply(function: "exp") do
57
- xml.Apply(function: "*") do
58
- xml.Constant(dataType: "double") do
59
- -1.0
60
- end
61
- xml.FieldRef(field: "lgbmValue")
62
- end
63
- end
64
- end
65
- end
66
- end
67
- end
68
- xml_segmentation(xml, @trees)
69
- end
70
- end
71
- xml.Segment(id: 2) do
72
- xml.True
73
- xml.RegressionModel(functionName: "classification", normalizationMethod: "none") do
74
- xml.MiningSchema do
75
- xml.MiningField(name: @target, usageType: "target")
76
- xml.MiningField(name: "transformedLgbmValue")
77
- end
78
- xml.Output do
79
- @labels.each do |label|
80
- xml.OutputField(name: "probability(#{label})", optype: "continuous", dataType: "double", feature: "probability", value: label)
81
- end
82
- end
83
- xml.RegressionTable(intercept: 0.0, targetCategory: @labels.last) do
84
- xml.NumericPredictor(name: "transformedLgbmValue", coefficient: "1.0")
85
- end
86
- xml.RegressionTable(intercept: 0.0, targetCategory: @labels.first)
87
- end
88
- end
89
- end
90
- else # multiclass
91
- xml.Segmentation(multipleModelMethod: "modelChain") do
92
- n = @trees.size / @labels.size
93
- @trees.each_slice(n).each_with_index do |trees, idx|
94
- xml.Segment(id: idx + 1) do
95
- xml.True
96
- xml.MiningModel(functionName: "regression") do
97
- xml.MiningSchema do
98
- @features.each do |k, _|
99
- xml.MiningField(name: k)
100
- end
101
- end
102
- xml.Output do
103
- xml.OutputField(name: "lgbmValue(#{@labels[idx]})", optype: "continuous", dataType: "double", feature: "predictedValue", isFinalResult: false)
104
- end
105
- xml_segmentation(xml, trees)
106
- end
107
- end
108
- end
109
- xml.Segment(id: @labels.size + 1) do
110
- xml.True
111
- xml.RegressionModel(functionName: "classification", normalizationMethod: "softmax") do
112
- xml.MiningSchema do
113
- xml.MiningField(name: @target, usageType: "target")
114
- @labels.each do |label|
115
- xml.MiningField(name: "lgbmValue(#{label})")
116
- end
117
- end
118
- xml.Output do
119
- @labels.each do |label|
120
- xml.OutputField(name: "probability(#{label})", optype: "continuous", dataType: "double", feature: "probability", value: label)
121
- end
122
- end
123
- @labels.each do |label|
124
- xml.RegressionTable(intercept: 0.0, targetCategory: label) do
125
- xml.NumericPredictor(name: "lgbmValue(#{label})", coefficient: "1.0")
126
- end
127
- end
128
- end
129
- end
130
- end
131
- end
132
- end
133
- end
134
- end
135
-
136
- def xml_segmentation(xml, trees)
137
- xml.Segmentation(multipleModelMethod: "sum") do
138
- trees.each_with_index do |node, i|
139
- xml.Segment(id: i + 1) do
140
- xml.True
141
- xml.TreeModel(functionName: "regression", missingValueStrategy: "none", noTrueChildStrategy: "returnLastPrediction", splitCharacteristic: "multiSplit") do
142
- xml.MiningSchema do
143
- node_fields(node).uniq.each do |k|
144
- xml.MiningField(name: display_field(k))
145
- end
146
- end
147
- node_pmml(node, xml)
148
- end
149
- end
150
- end
151
- end
152
- end
153
-
154
- def node_fields(node)
155
- fields = []
156
- fields << node.field if node.predicate
157
- node.children.each do |n|
158
- fields.concat(node_fields(n))
159
- end
160
- fields
161
- end
162
-
163
- def node_pmml(node, xml)
164
- xml.Node(score: node.score) do
165
- if node.predicate.nil?
166
- xml.True
167
- elsif node.operator == "in"
168
- xml.SimpleSetPredicate(field: display_field(node.field), booleanOperator: "isIn") do
169
- xml.Array(type: "string") do
170
- xml.text node.value.map { |v| escape_element(v) }.join(" ")
171
- end
172
- end
173
- else
174
- xml.SimplePredicate(field: display_field(node.field), operator: node.operator, value: node.value)
175
- end
176
- node.children.each do |n|
177
- node_pmml(n, xml)
178
- end
179
- end
180
- end
181
-
182
- def escape_element(v)
183
- "\"#{v.gsub("\"", "\\\"")}\""
184
- end
185
- end
186
- end
187
- end