ml4r 0.1.4 → 0.1.5
Sign up to get free protection for your applications and to get access to all the features.
- data/ext/ml4r/LinearRegression/LinearRegression.cpp +305 -0
- data/ext/ml4r/LinearRegression/OLSLinearRegression.cpp +75 -0
- data/ext/ml4r/MachineLearning/DecisionTree/DecisionTreeExperiment.cpp +50 -0
- data/ext/ml4r/MachineLearning/DecisionTree/DecisionTreeNode.cpp +195 -0
- data/ext/ml4r/MachineLearning/DecisionTree/NodeSplitter.cpp +551 -0
- data/ext/ml4r/MachineLearning/DecisionTree/NodeSplitterCategorical.cpp +22 -0
- data/ext/ml4r/MachineLearning/DecisionTree/NodeSplitterContinuous.cpp +21 -0
- data/ext/ml4r/MachineLearning/DecisionTree/SplitDefinition.cpp +142 -0
- data/ext/ml4r/MachineLearning/GBM/BernoulliCalculator.cpp +95 -0
- data/ext/ml4r/MachineLearning/GBM/GBMEstimator.cpp +601 -0
- data/ext/ml4r/MachineLearning/GBM/GBMOutput.cpp +86 -0
- data/ext/ml4r/MachineLearning/GBM/GBMRunner.cpp +117 -0
- data/ext/ml4r/MachineLearning/GBM/GaussianCalculator.cpp +94 -0
- data/ext/ml4r/MachineLearning/GBM/ZenithGBM.cpp +317 -0
- data/ext/ml4r/MachineLearning/MLData/MLData.cpp +232 -0
- data/ext/ml4r/MachineLearning/MLData/MLDataFields.cpp +1 -0
- data/ext/ml4r/MachineLearning/MLData/MLDataReader.cpp +139 -0
- data/ext/ml4r/MachineLearning/MLData/ZenithMLData.cpp +96 -0
- data/ext/ml4r/MachineLearning/MLData/ZenithMLDataReader.cpp +113 -0
- data/ext/ml4r/MachineLearning/MLExperiment.cpp +69 -0
- data/ext/ml4r/MachineLearning/MLRunner.cpp +183 -0
- data/ext/ml4r/MachineLearning/MLUtils.cpp +15 -0
- data/ext/ml4r/MachineLearning/RandomForest/RandomForestEstimator.cpp +172 -0
- data/ext/ml4r/MachineLearning/RandomForest/RandomForestOutput.cpp +66 -0
- data/ext/ml4r/MachineLearning/RandomForest/RandomForestRunner.cpp +84 -0
- data/ext/ml4r/MachineLearning/RandomForest/ZenithRandomForest.cpp +184 -0
- data/ext/ml4r/ml4r.cpp +34 -0
- data/ext/ml4r/ml4r_wrap.cpp +15727 -0
- data/ext/ml4r/utils/MathUtils.cpp +204 -0
- data/ext/ml4r/utils/StochasticUtils.cpp +73 -0
- data/ext/ml4r/utils/Utils.cpp +14 -0
- data/ext/ml4r/utils/VlcMessage.cpp +3 -0
- metadata +33 -1
@@ -0,0 +1,551 @@
|
|
1
|
+
#include "MachineLearning/DecisionTree/NodeSplitter.h"
|
2
|
+
#include "MachineLearning/DecisionTree/DecisionTreeExperiment.h"
|
3
|
+
#include "MachineLearning/DecisionTree/SplitDefinition.h"
|
4
|
+
#include "MachineLearning/DecisionTree/DecisionTreeNode.h"
|
5
|
+
#include "MachineLearning/DecisionTree/CategoryInfo.h"
|
6
|
+
#include "MachineLearning/MLData/MLData.h"
|
7
|
+
#include "MachineLearning/GBM/GBMEstimator.h"
|
8
|
+
#include "utils/Utils.h"
|
9
|
+
#include "utils/StochasticUtils.h"
|
10
|
+
|
11
|
+
#include <boost/foreach.hpp>
|
12
|
+
#include <boost/lexical_cast.hpp>
|
13
|
+
#include <cmath>
|
14
|
+
using boost::lexical_cast;
|
15
|
+
|
16
|
+
// Constructs a splitter over the given dataset.
//   data            — dataset (not owned); queried for missing-value config.
//   minObservations — minimum experiment count required on each side of a split.
//   scale           — temperature for stochastic split selection
//                     (infinity => purely greedy selection).
NodeSplitter::NodeSplitter(MLData* data, int minObservations, double scale)
    : m_data(data), m_minObservations(minObservations), m_scale(scale)
{
    // Cache the missing-value configuration up front so the hot splitting
    // loops don't query the dataset repeatedly.
    m_missingValueDefined = m_data->missingValueDefined();
    if (m_missingValueDefined)
        m_missingValue = m_data->getMissingValue();
}
|
23
|
+
|
24
|
+
// Trivial destructor: NodeSplitter holds no resources that need explicit release.
NodeSplitter::~NodeSplitter() {}
|
25
|
+
|
26
|
+
// Scores a candidate split: the weighted between-partition spread of the mean
// response Z (larger is better).
//   - With no missing partition this is the classic two-sample form:
//       wL * wR * (zL/wL - zR/wR)^2 / (wL + wR)
//   - With a missing partition, all three pairwise mean differences contribute,
//     normalized by the total weight.
// Fix: the original divided by lhsSumW / rhsSumW unconditionally, so an empty
// side produced 0/0 => NaN (e.g. the first sweep step in the categorical
// splitter). NaN happened to be filtered later only because `NaN > best` is
// false; returning 0 for an empty side makes that explicit and keeps NaN out
// of the improvement vectors fed to the stochastic-selection pdf.
double NodeSplitter::calculateImprovement(double lhsSumW, double lhsSumZ, double rhsSumW, double rhsSumZ, double missingSumW, double missingSumZ)
{
    if (lhsSumW == 0 || rhsSumW == 0)
        return 0.0; // a split leaving one side empty is no split at all

    if (missingSumW == 0)
    {
        double meanZDifference = lhsSumZ / lhsSumW - rhsSumZ / rhsSumW;
        return lhsSumW * rhsSumW * meanZDifference * meanZDifference / (lhsSumW + rhsSumW);
    }

    double lhsMean     = lhsSumZ / lhsSumW;
    double rhsMean     = rhsSumZ / rhsSumW;
    double missingMean = missingSumZ / missingSumW;

    double improvement = 0.0;
    improvement += lhsSumW * rhsSumW     * (lhsMean - rhsMean)     * (lhsMean - rhsMean);
    improvement += lhsSumW * missingSumW * (lhsMean - missingMean) * (lhsMean - missingMean);
    improvement += rhsSumW * missingSumW * (rhsMean - missingMean) * (rhsMean - missingMean);
    return improvement / (lhsSumW + rhsSumW + missingSumW);
}
|
49
|
+
|
50
|
+
// double NodeSplitter::calculateImprovement(double lhsSumW, double lhsSumZ, double rhsSumW, double rhsSumZ, double missingSumW, double missingSumZ)
|
51
|
+
// {
|
52
|
+
// double improvement = 0.0;
|
53
|
+
//
|
54
|
+
// if (lhsSumW)
|
55
|
+
// improvement += pow(lhsSumZ, 2) / lhsSumW;
|
56
|
+
//
|
57
|
+
// if (rhsSumW)
|
58
|
+
// improvement += pow(rhsSumZ, 2) / rhsSumW;
|
59
|
+
//
|
60
|
+
// if (missingSumW)
|
61
|
+
// improvement += pow(missingSumZ, 2) / missingSumW;
|
62
|
+
//
|
63
|
+
// return improvement;
|
64
|
+
// }
|
65
|
+
|
66
|
+
// Builds the child node holding the experiments that fall on the
// left-hand side of the given split.
shared_ptr<DecisionTreeNode> NodeSplitter::createLhsChild( shared_ptr<SplitDefinition> splitDefinition )
{
    return createChild(splitDefinition, LHS);
}
|
70
|
+
|
71
|
+
// Builds the child node holding the experiments that fall on the
// right-hand side of the given split.
shared_ptr<DecisionTreeNode> NodeSplitter::createRhsChild( shared_ptr<SplitDefinition> splitDefinition )
{
    return createChild(splitDefinition, RHS);
}
|
75
|
+
|
76
|
+
// Builds the child node holding the experiments whose split feature carries
// the dataset's missing-value sentinel.
shared_ptr<DecisionTreeNode> NodeSplitter::createMissingChild( shared_ptr<SplitDefinition> splitDefinition )
{
    return createChild(splitDefinition, MISSING);
}
|
80
|
+
|
81
|
+
// Materializes one child (LHS, RHS or MISSING) of a split: selects the parent
// experiments that fall into that partition and seeds the new node with the
// partition's sumZ/sumW already recorded on the split definition.
shared_ptr<DecisionTreeNode> NodeSplitter::createChild( shared_ptr<SplitDefinition> splitDefinition, Partition partition )
{
    vector<shared_ptr<DecisionTreeExperiment> > parentExperiments = splitDefinition->getNodeToSplit()->getExperiments();
    vector<shared_ptr<DecisionTreeExperiment> > childExperiments =
        partitionExperiments(parentExperiments, splitDefinition, partition);

    double sumZ = 0.0;
    double sumW = 0.0;
    switch (partition)
    {
    case LHS:
        sumZ = splitDefinition->getLhsSumZ();
        sumW = splitDefinition->getLhsSumW();
        break;
    case RHS:
        sumZ = splitDefinition->getRhsSumZ();
        sumW = splitDefinition->getRhsSumW();
        break;
    default: // MISSING
        sumZ = splitDefinition->getMissingSumZ();
        sumW = splitDefinition->getMissingSumW();
        break;
    }

    return shared_ptr<DecisionTreeNode>(
        new DecisionTreeNode(childExperiments, sumZ, sumW, partition, splitDefinition));
}
|
108
|
+
|
109
|
+
// Returns the subset of `experiments` that lands in `partition` under
// `splitDefinition`. Categorical splits route by membership of the feature
// value in the LHS/RHS category sets; continuous splits route by comparison
// against the split value. When a missing-value sentinel is defined,
// experiments carrying it always route to MISSING.
// Fix: removed the dead local `bool rhs = !partition;` — it negated an enum
// value and was never read.
vector<shared_ptr<DecisionTreeExperiment> > NodeSplitter::partitionExperiments(vector<shared_ptr<DecisionTreeExperiment> >& experiments,
                                                                               shared_ptr<SplitDefinition> splitDefinition, Partition partition)
{
    vector<shared_ptr<DecisionTreeExperiment> > partitionExperiments;

    // Reserve using the counts recorded at split time to avoid reallocation.
    if (partition == LHS)
        partitionExperiments.reserve(splitDefinition->getLhsExperimentCount());
    else if (partition == RHS)
        partitionExperiments.reserve(splitDefinition->getRhsExperimentCount());
    else if (partition == MISSING)
        partitionExperiments.reserve(splitDefinition->getMissingExperimentCount());

    int featureIndex = splitDefinition->getFeatureIndex();

    if (splitDefinition->isCategorical())
    {
        set<double>& lhsCategories = splitDefinition->getLhsCategories();
        set<double>& rhsCategories = splitDefinition->getRhsCategories();

        BOOST_FOREACH(shared_ptr<DecisionTreeExperiment>& experiment, experiments)
        {
            double featureValue = experiment->getFeatureValue(featureIndex);

            if ((partition == MISSING && m_missingValueDefined && m_missingValue == featureValue) ||
                (partition == LHS && lhsCategories.count(featureValue) > 0) ||
                (partition == RHS && rhsCategories.count(featureValue) > 0))
            {
                partitionExperiments.push_back(experiment);
            }
        }
    }
    else
    {
        double splitValue = splitDefinition->getSplitValue();
        BOOST_FOREACH(shared_ptr<DecisionTreeExperiment>& experiment, experiments)
        {
            double featureValue = experiment->getFeatureValue(featureIndex);

            if (m_missingValueDefined && featureValue == m_missingValue)
            {
                // Missing values never fall to LHS/RHS regardless of splitValue.
                if (partition == MISSING)
                    partitionExperiments.push_back(experiment);
            }
            else if ((partition == LHS && featureValue < splitValue) ||
                     (partition == RHS && featureValue > splitValue))
            {
                // NOTE(review): a value exactly equal to splitValue falls into
                // neither side — preserved from the original implementation.
                partitionExperiments.push_back(experiment);
            }
        }
    }
    return partitionExperiments;
}
|
165
|
+
|
166
|
+
// Attempts to split `nodeToSplit` on the best of `featuresToConsider`.
// Returns the three children (LHS, RHS, MISSING) on success, or an empty
// vector when the node carries no weight or no improving split exists.
// When m_scale is finite, the winning split is sampled from the per-feature
// improvement scores instead of taken greedily.
vector<shared_ptr<DecisionTreeNode> > NodeSplitter::splitNode( shared_ptr<DecisionTreeNode> nodeToSplit, vector<int> featuresToConsider )
{
    vector<shared_ptr<DecisionTreeNode> > children;

    if (nodeToSplit->getSumW() == 0)
        return children; // no weight => nothing to split

    shared_ptr<SplitDefinition> bestSplit;
    double bestImprovement = 0.0;

    vector<double> candidateImprovements;
    vector<shared_ptr<SplitDefinition> > candidateSplits;
    candidateImprovements.reserve(featuresToConsider.size());
    candidateSplits.reserve(featuresToConsider.size());

    set<int>& categoricalFeatures = m_data->getCategoricalFeatureIndices();

    // Build one candidate split per feature, tracking the greedy best.
    BOOST_FOREACH(int featureIndex, featuresToConsider)
    {
        shared_ptr<SplitDefinition> candidate =
            Utils::hasElement(categoricalFeatures, featureIndex)
                ? createCategoricalSplitDefinition(nodeToSplit, featureIndex)
                : createContinuousSplitDefinition(nodeToSplit, featureIndex);

        candidateSplits.push_back(candidate);
        candidateImprovements.push_back(candidate.get() ? candidate->getImprovement() : 0);

        if (!candidate.get())
            continue; // feature produced no admissible split

        if (candidate->getImprovement() > bestImprovement)
        {
            bestImprovement = candidate->getImprovement();
            bestSplit = candidate;
        }
    }

    if (bestImprovement == 0.0)
        return children; // no feature improves on the current node

    // Stochastic feature selection: weight each candidate by its scaled
    // relative improvement and sample from the resulting pdf.
    // NOTE(review): unlike createContinuousSplitDefinition, these weights are
    // NOT passed through exp() before building the pdf — confirm whether that
    // asymmetry is intentional (applying exp() here would also give weight 1
    // to features whose candidate split is null).
    if (m_scale != std::numeric_limits<double>::infinity() && candidateImprovements.size() > 1)
    {
        vector<float> weights;
        BOOST_FOREACH(double improvement, candidateImprovements)
            weights.push_back(m_scale * improvement / bestImprovement);

        vector<float> pdf = StochasticUtils::convertHistogramToPdf(weights);
        int chosenIndex = StochasticUtils::chooseCategoryFromPdf(pdf);
        bestImprovement = candidateImprovements.at(chosenIndex);
        bestSplit = candidateSplits.at(chosenIndex);
    }

    shared_ptr<DecisionTreeNode> lhsChild     = createLhsChild(bestSplit);
    shared_ptr<DecisionTreeNode> rhsChild     = createRhsChild(bestSplit);
    shared_ptr<DecisionTreeNode> missingChild = createMissingChild(bestSplit);

    // Record the split on the parent, then hand the children back so the
    // caller can replace the parent in its terminal-node set.
    nodeToSplit->defineSplit(bestSplit, lhsChild, rhsChild, missingChild);

    children.push_back(lhsChild);
    children.push_back(rhsChild);
    children.push_back(missingChild);
    return children;
}
|
241
|
+
|
242
|
+
struct FeatureSorter
|
243
|
+
{
|
244
|
+
FeatureSorter()
|
245
|
+
: featureIndexToSort(-1)
|
246
|
+
{}
|
247
|
+
|
248
|
+
int featureIndexToSort;
|
249
|
+
|
250
|
+
bool operator() (shared_ptr<DecisionTreeExperiment> a, shared_ptr<DecisionTreeExperiment> b)
|
251
|
+
{
|
252
|
+
if (featureIndexToSort == -1)
|
253
|
+
throw std::runtime_error("SortOnFeature object doesn't know which feature to sort on!");
|
254
|
+
|
255
|
+
return a->getFeatureValue(featureIndexToSort) < b->getFeatureValue(featureIndexToSort);
|
256
|
+
}
|
257
|
+
} featureSorter;
|
258
|
+
|
259
|
+
// Finds the best threshold split of `node` on continuous feature
// `featureIndex`. Experiments are sorted by the feature value; every boundary
// between two distinct values that leaves at least m_minObservations on each
// side is scored with calculateImprovement. Missing-valued experiments are
// accumulated into their own partition as the sweep passes them. When m_scale
// is finite, the returned boundary is sampled from exp-scaled relative
// improvements instead of taken greedily. Returns an empty pointer when no
// admissible boundary exists.
shared_ptr<SplitDefinition> NodeSplitter::createContinuousSplitDefinition( shared_ptr<DecisionTreeNode> node, int featureIndex )
{
    vector<shared_ptr<DecisionTreeExperiment> > sortedExperiments = node->getExperiments();

    // featureSorter is a file-level global comparator; configure it for this feature.
    featureSorter.featureIndexToSort = featureIndex;
    sort(sortedExperiments.begin(), sortedExperiments.end(), featureSorter);

    // Running partition sums: everything starts on the RHS and migrates to
    // the LHS (or MISSING) as the sweep advances.
    double lhsSumZ = 0, lhsSumW = 0;
    double missingSumZ = 0, missingSumW = 0;
    double rhsSumZ = node->getSumZ();
    double rhsSumW = node->getSumW();
    int lhsCount = 0, missingCount = 0;
    int rhsCount = (int) sortedExperiments.size();

    // One entry per admissible candidate boundary (parallel vectors).
    vector<double> vecLhsSumZ, vecLhsSumW;
    vector<int>    vecLhsCount;
    vector<double> vecRhsSumZ, vecRhsSumW;
    vector<int>    vecRhsCount;
    vector<double> vecMissingSumZ, vecMissingSumW;
    vector<int>    vecMissingCount;
    vector<double> vecImprovement;
    vector<int>    vecPosition;

    double bestImprovement = 0.0;
    int bestPosition = -1; // index (in sorted order) of the last LHS experiment
    int bestIndex = -1;    // index into the parallel candidate vectors

    int position = -1;     // index of the previously processed experiment
    double previousFeatureValue = 0;

    BOOST_FOREACH(shared_ptr<DecisionTreeExperiment> experiment, sortedExperiments)
    {
        double featureValue = experiment->getFeatureValue(featureIndex);

        // A boundary is admissible only between two distinct feature values
        // with enough observations already accumulated on each side.
        if (featureValue != previousFeatureValue &&
            lhsCount >= m_minObservations &&
            rhsCount >= m_minObservations)
        {
            double improvement = calculateImprovement(lhsSumW, lhsSumZ, rhsSumW, rhsSumZ, missingSumW, missingSumZ);
            vecPosition.push_back(position);
            vecImprovement.push_back(improvement);
            vecLhsSumZ.push_back(lhsSumZ);
            vecLhsSumW.push_back(lhsSumW);
            vecLhsCount.push_back(lhsCount);
            vecRhsSumZ.push_back(rhsSumZ);
            vecRhsSumW.push_back(rhsSumW);
            vecRhsCount.push_back(rhsCount);
            vecMissingSumZ.push_back(missingSumZ);
            vecMissingSumW.push_back(missingSumW);
            vecMissingCount.push_back(missingCount);

            if (improvement > bestImprovement)
            {
                bestImprovement = improvement;
                bestPosition = position;
                bestIndex = (int) vecPosition.size() - 1;
            }
        }

        // Move this experiment off the RHS into either MISSING or the LHS.
        double weight = experiment->getWeight();
        double z = experiment->getZ();
        rhsSumZ -= weight * z;
        rhsSumW -= weight;
        --rhsCount;

        if (m_missingValueDefined && featureValue == m_missingValue)
        {
            missingSumZ += weight * z;
            missingSumW += weight;
            ++missingCount;
        }
        else
        {
            lhsSumZ += weight * z;
            lhsSumW += weight;
            ++lhsCount;
        }

        previousFeatureValue = featureValue;
        ++position;
    }

    if (bestPosition == -1)
        return shared_ptr<SplitDefinition>(); // no admissible boundary found

    // Stochastic boundary selection among the recorded candidates.
    if (m_scale != std::numeric_limits<double>::infinity() && vecImprovement.size() > 1)
    {
        vector<float> exp_u;
        exp_u.reserve(vecImprovement.size());
        BOOST_FOREACH(double& improvement, vecImprovement)
            exp_u.push_back(exp(m_scale * improvement / bestImprovement));

        vector<float> pdf = StochasticUtils::convertHistogramToPdf(exp_u);
        bestIndex = StochasticUtils::chooseCategoryFromPdf(pdf, "improvements");
    }

    double bestLhsSumZ     = vecLhsSumZ.at(bestIndex);
    double bestLhsSumW     = vecLhsSumW.at(bestIndex);
    int    bestLhsCount    = vecLhsCount.at(bestIndex);
    double bestRhsSumZ     = vecRhsSumZ.at(bestIndex);
    double bestRhsSumW     = vecRhsSumW.at(bestIndex);
    int    bestRhsCount    = vecRhsCount.at(bestIndex);
    double bestMissingSumZ = vecMissingSumZ.at(bestIndex);
    double bestMissingSumW = vecMissingSumW.at(bestIndex);
    int    bestMissingCount = vecMissingCount.at(bestIndex);
    bestImprovement = vecImprovement.at(bestIndex);
    bestPosition = vecPosition.at(bestIndex);

    if (bestPosition >= (int) (sortedExperiments.size()-1))
        throw std::runtime_error(string("Unexpected bestPosition: ") + lexical_cast<string>(bestPosition));

    double lhsFeatureValue = sortedExperiments.at(bestPosition)->getFeatureValue(featureIndex);
    double rhsFeatureValue = sortedExperiments.at(bestPosition + 1)->getFeatureValue(featureIndex);

    // Split halfway between the bordering values; if the LHS border is the
    // missing-value sentinel, use the sentinel itself as the split value.
    double splitValue;
    if (m_missingValueDefined && (lhsFeatureValue == m_missingValue))
        splitValue = m_missingValue;
    else
        splitValue = 0.5 * (lhsFeatureValue + rhsFeatureValue);

    return shared_ptr<SplitDefinition>(
        new SplitDefinition(node, featureIndex, splitValue, bestLhsSumZ, bestLhsSumW, bestLhsCount,
                            bestRhsSumZ, bestRhsSumW, bestRhsCount, bestMissingSumZ, bestMissingSumW,
                            bestMissingCount, bestImprovement));
}
|
434
|
+
|
435
|
+
// Finds the best binary grouping of the categories of `featureIndex` for
// `node`. Experiments are aggregated per category, the categories sorted by
// CategoryInfo's ordering, and the sorted sequence swept left-to-right; the
// best prefix/suffix cut satisfying m_minObservations on both sides wins.
// Missing-valued experiments form their own partition and are excluded from
// the sweep. Returns an empty pointer when there is nothing to split.
shared_ptr<SplitDefinition> NodeSplitter::createCategoricalSplitDefinition( shared_ptr<DecisionTreeNode> node, int featureIndex )
{
    vector<shared_ptr<DecisionTreeExperiment> > experiments = node->getExperiments();

    map<double, CategoryInfo> experimentsPerCategory;
    double missingSumZ = 0, missingSumW = 0;
    int missingCount = 0;

    // Bucket experiments by category, diverting missing values to their own sums.
    BOOST_FOREACH(shared_ptr<DecisionTreeExperiment>& experiment, experiments)
    {
        double featureValue = experiment->getFeatureValue(featureIndex);

        if (m_missingValueDefined && m_missingValue == featureValue)
        {
            double w = experiment->getWeight();
            double z = experiment->getZ();
            missingSumZ += w * z;
            missingSumW += w;
            ++missingCount;
        }
        else
        {
            CategoryInfo& info = experimentsPerCategory[featureValue];
            info.category = featureValue;
            info.addExperiment(experiment);
        }
    }

    if (experimentsPerCategory.size() == 1)
        return shared_ptr<SplitDefinition>(); // a single category cannot be split

    // Flatten into a vector and sort (CategoryInfo supplies the ordering).
    vector<CategoryInfo> categoryInfoVector;
    typedef pair<double, CategoryInfo> ElementType;
    BOOST_FOREACH(ElementType e, experimentsPerCategory)
        categoryInfoVector.push_back(e.second);

    sort(categoryInfoVector.begin(), categoryInfoVector.end());

    // Sweep: categories migrate from the RHS to the LHS one at a time.
    double lhsSumZ = 0, lhsSumW = 0;
    double rhsSumZ = node->getSumZ();
    double rhsSumW = node->getSumW();
    int lhsCount = 0;
    int rhsCount = (int) experiments.size();

    double bestImprovement = 0.0;
    int bestPosition = -1; // index of the last category assigned to the LHS
    double bestLhsSumZ = -1, bestLhsSumW = -1;
    int bestLhsCount = -1;
    double bestRhsSumZ = -1, bestRhsSumW = -1;
    int bestRhsCount = -1;

    int position = -1;

    BOOST_FOREACH(CategoryInfo& e, categoryInfoVector)
    {
        double improvement = calculateImprovement(lhsSumW, lhsSumZ, rhsSumW, rhsSumZ, missingSumW, missingSumZ);
        if (improvement > bestImprovement &&
            lhsCount >= m_minObservations &&
            rhsCount >= m_minObservations)
        {
            bestImprovement = improvement;
            bestPosition = position;
            bestLhsSumZ = lhsSumZ;
            bestLhsSumW = lhsSumW;
            bestLhsCount = lhsCount;
            bestRhsSumZ = rhsSumZ;
            bestRhsSumW = rhsSumW;
            bestRhsCount = rhsCount;
        }

        ++position;
        rhsSumW -= e.sumW;
        rhsSumZ -= e.sumZ;
        rhsCount -= e.countN;

        lhsSumW += e.sumW;
        lhsSumZ += e.sumZ;
        lhsCount += e.countN;
    }

    // No admissible cut and no missing partition => nothing to report.
    if (bestPosition == -1 && missingSumW == 0)
        return shared_ptr<SplitDefinition>();

    // Categories up to and including bestPosition go left; the rest go right.
    set<double> lhsCategories;
    set<double> rhsCategories;
    int index = -1;
    BOOST_FOREACH(CategoryInfo& info, categoryInfoVector)
    {
        ++index;
        if (index <= bestPosition)
            lhsCategories.insert(info.category);
        else
            rhsCategories.insert(info.category);
    }

    return shared_ptr<SplitDefinition>(
        new SplitDefinition(node, featureIndex, lhsCategories, rhsCategories, bestLhsSumZ, bestLhsSumW,
                            bestLhsCount, bestRhsSumZ, bestRhsSumW, bestRhsCount, missingSumZ, missingSumW,
                            missingCount, bestImprovement));
}
|
541
|
+
|
542
|
+
// Dispatches to the categorical or continuous split builder depending on how
// the dataset classifies `featureIndex`.
shared_ptr<SplitDefinition> NodeSplitter::createSplitDefinition( shared_ptr<DecisionTreeNode> node, int featureIndex )
{
    bool categorical = Utils::hasElement(m_data->getCategoricalFeatureIndices(), featureIndex);
    return categorical ? createCategoricalSplitDefinition(node, featureIndex)
                       : createContinuousSplitDefinition(node, featureIndex);
}
|
549
|
+
|
550
|
+
|
551
|
+
|
@@ -0,0 +1,22 @@
|
|
1
|
+
#include "MachineLearning/DecisionTree/NodeSplitterCategorical.h"
|
2
|
+
#include "MachineLearning/DecisionTree/DecisionTreeNode.h"
|
3
|
+
#include "MachineLearning/DecisionTree/DecisionTreeExperiment.h"
|
4
|
+
#include "MachineLearning/DecisionTree/SplitDefinition.h"
|
5
|
+
#include "MachineLearning/DecisionTree/CategoryInfo.h"
|
6
|
+
|
7
|
+
// Thin subclass wiring: all state lives in the NodeSplitter base.
NodeSplitterCategorical::NodeSplitterCategorical(MLData* data, int minObservations, double scale)
    : NodeSplitter(data, minObservations, scale)
{
}
|
12
|
+
|
13
|
+
// Nothing to release beyond what the base class handles.
NodeSplitterCategorical::~NodeSplitterCategorical()
{
}
|
17
|
+
|
18
|
+
// Builds a split definition for `node` on `featureIndex`.
// Fix: the original body was empty, so control flowed off the end of a
// value-returning function — undefined behavior in C++ (typically a garbage
// shared_ptr). Delegate to the base-class dispatcher, which routes to the
// categorical or continuous builder as appropriate.
shared_ptr<SplitDefinition> NodeSplitterCategorical::createSplitDefinition(shared_ptr<DecisionTreeNode> node,
                                                                           int featureIndex)
{
    return NodeSplitter::createSplitDefinition(node, featureIndex);
}
|
@@ -0,0 +1,21 @@
|
|
1
|
+
#include "MachineLearning/DecisionTree/NodeSplitterContinuous.h"
|
2
|
+
#include "MachineLearning/DecisionTree/DecisionTreeExperiment.h"
|
3
|
+
#include "MachineLearning/DecisionTree/SplitDefinition.h"
|
4
|
+
#include "MachineLearning/DecisionTree/DecisionTreeNode.h"
|
5
|
+
|
6
|
+
|
7
|
+
|
8
|
+
// Thin subclass wiring: all state lives in the NodeSplitter base.
NodeSplitterContinuous::NodeSplitterContinuous(MLData* data, int minObservations, double scale)
    : NodeSplitter(data, minObservations, scale)
{
}
|
12
|
+
|
13
|
+
// Nothing to release beyond what the base class handles.
NodeSplitterContinuous::~NodeSplitterContinuous()
{
}
|
17
|
+
|
18
|
+
// Builds a split definition for `node` on `featureIndex`.
// Fix: the original body was empty, so control flowed off the end of a
// value-returning function — undefined behavior in C++ (typically a garbage
// shared_ptr). Delegate to the base-class dispatcher, which routes to the
// continuous or categorical builder as appropriate.
shared_ptr<SplitDefinition> NodeSplitterContinuous::createSplitDefinition(shared_ptr<DecisionTreeNode> node, int featureIndex)
{
    return NodeSplitter::createSplitDefinition(node, featureIndex);
}
|