ml4r 0.1.4 → 0.1.5

Files changed (33)
  1. data/ext/ml4r/LinearRegression/LinearRegression.cpp +305 -0
  2. data/ext/ml4r/LinearRegression/OLSLinearRegression.cpp +75 -0
  3. data/ext/ml4r/MachineLearning/DecisionTree/DecisionTreeExperiment.cpp +50 -0
  4. data/ext/ml4r/MachineLearning/DecisionTree/DecisionTreeNode.cpp +195 -0
  5. data/ext/ml4r/MachineLearning/DecisionTree/NodeSplitter.cpp +551 -0
  6. data/ext/ml4r/MachineLearning/DecisionTree/NodeSplitterCategorical.cpp +22 -0
  7. data/ext/ml4r/MachineLearning/DecisionTree/NodeSplitterContinuous.cpp +21 -0
  8. data/ext/ml4r/MachineLearning/DecisionTree/SplitDefinition.cpp +142 -0
  9. data/ext/ml4r/MachineLearning/GBM/BernoulliCalculator.cpp +95 -0
  10. data/ext/ml4r/MachineLearning/GBM/GBMEstimator.cpp +601 -0
  11. data/ext/ml4r/MachineLearning/GBM/GBMOutput.cpp +86 -0
  12. data/ext/ml4r/MachineLearning/GBM/GBMRunner.cpp +117 -0
  13. data/ext/ml4r/MachineLearning/GBM/GaussianCalculator.cpp +94 -0
  14. data/ext/ml4r/MachineLearning/GBM/ZenithGBM.cpp +317 -0
  15. data/ext/ml4r/MachineLearning/MLData/MLData.cpp +232 -0
  16. data/ext/ml4r/MachineLearning/MLData/MLDataFields.cpp +1 -0
  17. data/ext/ml4r/MachineLearning/MLData/MLDataReader.cpp +139 -0
  18. data/ext/ml4r/MachineLearning/MLData/ZenithMLData.cpp +96 -0
  19. data/ext/ml4r/MachineLearning/MLData/ZenithMLDataReader.cpp +113 -0
  20. data/ext/ml4r/MachineLearning/MLExperiment.cpp +69 -0
  21. data/ext/ml4r/MachineLearning/MLRunner.cpp +183 -0
  22. data/ext/ml4r/MachineLearning/MLUtils.cpp +15 -0
  23. data/ext/ml4r/MachineLearning/RandomForest/RandomForestEstimator.cpp +172 -0
  24. data/ext/ml4r/MachineLearning/RandomForest/RandomForestOutput.cpp +66 -0
  25. data/ext/ml4r/MachineLearning/RandomForest/RandomForestRunner.cpp +84 -0
  26. data/ext/ml4r/MachineLearning/RandomForest/ZenithRandomForest.cpp +184 -0
  27. data/ext/ml4r/ml4r.cpp +34 -0
  28. data/ext/ml4r/ml4r_wrap.cpp +15727 -0
  29. data/ext/ml4r/utils/MathUtils.cpp +204 -0
  30. data/ext/ml4r/utils/StochasticUtils.cpp +73 -0
  31. data/ext/ml4r/utils/Utils.cpp +14 -0
  32. data/ext/ml4r/utils/VlcMessage.cpp +3 -0
  33. metadata +33 -1
data/ext/ml4r/LinearRegression/LinearRegression.cpp
@@ -0,0 +1,305 @@
+ #include <boost/numeric/ublas/vector.hpp>
+ #include <boost/numeric/ublas/vector_proxy.hpp>
+ #include <boost/numeric/ublas/matrix.hpp>
+ #include <boost/numeric/ublas/triangular.hpp>
+ #include <boost/numeric/ublas/lu.hpp>
+ #include <boost/numeric/ublas/io.hpp>
+ #include <boost/foreach.hpp>
+ #include <iostream>
+ using std::cout;
+ using std::endl;
+
+ #include "LinearRegression/LinearRegression.h"
+ #include "utils/MatrixInversion.h"
+ #include "utils/Utils.h"
+ namespace ublas = boost::numeric::ublas;
+
+ using std::vector;
+ using ublas::prod;
+ using ublas::matrix;
+
+
+ void LinearRegression::setWeights(vector<double> weights)
+ {
+     m_ws = weights;
+ }
+
+ void LinearRegression::setFixedConstant(double val)
+ {
+     m_constant = val;
+     m_constantIsFixed = true;
+ }
+
+ pair<vector<double>,double> LinearRegression::getParameterEstimates()
+ {
+     return make_pair(m_bs,m_constant);
+ }
+
+ void LinearRegression::checkDimensions()
+ {
+     if (!m_ys.size())
+         throw std::runtime_error("[LinearRegression] Number of observations equals zero");
+
+     if (m_xs.size() != m_ys.size())
+         throw std::runtime_error("[LinearRegression] Number of observations in x doesn't match number of observations in y");
+
+     if (m_ws.size() && m_ws.size() != m_ys.size())
+         throw std::runtime_error("[LinearRegression] Number of specified weights doesn't match number of observations");
+
+     unsigned long dimensionOfX = m_xs.front().size();
+     BOOST_FOREACH(vector<double>& x, m_xs)
+         if (x.size() != dimensionOfX)
+             throw std::runtime_error("[LinearRegression] Dimensions of x variables are inconsistent between observations");
+ }
+
+ void LinearRegression::calculateStatistics()
+ {
+     if (!m_paramsAreValid)
+         throw std::runtime_error("[LinearRegression] Parameters have not been estimated");
+
+     calculateModelStatistics();
+     calculateParameterStatistics();
+ }
+
+ void LinearRegression::calculateParameterStatistics2()
+ {
+     // form the matrix X'X
+     ublas::matrix<double> X(m_xs.size(), m_xs.front().size()+1);
+     ublas::matrix<double>::iterator2 matrixIterator = X.begin2();
+     BOOST_FOREACH(vector<double>& row, m_xs)
+     {
+         matrixIterator = std::copy(row.begin(), row.end(), matrixIterator);
+         *(matrixIterator++) = 1.0;
+     }
+     ublas::matrix<double> X_transpose_X = ublas::prod(ublas::trans(X), X);
+
+     // Invert the matrix
+     ublas::matrix<double> X_transpose_X_inverse(X_transpose_X);
+     InvertMatrix(X_transpose_X, X_transpose_X_inverse);
+
+     // Also construct a t-stat for the constant
+     if (!m_constantIsFixed) m_bs.push_back(m_constant);
+
+     m_tStatistics.resize(m_bs.size());
+     for (unsigned int i=0; i<m_bs.size(); ++i)
+     {
+         m_tStatistics.at(i) = m_bs.at(i) / (m_s * sqrt(X_transpose_X_inverse(i,i)));
+     }
+
+     if (!m_constantIsFixed) m_bs.pop_back();
+ }
+
+ void LinearRegression::calculateModelStatistics()
+ {
+     checkDimensions();
+     checkParametersAreEstimated();
+     estimateYs();
+
+     double meanY = Utils::vectorSum(m_ys) / m_n;
+     double sumSquaresTotal = 0.0;
+     double sumSquaresRegression = 0.0;
+     double sumSquaresError = 0.0;
+     double meanWeight = Utils::vectorSum(m_ws) / m_n;
+     for (int i=0; i<m_n; ++i)
+     {
+         sumSquaresTotal += m_ws.at(i) / meanWeight * pow(m_ys.at(i) - meanY, 2.0);
+         sumSquaresRegression += m_ws.at(i) / meanWeight * pow(m_fittedYs.at(i) - meanY, 2.0);
+         sumSquaresError += m_ws.at(i) / meanWeight * pow(m_ys.at(i) - m_fittedYs.at(i), 2.0);
+     }
+
+     double meanSquaredRegression = sumSquaresRegression / (m_k);
+
+     m_rSquared = 1.0 - (sumSquaresError / sumSquaresTotal);
+     m_adjustedRSquared = 1.0 - (sumSquaresError / (m_n - m_p)) / (sumSquaresTotal / (m_n - 1));
+     m_fStatistic = (m_n-m_p) * sumSquaresRegression / (sumSquaresError * m_k);
+     m_sSquared = 1.0 / (m_n-m_p) * sumSquaresError;
+     m_s = sqrt(m_sSquared);
+
+     m_h_diagonal.resize(m_n, 0.0);
+     // auto XIterator = m_X.begin2(); // row-wise
+     // auto AIterator = m_A.begin1(); // column-wise
+     for (int i = 0; i < m_n; ++i)
+     {
+         double sumProduct = 0.0;
+         for (int j = 0; j < m_p; ++j)
+             sumProduct += m_X(i, j) * m_A(j, i);
+         m_h_diagonal.at(i) = sumProduct;
+     }
+
+     m_pressStatistic = 0.0;
+     m_presarStatistic = 0.0;
+
+     m_predictedYs.resize(m_n);
+     for (int i = 0; i < m_n; ++i)
+     {
+         double ei = m_fittedYs.at(i) - m_ys.at(i);
+         double hii = m_h_diagonal.at(i);
+         double ei_prediction = ei / (1.0 - hii); // leave-one-out (PRESS) residual
+         m_predictedYs.at(i) = m_ys.at(i) + ei_prediction;
+         m_presarStatistic += m_ws.at(i) / meanWeight * fabs(ei_prediction);
+         m_pressStatistic += m_ws.at(i) / meanWeight * pow(ei_prediction, 2.0);
+     }
+     m_rSquaredPrediction = 1.0 - m_pressStatistic / sumSquaresTotal;
+ }
+
+ void LinearRegression::estimateYs()
+ {
+     m_fittedYs.clear();
+     m_fittedYs.resize(m_ys.size(), m_constant);
+     for (unsigned int i=0; i<m_ys.size(); ++i)
+     {
+         for (unsigned int j=0; j<m_bs.size(); ++j)
+             m_fittedYs.at(i) += m_bs.at(j) * m_xs.at(i).at(j);
+     }
+ }
+
+ void LinearRegression::checkParametersAreEstimated()
+ {
+     if (!m_paramsAreValid)
+         throw std::runtime_error("[LinearRegression] Parameters have not been estimated");
+ }
+
+ double LinearRegression::getRSquared()
+ {
+     return m_rSquared;
+ }
+
+ double LinearRegression::getFstatistic()
+ {
+     return m_fStatistic;
+ }
+
+ vector<double>& LinearRegression::getFittedYs()
+ {
+     return m_fittedYs;
+ }
+
+ vector<double>& LinearRegression::getTstatistics()
+ {
+     return m_tStatistics;
+ }
+
+ void LinearRegression::populateMembers()
+ {
+     m_k = m_xs.front().size();
+     m_p = m_k + (m_constantIsFixed ? 0 : 1);
+     m_n = m_xs.size();
+
+     // populate m_X
+     m_X.resize(m_n, m_p);
+     ublas::matrix<double>::iterator2 matrixIterator = m_X.begin2();
+     BOOST_FOREACH(vector<double>& row, m_xs)
+     {
+         matrixIterator = std::copy(row.begin(), row.end(), matrixIterator);
+         if (!m_constantIsFixed) *(matrixIterator++) = 1.0;
+     }
+
+     // populate m_Y
+     m_Y.resize(m_n, 1);
+     ublas::matrix<double>::iterator1 matrixIterator2 = m_Y.begin1();
+     BOOST_FOREACH(double& y, m_ys)
+     {
+         (*matrixIterator2) = y;
+         ++matrixIterator2;
+     }
+
+     // populate m_ws with 1's if it's not already defined
+     if (!m_ws.size())
+     {
+         m_ws.resize(m_n, 1.0);
+     }
+
+     // form the matrix X' [P x N]
+     m_Xtranspose = ublas::trans(m_X);
+
+     // form the matrix X'WX [P x N] . [N x N] . [N x P] => [P x P]
+     m_Xtranspose_W_X.resize(m_p, m_p);
+     m_Xtranspose_W = multiplyMatrixByWeights(m_Xtranspose);
+     m_Xtranspose_W_X = ublas::prod(m_Xtranspose_W, m_X);
+
+     // Invert the matrix
+     m_Xtranspose_W_X_inverse.resize(m_p, m_p);
+     InvertMatrix(m_Xtranspose_W_X, m_Xtranspose_W_X_inverse);
+ }
+
+ void LinearRegression::calculateParameterStatistics()
+ {
+     m_tStatistics.resize(m_p);
+     m_standardErrors.resize(m_p);
+
+     ublas::matrix<double> AAt = prod(m_A, ublas::trans(m_A));
+     for (int i=0; i<m_p; ++i)
+     {
+         // with weights, var(b) comes from A*A', where A = (X'WX)^-1 X'W
+         m_standardErrors.at(i) = m_s * sqrt(AAt(i,i));
+         m_tStatistics.at(i) = m_B(i,0) / m_standardErrors.at(i);
+     }
+ }
+
+ double LinearRegression::getPressStatistic()
+ {
+     return m_pressStatistic;
+ }
+
+ double LinearRegression::getPresarStatistic()
+ {
+     return m_presarStatistic;
+ }
+
+ double LinearRegression::getRSquaredPrediction()
+ {
+     return m_rSquaredPrediction;
+ }
+
+ vector<double>& LinearRegression::getPredictedYs()
+ {
+     return m_predictedYs;
+ }
+
+ double LinearRegression::getAdjustedRSquared()
+ {
+     return m_adjustedRSquared;
+ }
+
+ matrix<double> LinearRegression::multiplyMatrixByWeights(matrix<double>& mat)
+ {
+     if (mat.size2() != m_ws.size())
+         throw std::runtime_error("[LinearRegression::multiplyMatrixByWeights] invalid matrix dimensions!");
+
+     matrix<double> new_matrix = mat; // copy
+     for (unsigned int j = 0; j < new_matrix.size2(); ++j) // each column
+     {
+         double weight = m_ws.at(j);
+         for (unsigned int i = 0; i < new_matrix.size1(); ++i) // each row
+             new_matrix(i,j) *= weight;
+     }
+     return new_matrix;
+ }
+
+ matrix<double> LinearRegression::multiplyWeightsByMatrix(matrix<double>& mat)
+ {
+     if (mat.size1() != m_ws.size())
+         throw std::runtime_error("[LinearRegression::multiplyWeightsByMatrix] invalid matrix dimensions!");
+
+     matrix<double> new_matrix = mat; // copy
+     for (unsigned int i = 0; i < new_matrix.size1(); ++i) // each row
+     {
+         double weight = m_ws.at(i);
+         for (unsigned int j = 0; j < new_matrix.size2(); ++j) // each column
+             new_matrix(i,j) *= weight;
+     }
+     return new_matrix;
+ }
+
+ vector<double>& LinearRegression::getStandardErrors()
+ {
+     return m_standardErrors;
+ }
+
+ double LinearRegression::getSSquared()
+ {
+     return m_sSquared;
+ }
+
+
+
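Note on the diagnostics in calculateModelStatistics above: the prediction statistics follow the standard leave-one-out identity, with h_ii the hat-matrix diagonal that the loop accumulates from m_X and m_A. A sketch in the code's own symbols (the w_i / w-bar scaling mirrors the weighted sums in the loop; PRESAR is the analogous sum of absolute values):

    e_i^{(-i)} = \frac{e_i}{1 - h_{ii}}, \qquad
    \mathrm{PRESS} = \sum_{i=1}^{n} \frac{w_i}{\bar{w}} \left( e_i^{(-i)} \right)^2, \qquad
    R^2_{\mathrm{pred}} = 1 - \frac{\mathrm{PRESS}}{SS_{\mathrm{tot}}}

where e_i = \hat{y}_i - y_i and h_{ii} is the i-th diagonal element of H = X (X'WX)^{-1} X'W.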
data/ext/ml4r/LinearRegression/OLSLinearRegression.cpp
@@ -0,0 +1,75 @@
+ #include "LinearRegression/OLSLinearRegression.h"
+ #include "utils/MathUtils.h"
+ #include "utils/Utils.h"
+
+ #include <iostream>
+ #include <boost/numeric/ublas/io.hpp>
+ using std::cout;
+ using std::endl;
+
+ namespace ublas = boost::numeric::ublas;
+ using Utils::operator+=;
+ using ublas::matrix;
+ using ublas::prod;
+
+ OLSLinearRegression::OLSLinearRegression(std::vector<double> xs, std::vector<double> ys,
+                                          std::vector<double> weights)
+     : LinearRegression(xs, ys, weights)
+ {
+     calculate();
+ }
+
+ OLSLinearRegression::OLSLinearRegression(std::vector<std::vector<double> > xs, std::vector<double> ys,
+                                          std::vector<double> weights)
+     : LinearRegression(xs, ys, weights)
+ {
+     calculate();
+ }
+
+ OLSLinearRegression::OLSLinearRegression(std::vector<std::vector<double> > xs, std::vector<double> ys,
+                                          double fixedConstant, std::vector<double> weights)
+     : LinearRegression(xs, ys, fixedConstant, weights)
+ {
+     calculate();
+ }
+
+ OLSLinearRegression::~OLSLinearRegression()
+ {}
+
+ void OLSLinearRegression::calculate()
+ {
+     checkDimensions();
+     // matrix based implementation
+
+     // b = inverse(X'WX)X'Wy
+     // where X is the data matrix (rows are observations, columns are our X variables; if a constant is to be estimated,
+     // then a final column of 1's is appended, and the last estimated parameter will be the constant),
+     // X' is the transpose of X
+     // W is a matrix with diag(w1, w2, w3) etc, where wi is the weight of observation i
+     // y is the column matrix containing the observed y's.
+     populateMembers();
+     EstimateBs();
+     if (m_paramsAreValid) calculateStatistics();
+ }
+
+ void OLSLinearRegression::EstimateBs()
+ {
+     matrix<double> Y = m_Y;
+     if (m_constantIsFixed)
+     {
+         for (int i = 0; i < m_n; ++i) Y(i, 0) -= m_constant;
+     }
+
+     m_A = prod(m_Xtranspose_W_X_inverse, m_Xtranspose_W);
+     m_B = prod(m_A, Y);
+
+     // set m_bs and constant
+     m_bs.resize(m_k);
+     for (int i = 0; i < m_k; ++i)
+         m_bs.at(i) = m_B(i, 0);
+
+     if (!m_constantIsFixed)
+         m_constant = m_B(m_p-1, 0);
+
+     m_paramsAreValid = true;
+ }
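For orientation, a minimal driver showing how these constructors fit together (a hypothetical usage sketch, not shipped with the gem; it relies only on the constructors above and the accessors defined in LinearRegression.cpp):

    #include "LinearRegression/OLSLinearRegression.h"
    #include <iostream>
    #include <vector>

    int main()
    {
        std::vector<std::vector<double> > xs;  // rows are observations
        std::vector<double> ys, ws;
        for (int i = 0; i < 10; ++i)
        {
            std::vector<double> row(1, static_cast<double>(i));    // one x variable
            xs.push_back(row);
            ys.push_back(2.0 * i + 1.0 + ((i % 2) ? 0.1 : -0.1));  // roughly y = 2x + 1
            ws.push_back(1.0);                                     // equal weights
        }

        OLSLinearRegression ols(xs, ys, ws);  // estimation happens in the constructor

        std::pair<std::vector<double>, double> params = ols.getParameterEstimates();
        std::cout << "slope = " << params.first.at(0)
                  << ", constant = " << params.second
                  << ", R^2 = " << ols.getRSquared() << std::endl;
        return 0;
    }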
data/ext/ml4r/MachineLearning/DecisionTree/DecisionTreeExperiment.cpp
@@ -0,0 +1,50 @@
+ #include "MachineLearning/DecisionTree/DecisionTreeExperiment.h"
+
+ DecisionTreeExperiment::DecisionTreeExperiment()
+     : MLExperiment()
+ {
+
+ }
+
+ DecisionTreeExperiment::DecisionTreeExperiment(shared_ptr<MLExperiment> mlExperiment)
+     : MLExperiment(mlExperiment)
+ {
+
+ }
+
+
+
+ DecisionTreeExperiment::~DecisionTreeExperiment()
+ {
+
+ }
+
+ void DecisionTreeExperiment::setF(double f)
+ {
+     m_F = f;
+ }
+
+ void DecisionTreeExperiment::setZ(double z)
+ {
+     m_Z = z;
+ }
+
+ double DecisionTreeExperiment::getF()
+ {
+     return m_F;
+ }
+
+ double DecisionTreeExperiment::getZ()
+ {
+     return m_Z;
+ }
+
+ void DecisionTreeExperiment::incrementF(double increment)
+ {
+     m_F += increment;
+ }
+
+ double DecisionTreeExperiment::getY()
+ {
+     return m_yValue;
+ }
data/ext/ml4r/MachineLearning/DecisionTree/DecisionTreeNode.cpp
@@ -0,0 +1,195 @@
+ #include "MachineLearning/DecisionTree/DecisionTreeNode.h"
+ #include "MachineLearning/DecisionTree/SplitDefinition.h"
+ #include "MachineLearning/DecisionTree/DecisionTreeExperiment.h"
+ #include "utils/Utils.h"
+
+ #include <stdexcept>
+ using std::runtime_error;
+
+ bool DecisionTreeNode::m_missingValueDefined = false;
+ double DecisionTreeNode::m_missingValue = -1.0;
+
+ DecisionTreeNode::DecisionTreeNode( vector<shared_ptr<DecisionTreeExperiment> > experiments,
+                                     double sumZ,
+                                     double sumW,
+                                     Partition partition,
+                                     shared_ptr<SplitDefinition> parentSplitDefinition)
+     : m_experiments(experiments), m_nodeHasChildren(false), m_sumZ(sumZ), m_sumW(sumW),
+       m_whichPartitionAmI(partition), m_parentSplitDefinition(parentSplitDefinition)
+ {
+
+ }
+
+ DecisionTreeNode::~DecisionTreeNode()
+ {
+
+ }
+
+ shared_ptr<DecisionTreeNode> DecisionTreeNode::getTerminalNodeForExperiment(shared_ptr<DecisionTreeExperiment> experiment)
+ {
+     if (!m_nodeHasChildren)
+         throw std::runtime_error("Node is a terminal node, so you shouldn't ask it for a terminal node!");
+
+     if (m_splitDefinition.get() == 0)
+         throw std::runtime_error("Node has children, but split definition is empty");
+
+     shared_ptr<DecisionTreeNode> childForExperiment = chooseChild(experiment);
+
+     if (childForExperiment.get() == 0)
+         return childForExperiment;
+     else if (childForExperiment->getSumW() == 0)
+     {
+         // this likely means that the value is missing, but there weren't any missing values in the
+         // bagged training set. Therefore, there is no weight in the missing child.
+         // return an empty pointer, and this DecisionTreeNode will become the one chosen.
+         return shared_ptr<DecisionTreeNode>();
+     }
+     else if (childForExperiment->isTerminalNode())
+         return childForExperiment;
+     else
+     {
+         shared_ptr<DecisionTreeNode> terminalNode = childForExperiment->getTerminalNodeForExperiment(experiment);
+         if (terminalNode.get() == 0)
+         {
+             // we have encountered a NEW category...therefore we couldn't split on childForExperiment.
+             // therefore, return the child itself.
+             return childForExperiment;
+         }
+         else
+             return terminalNode;
+     }
+
+ }
+
+ shared_ptr<DecisionTreeNode> DecisionTreeNode::chooseChild(shared_ptr<DecisionTreeExperiment> experiment)
+ {
+     if (!m_nodeHasChildren)
+         throw std::runtime_error("[DecisionTreeNode::chooseChild] - this Decision Tree has no children!");
+
+     double featureValue = experiment->getFeatureValue(m_splitDefinition->getFeatureIndex());
+
+     if (m_missingValueDefined && m_missingValue == featureValue)
+         return m_missingChild;
+
+     if (m_splitDefinition->isCategorical()) // categorical variable
+     {
+         if (Utils::hasElement(m_splitDefinition->getLhsCategories(), featureValue))
+             return m_lhsChild;
+         else if (Utils::hasElement(m_splitDefinition->getRhsCategories(), featureValue))
+             return m_rhsChild;
+         else
+         {
+             // it's not missing, but not in left or right. Therefore, we have a NEW category.
+             // We should return an empty pointer, and let the parent handle it.
+             return shared_ptr<DecisionTreeNode>();
+         }
+     }
+     else // continuous variable
+     {
+         double splitValue = m_splitDefinition->getSplitValue();
+         if (m_missingValueDefined && m_missingValue == splitValue)
+         {
+             // complicated logic. Our split value equals the missing value. Therefore, we split off missing versus
+             // everything else (which gets put in the rhsChild). As our feature value is not the missing value, we choose
+             // the rhsChild.
+             return m_rhsChild;
+         }
+         else if (featureValue < splitValue)
+             return m_lhsChild;
+         else
+             return m_rhsChild;
+     }
+ }
+
+ void DecisionTreeNode::defineSplit( shared_ptr<SplitDefinition> splitDefinition,
+                                     shared_ptr<DecisionTreeNode> lhsChild,
+                                     shared_ptr<DecisionTreeNode> rhsChild,
+                                     shared_ptr<DecisionTreeNode> missingChild)
+ {
+     setChildren(lhsChild, rhsChild, missingChild);
+     m_splitDefinition = splitDefinition;
+ }
+
+ void DecisionTreeNode::setChildren( shared_ptr<DecisionTreeNode> lhsChild,
+                                     shared_ptr<DecisionTreeNode> rhsChild,
+                                     shared_ptr<DecisionTreeNode> missingChild)
+ {
+     m_nodeHasChildren = true;
+     m_lhsChild = lhsChild;
+     m_rhsChild = rhsChild;
+     m_missingChild = missingChild;
+ }
+
+ vector<shared_ptr<DecisionTreeExperiment> > DecisionTreeNode::getExperiments()
+ {
+     return m_experiments;
+ }
+
+ bool DecisionTreeNode::isTerminalNode()
+ {
+     return !m_nodeHasChildren;
+ }
+
+ void DecisionTreeNode::clearExperimentsWithinTree()
+ {
+     m_experiments.clear();
+     if (m_nodeHasChildren)
+     {
+         m_lhsChild->clearExperimentsWithinTree();
+         m_rhsChild->clearExperimentsWithinTree();
+         m_missingChild->clearExperimentsWithinTree();
+     }
+ }
+
+ double DecisionTreeNode::getSumZ()
+ {
+     return m_sumZ;
+ }
+
+ double DecisionTreeNode::getSumW()
+ {
+     return m_sumW;
+ }
+
+ void DecisionTreeNode::setMissingValue( double missingValue )
+ {
+     m_missingValue = missingValue;
+     m_missingValueDefined = true;
+ }
+
+ shared_ptr<SplitDefinition> DecisionTreeNode::getSplitDefinition()
+ {
+     return m_splitDefinition;
+ }
+
+ shared_ptr<SplitDefinition> DecisionTreeNode::getParentSplitDefinition()
+ {
+     return m_parentSplitDefinition;
+ }
+
+ Partition DecisionTreeNode::getPartition()
+ {
+     return m_whichPartitionAmI;
+ }
+
+ void DecisionTreeNode::setSumZ( double sumZ )
+ {
+     m_sumZ = sumZ;
+ }
+
+ void DecisionTreeNode::setSumW( double sumW )
+ {
+     m_sumW = sumW;
+ }
+
+ void DecisionTreeNode::updateSums()
+ {
+     m_sumW = 0.0;
+     m_sumZ = 0.0;
+     for (unsigned int i=0; i<m_experiments.size(); ++i)
+     {
+         double w = m_experiments.at(i)->getWeight();
+         m_sumW += w;
+         m_sumZ += w * m_experiments.at(i)->getZ();
+     }
+ }
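The fallback behaviour documented in getTerminalNodeForExperiment and chooseChild above (an empty pointer signals an unseen category; a zero-weight missing child stops the descent) is easiest to see from the caller's side. A minimal, hypothetical scoring helper, assuming only the methods defined in this file and that a node's prediction is the weighted mean response m_sumZ / m_sumW (this driver is not part of the gem):

    #include "MachineLearning/DecisionTree/DecisionTreeNode.h"
    #include "MachineLearning/DecisionTree/DecisionTreeExperiment.h"

    // Hypothetical helper: descend from the root and score an experiment.
    double scoreExperiment(shared_ptr<DecisionTreeNode> root,
                           shared_ptr<DecisionTreeExperiment> experiment)
    {
        shared_ptr<DecisionTreeNode> node = root;
        if (!root->isTerminalNode())
        {
            // Returns an empty pointer when the walk cannot proceed past the
            // root (an unseen category, or a missing value with no weight in
            // the missing child); in that case we score at the root itself.
            shared_ptr<DecisionTreeNode> leaf = root->getTerminalNodeForExperiment(experiment);
            if (leaf.get() != 0)
                node = leaf;
        }
        return node->getSumZ() / node->getSumW();  // assumed prediction: weighted mean response
    }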