ml4r 0.1.4 → 0.1.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33):
  1. data/ext/ml4r/LinearRegression/LinearRegression.cpp +305 -0
  2. data/ext/ml4r/LinearRegression/OLSLinearRegression.cpp +75 -0
  3. data/ext/ml4r/MachineLearning/DecisionTree/DecisionTreeExperiment.cpp +50 -0
  4. data/ext/ml4r/MachineLearning/DecisionTree/DecisionTreeNode.cpp +195 -0
  5. data/ext/ml4r/MachineLearning/DecisionTree/NodeSplitter.cpp +551 -0
  6. data/ext/ml4r/MachineLearning/DecisionTree/NodeSplitterCategorical.cpp +22 -0
  7. data/ext/ml4r/MachineLearning/DecisionTree/NodeSplitterContinuous.cpp +21 -0
  8. data/ext/ml4r/MachineLearning/DecisionTree/SplitDefinition.cpp +142 -0
  9. data/ext/ml4r/MachineLearning/GBM/BernoulliCalculator.cpp +95 -0
  10. data/ext/ml4r/MachineLearning/GBM/GBMEstimator.cpp +601 -0
  11. data/ext/ml4r/MachineLearning/GBM/GBMOutput.cpp +86 -0
  12. data/ext/ml4r/MachineLearning/GBM/GBMRunner.cpp +117 -0
  13. data/ext/ml4r/MachineLearning/GBM/GaussianCalculator.cpp +94 -0
  14. data/ext/ml4r/MachineLearning/GBM/ZenithGBM.cpp +317 -0
  15. data/ext/ml4r/MachineLearning/MLData/MLData.cpp +232 -0
  16. data/ext/ml4r/MachineLearning/MLData/MLDataFields.cpp +1 -0
  17. data/ext/ml4r/MachineLearning/MLData/MLDataReader.cpp +139 -0
  18. data/ext/ml4r/MachineLearning/MLData/ZenithMLData.cpp +96 -0
  19. data/ext/ml4r/MachineLearning/MLData/ZenithMLDataReader.cpp +113 -0
  20. data/ext/ml4r/MachineLearning/MLExperiment.cpp +69 -0
  21. data/ext/ml4r/MachineLearning/MLRunner.cpp +183 -0
  22. data/ext/ml4r/MachineLearning/MLUtils.cpp +15 -0
  23. data/ext/ml4r/MachineLearning/RandomForest/RandomForestEstimator.cpp +172 -0
  24. data/ext/ml4r/MachineLearning/RandomForest/RandomForestOutput.cpp +66 -0
  25. data/ext/ml4r/MachineLearning/RandomForest/RandomForestRunner.cpp +84 -0
  26. data/ext/ml4r/MachineLearning/RandomForest/ZenithRandomForest.cpp +184 -0
  27. data/ext/ml4r/ml4r.cpp +34 -0
  28. data/ext/ml4r/ml4r_wrap.cpp +15727 -0
  29. data/ext/ml4r/utils/MathUtils.cpp +204 -0
  30. data/ext/ml4r/utils/StochasticUtils.cpp +73 -0
  31. data/ext/ml4r/utils/Utils.cpp +14 -0
  32. data/ext/ml4r/utils/VlcMessage.cpp +3 -0
  33. metadata +33 -1
@@ -0,0 +1,551 @@
1
+ #include "MachineLearning/DecisionTree/NodeSplitter.h"
2
+ #include "MachineLearning/DecisionTree/DecisionTreeExperiment.h"
3
+ #include "MachineLearning/DecisionTree/SplitDefinition.h"
4
+ #include "MachineLearning/DecisionTree/DecisionTreeNode.h"
5
+ #include "MachineLearning/DecisionTree/CategoryInfo.h"
6
+ #include "MachineLearning/MLData/MLData.h"
7
+ #include "MachineLearning/GBM/GBMEstimator.h"
8
+ #include "utils/Utils.h"
9
+ #include "utils/StochasticUtils.h"
10
+
11
+ #include <boost/foreach.hpp>
12
+ #include <boost/lexical_cast.hpp>
13
+ #include <cmath>
14
+ using boost::lexical_cast;
15
+
16
+ NodeSplitter::NodeSplitter(MLData* data, int minObservations, double scale)
17
+ : m_data(data), m_minObservations(minObservations),m_scale(scale)
18
+ {
19
+ m_missingValueDefined = m_data->missingValueDefined();
20
+ if (m_missingValueDefined)
21
+ m_missingValue = m_data->getMissingValue();
22
+ }
23
+
24
// Trivial destructor: m_data is not deleted here (the splitter does not own it).
NodeSplitter::~NodeSplitter() {}
25
+
26
+ double NodeSplitter::calculateImprovement(double lhsSumW, double lhsSumZ, double rhsSumW, double rhsSumZ, double missingSumW, double missingSumZ)
27
+ {
28
+ double improvement = 0.0;
29
+
30
+ if (missingSumW == 0)
31
+ {
32
+ double meanZDifference = lhsSumZ / lhsSumW - rhsSumZ / rhsSumW;
33
+ improvement = lhsSumW * rhsSumW * pow(meanZDifference, 2) / (lhsSumW + rhsSumW);
34
+ }
35
+ else
36
+ {
37
+ double meanLRDifference = lhsSumZ / lhsSumW - rhsSumZ / rhsSumW;
38
+ double meanLMDifference = lhsSumZ / lhsSumW - missingSumZ / missingSumW;
39
+ double meanRMDifference = rhsSumZ / rhsSumW - missingSumZ / missingSumW;
40
+
41
+ improvement += lhsSumW * rhsSumW * pow(meanLRDifference, 2);
42
+ improvement += lhsSumW * missingSumW * pow(meanLMDifference, 2);
43
+ improvement += rhsSumW * missingSumW * pow(meanRMDifference, 2);
44
+ improvement /= (lhsSumW + rhsSumW + missingSumW);
45
+ }
46
+
47
+ return improvement;
48
+ }
49
+
50
+ // double NodeSplitter::calculateImprovement(double lhsSumW, double lhsSumZ, double rhsSumW, double rhsSumZ, double missingSumW, double missingSumZ)
51
+ // {
52
+ // double improvement = 0.0;
53
+ //
54
+ // if (lhsSumW)
55
+ // improvement += pow(lhsSumZ, 2) / lhsSumW;
56
+ //
57
+ // if (rhsSumW)
58
+ // improvement += pow(rhsSumZ, 2) / rhsSumW;
59
+ //
60
+ // if (missingSumW)
61
+ // improvement += pow(missingSumZ, 2) / missingSumW;
62
+ //
63
+ // return improvement;
64
+ // }
65
+
66
// Builds the child node holding the experiments on the left-hand side of the split.
shared_ptr<DecisionTreeNode> NodeSplitter::createLhsChild( shared_ptr<SplitDefinition> splitDefinition )
{
    return createChild(splitDefinition, LHS);
}
70
+
71
// Builds the child node holding the experiments on the right-hand side of the split.
shared_ptr<DecisionTreeNode> NodeSplitter::createRhsChild( shared_ptr<SplitDefinition> splitDefinition )
{
    return createChild(splitDefinition, RHS);
}
75
+
76
// Builds the child node holding the experiments whose feature value is missing.
shared_ptr<DecisionTreeNode> NodeSplitter::createMissingChild( shared_ptr<SplitDefinition> splitDefinition )
{
    return createChild(splitDefinition, MISSING);
}
80
+
81
+ shared_ptr<DecisionTreeNode> NodeSplitter::createChild( shared_ptr<SplitDefinition> splitDefinition, Partition partition )
82
+ {
83
+ vector<shared_ptr<DecisionTreeExperiment> > experiments = splitDefinition->getNodeToSplit()->getExperiments();
84
+ vector<shared_ptr<DecisionTreeExperiment> > childExperiments =
85
+ partitionExperiments(experiments, splitDefinition, partition);
86
+
87
+ double sumZ;
88
+ double sumW;
89
+ if (partition == LHS)
90
+ {
91
+ sumZ = splitDefinition->getLhsSumZ();
92
+ sumW = splitDefinition->getLhsSumW();
93
+ }
94
+ else if (partition == RHS)
95
+ {
96
+ sumZ = splitDefinition->getRhsSumZ();
97
+ sumW = splitDefinition->getRhsSumW();
98
+ }
99
+ else
100
+ {
101
+ sumZ = splitDefinition->getMissingSumZ();
102
+ sumW = splitDefinition->getMissingSumW();
103
+ }
104
+ shared_ptr<DecisionTreeNode> child =
105
+ shared_ptr<DecisionTreeNode>(new DecisionTreeNode(childExperiments, sumZ, sumW, partition, splitDefinition));
106
+ return child;
107
+ }
108
+
109
+ vector<shared_ptr<DecisionTreeExperiment> > NodeSplitter::partitionExperiments(vector<shared_ptr<DecisionTreeExperiment> >& experiments,
110
+ shared_ptr<SplitDefinition> splitDefinition, Partition partition)
111
+ {
112
+ bool rhs = !partition;
113
+ vector<shared_ptr<DecisionTreeExperiment> > partitionExperiments;
114
+
115
+ if (partition == LHS)
116
+ partitionExperiments.reserve(splitDefinition->getLhsExperimentCount());
117
+ else if (partition == RHS)
118
+ partitionExperiments.reserve(splitDefinition->getRhsExperimentCount());
119
+ else if (partition == MISSING)
120
+ partitionExperiments.reserve(splitDefinition->getMissingExperimentCount());
121
+
122
+ int featureIndex = splitDefinition->getFeatureIndex();
123
+
124
+ if (splitDefinition->isCategorical())
125
+ {
126
+ // categorical
127
+ set<double>& lhsCategories = splitDefinition->getLhsCategories();
128
+ set<double>& rhsCategories = splitDefinition->getRhsCategories();
129
+
130
+ BOOST_FOREACH(shared_ptr<DecisionTreeExperiment>& experiment, experiments)
131
+ {
132
+ double featureValue = experiment->getFeatureValue(featureIndex);
133
+ set<double>::const_iterator lhsIt = lhsCategories.find(featureValue);
134
+ set<double>::const_iterator rhsIt = rhsCategories.find(featureValue);
135
+
136
+ if ((partition == MISSING && m_missingValueDefined && m_missingValue == featureValue) ||
137
+ (partition == LHS && lhsIt != lhsCategories.end()) ||
138
+ (partition == RHS && rhsIt != rhsCategories.end()))
139
+ {
140
+ partitionExperiments.push_back(experiment);
141
+ }
142
+ }
143
+ }
144
+ else
145
+ {
146
+ // continuous
147
+ double splitValue = splitDefinition->getSplitValue();
148
+ BOOST_FOREACH(shared_ptr<DecisionTreeExperiment>& experiment, experiments)
149
+ {
150
+ double featureValue = experiment->getFeatureValue(featureIndex);
151
+
152
+ if (m_missingValueDefined && featureValue == m_missingValue)
153
+ {
154
+ // experiment has a missing value
155
+ if (partition == MISSING)
156
+ partitionExperiments.push_back(experiment);
157
+ }
158
+ else if ((partition == LHS && featureValue < splitValue) ||
159
+ (partition == RHS && featureValue > splitValue))
160
+ partitionExperiments.push_back(experiment);
161
+ }
162
+ }
163
+ return partitionExperiments;
164
+ }
165
+
166
+ vector<shared_ptr<DecisionTreeNode> > NodeSplitter::splitNode( shared_ptr<DecisionTreeNode> nodeToSplit, vector<int> featuresToConsider )
167
+ {
168
+ vector<shared_ptr<DecisionTreeNode> > children;
169
+
170
+ if (nodeToSplit->getSumW() == 0)
171
+ return children;
172
+
173
+ // find terminal node with best improvement for any of those variables
174
+ shared_ptr<SplitDefinition> bestSplit;
175
+ double bestImprovement = 0.0;
176
+
177
+ vector<double> vecImprovements;
178
+ vector<shared_ptr<SplitDefinition> > vecSplits;
179
+
180
+ vecImprovements.reserve(featuresToConsider.size());
181
+ vecSplits.reserve(featuresToConsider.size());
182
+
183
+ set<int>& categoricalFeatures = m_data->getCategoricalFeatureIndices();
184
+
185
+ BOOST_FOREACH(int featureIndex, featuresToConsider)
186
+ {
187
+
188
+ shared_ptr<SplitDefinition> split;
189
+
190
+ if (Utils::hasElement(categoricalFeatures,featureIndex))
191
+ split = createCategoricalSplitDefinition(nodeToSplit, featureIndex);
192
+ else
193
+ split = createContinuousSplitDefinition(nodeToSplit, featureIndex);
194
+
195
+ vecSplits.push_back(split);
196
+ vecImprovements.push_back(split.get() ? split->getImprovement() : 0);
197
+
198
+ if (!split.get()) // it returned an invalid
199
+ continue;
200
+
201
+ if (split->getImprovement() > bestImprovement)
202
+ {
203
+ bestImprovement = split->getImprovement();
204
+ bestSplit = split;
205
+ }
206
+ }
207
+ if (bestImprovement == 0.0)
208
+ return children;
209
+
210
+ if (m_scale != std::numeric_limits<double>::infinity() && vecImprovements.size() > 1)
211
+ {
212
+ vector<float> exp_u;
213
+ BOOST_FOREACH(double improvement, vecImprovements)
214
+ exp_u.push_back(m_scale * improvement / bestImprovement);
215
+
216
+ vector<float> pdf = StochasticUtils::convertHistogramToPdf(exp_u);
217
+ int bestIndex = StochasticUtils::chooseCategoryFromPdf(pdf);
218
+ bestImprovement = vecImprovements.at(bestIndex);
219
+ bestSplit = vecSplits.at(bestIndex);
220
+ }
221
+
222
+ int featureIndex = bestSplit->getFeatureIndex();
223
+ bool isCategorical = Utils::hasElement(categoricalFeatures,featureIndex);
224
+
225
+ shared_ptr<DecisionTreeNode> lhsChild = createLhsChild(bestSplit);
226
+ shared_ptr<DecisionTreeNode> rhsChild = createRhsChild(bestSplit);
227
+ shared_ptr<DecisionTreeNode> missingChild = createMissingChild(bestSplit);
228
+
229
+ nodeToSplit->defineSplit(bestSplit,lhsChild,rhsChild,missingChild);
230
+
231
+ // if (m_parameters.verbose)
232
+ // vlcMessage.Write("Split at feature index " + ToString(bestSplit->getFeatureIndex()) + " at value " + ToString(bestSplit->getSplitValue()) + " with improvement " + ToString(bestSplit->getImprovement()));
233
+
234
+ // finally, remove the node we just split from the terminal nodes, and add the children
235
+ children.push_back(lhsChild);
236
+ children.push_back(rhsChild);
237
+ children.push_back(missingChild);
238
+
239
+ return children;
240
+ }
241
+
242
// Comparator that orders experiments by one feature's value, ascending.
// featureIndexToSort must be assigned before use with std::sort; the
// sentinel -1 means "not configured" and triggers an exception.
struct FeatureSorter
{
    FeatureSorter()
        : featureIndexToSort(-1)
    {}

    int featureIndexToSort;  // index of the feature to compare on; -1 = unset

    bool operator() (shared_ptr<DecisionTreeExperiment> a, shared_ptr<DecisionTreeExperiment> b)
    {
        if (featureIndexToSort == -1)
            throw std::runtime_error("SortOnFeature object doesn't know which feature to sort on!");

        return a->getFeatureValue(featureIndexToSort) < b->getFeatureValue(featureIndexToSort);
    }
} featureSorter; // NOTE(review): file-scope mutable instance shared by all callers — not thread-safe; confirm single-threaded use
258
+
259
// Finds a threshold split of `node` on continuous feature `featureIndex`.
//
// Experiments are sorted by feature value, then a running LHS/RHS (plus
// missing-value) partition is swept across the sorted order. A candidate
// split is recorded at each boundary between two distinct feature values
// where both sides hold at least m_minObservations experiments. When
// m_scale is finite, the returned candidate is sampled with probability
// proportional to exp(m_scale * improvement / bestImprovement); otherwise
// the best candidate is returned. Returns a null shared_ptr when no valid
// candidate exists.
shared_ptr<SplitDefinition> NodeSplitter::createContinuousSplitDefinition( shared_ptr<DecisionTreeNode> node, int featureIndex )
{
    // Work on a copy of the node's experiment list so the sort does not
    // disturb the node itself.
    vector<shared_ptr<DecisionTreeExperiment> > sortedExperiments = node->getExperiments();

    // featureSorter is the file-scope mutable comparator defined above.
    featureSorter.featureIndexToSort = featureIndex;
    sort(sortedExperiments.begin(), sortedExperiments.end(), featureSorter);

    double rhsSumZ = 0, rhsSumW = 0, lhsSumZ = 0, lhsSumW = 0;
    double missingSumZ = 0, missingSumW = 0;

    // Per-candidate statistics, recorded for every valid split position so a
    // candidate can later be chosen stochastically.
    vector<double> vecLhsSumZ;
    vector<double> vecLhsSumW;
    vector<int> vecLhsCount;
    vector<double> vecRhsSumZ;
    vector<double> vecRhsSumW;
    vector<int> vecRhsCount;
    vector<double> vecMissingSumZ;
    vector<double> vecMissingSumW;
    vector<int> vecMissingCount;
    vector<double> vecImprovement;
    vector<int> vecPosition;

    // Statistics of the finally-chosen candidate (filled from the vectors
    // once bestIndex is known; bestPosition == -1 guards against reading
    // them unset).
    double bestLhsSumZ;
    double bestLhsSumW;
    int bestLhsCount;
    double bestRhsSumZ;
    double bestRhsSumW;
    int bestRhsCount;
    double bestMissingSumZ;
    double bestMissingSumW;
    int bestMissingCount;

    double bestImprovement = 0.0;
    int bestPosition = -1;
    int bestIndex = -1;

    // Every experiment starts on the RHS and migrates to the LHS (or the
    // missing bucket) one at a time as the sweep advances.
    int lhsCount = 0, missingCount = 0;
    int rhsCount = (int) sortedExperiments.size();

    rhsSumZ = node->getSumZ();
    rhsSumW = node->getSumW();
    int position = -1;           // index of the last experiment moved to the LHS
    double previousFeatureValue = 0;

    BOOST_FOREACH(shared_ptr<DecisionTreeExperiment> experiment, sortedExperiments)
    {
        double featureValue = experiment->getFeatureValue(featureIndex);

        if (featureValue != previousFeatureValue)
        {
            // Debug tracing hook (intentionally empty).
        }

        // A split is only meaningful between two distinct feature values, and
        // only when both sides retain at least m_minObservations experiments.
        if (featureValue != previousFeatureValue &&
            lhsCount >= m_minObservations &&
            rhsCount >= m_minObservations
            )
        {
            double improvement = calculateImprovement(lhsSumW, lhsSumZ, rhsSumW, rhsSumZ, missingSumW, missingSumZ);
            vecPosition.push_back(position);
            vecImprovement.push_back(improvement);
            vecLhsSumZ.push_back(lhsSumZ);
            vecLhsSumW.push_back(lhsSumW);
            vecLhsCount.push_back(lhsCount);
            vecRhsSumZ.push_back(rhsSumZ);
            vecRhsSumW.push_back(rhsSumW);
            vecRhsCount.push_back(rhsCount);
            vecMissingSumZ.push_back(missingSumZ);
            vecMissingSumW.push_back(missingSumW);
            vecMissingCount.push_back(missingCount);

            if (improvement > bestImprovement)
            {
                bestImprovement = improvement;
                bestPosition = position;
                // Index of this candidate within the vec* arrays above.
                bestIndex = (int) vecPosition.size() - 1;
            }
        }
        // Move the current experiment off the RHS and onto either the
        // missing bucket or the LHS.
        double weight = experiment->getWeight();
        double z = experiment->getZ();
        rhsSumZ -= weight * z;
        rhsSumW -= weight;
        --rhsCount;

        if (m_missingValueDefined && featureValue == m_missingValue)
        {
            missingSumZ += weight * z;
            missingSumW += weight;
            ++missingCount;
        }
        else
        {
            lhsSumZ += weight * z;
            lhsSumW += weight;
            ++lhsCount;
        }

        previousFeatureValue = featureValue;
        ++position;
    }

    // No candidate satisfied the constraints: signal "no split" to the caller.
    if (bestPosition == -1)
        return shared_ptr<SplitDefinition>();

    // Stochastic mode: sample a candidate with probability proportional to
    // exp(m_scale * improvement / bestImprovement).
    if (m_scale != std::numeric_limits<double>::infinity() && vecImprovement.size() > 1)
    {
        vector<float> exp_u;
        exp_u.reserve(vecImprovement.size());
        BOOST_FOREACH(double& improvement, vecImprovement)
        {
            exp_u.push_back(exp(m_scale * improvement / bestImprovement));
        }
        vector<float> pdf = StochasticUtils::convertHistogramToPdf(exp_u);
        bestIndex = StochasticUtils::chooseCategoryFromPdf(pdf, "improvements");
    }

    // Materialise the chosen candidate's statistics.
    bestLhsSumZ = vecLhsSumZ.at(bestIndex);
    bestLhsSumW = vecLhsSumW.at(bestIndex);
    bestLhsCount = vecLhsCount.at(bestIndex);
    bestRhsSumZ = vecRhsSumZ.at(bestIndex);
    bestRhsSumW = vecRhsSumW.at(bestIndex);
    bestRhsCount = vecRhsCount.at(bestIndex);
    bestMissingSumZ = vecMissingSumZ.at(bestIndex);
    bestMissingSumW = vecMissingSumW.at(bestIndex);
    bestMissingCount = vecMissingCount.at(bestIndex);
    bestImprovement = vecImprovement.at(bestIndex);
    bestPosition = vecPosition.at(bestIndex);

    // bestPosition indexes the last LHS experiment, so at least one
    // experiment must follow it to form the RHS.
    if (bestPosition >= (int) (sortedExperiments.size()-1))
        throw std::runtime_error(string("Unexpected bestPosition: ") + lexical_cast<string>(bestPosition));

    double lhsFeatureValue = sortedExperiments.at(bestPosition)->getFeatureValue(featureIndex);
    double rhsFeatureValue = sortedExperiments.at(bestPosition + 1)->getFeatureValue(featureIndex);

    // Threshold is the midpoint of the two neighbouring values; if the LHS
    // boundary is the missing value itself, the missing value is used directly.
    double splitValue;
    if (m_missingValueDefined && (lhsFeatureValue == m_missingValue))
        splitValue = m_missingValue;
    else
        splitValue = 0.5 * (lhsFeatureValue + rhsFeatureValue);


    shared_ptr<SplitDefinition> splitDefinition = shared_ptr<SplitDefinition>
        (new SplitDefinition(node, featureIndex, splitValue, bestLhsSumZ, bestLhsSumW, bestLhsCount,
        bestRhsSumZ, bestRhsSumW, bestRhsCount, bestMissingSumZ, bestMissingSumW,
        bestMissingCount, bestImprovement));

    return splitDefinition;
}
434
+
435
// Finds a binary category-subset split of `node` on categorical feature
// `featureIndex`. Per-category statistics are aggregated, the categories
// are sorted (by CategoryInfo's own ordering), and a sweep moves categories
// one at a time from RHS to LHS, keeping the partition with the highest
// improvement that satisfies m_minObservations on both sides. Returns a
// null shared_ptr when there is nothing to split.
shared_ptr<SplitDefinition> NodeSplitter::createCategoricalSplitDefinition( shared_ptr<DecisionTreeNode> node, int featureIndex )
{
    vector<shared_ptr<DecisionTreeExperiment> > experiments = node->getExperiments();

    map<double, CategoryInfo> experimentsPerCategory;

    // Accumulate per-category statistics; missing-valued experiments are
    // tallied separately and never belong to any category.
    double missingSumZ = 0, missingSumW = 0;
    int missingCount = 0;
    BOOST_FOREACH(shared_ptr<DecisionTreeExperiment>& experiment, experiments)
    {
        double featureValue = experiment->getFeatureValue(featureIndex);

        if (m_missingValueDefined && m_missingValue == featureValue)
        {
            double w = experiment->getWeight();
            double z = experiment->getZ();
            missingSumZ += w * z;
            missingSumW += w;
            missingCount++;
        }
        else
        {
            // operator[] default-constructs the CategoryInfo on first sight
            // of this category.
            CategoryInfo& info = experimentsPerCategory[featureValue];
            info.category = featureValue;
            info.addExperiment(experiment);
        }
    }

    if (experimentsPerCategory.size() == 1)
        return shared_ptr<SplitDefinition>(); // can't split one thing!

    // Copy into a vector so the categories can be sorted.
    vector<CategoryInfo> categoryInfoVector;
    typedef pair<double, CategoryInfo> ElementType;
    BOOST_FOREACH(ElementType e, experimentsPerCategory)
        categoryInfoVector.push_back(e.second);

    sort(categoryInfoVector.begin(), categoryInfoVector.end());

    double rhsSumZ = 0, rhsSumW = 0, lhsSumZ = 0, lhsSumW = 0;
    double bestImprovement = 0.0;
    int bestPosition = -1;
    // -1 sentinels: only overwritten when a valid partition is found.
    double bestLhsSumZ = -1;
    double bestLhsSumW = -1;
    int bestLhsCount = -1;
    double bestRhsSumZ = -1;
    double bestRhsSumW = -1;
    int bestRhsCount = -1;
    int lhsCount = 0, rhsCount = 0;

    // All categories start on the RHS and migrate to the LHS one at a time.
    rhsSumZ = node->getSumZ();
    rhsSumW = node->getSumW();
    rhsCount = (int) experiments.size();

    int position = -1; // index of the last category moved to the LHS

    BOOST_FOREACH(CategoryInfo& e, categoryInfoVector)
    {
        // Evaluate the partition BEFORE moving category `e`, i.e. with
        // categories [0, position] on the LHS.
        double improvement = calculateImprovement(lhsSumW, lhsSumZ, rhsSumW, rhsSumZ, missingSumW, missingSumZ);
        if ( improvement > bestImprovement
            && lhsCount >= m_minObservations
            && rhsCount >= m_minObservations)
        {
            bestImprovement = improvement;
            bestPosition = position;
            bestLhsSumZ = lhsSumZ;
            bestLhsSumW = lhsSumW;
            bestLhsCount = lhsCount;
            bestRhsSumZ = rhsSumZ;
            bestRhsSumW = rhsSumW;
            bestRhsCount = rhsCount;
        }

        ++position;
        rhsSumW -= e.sumW;
        rhsSumZ -= e.sumZ;
        rhsCount -= e.countN;

        lhsSumW += e.sumW;
        lhsSumZ += e.sumZ;
        lhsCount += e.countN;
    }

    // No valid partition and nothing in the missing bucket: no split.
    if (bestPosition == -1 && missingSumW == 0)
        return shared_ptr<SplitDefinition>();

    // NOTE(review): when bestPosition == -1 but missingSumW > 0, a split is
    // still built with every category on the RHS and the best* statistics at
    // their -1 sentinels — confirm this is the intended behaviour.
    set<double> lhsCategories;
    set<double> rhsCategories;

    // Categories up to and including bestPosition go left, the rest right.
    int index = -1;
    BOOST_FOREACH(CategoryInfo& info, categoryInfoVector)
    {
        ++index;
        if (index <= bestPosition)
            lhsCategories.insert(info.category);
        else
            rhsCategories.insert(info.category);
    }

    // We have what we need to create a split definition now.
    shared_ptr<SplitDefinition> splitDefinition = shared_ptr<SplitDefinition>
        (new SplitDefinition(node, featureIndex, lhsCategories, rhsCategories, bestLhsSumZ, bestLhsSumW,
        bestLhsCount, bestRhsSumZ, bestRhsSumW, bestRhsCount, missingSumZ, missingSumW, missingCount, bestImprovement));

    return splitDefinition;
}
541
+
542
+ shared_ptr<SplitDefinition> NodeSplitter::createSplitDefinition( shared_ptr<DecisionTreeNode> node, int featureIndex )
543
+ {
544
+ if (Utils::hasElement(m_data->getCategoricalFeatureIndices(),featureIndex))
545
+ return createCategoricalSplitDefinition(node, featureIndex);
546
+ else
547
+ return createContinuousSplitDefinition(node, featureIndex);
548
+ }
549
+
550
+
551
+
@@ -0,0 +1,22 @@
1
+ #include "MachineLearning/DecisionTree/NodeSplitterCategorical.h"
2
+ #include "MachineLearning/DecisionTree/DecisionTreeNode.h"
3
+ #include "MachineLearning/DecisionTree/DecisionTreeExperiment.h"
4
+ #include "MachineLearning/DecisionTree/SplitDefinition.h"
5
+ #include "MachineLearning/DecisionTree/CategoryInfo.h"
6
+
7
// Forwards construction to the NodeSplitter base; no additional state here.
NodeSplitterCategorical::NodeSplitterCategorical(MLData* data, int minObservations, double scale)
: NodeSplitter(data, minObservations, scale)
{

}
12
+
13
// Trivial destructor: this subclass holds no resources of its own.
NodeSplitterCategorical::~NodeSplitterCategorical()
{

}
17
+
18
+ shared_ptr<SplitDefinition> NodeSplitterCategorical::createSplitDefinition(shared_ptr<DecisionTreeNode> node,
19
+ int featureIndex)
20
+ {
21
+
22
+ }
@@ -0,0 +1,21 @@
1
+ #include "MachineLearning/DecisionTree/NodeSplitterContinuous.h"
2
+ #include "MachineLearning/DecisionTree/DecisionTreeExperiment.h"
3
+ #include "MachineLearning/DecisionTree/SplitDefinition.h"
4
+ #include "MachineLearning/DecisionTree/DecisionTreeNode.h"
5
+
6
+
7
+
8
// Forwards construction to the NodeSplitter base; no additional state here.
NodeSplitterContinuous::NodeSplitterContinuous(MLData* data, int minObservations, double scale)
: NodeSplitter(data, minObservations, scale)
{
}
12
+
13
// Trivial destructor: this subclass holds no resources of its own.
NodeSplitterContinuous::~NodeSplitterContinuous()
{

}
17
+
18
+ shared_ptr<SplitDefinition> NodeSplitterContinuous::createSplitDefinition(shared_ptr<DecisionTreeNode> node, int featureIndex)
19
+ {
20
+
21
+ }