eps 0.2.1 → 0.3.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +14 -0
- data/LICENSE.txt +1 -1
- data/README.md +183 -243
- data/lib/eps.rb +27 -3
- data/lib/eps/base_estimator.rb +316 -47
- data/lib/eps/data_frame.rb +141 -0
- data/lib/eps/evaluators/lightgbm.rb +116 -0
- data/lib/eps/evaluators/linear_regression.rb +54 -0
- data/lib/eps/evaluators/naive_bayes.rb +95 -0
- data/lib/eps/evaluators/node.rb +26 -0
- data/lib/eps/label_encoder.rb +41 -0
- data/lib/eps/lightgbm.rb +237 -0
- data/lib/eps/linear_regression.rb +132 -386
- data/lib/eps/metrics.rb +46 -0
- data/lib/eps/model.rb +16 -58
- data/lib/eps/naive_bayes.rb +175 -164
- data/lib/eps/pmml_generators/lightgbm.rb +187 -0
- data/lib/eps/statistics.rb +79 -0
- data/lib/eps/text_encoder.rb +81 -0
- data/lib/eps/utils.rb +22 -0
- data/lib/eps/version.rb +1 -1
- metadata +33 -7
@@ -0,0 +1,116 @@
|
|
1
|
+
module Eps
  module Evaluators
    # Evaluates a LightGBM model stored as decision trees of node objects
    # (Evaluators::Node), summing the matching leaf score of every tree.
    class LightGBM
      attr_reader :features

      # @param trees [Array] root node of each tree; for multiclass models the
      #   trees are grouped by label (all trees of the first label, then all
      #   trees of the second label, ...)
      # @param objective [String] "regression", "binary", or anything else for multiclass
      # @param labels [Array, nil] class labels, indexed by class position
      # @param features [Hash] feature name => type
      # @param text_features [Hash, nil] text feature name => TextEncoder options;
      #   nil is treated as "no text features" (as Evaluators::LinearRegression does)
      def initialize(trees:, objective:, labels:, features:, text_features:)
        @trees = trees
        @objective = objective
        @labels = labels
        @features = features
        @text_features = text_features || {}
      end

      # Predicts one value per row of +data+. +data+ itself is not modified.
      #
      # @param data must respond to #map (yielding each row, converted with
      #   #to_h) and #columns (column name => values) — presumably an
      #   Eps::DataFrame; TODO confirm against callers
      # @return [Array] summed tree scores for regression, predicted labels otherwise
      def predict(data)
        rows = data.map(&:to_h)

        # sparse matrix: in each row, replace a text feature with
        # [feature, word] => count entries; words absent from a row are
        # treated as a count of 0 by #matches?
        @text_features.each do |k, v|
          encoder = TextEncoder.new(v)
          # read (not delete) the column so the caller's data stays intact
          counts = encoder.transform(data.columns[k])

          counts.each_with_index do |xc, i|
            row = rows[i]
            row.delete(k)
            xc.each do |word, count|
              row[[k, word]] = count
            end
          end
        end

        case @objective
        when "regression"
          sum_trees(rows, @trees)
        when "binary"
          sum_trees(rows, @trees).map { |s| @labels[sigmoid(s) > 0.5 ? 1 : 0] }
        else
          # one consecutive slice of trees per label; the label whose slice
          # sums highest wins (the first one on ties)
          num_trees = @trees.size / @labels.size
          tree_scores = @trees.each_slice(num_trees).map { |trees| sum_trees(rows, trees) }
          rows.size.times.map do |i|
            v = tree_scores.map { |s| s[i] }
            idx = v.each_with_index.max_by { |v2, _| v2 }.last
            @labels[idx]
          end
        end
      end

      private

      # Sum of the scores of +trees+ for every row, added left to right.
      def sum_trees(data, trees)
        data.map do |row|
          trees.inject(0) { |sum, node| sum + node_score(node, row) }
        end
      end

      # True when +row+ satisfies the predicate of +node+ (a nil predicate
      # always matches). A missing value never matches, except for sparse
      # text features, whose missing words count as 0.
      def matches?(node, row)
        if node.predicate.nil?
          true
        else
          v = row[node.field]

          # sparse text feature
          v = 0 if v.nil? && node.field.is_a?(Array)

          if v.nil?
            # missingValueStrategy="none"
            false
          else
            case node.operator
            when "equal"
              v == node.value
            when "in"
              node.value.include?(v)
            when "greaterThan"
              v > node.value
            when "lessOrEqual"
              v <= node.value
            else
              raise "Unknown operator: #{node.operator}"
            end
          end
        end
      end

      # nil when +node+ does not match +row+; otherwise the score of the first
      # child that yields one, falling back to the node's own score.
      def node_score(node, row)
        if matches?(node, row)
          node.children.each do |c|
            score = node_score(c, row)
            return score if score
          end
          # noTrueChildStrategy="returnLastPrediction"
          node.score
        else
          nil
        end
      end

      # Logistic function mapping a raw binary score to a probability.
      def sigmoid(x)
        1.0 / (1 + Math::E**(-x))
      end
    end
  end
end
|
@@ -0,0 +1,54 @@
|
|
1
|
+
module Eps
  module Evaluators
    # Scores rows with a linear model: the intercept plus weighted numeric,
    # categorical (one weight per category) and text (weight per word count) terms.
    class LinearRegression
      attr_reader :features

      # @param coefficients [Hash] term => weight. A plain key names a numeric
      #   feature (or "_intercept"); a [feature, value] key weights one category
      #   or one word of a text feature. Feature names are stringified.
      # @param features [Hash] feature name => "categorical", "text", or any
      #   other type (treated as numeric)
      # @param text_features [Hash, nil] text feature name => TextEncoder options
      def initialize(coefficients:, features:, text_features:)
        @coefficients = coefficients.map do |term, weight|
          normalized =
            if term.is_a?(Array)
              [term[0].to_s, term[1]]
            else
              term.to_s
            end
          [normalized, weight]
        end.to_h
        @features = features
        @text_features = text_features || {}
      end

      # Returns one score per row of +x+ (which must respond to #size and
      # #columns). Raises when a feature column is absent or contains nil.
      def predict(x)
        scores = Array.new(x.size, @coefficients["_intercept"])

        @features.each do |name, type|
          column = x.columns[name]
          raise "Missing data in #{name}" unless column && column.none?(&:nil?)

          if type == "categorical"
            column.each_with_index { |value, row| scores[row] += @coefficients[[name, value]].to_f }
          elsif type == "text"
            encoder = TextEncoder.new(@text_features[name])
            row_counts = encoder.transform(column)

            # weights of this feature's words, keyed by word
            word_weights = {}
            @coefficients.each do |term, weight|
              word_weights[term.last] = weight if term.is_a?(Array) && term.first == name
            end

            row_counts.each_with_index do |counts, row|
              counts.each do |word, count|
                weight = word_weights[word]
                scores[row] += weight * count if weight
              end
            end
          else
            weight = @coefficients[name].to_f
            column.each_with_index { |value, row| scores[row] += weight * value }
          end
        end

        scores
      end

      # Coefficients keyed by Symbol; [feature, value] terms are joined
      # (e.g. ["color", "red"] becomes :colorred).
      def coefficients
        @coefficients.each_with_object({}) do |(term, weight), result|
          result[Array(term).join.to_sym] = weight
        end
      end
    end
  end
end
|
@@ -0,0 +1,95 @@
|
|
1
|
+
module Eps
  module Evaluators
    # Naive Bayes evaluator working in log space: each class score is
    # log(prior) plus the log conditional probability of every feature value.
    class NaiveBayes
      attr_reader :features, :probabilities

      # @param probabilities [Hash] :prior => {class => count};
      #   :conditional => for categorical features {feature => {value => {class => count}}},
      #   otherwise {key => {class => {mean:, stdev:}}}
      # @param features [Hash] feature name => "categorical", "derived", or any
      #   other type (modelled as Gaussian)
      # @param derived [Hash, nil] for "derived" features: feature name =>
      #   {conditional key => value the feature is compared with}
      # @param legacy [Boolean] when true, zero probabilities become 0.0001 instead of 0
      def initialize(probabilities:, features:, derived: nil, legacy: false)
        @probabilities = probabilities
        @features = features
        @derived = derived
        @legacy = legacy
      end

      # Most likely class for each row of +x+; ties go to the smallest label.
      def predict(x)
        probs = calculate_class_probabilities(x)
        probs.map do |xp|
          # convert probabilities
          # not needed when just returning label
          # sum = xp.values.map { |v| Math.exp(v) }.sum.to_f
          # p xp.map { |k, v| [k, Math.exp(v) / sum] }.to_h
          # min_by picks the same entry as sorting by this (unique) key, in O(n)
          xp.min_by { |label, log_p| [-log_p, label] }[0]
        end
      end

      # Log-probability of every class for every row of +x+, returned as a
      # DataFrame with one column per class.
      # use log to prevent underflow
      # https://www.antoniomallia.it/lets-implement-a-gaussian-naive-bayes-classifier-in-python.html
      def calculate_class_probabilities(x)
        probs = Eps::DataFrame.new

        # assign very small probability if probability is 0
        tiny_p = @legacy ? 0.0001 : 0

        total = probabilities[:prior].values.sum.to_f
        probabilities[:prior].each do |c, cv|
          prior = Math.log(cv / total)
          px = [prior] * x.size

          @features.each do |k, type|
            case type
            when "categorical"
              value_counts = probabilities[:conditional][k]
              # count of class c over all values of feature k; independent of
              # the row, so computed at most once. A value never seen with
              # class c counts as 0 (the numerator below already assumes this).
              denom = nil

              x.columns[k].each_with_index do |xi, i|
                vc = value_counts[xi]
                # unknown value if not vc
                next unless vc

                denom ||= value_counts.map { |_, counts| counts[c] || 0 }.sum.to_f
                p2 = vc[c].to_f / denom

                # TODO use proper smoothing instead
                p2 = tiny_p if p2 == 0

                px[i] += Math.log(p2)
              end
            when "derived"
              @derived[k].each do |k2, v2|
                vc = probabilities[:conditional][k2][c]

                x.columns[k].each_with_index do |xi, i|
                  px[i] += Math.log(calculate_probability(xi == v2 ? 1 : 0, vc[:mean], vc[:stdev]))
                end
              end
            else
              vc = probabilities[:conditional][k][c]

              if vc[:stdev] != 0 && !vc[:stdev].nil?
                x.columns[k].each_with_index do |xi, i|
                  px[i] += Math.log(calculate_probability(xi, vc[:mean], vc[:stdev]))
                end
              else
                # zero/unknown spread: only the mean itself is possible
                x.columns[k].each_with_index do |xi, i|
                  if xi != vc[:mean]
                    # TODO use proper smoothing instead
                    px[i] += Math.log(tiny_p)
                  end
                end
              end
            end
          end

          probs.columns[c] = px
        end

        probs
      end

      SQRT_2PI = Math.sqrt(2 * Math::PI)

      # Gaussian probability density of +x+.
      # TODO memoize for performance
      def calculate_probability(x, mean, stdev)
        exponent = Math.exp(-((x - mean)**2) / (2 * (stdev**2)))
        (1 / (SQRT_2PI * stdev)) * exponent
      end
    end
  end
end
|
@@ -0,0 +1,26 @@
|
|
1
|
+
module Eps
  module Evaluators
    # One node of a decision tree used by the evaluators.
    class Node
      attr_accessor :score, :predicate, :children, :leaf_index

      # @param predicate [Hash, nil] {field:, operator:, value:}; nil means no condition
      # @param score [Numeric, nil] score of this node
      # @param children [Array, nil] child nodes (defaults to an empty array)
      # @param leaf_index [Integer, nil] optional leaf identifier
      def initialize(predicate: nil, score: nil, children: nil, leaf_index: nil)
        @leaf_index = leaf_index
        @score = score
        @predicate = predicate
        @children = children || []
      end

      # Shortcut readers for the parts of the predicate (#field, #operator,
      # #value). They expect a predicate to be present.
      %i[field operator value].each do |part|
        define_method(part) { @predicate[part] }
      end
    end
  end
end
|
@@ -0,0 +1,41 @@
|
|
1
|
+
module Eps
  # Maps labels to integer codes (0...n in sorted string order) and back.
  class LabelEncoder
    attr_reader :labels

    def initialize
      @labels = {}
    end

    # Learns the codes from the non-nil values of +y+ (compared as strings),
    # replacing any previously learned labels.
    # @return [Hash{String => Integer}] label => code
    def fit(y)
      @labels = y.compact.map(&:to_s).uniq.sort.each_with_index.to_h
    end

    # Fits on +y+, then encodes it.
    def fit_transform(y)
      fit(y)
      transform(y)
    end

    # Encodes each value of +y+ (looked up by #to_s); nil stays nil.
    # @raise [RuntimeError] for a value not seen by #fit
    def transform(y)
      y.map do |yi|
        next nil if yi.nil?

        v = @labels[yi.to_s]
        raise "Unknown label: #{yi}" unless v
        v
      end
    end

    # Decodes codes (anything convertible with #to_i, e.g. "2") back to
    # labels. nil stays nil, mirroring #transform — previously nil.to_i (0)
    # turned it into the first label.
    def inverse_transform(y)
      inverse = @labels.invert
      y.map { |yi| yi.nil? ? nil : inverse[yi.to_i] }
    end
  end
end
|
data/lib/eps/lightgbm.rb
ADDED
@@ -0,0 +1,237 @@
|
|
1
|
+
require "eps/pmml_generators/lightgbm"
|
2
|
+
|
3
|
+
module Eps
  # LightGBM estimator. Training goes through the lightgbm gem (::LightGBM);
  # predictions are made by Evaluators::LightGBM on trees converted either
  # from the booster's JSON dump or from a PMML document.
  class LightGBM < BaseEstimator
    include PmmlGenerators::LightGBM

    # Builds an Evaluators::LightGBM from PMML. The block is handed to
    # BaseEstimator.load_pmml, which presumably yields the parsed XML
    # document — TODO confirm against BaseEstimator.
    def self.load_pmml(data)
      super do |doc|
        objective = doc.css("MiningModel").first.attribute("functionName").value
        if objective == "classification"
          labels = doc.css("RegressionModel OutputField").map { |n| n.attribute("value").value }
          objective = labels.size > 2 ? "multiclass" : "binary"
        end

        features = {}
        text_features, derived_fields = extract_text_features(doc, features)
        dictionary = doc.css("DataDictionary").first
        # the first DataField is skipped (presumably the target)
        dictionary.css("DataField")[1..-1].to_a.each do |field_node|
          features[field_node.attribute("name").value] =
            if field_node.attribute("optype").value == "categorical"
              "categorical"
            else
              "numeric"
            end
        end

        trees = doc.css("Segmentation TreeModel").map do |tree|
          find_nodes(tree.css("Node").first, derived_fields)
        end

        Evaluators::LightGBM.new(trees: trees, objective: objective, labels: labels, features: features, text_features: text_features)
      end
    end

    private

    # Text summary of the (up to) ten most important features, each as a
    # percentage of the booster's total feature importance.
    def _summary(extended: false)
      str = String.new("")
      importance = @booster.feature_importance
      total = importance.sum.to_f
      if total == 0
        str << "Model needs more data for better predictions\n"
      else
        str << "Most important features\n"
        # keys may mix Strings with [feature, word] Arrays (text features);
        # comparing both as arrays of strings keeps ties between the two
        # kinds from raising ArgumentError inside sort_by
        @importance_keys.zip(importance).sort_by { |k, v| [-v, Array(k).map(&:to_s)] }.first(10).each do |k, v|
          str << "#{display_field(k)}: #{(100 * v / total).round}\n"
        end
      end
      str
    end

    # Converts a PMML <Node> element (recursively) into an Evaluators::Node.
    # The first child element is the node's predicate; the remaining child
    # elements are its child nodes. +derived_fields+ maps PMML derived field
    # names to the feature keys used by the evaluator.
    def self.find_nodes(xml, derived_fields)
      score = BigDecimal(xml.attribute("score").value).to_f

      elements = xml.elements
      xml_predicate = elements.first

      predicate =
        if xml_predicate.name == "True"
          nil
        elsif xml_predicate.name == "SimpleSetPredicate"
          operator = "in"
          # whitespace-separated items; quoted items may contain spaces and \"
          value = xml_predicate.css("Array").text.scan(/"(.+?)(?<!\\)"|(\S+)/).flatten.compact.map { |v| v.gsub('\"', '"') }
          field = xml_predicate.attribute("field").value
          field = derived_fields[field] if derived_fields[field]
          {
            field: field,
            operator: operator,
            value: value
          }
        else
          operator = xml_predicate.attribute("operator").value
          value = xml_predicate.attribute("value").value
          # PMML stores thresholds as strings; every ordered comparison (not
          # only greaterThan) is made against numeric feature values, and a
          # String threshold would make e.g. `v <= value` raise
          value = BigDecimal(value).to_f if %w[greaterThan greaterOrEqual lessThan lessOrEqual].include?(operator)
          field = xml_predicate.attribute("field").value
          field = derived_fields[field] if derived_fields[field]
          {
            field: field,
            operator: operator,
            value: value
          }
        end

      children = elements[1..-1].map { |n| find_nodes(n, derived_fields) }

      Evaluators::Node.new(score: score, predicate: predicate, children: children)
    end

    # Trains a booster on @train_set (validating on @validation_set when
    # present) and returns an Evaluators::LightGBM for it.
    #
    # @param verbose [Boolean, nil] forwarded as verbose_eval
    # @param early_stopping [Integer, false, nil] rounds without improvement
    #   before stopping; defaults to 50 when a validation set exists, false disables
    def _train(verbose: nil, early_stopping: nil)
      # NOTE(review): train_set is not copied (validation_set is), so the
      # encoding below modifies @train_set in place — confirm this is intended
      train_set = @train_set
      validation_set = @validation_set.dup

      # objective
      objective =
        if @target_type == "numeric"
          "regression"
        else
          label_encoder = LabelEncoder.new
          train_set.label = label_encoder.fit_transform(train_set.label)
          validation_set.label = label_encoder.transform(validation_set.label) if validation_set
          labels = label_encoder.labels.keys

          if labels.size > 2
            "multiclass"
          else
            "binary"
          end
        end

      # label encoding
      label_encoders = {}
      @features.each do |k, type|
        if type == "categorical"
          label_encoder = LabelEncoder.new
          train_set.columns[k] = label_encoder.fit_transform(train_set.columns[k])
          validation_set.columns[k] = label_encoder.transform(validation_set.columns[k]) if validation_set
          label_encoders[k] = label_encoder
        end
      end

      # text feature encoding
      prep_text_features(train_set)
      prep_text_features(validation_set) if validation_set

      # create params
      params = {objective: objective}
      params[:num_classes] = labels.size if objective == "multiclass"
      if train_set.size < 30
        params[:min_data_in_bin] = 1
        params[:min_data_in_leaf] = 1
      end

      # dataset column order after text encoding; LightGBM's split_feature
      # indexes refer to positions in this list (see #parse_tree)
      feature_keys = train_set.columns.keys

      # create datasets
      # categorical indexes must be positions in the encoded columns, not in
      # @features: text encoding can add or move columns, shifting positions
      categorical_idx = feature_keys.each_with_index.select { |k, _| @features[k] == "categorical" }.map(&:last)
      train_ds = ::LightGBM::Dataset.new(train_set.map_rows(&:to_a), label: train_set.label, categorical_feature: categorical_idx, params: params)
      validation_ds = ::LightGBM::Dataset.new(validation_set.map_rows(&:to_a), label: validation_set.label, categorical_feature: categorical_idx, params: params, reference: train_ds) if validation_set

      # train
      valid_sets = [train_ds]
      valid_sets << validation_ds if validation_ds
      valid_names = ["training"]
      valid_names << "validation" if validation_ds
      early_stopping = 50 if early_stopping.nil? && validation_ds
      early_stopping = nil if early_stopping == false
      booster = ::LightGBM.train(params, train_ds, num_boost_round: 1000, early_stopping_rounds: early_stopping, valid_sets: valid_sets, valid_names: valid_names, verbose_eval: verbose || false)

      # separate summary from verbose_eval
      puts if verbose

      @importance_keys = feature_keys

      # create evaluator
      @label_encoders = label_encoders
      booster_tree = JSON.parse(booster.to_json)
      trees = booster_tree["tree_info"].map { |s| parse_tree(s["tree_structure"]) }
      # reshape: the dump interleaves classes per boosting round; the
      # evaluator expects all trees of one label to be consecutive
      if objective == "multiclass"
        new_trees = []
        grouped = trees.each_slice(labels.size).to_a
        labels.size.times do |i|
          new_trees.concat grouped.map { |v| v[i] }
        end
        trees = new_trees
      end

      # for pmml
      @objective = objective
      @labels = labels
      @feature_importance = booster.feature_importance
      @trees = trees
      @booster = booster

      # reset pmml
      @pmml = nil

      Evaluators::LightGBM.new(trees: trees, objective: objective, labels: labels, features: @features, text_features: @text_features)
    end

    # NOTE(review): PmmlLoaders::LightGBM is not among the files of this
    # release; confirm this method is unused or meant to return another class.
    def evaluator_class
      PmmlLoaders::LightGBM
    end

    # for evaluator

    # Converts one node of the booster's JSON tree dump into an
    # Evaluators::Node. The two branches of a split are folded into one node:
    # the default branch becomes the parent and the other branch, carrying
    # the split predicate, is prepended to its children, so the predicate
    # branch is tried first.
    def parse_tree(node)
      if node["leaf_value"]
        score = node["leaf_value"]
        Evaluators::Node.new(score: score, leaf_index: node["leaf_index"])
      else
        field = @importance_keys[node["split_feature"]]
        operator =
          case node["decision_type"]
          when "=="
            "equal"
          when "<="
            node["default_left"] ? "greaterThan" : "lessOrEqual"
          else
            raise "Unknown decision type: #{node["decision_type"]}"
          end

        value =
          if operator == "equal"
            # categorical thresholds are label-encoded codes, "||"-joined for sets
            if node["threshold"].include?("||")
              operator = "in"
              @label_encoders[field].inverse_transform(node["threshold"].split("||"))
            else
              @label_encoders[field].inverse_transform([node["threshold"]])[0]
            end
          else
            node["threshold"]
          end

        predicate = {
          field: field,
          value: value,
          operator: operator
        }

        left = parse_tree(node["left_child"])
        right = parse_tree(node["right_child"])

        if node["default_left"]
          right.predicate = predicate
          left.children.unshift right
          left
        else
          left.predicate = predicate
          right.children.unshift left
          right
        end
      end
    end
  end
end
|