eps 0.2.1 → 0.3.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +14 -0
- data/LICENSE.txt +1 -1
- data/README.md +183 -243
- data/lib/eps.rb +27 -3
- data/lib/eps/base_estimator.rb +316 -47
- data/lib/eps/data_frame.rb +141 -0
- data/lib/eps/evaluators/lightgbm.rb +116 -0
- data/lib/eps/evaluators/linear_regression.rb +54 -0
- data/lib/eps/evaluators/naive_bayes.rb +95 -0
- data/lib/eps/evaluators/node.rb +26 -0
- data/lib/eps/label_encoder.rb +41 -0
- data/lib/eps/lightgbm.rb +237 -0
- data/lib/eps/linear_regression.rb +132 -386
- data/lib/eps/metrics.rb +46 -0
- data/lib/eps/model.rb +16 -58
- data/lib/eps/naive_bayes.rb +175 -164
- data/lib/eps/pmml_generators/lightgbm.rb +187 -0
- data/lib/eps/statistics.rb +79 -0
- data/lib/eps/text_encoder.rb +81 -0
- data/lib/eps/utils.rb +22 -0
- data/lib/eps/version.rb +1 -1
- metadata +33 -7
@@ -0,0 +1,116 @@
|
|
1
|
+
module Eps
  module Evaluators
    # Pure-Ruby evaluator for a gradient-boosted tree ensemble.
    #
    # Each tree is a root node object that responds to +predicate+, +field+,
    # +operator+, +value+, +children+ and +score+. Rows are walked through
    # every tree and the resulting scores are combined according to the
    # objective ("regression", "binary", anything else is treated as
    # multiclass).
    class LightGBM
      attr_reader :features

      # trees:         array of root nodes; for multiclass they are expected
      #                in contiguous, equally sized groups, one per label
      # objective:     "regression", "binary" or any other value (multiclass)
      # labels:        class labels indexed by class number (unused for regression)
      # features:      exposed as-is through #features
      # text_features: hash of text column name => TextEncoder options; nil is
      #                treated as "no text features", like Evaluators::LinearRegression
      def initialize(trees:, objective:, labels:, features:, text_features:)
        @trees = trees
        @objective = objective
        @labels = labels
        @features = features
        @text_features = text_features || {}
      end

      # Returns one prediction per row of +data+.
      #
      # +data+ must respond to +map+ (yielding objects convertible with
      # +to_h+) and +columns+. It is only read: text columns are expanded on
      # the per-row hashes, never on the frame itself, so the same frame can
      # safely be passed to #predict more than once.
      def predict(data)
        rows = data.map(&:to_h)

        # sparse matrix: replace each text column with [column, word] => count
        # entries; words that do not occur are simply absent from the row
        @text_features.each do |k, v|
          encoder = TextEncoder.new(v)
          counts = encoder.transform(data.columns[k])

          counts.each_with_index do |xc, i|
            row = rows[i]
            row.delete(k)
            xc.each do |word, count|
              row[[k, word]] = count
            end
          end
        end

        case @objective
        when "regression"
          sum_trees(rows, @trees)
        when "binary"
          sum_trees(rows, @trees).map { |s| @labels[sigmoid(s) > 0.5 ? 1 : 0] }
        else
          # one slice of trees per label; the label with the highest summed
          # score wins (first one on ties)
          num_trees = @trees.size / @labels.size
          tree_scores = @trees.each_slice(num_trees).map { |trees| sum_trees(rows, trees) }

          # iterate over the rows captured above rather than data.size so the
          # result always lines up with tree_scores
          rows.each_index.map do |i|
            v = tree_scores.map { |s| s[i] }
            idx = v.each_with_index.max_by { |v2, _| v2 }.last
            @labels[idx]
          end
        end
      end

      private

      # Sums the score of every tree for each row; returns an array of totals.
      # Plain left-to-right addition is used on purpose (no compensated sum).
      def sum_trees(data, trees)
        data.map do |row|
          trees.inject(0) { |sum, node| sum + node_score(node, row) }
        end
      end

      # True when +row+ satisfies the node's predicate. A nil predicate always
      # matches; a missing value never matches.
      def matches?(node, row)
        return true if node.predicate.nil?

        v = row[node.field]

        # sparse text feature: an absent [column, word] entry means count 0
        v = 0 if v.nil? && node.field.is_a?(Array)

        # missingValueStrategy="none"
        return false if v.nil?

        case node.operator
        when "equal"
          v == node.value
        when "in"
          node.value.include?(v)
        when "greaterThan"
          v > node.value
        when "lessOrEqual"
          v <= node.value
        else
          raise "Unknown operator: #{node.operator}"
        end
      end

      # Score of the deepest matching node under +node+, or nil when +node+
      # itself does not match.
      def node_score(node, row)
        return nil unless matches?(node, row)

        node.children.each do |c|
          score = node_score(c, row)
          return score if score
        end

        # noTrueChildStrategy="returnLastPrediction"
        node.score
      end

      def sigmoid(x)
        1.0 / (1 + Math::E**(-x))
      end
    end
  end
end
|
@@ -0,0 +1,54 @@
|
|
1
|
+
module Eps
  module Evaluators
    # Scores rows with a fitted linear model: intercept plus one contribution
    # per feature, depending on the feature's type.
    class LinearRegression
      attr_reader :features

      # Coefficient keys are normalized to strings; two-part keys
      # ([feature, level_or_word]) only have their first element stringified.
      def initialize(coefficients:, features:, text_features:)
        @coefficients = coefficients.each_with_object({}) do |(name, weight), normalized|
          key = name.is_a?(Array) ? [name[0].to_s, name[1]] : name.to_s
          normalized[key] = weight
        end
        @features = features
        @text_features = text_features || {}
      end

      # Returns an array with one score per row of +x+ (+x+ must respond to
      # +size+ and +columns+). Raises when a feature column is absent or
      # contains nil.
      def predict(x)
        scores = Array.new(x.size, @coefficients["_intercept"])

        @features.each do |k, type|
          column = x.columns[k]
          raise "Missing data in #{k}" unless column && column.none?(&:nil?)

          if type == "categorical"
            # levels without a coefficient contribute 0.0
            column.each_with_index do |level, i|
              scores[i] += @coefficients[[k, level]].to_f
            end
          elsif type == "text"
            counts = TextEncoder.new(@text_features[k]).transform(column)

            # word => weight for this text feature only
            word_weights = {}
            @coefficients.each do |key, weight|
              word_weights[key.last] = weight if key.is_a?(Array) && key.first == k
            end

            counts.each_with_index do |word_counts, i|
              word_counts.each do |word, count|
                weight = word_weights[word]
                scores[i] += weight * count if weight
              end
            end
          else
            slope = @coefficients[k].to_f
            column.each_with_index do |value, i|
              scores[i] += slope * value
            end
          end
        end

        scores
      end

      # Coefficients keyed by symbol; two-part keys are joined without a
      # separator before being symbolized.
      def coefficients
        @coefficients.each_with_object({}) do |(key, weight), out|
          out[Array(key).join.to_sym] = weight
        end
      end
    end
  end
end
|
@@ -0,0 +1,95 @@
|
|
1
|
+
module Eps
  module Evaluators
    # Naive Bayes classifier evaluator working on log-probabilities.
    #
    # +probabilities+ is a hash with a :prior entry (class => count) and a
    # :conditional entry (per-feature statistics); see
    # #calculate_class_probabilities for the shapes that are read.
    class NaiveBayes
      attr_reader :features, :probabilities

      SQRT_2PI = Math.sqrt(2 * Math::PI)

      def initialize(probabilities:, features:, derived: nil, legacy: false)
        @probabilities = probabilities
        @features = features
        @derived = derived
        @legacy = legacy
      end

      # Returns the most likely class for each row of +x+. Ties are broken by
      # the smaller class key.
      def predict(x)
        calculate_class_probabilities(x).map do |xp|
          # the log scores are not normalized into probabilities because only
          # the best label is returned
          xp.sort_by { |k, v| [-v, k] }[0][0]
        end
      end

      # Builds an Eps::DataFrame with one column per class holding the
      # unnormalized log-probability of every row of +x+.
      #
      # use log to prevent underflow
      # https://www.antoniomallia.it/lets-implement-a-gaussian-naive-bayes-classifier-in-python.html
      def calculate_class_probabilities(x)
        probs = Eps::DataFrame.new

        # assign very small probability if probability is 0
        # (non-legacy models use 0, i.e. a log of -Infinity)
        tiny_p = @legacy ? 0.0001 : 0

        total = probabilities[:prior].values.sum.to_f
        probabilities[:prior].each do |c, cv|
          prior = Math.log(cv / total)
          px = [prior] * x.size

          @features.each do |k, type|
            case type
            when "categorical"
              conditional = probabilities[:conditional][k]

              # total count for this class across all levels; it does not
              # depend on the row, so it is computed at most once, and only
              # when some row actually has a known level
              denom = nil

              x.columns[k].each_with_index do |xi, i|
                vc = conditional[xi]

                # unknown value if not vc
                next unless vc

                denom ||= conditional.map { |_, counts| counts[c] }.sum.to_f
                p2 = vc[c].to_f / denom

                # TODO use proper smoothing instead
                p2 = tiny_p if p2 == 0

                px[i] += Math.log(p2)
              end
            when "derived"
              # each derived field is a 0/1 indicator scored with a Gaussian
              @derived[k].each do |k2, v2|
                vc = probabilities[:conditional][k2][c]

                x.columns[k].each_with_index do |xi, i|
                  px[i] += Math.log(calculate_probability(xi == v2 ? 1 : 0, vc[:mean], vc[:stdev]))
                end
              end
            else
              vc = probabilities[:conditional][k][c]

              if vc[:stdev] != 0 && !vc[:stdev].nil?
                x.columns[k].each_with_index do |xi, i|
                  px[i] += Math.log(calculate_probability(xi, vc[:mean], vc[:stdev]))
                end
              else
                # degenerate distribution: only the exact mean is possible
                x.columns[k].each_with_index do |xi, i|
                  if xi != vc[:mean]
                    # TODO use proper smoothing instead
                    px[i] += Math.log(tiny_p)
                  end
                end
              end
            end
          end

          probs.columns[c] = px
        end

        probs
      end

      # Gaussian probability density of +x+ for the given +mean+ and +stdev+.
      #
      # +stdev+ is coerced to Float first: with all-Integer arguments the
      # exponent would otherwise be computed with floored Integer division
      # (e.g. -1 / 8 == -1) and give a wrong density.
      #
      # TODO memoize for performance
      def calculate_probability(x, mean, stdev)
        stdev = stdev.to_f
        exponent = Math.exp(-((x - mean)**2) / (2 * (stdev**2)))
        (1 / (SQRT_2PI * stdev)) * exponent
      end
    end
  end
end
|
@@ -0,0 +1,26 @@
|
|
1
|
+
module Eps
  module Evaluators
    # A single decision-tree node: an optional predicate hash
    # (:field, :operator, :value), a score, child nodes and a leaf index.
    class Node
      attr_accessor :score, :predicate, :children, :leaf_index

      # +children+ defaults to a fresh empty array when omitted or nil.
      def initialize(predicate: nil, score: nil, children: nil, leaf_index: nil)
        @predicate, @score, @leaf_index = predicate, score, leaf_index
        @children = children || []
      end

      # Readers for the individual parts of the predicate hash. They raise
      # NoMethodError when the node has no predicate.
      %i[field operator value].each do |part|
        define_method(part) { @predicate[part] }
      end
    end
  end
end
|
@@ -0,0 +1,41 @@
|
|
1
|
+
module Eps
  # Maps labels (compared by their string form) to consecutive integer codes
  # and back.
  class LabelEncoder
    attr_reader :labels

    def initialize
      @labels = {}
    end

    # Learns the mapping from +y+: nils are ignored, labels are stringified,
    # de-duplicated and numbered in sorted order. Returns the mapping.
    def fit(y)
      @labels = y.compact.map(&:to_s).uniq.sort.each_with_index.to_h
    end

    def fit_transform(y)
      fit(y)
      transform(y)
    end

    # Converts labels to their codes; nil stays nil. Raises on a label that
    # was not seen by #fit.
    def transform(y)
      y.map do |yi|
        next nil if yi.nil?

        @labels.fetch(yi.to_s) { raise "Unknown label: #{yi}" }
      end
    end

    # Converts codes (anything responding to +to_i+) back to labels; codes
    # without a label become nil.
    def inverse_transform(y)
      by_code = @labels.invert
      y.map { |yi| by_code[yi.to_i] }
    end
  end
end
|
data/lib/eps/lightgbm.rb
ADDED
@@ -0,0 +1,237 @@
|
|
1
|
+
require "eps/pmml_generators/lightgbm"
|
2
|
+
|
3
|
+
module Eps
  # LightGBM-backed estimator: trains a booster, converts its trees into
  # Evaluators::Node structures and can rebuild the same evaluator from PMML.
  class LightGBM < BaseEstimator
    include PmmlGenerators::LightGBM

    # Builds an Evaluators::LightGBM from a PMML document. The parsed document
    # is yielded by the superclass implementation.
    def self.load_pmml(data)
      super do |doc|
        labels = nil
        objective = doc.css("MiningModel").first.attribute("functionName").value
        if objective == "classification"
          labels = doc.css("RegressionModel OutputField").map { |n| n.attribute("value").value }
          objective = labels.size > 2 ? "multiclass" : "binary"
        end

        features = {}
        text_features, derived_fields = extract_text_features(doc, features)
        dictionary = doc.css("DataDictionary").first
        # the first DataField is skipped — presumably the target; verify against the generator
        dictionary.css("DataField")[1..-1].to_a.each do |data_field|
          features[data_field.attribute("name").value] =
            if data_field.attribute("optype").value == "categorical"
              "categorical"
            else
              "numeric"
            end
        end

        trees = doc.css("Segmentation TreeModel").map do |tree|
          find_nodes(tree.css("Node").first, derived_fields)
        end

        Evaluators::LightGBM.new(trees: trees, objective: objective, labels: labels, features: features, text_features: text_features)
      end
    end

    private

    # Text summary listing up to ten features by their share of the booster's
    # total feature importance. +extended+ is accepted but not used here.
    def _summary(extended: false)
      str = String.new("")
      importance = @booster.feature_importance
      total = importance.sum.to_f
      if total == 0
        str << "Model needs more data for better predictions\n"
      else
        str << "Most important features\n"
        @importance_keys.zip(importance).sort_by { |k, v| [-v, k] }.first(10).each do |k, v|
          str << "#{display_field(k)}: #{(100 * v / total).round}\n"
        end
      end
      str
    end

    # Recursively converts a PMML <Node> element into an Evaluators::Node.
    # The first child element is the predicate; the remaining ones are child
    # nodes. +derived_fields+ maps derived field names to the field key used
    # by the evaluator.
    #
    # Note: `private` above does not apply to `def self.` methods, so this
    # remains callable as a class method (load_pmml relies on that).
    def self.find_nodes(xml, derived_fields)
      score = BigDecimal(xml.attribute("score").value).to_f

      elements = xml.elements
      xml_predicate = elements.first

      predicate =
        if xml_predicate.name == "True"
          nil
        elsif xml_predicate.name == "SimpleSetPredicate"
          operator = "in"
          # values are whitespace separated; quoted values may contain spaces and escaped quotes
          value = xml_predicate.css("Array").text.scan(/"(.+?)(?<!\\)"|(\S+)/).flatten.compact.map { |v| v.gsub('\"', '"') }
          field = xml_predicate.attribute("field").value
          field = derived_fields[field] if derived_fields[field]
          {
            field: field,
            operator: operator,
            value: value
          }
        else
          operator = xml_predicate.attribute("operator").value
          value = xml_predicate.attribute("value").value
          # both ordering operators are evaluated numerically (v > value,
          # v <= value), so the threshold must not stay a String
          value = BigDecimal(value).to_f if operator == "greaterThan" || operator == "lessOrEqual"
          field = xml_predicate.attribute("field").value
          field = derived_fields[field] if derived_fields[field]
          {
            field: field,
            operator: operator,
            value: value
          }
        end

      children = elements[1..-1].map { |n| find_nodes(n, derived_fields) }

      Evaluators::Node.new(score: score, predicate: predicate, children: children)
    end

    # Trains the booster on @train_set (validated against a copy of
    # @validation_set when present) and returns the matching evaluator.
    #
    # verbose:        forwarded as verbose_eval
    # early_stopping: number of rounds; nil means 50 when a validation set
    #                 exists, false disables it
    #
    # Note: @train_set is not copied, so its label and categorical columns
    # are replaced by their encoded versions.
    def _train(verbose: nil, early_stopping: nil)
      train_set = @train_set
      validation_set = @validation_set.dup
      labels = nil

      # objective
      objective =
        if @target_type == "numeric"
          "regression"
        else
          label_encoder = LabelEncoder.new
          train_set.label = label_encoder.fit_transform(train_set.label)
          validation_set.label = label_encoder.transform(validation_set.label) if validation_set
          labels = label_encoder.labels.keys

          if labels.size > 2
            "multiclass"
          else
            "binary"
          end
        end

      # label encoding
      label_encoders = {}
      @features.each do |k, type|
        if type == "categorical"
          label_encoder = LabelEncoder.new
          train_set.columns[k] = label_encoder.fit_transform(train_set.columns[k])
          validation_set.columns[k] = label_encoder.transform(validation_set.columns[k]) if validation_set
          label_encoders[k] = label_encoder
        end
      end

      # text feature encoding
      prep_text_features(train_set)
      prep_text_features(validation_set) if validation_set

      # create params
      params = {objective: objective}
      params[:num_classes] = labels.size if objective == "multiclass"
      # relax LightGBM's minimums for very small training sets
      if train_set.size < 30
        params[:min_data_in_bin] = 1
        params[:min_data_in_leaf] = 1
      end

      # create datasets
      categorical_idx = @features.values.map.with_index.select { |type, _| type == "categorical" }.map(&:last)
      train_ds = ::LightGBM::Dataset.new(train_set.map_rows(&:to_a), label: train_set.label, categorical_feature: categorical_idx, params: params)
      validation_ds = ::LightGBM::Dataset.new(validation_set.map_rows(&:to_a), label: validation_set.label, categorical_feature: categorical_idx, params: params, reference: train_ds) if validation_set

      # train
      valid_sets = [train_ds]
      valid_sets << validation_ds if validation_ds
      valid_names = ["training"]
      valid_names << "validation" if validation_ds
      early_stopping = 50 if early_stopping.nil? && validation_ds
      early_stopping = nil if early_stopping == false
      booster = ::LightGBM.train(params, train_ds, num_boost_round: 1000, early_stopping_rounds: early_stopping, valid_sets: valid_sets, valid_names: valid_names, verbose_eval: verbose || false)

      # separate summary from verbose_eval
      puts if verbose

      # parse_tree resolves split_feature indexes through these keys
      @importance_keys = train_set.columns.keys

      # create evaluator
      @label_encoders = label_encoders
      booster_tree = JSON.parse(booster.to_json)
      trees = booster_tree["tree_info"].map { |s| parse_tree(s["tree_structure"]) }
      # reshape: the dump is read in consecutive groups of labels.size trees
      # (tree i of each group belongs to class i); regroup so that all trees
      # of a class are contiguous, which is what the evaluator slices on
      if objective == "multiclass"
        new_trees = []
        grouped = trees.each_slice(labels.size).to_a
        labels.size.times do |i|
          new_trees.concat grouped.map { |v| v[i] }
        end
        trees = new_trees
      end

      # for pmml
      @objective = objective
      @labels = labels
      @feature_importance = booster.feature_importance
      @trees = trees
      @booster = booster

      # reset pmml
      @pmml = nil

      Evaluators::LightGBM.new(trees: trees, objective: objective, labels: labels, features: @features, text_features: @text_features)
    end

    # NOTE(review): no PmmlLoaders module is visible in this file — confirm
    # this constant exists (or that this method is unused) before relying on it.
    def evaluator_class
      PmmlLoaders::LightGBM
    end

    # for evaluator

    # Converts one node of the booster's JSON tree dump into Evaluators::Node
    # objects. A split is represented by making the default-direction child
    # the parent (its score is the fallback when no child predicate matches)
    # and attaching the split predicate to the other child.
    def parse_tree(node)
      if node["leaf_value"]
        score = node["leaf_value"]
        Evaluators::Node.new(score: score, leaf_index: node["leaf_index"])
      else
        field = @importance_keys[node["split_feature"]]
        operator =
          case node["decision_type"]
          when "=="
            "equal"
          when "<="
            # the predicate is attached to the non-default child, so it is
            # inverted when the default direction is left
            node["default_left"] ? "greaterThan" : "lessOrEqual"
          else
            raise "Unknown decision type: #{node["decision_type"]}"
          end

        value =
          if operator == "equal"
            # categorical thresholds hold encoded labels, "||"-separated when
            # there are several; map them back to the original labels
            if node["threshold"].include?("||")
              operator = "in"
              @label_encoders[field].inverse_transform(node["threshold"].split("||"))
            else
              @label_encoders[field].inverse_transform([node["threshold"]])[0]
            end
          else
            node["threshold"]
          end

        predicate = {
          field: field,
          value: value,
          operator: operator
        }

        left = parse_tree(node["left_child"])
        right = parse_tree(node["right_child"])

        if node["default_left"]
          right.predicate = predicate
          left.children.unshift right
          left
        else
          left.predicate = predicate
          right.children.unshift left
          right
        end
      end
    end
  end
end
|