ml4r 0.1.4 → 0.1.5
- data/ext/ml4r/LinearRegression/LinearRegression.cpp +305 -0
- data/ext/ml4r/LinearRegression/OLSLinearRegression.cpp +75 -0
- data/ext/ml4r/MachineLearning/DecisionTree/DecisionTreeExperiment.cpp +50 -0
- data/ext/ml4r/MachineLearning/DecisionTree/DecisionTreeNode.cpp +195 -0
- data/ext/ml4r/MachineLearning/DecisionTree/NodeSplitter.cpp +551 -0
- data/ext/ml4r/MachineLearning/DecisionTree/NodeSplitterCategorical.cpp +22 -0
- data/ext/ml4r/MachineLearning/DecisionTree/NodeSplitterContinuous.cpp +21 -0
- data/ext/ml4r/MachineLearning/DecisionTree/SplitDefinition.cpp +142 -0
- data/ext/ml4r/MachineLearning/GBM/BernoulliCalculator.cpp +95 -0
- data/ext/ml4r/MachineLearning/GBM/GBMEstimator.cpp +601 -0
- data/ext/ml4r/MachineLearning/GBM/GBMOutput.cpp +86 -0
- data/ext/ml4r/MachineLearning/GBM/GBMRunner.cpp +117 -0
- data/ext/ml4r/MachineLearning/GBM/GaussianCalculator.cpp +94 -0
- data/ext/ml4r/MachineLearning/GBM/ZenithGBM.cpp +317 -0
- data/ext/ml4r/MachineLearning/MLData/MLData.cpp +232 -0
- data/ext/ml4r/MachineLearning/MLData/MLDataFields.cpp +1 -0
- data/ext/ml4r/MachineLearning/MLData/MLDataReader.cpp +139 -0
- data/ext/ml4r/MachineLearning/MLData/ZenithMLData.cpp +96 -0
- data/ext/ml4r/MachineLearning/MLData/ZenithMLDataReader.cpp +113 -0
- data/ext/ml4r/MachineLearning/MLExperiment.cpp +69 -0
- data/ext/ml4r/MachineLearning/MLRunner.cpp +183 -0
- data/ext/ml4r/MachineLearning/MLUtils.cpp +15 -0
- data/ext/ml4r/MachineLearning/RandomForest/RandomForestEstimator.cpp +172 -0
- data/ext/ml4r/MachineLearning/RandomForest/RandomForestOutput.cpp +66 -0
- data/ext/ml4r/MachineLearning/RandomForest/RandomForestRunner.cpp +84 -0
- data/ext/ml4r/MachineLearning/RandomForest/ZenithRandomForest.cpp +184 -0
- data/ext/ml4r/ml4r.cpp +34 -0
- data/ext/ml4r/ml4r_wrap.cpp +15727 -0
- data/ext/ml4r/utils/MathUtils.cpp +204 -0
- data/ext/ml4r/utils/StochasticUtils.cpp +73 -0
- data/ext/ml4r/utils/Utils.cpp +14 -0
- data/ext/ml4r/utils/VlcMessage.cpp +3 -0
- metadata +33 -1
data/ext/ml4r/LinearRegression/LinearRegression.cpp
@@ -0,0 +1,305 @@
#include <boost/numeric/ublas/vector.hpp>
#include <boost/numeric/ublas/vector_proxy.hpp>
#include <boost/numeric/ublas/matrix.hpp>
#include <boost/numeric/ublas/triangular.hpp>
#include <boost/numeric/ublas/lu.hpp>
#include <boost/numeric/ublas/io.hpp>
#include <boost/foreach.hpp>
#include <iostream>
using std::cout;
using std::endl;

#include "LinearRegression/LinearRegression.h"
#include "utils/MatrixInversion.h"
#include "utils/Utils.h"
namespace ublas = boost::numeric::ublas;

using std::vector;
using ublas::prod;
using ublas::matrix;


void LinearRegression::setWeights(vector<double> weights)
{
    m_ws = weights;
}

void LinearRegression::setFixedConstant(double val)
{
    m_constant = val;
    m_constantIsFixed = true;
}

pair<vector<double>,double> LinearRegression::getParameterEstimates()
{
    return make_pair(m_bs, m_constant);
}

void LinearRegression::checkDimensions()
{
    if (!m_ys.size())
        throw std::runtime_error("[LinearRegression] Number of observations equals zero");

    if (m_xs.size() != m_ys.size())
        throw std::runtime_error("[LinearRegression] Number of observations in x doesn't match number of observations in y");

    if (m_ws.size() && m_ws.size() != m_ys.size())
        throw std::runtime_error("[LinearRegression] Number of specified weights doesn't match number of observations");

    unsigned long dimensionOfX = m_xs.front().size();
    BOOST_FOREACH(vector<double>& x, m_xs)
        if (x.size() != dimensionOfX)
            throw std::runtime_error("[LinearRegression] Dimensions of x variables are inconsistent between observations");
}

void LinearRegression::calculateStatistics()
{
    if (!m_paramsAreValid)
        throw std::runtime_error("[LinearRegression] Parameters have not been estimated");

    calculateModelStatistics();
    calculateParameterStatistics();
}

void LinearRegression::calculateParameterStatistics2()
{
    // form the matrix X'X
    ublas::matrix<double> X(m_xs.size(), m_xs.front().size() + 1);
    ublas::matrix<double>::iterator2 matrixIterator = X.begin2();
    BOOST_FOREACH(vector<double>& row, m_xs)
    {
        matrixIterator = std::copy(row.begin(), row.end(), matrixIterator);
        *(matrixIterator++) = 1.0;
    }
    ublas::matrix<double> X_transpose_X = ublas::prod(ublas::trans(X), X);

    // Invert the matrix
    ublas::matrix<double> X_transpose_X_inverse(X_transpose_X);
    InvertMatrix(X_transpose_X, X_transpose_X_inverse);

    // Also construct a t-stat for the constant
    if (!m_constantIsFixed) m_bs.push_back(m_constant);

    m_tStatistics.resize(m_bs.size());
    for (unsigned int i = 0; i < m_bs.size(); ++i)
    {
        m_tStatistics.at(i) = m_bs.at(i) / (m_s * sqrt(X_transpose_X_inverse(i,i)));
    }

    if (!m_constantIsFixed) m_bs.pop_back();
}

void LinearRegression::calculateModelStatistics()
{
    checkDimensions();
    checkParametersAreEstimated();
    estimateYs();

    double meanY = Utils::vectorSum(m_ys) / m_n;
    double sumSquaresTotal = 0.0;
    double sumSquaresRegression = 0.0;
    double sumSquaresError = 0.0;
    double meanWeight = Utils::vectorSum(m_ws) / m_n;
    for (int i = 0; i < m_n; ++i)
    {
        sumSquaresTotal      += m_ws.at(i) / meanWeight * pow(m_ys.at(i) - meanY, 2.0);
        sumSquaresRegression += m_ws.at(i) / meanWeight * pow(m_fittedYs.at(i) - meanY, 2.0);
        sumSquaresError      += m_ws.at(i) / meanWeight * pow(m_ys.at(i) - m_fittedYs.at(i), 2.0);
    }

    m_rSquared = 1.0 - (sumSquaresError / sumSquaresTotal);
    m_adjustedRSquared = 1.0 - (sumSquaresError / (m_n - m_p)) / (sumSquaresTotal / (m_n - 1));
    m_fStatistic = (m_n - m_p) * sumSquaresRegression / (sumSquaresError * m_k);
    m_sSquared = 1.0 / (m_n - m_p) * sumSquaresError;
    m_s = sqrt(m_sSquared);

    // diagonal of the hat matrix H = X (X'WX)^-1 X'W = X.A
    m_h_diagonal.resize(m_n, 0.0);
    for (int i = 0; i < m_n; ++i)
    {
        double sumProduct = 0.0;
        for (int j = 0; j < m_p; ++j)
            sumProduct += m_X(i, j) * m_A(j, i);
        m_h_diagonal.at(i) = sumProduct;
    }

    m_pressStatistic = 0.0;
    m_presarStatistic = 0.0;

    m_predictedYs.resize(m_n);
    for (int i = 0; i < m_n; ++i)
    {
        double ei = m_fittedYs.at(i) - m_ys.at(i);
        double hii = m_h_diagonal.at(i);
        double ei_prediction = ei / (1.0 - hii); // leave-one-out (PRESS) residual
        m_predictedYs.at(i) = m_ys.at(i) + ei_prediction;
        m_presarStatistic += m_ws.at(i) / meanWeight * fabs(ei_prediction);
        m_pressStatistic  += m_ws.at(i) / meanWeight * pow(ei_prediction, 2.0);
    }
    m_rSquaredPrediction = 1.0 - m_pressStatistic / sumSquaresTotal;
}

void LinearRegression::estimateYs()
{
    m_fittedYs.clear();
    m_fittedYs.resize(m_ys.size(), m_constant);
    for (unsigned int i = 0; i < m_ys.size(); ++i)
    {
        for (unsigned int j = 0; j < m_bs.size(); ++j)
            m_fittedYs.at(i) += m_bs.at(j) * m_xs.at(i).at(j);
    }
}

void LinearRegression::checkParametersAreEstimated()
{
    if (!m_paramsAreValid)
        throw std::runtime_error("[LinearRegression] Parameters have not been estimated");
}

double LinearRegression::getRSquared()
{
    return m_rSquared;
}

double LinearRegression::getFstatistic()
{
    return m_fStatistic;
}

vector<double>& LinearRegression::getFittedYs()
{
    return m_fittedYs;
}

vector<double>& LinearRegression::getTstatistics()
{
    return m_tStatistics;
}

void LinearRegression::populateMembers()
{
    m_k = m_xs.front().size();
    m_p = m_k + (m_constantIsFixed ? 0 : 1);
    m_n = m_xs.size();

    // populate m_X
    m_X.resize(m_n, m_p);
    ublas::matrix<double>::iterator2 matrixIterator = m_X.begin2();
    BOOST_FOREACH(vector<double>& row, m_xs)
    {
        matrixIterator = std::copy(row.begin(), row.end(), matrixIterator);
        if (!m_constantIsFixed) *(matrixIterator++) = 1.0;
    }

    // populate m_Y
    m_Y.resize(m_n, 1);
    ublas::matrix<double>::iterator1 matrixIterator2 = m_Y.begin1();
    BOOST_FOREACH(double& y, m_ys)
    {
        (*matrixIterator2) = y;
        ++matrixIterator2;
    }

    // populate m_ws with 1's if it's not already defined
    if (!m_ws.size())
    {
        m_ws.resize(m_n, 1.0);
    }

    // form the matrix X' [P x N]
    m_Xtranspose = ublas::trans(m_X);

    // form the matrix X'WX: [P x N] . [N x N] . [N x P] => [P x P]
    m_Xtranspose_W_X.resize(m_p, m_p);
    m_Xtranspose_W = multiplyMatrixByWeights(m_Xtranspose);
    m_Xtranspose_W_X = ublas::prod(m_Xtranspose_W, m_X);

    // Invert the matrix
    m_Xtranspose_W_X_inverse.resize(m_p, m_p);
    InvertMatrix(m_Xtranspose_W_X, m_Xtranspose_W_X_inverse);
}

void LinearRegression::calculateParameterStatistics()
{
    m_tStatistics.resize(m_p);
    m_standardErrors.resize(m_p);

    // with observation weights, var(b) is estimated as s^2 * (AA'), where A = (X'WX)^-1 X'W
    ublas::matrix<double> AAt = prod(m_A, ublas::trans(m_A));
    for (int i = 0; i < m_p; ++i)
    {
        m_standardErrors.at(i) = m_s * sqrt(AAt(i,i));
        m_tStatistics.at(i) = m_B(i,0) / m_standardErrors.at(i);
    }
}

double LinearRegression::getPressStatistic()
{
    return m_pressStatistic;
}

double LinearRegression::getPresarStatistic()
{
    return m_presarStatistic;
}

double LinearRegression::getRSquaredPrediction()
{
    return m_rSquaredPrediction;
}

vector<double>& LinearRegression::getPredictedYs()
{
    return m_predictedYs;
}

double LinearRegression::getAdjustedRSquared()
{
    return m_adjustedRSquared;
}

matrix<double> LinearRegression::multiplyMatrixByWeights(matrix<double>& mat)
{
    if (mat.size2() != m_ws.size())
        throw std::runtime_error("[LinearRegression::multiplyMatrixByWeights] invalid matrix dimensions!");

    matrix<double> new_matrix = mat; // copy
    for (unsigned int j = 0; j < new_matrix.size2(); ++j) // each column
    {
        double weight = m_ws.at(j);
        for (unsigned int i = 0; i < new_matrix.size1(); ++i) // each row
            new_matrix(i,j) *= weight;
    }
    return new_matrix;
}

matrix<double> LinearRegression::multiplyWeightsByMatrix(matrix<double>& mat)
{
    if (mat.size1() != m_ws.size())
        throw std::runtime_error("[LinearRegression::multiplyWeightsByMatrix] invalid matrix dimensions!");

    matrix<double> new_matrix = mat; // copy
    for (unsigned int i = 0; i < new_matrix.size1(); ++i) // each row
    {
        double weight = m_ws.at(i);
        for (unsigned int j = 0; j < new_matrix.size2(); ++j) // each column
            new_matrix(i,j) *= weight;
    }
    return new_matrix;
}

vector<double>& LinearRegression::getStandardErrors()
{
    return m_standardErrors;
}

double LinearRegression::getSSquared()
{
    return m_sSquared;
}
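For reference, the model and parameter statistics computed above are the standard weighted least-squares quantities. In the notation of the code (this restatement is mine, not part of the source):

    e_i = \hat{y}_i - y_i,          h_{ii} = [X (X'WX)^{-1} X'W]_{ii} = [X A]_{ii}
    e_{(i)} = e_i / (1 - h_{ii})    (the leave-one-out "PRESS" residual)
    PRESS = \sum_i (w_i / \bar{w}) e_{(i)}^2,    R^2_{prediction} = 1 - PRESS / SS_{total}
    adjusted R^2 = 1 - (SSE / (n - p)) / (SST / (n - 1))
    t_j = b_j / (s \sqrt{(A A')_{jj}}),          s^2 = SSE / (n - p)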
data/ext/ml4r/LinearRegression/OLSLinearRegression.cpp
@@ -0,0 +1,75 @@
#include "LinearRegression/OLSLinearRegression.h"
#include "utils/MathUtils.h"
#include "utils/Utils.h"

#include <iostream>
#include <boost/numeric/ublas/io.hpp>
using std::cout;
using std::endl;

namespace ublas = boost::numeric::ublas;
using Utils::operator+=;
using ublas::matrix;
using ublas::prod;

OLSLinearRegression::OLSLinearRegression(std::vector<double> xs, std::vector<double> ys,
                                         std::vector<double> weights)
    : LinearRegression(xs, ys, weights)
{
    calculate();
}

OLSLinearRegression::OLSLinearRegression(std::vector<std::vector<double> > xs, std::vector<double> ys,
                                         std::vector<double> weights)
    : LinearRegression(xs, ys, weights)
{
    calculate();
}

OLSLinearRegression::OLSLinearRegression(std::vector<std::vector<double> > xs, std::vector<double> ys,
                                         double fixedConstant, std::vector<double> weights)
    : LinearRegression(xs, ys, fixedConstant, weights)
{
    calculate();
}

OLSLinearRegression::~OLSLinearRegression()
{}

void OLSLinearRegression::calculate()
{
    checkDimensions();
    // matrix based implementation
    //
    // b = inverse(X'WX) X'Wy
    // where X is the data matrix (rows are observations, columns are our X variables; if a constant is to be
    // estimated, the last column is set to 1, and the last estimated parameter will be the constant),
    // X' is the transpose of X,
    // W is the diagonal matrix diag(w1, w2, w3, ...), where wi is the weight of observation i,
    // y is the column matrix containing the observed y's.
    populateMembers();
    EstimateBs();
    if (m_paramsAreValid) calculateStatistics();
}

void OLSLinearRegression::EstimateBs()
{
    matrix<double> Y = m_Y;
    if (m_constantIsFixed)
    {
        for (int i = 0; i < m_n; ++i) Y(i, 0) -= m_constant;
    }

    m_A = prod(m_Xtranspose_W_X_inverse, m_Xtranspose_W);
    m_B = prod(m_A, Y);

    // set m_bs and constant
    m_bs.resize(m_k);
    for (int i = 0; i < m_k; ++i)
        m_bs.at(i) = m_B(i, 0);

    if (!m_constantIsFixed)
        m_constant = m_B(m_p - 1, 0);

    m_paramsAreValid = true;
}
data/ext/ml4r/MachineLearning/DecisionTree/DecisionTreeExperiment.cpp
@@ -0,0 +1,50 @@
#include "MachineLearning/DecisionTree/DecisionTreeExperiment.h"

DecisionTreeExperiment::DecisionTreeExperiment()
    : MLExperiment()
{}

DecisionTreeExperiment::DecisionTreeExperiment(shared_ptr<MLExperiment> mlExperiment)
    : MLExperiment(mlExperiment)
{}

DecisionTreeExperiment::~DecisionTreeExperiment()
{}

void DecisionTreeExperiment::setF(double f)
{
    m_F = f;
}

void DecisionTreeExperiment::setZ(double z)
{
    m_Z = z;
}

double DecisionTreeExperiment::getF()
{
    return m_F;
}

double DecisionTreeExperiment::getZ()
{
    return m_Z;
}

void DecisionTreeExperiment::incrementF(double increment)
{
    m_F += increment;
}

double DecisionTreeExperiment::getY()
{
    return m_yValue;
}
data/ext/ml4r/MachineLearning/DecisionTree/DecisionTreeNode.cpp
@@ -0,0 +1,195 @@
#include "MachineLearning/DecisionTree/DecisionTreeNode.h"
#include "MachineLearning/DecisionTree/SplitDefinition.h"
#include "MachineLearning/DecisionTree/DecisionTreeExperiment.h"
#include "utils/Utils.h"

#include <stdexcept>
using std::runtime_error;

bool   DecisionTreeNode::m_missingValueDefined = false;
double DecisionTreeNode::m_missingValue = -1.0;

DecisionTreeNode::DecisionTreeNode(vector<shared_ptr<DecisionTreeExperiment> > experiments,
                                   double sumZ,
                                   double sumW,
                                   Partition partition,
                                   shared_ptr<SplitDefinition> parentSplitDefinition)
    : m_experiments(experiments), m_nodeHasChildren(false), m_sumZ(sumZ), m_sumW(sumW),
      m_whichPartitionAmI(partition), m_parentSplitDefinition(parentSplitDefinition)
{}

DecisionTreeNode::~DecisionTreeNode()
{}

shared_ptr<DecisionTreeNode> DecisionTreeNode::getTerminalNodeForExperiment(shared_ptr<DecisionTreeExperiment> experiment)
{
    if (!m_nodeHasChildren)
        throw std::runtime_error("Node is a terminal node, so you shouldn't ask it for a terminal node!");

    if (m_splitDefinition.get() == 0)
        throw std::runtime_error("Node has children, but split definition is empty");

    shared_ptr<DecisionTreeNode> childForExperiment = chooseChild(experiment);

    if (childForExperiment.get() == 0)
        return childForExperiment;
    else if (childForExperiment->getSumW() == 0)
    {
        // this likely means that the value is missing, but there weren't any missing values in the
        // bagged training set, so there is no weight in the missing child.
        // Return an empty pointer, and this DecisionTreeNode will become the one chosen.
        return shared_ptr<DecisionTreeNode>();
    }
    else if (childForExperiment->isTerminalNode())
        return childForExperiment;
    else
    {
        shared_ptr<DecisionTreeNode> terminalNode = childForExperiment->getTerminalNodeForExperiment(experiment);
        if (terminalNode.get() == 0)
        {
            // we have encountered a NEW category, so we couldn't descend below childForExperiment;
            // return the child itself.
            return childForExperiment;
        }
        else
            return terminalNode;
    }
}

shared_ptr<DecisionTreeNode> DecisionTreeNode::chooseChild(shared_ptr<DecisionTreeExperiment> experiment)
{
    if (!m_nodeHasChildren)
        throw std::runtime_error("[DecisionTreeNode::chooseChild] - this Decision Tree has no children!");

    double featureValue = experiment->getFeatureValue(m_splitDefinition->getFeatureIndex());

    if (m_missingValueDefined && m_missingValue == featureValue)
        return m_missingChild;

    if (m_splitDefinition->isCategorical()) // categorical variable
    {
        if (Utils::hasElement(m_splitDefinition->getLhsCategories(), featureValue))
            return m_lhsChild;
        else if (Utils::hasElement(m_splitDefinition->getRhsCategories(), featureValue))
            return m_rhsChild;
        else
        {
            // it's not missing, but it's in neither the left nor the right split: we have a NEW category.
            // Return an empty pointer and let the parent handle it.
            return shared_ptr<DecisionTreeNode>();
        }
    }
    else // continuous variable
    {
        double splitValue = m_splitDefinition->getSplitValue();
        if (m_missingValueDefined && m_missingValue == splitValue)
        {
            // subtle case: the split value equals the missing value, so the split separates missing from
            // everything else (which goes to the rhsChild). As our feature value is not the missing value,
            // we choose the rhsChild.
            return m_rhsChild;
        }
        else if (featureValue < splitValue)
            return m_lhsChild;
        else
            return m_rhsChild;
    }
}

void DecisionTreeNode::defineSplit(shared_ptr<SplitDefinition> splitDefinition,
                                   shared_ptr<DecisionTreeNode> lhsChild,
                                   shared_ptr<DecisionTreeNode> rhsChild,
                                   shared_ptr<DecisionTreeNode> missingChild)
{
    setChildren(lhsChild, rhsChild, missingChild);
    m_splitDefinition = splitDefinition;
}

void DecisionTreeNode::setChildren(shared_ptr<DecisionTreeNode> lhsChild,
                                   shared_ptr<DecisionTreeNode> rhsChild,
                                   shared_ptr<DecisionTreeNode> missingChild)
{
    m_nodeHasChildren = true;
    m_lhsChild = lhsChild;
    m_rhsChild = rhsChild;
    m_missingChild = missingChild;
}

vector<shared_ptr<DecisionTreeExperiment> > DecisionTreeNode::getExperiments()
{
    return m_experiments;
}

bool DecisionTreeNode::isTerminalNode()
{
    return !m_nodeHasChildren;
}

void DecisionTreeNode::clearExperimentsWithinTree()
{
    m_experiments.clear();
    if (m_nodeHasChildren)
    {
        m_lhsChild->clearExperimentsWithinTree();
        m_rhsChild->clearExperimentsWithinTree();
        m_missingChild->clearExperimentsWithinTree();
    }
}

double DecisionTreeNode::getSumZ()
{
    return m_sumZ;
}

double DecisionTreeNode::getSumW()
{
    return m_sumW;
}

void DecisionTreeNode::setMissingValue(double missingValue)
{
    m_missingValue = missingValue;
    m_missingValueDefined = true;
}

shared_ptr<SplitDefinition> DecisionTreeNode::getSplitDefinition()
{
    return m_splitDefinition;
}

shared_ptr<SplitDefinition> DecisionTreeNode::getParentSplitDefinition()
{
    return m_parentSplitDefinition;
}

Partition DecisionTreeNode::getPartition()
{
    return m_whichPartitionAmI;
}

void DecisionTreeNode::setSumZ(double sumZ)
{
    m_sumZ = sumZ;
}

void DecisionTreeNode::setSumW(double sumW)
{
    m_sumW = sumW;
}

void DecisionTreeNode::updateSums()
{
    m_sumW = 0.0;
    m_sumZ = 0.0;
    for (unsigned int i = 0; i < m_experiments.size(); ++i)
    {
        double w = m_experiments.at(i)->getWeight();
        m_sumW += w;
        m_sumZ += w * m_experiments.at(i)->getZ();
    }
}