ml4r 0.1.4 → 0.1.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- data/ext/ml4r/LinearRegression/LinearRegression.cpp +305 -0
- data/ext/ml4r/LinearRegression/OLSLinearRegression.cpp +75 -0
- data/ext/ml4r/MachineLearning/DecisionTree/DecisionTreeExperiment.cpp +50 -0
- data/ext/ml4r/MachineLearning/DecisionTree/DecisionTreeNode.cpp +195 -0
- data/ext/ml4r/MachineLearning/DecisionTree/NodeSplitter.cpp +551 -0
- data/ext/ml4r/MachineLearning/DecisionTree/NodeSplitterCategorical.cpp +22 -0
- data/ext/ml4r/MachineLearning/DecisionTree/NodeSplitterContinuous.cpp +21 -0
- data/ext/ml4r/MachineLearning/DecisionTree/SplitDefinition.cpp +142 -0
- data/ext/ml4r/MachineLearning/GBM/BernoulliCalculator.cpp +95 -0
- data/ext/ml4r/MachineLearning/GBM/GBMEstimator.cpp +601 -0
- data/ext/ml4r/MachineLearning/GBM/GBMOutput.cpp +86 -0
- data/ext/ml4r/MachineLearning/GBM/GBMRunner.cpp +117 -0
- data/ext/ml4r/MachineLearning/GBM/GaussianCalculator.cpp +94 -0
- data/ext/ml4r/MachineLearning/GBM/ZenithGBM.cpp +317 -0
- data/ext/ml4r/MachineLearning/MLData/MLData.cpp +232 -0
- data/ext/ml4r/MachineLearning/MLData/MLDataFields.cpp +1 -0
- data/ext/ml4r/MachineLearning/MLData/MLDataReader.cpp +139 -0
- data/ext/ml4r/MachineLearning/MLData/ZenithMLData.cpp +96 -0
- data/ext/ml4r/MachineLearning/MLData/ZenithMLDataReader.cpp +113 -0
- data/ext/ml4r/MachineLearning/MLExperiment.cpp +69 -0
- data/ext/ml4r/MachineLearning/MLRunner.cpp +183 -0
- data/ext/ml4r/MachineLearning/MLUtils.cpp +15 -0
- data/ext/ml4r/MachineLearning/RandomForest/RandomForestEstimator.cpp +172 -0
- data/ext/ml4r/MachineLearning/RandomForest/RandomForestOutput.cpp +66 -0
- data/ext/ml4r/MachineLearning/RandomForest/RandomForestRunner.cpp +84 -0
- data/ext/ml4r/MachineLearning/RandomForest/ZenithRandomForest.cpp +184 -0
- data/ext/ml4r/ml4r.cpp +34 -0
- data/ext/ml4r/ml4r_wrap.cpp +15727 -0
- data/ext/ml4r/utils/MathUtils.cpp +204 -0
- data/ext/ml4r/utils/StochasticUtils.cpp +73 -0
- data/ext/ml4r/utils/Utils.cpp +14 -0
- data/ext/ml4r/utils/VlcMessage.cpp +3 -0
- metadata +33 -1
@@ -0,0 +1,117 @@
|
|
1
|
+
|
2
|
+
#include "MachineLearning/GBM/GBMRunner.h"
|
3
|
+
#include "MachineLearning/GBM/GBMEstimator.h"
|
4
|
+
#include "MachineLearning/GBM/GBMOutput.h"
|
5
|
+
#include "MachineLearning/GBM/BernoulliCalculator.h"
|
6
|
+
#include "MachineLearning/GBM/GaussianCalculator.h"
|
7
|
+
#include "MachineLearning/MLData/MLData.h"
|
8
|
+
#include "MachineLearning/DecisionTree/DecisionTreeExperiment.h"
|
9
|
+
#include "MachineLearning/DecisionTree/DecisionTreeNode.h"
|
10
|
+
#include "MachineLearning/DecisionTree/FeatureInteraction.h"
|
11
|
+
|
12
|
+
#include "utils/VlcMessage.h"
|
13
|
+
|
14
|
+
// #ifdef TBB_USE_THREADING_TOOLS
|
15
|
+
// #undef TBB_USE_THREADING_TOOLS
|
16
|
+
// #endif
|
17
|
+
// #define TBB_USE_THREADING_TOOLS 1
|
18
|
+
// #include "tbb/task_scheduler_init.h"
|
19
|
+
// #include "tbb/parallel_for.h"
|
20
|
+
// #include "tbb/blocked_range.h"
|
21
|
+
// #include "tbb/explicit_range.h"
|
22
|
+
|
23
|
+
#include <math.h>
|
24
|
+
#include <boost/pointer_cast.hpp>
|
25
|
+
#include <boost/make_shared.hpp>
|
26
|
+
#include <boost/foreach.hpp>
|
27
|
+
using boost::make_shared;
|
28
|
+
using boost::dynamic_pointer_cast;
|
29
|
+
|
30
|
+
|
31
|
+
/// Creates a GBM runner whose parameter set is a freshly allocated,
/// default-constructed GBMParameters object.
GBMRunner::GBMRunner()
{
    parameters = boost::make_shared<GBMParameters>();
}
|
35
|
+
|
36
|
+
/// Destructor: performs no explicit cleanup of its own.
GBMRunner::~GBMRunner() {}
|
40
|
+
|
41
|
+
void GBMRunner::config()
|
42
|
+
{
|
43
|
+
|
44
|
+
vector<string>& dataFeatures = m_data->getFeatures();
|
45
|
+
|
46
|
+
// parameters->loadedFeatures = dataFeatures;
|
47
|
+
if (parameters->featuresToRun.empty())
|
48
|
+
parameters->featuresToRun = dataFeatures;
|
49
|
+
else
|
50
|
+
{
|
51
|
+
BOOST_FOREACH(string feature, parameters->featuresToRun)
|
52
|
+
{
|
53
|
+
if (!Utils::hasElement(dataFeatures, feature))
|
54
|
+
throw std::runtime_error("Feature '" + feature + "' specified as part of parameter 'featuresToRun', but feature not found in data");
|
55
|
+
}
|
56
|
+
}
|
57
|
+
if (parameters->featuresToRun.empty())
|
58
|
+
throw std::runtime_error("There are no features to run!");
|
59
|
+
|
60
|
+
if (m_data->missingValueDefined())
|
61
|
+
DecisionTreeNode::setMissingValue(m_data->getMissingValue());
|
62
|
+
|
63
|
+
}
|
64
|
+
|
65
|
+
void GBMRunner::estimateMore(int numTrees)
|
66
|
+
{
|
67
|
+
int numFolds = m_data->getNumFolds();
|
68
|
+
int numThreads = numFolds; // TODO: change this!
|
69
|
+
|
70
|
+
// tbb::task_scheduler_init init(numFolds);
|
71
|
+
// static tbb::simple_partitioner sp;
|
72
|
+
|
73
|
+
int grainSize = numFolds / numThreads;
|
74
|
+
|
75
|
+
// tbb::parallel_for(explicit_range<size_t>(0, numFolds, grainSize),
|
76
|
+
// [&](const explicit_range<size_t>& r) {
|
77
|
+
// int threadNumber = r.begin() / grainSize;
|
78
|
+
// for(size_t foldIndex=r.begin(); foldIndex!=r.end(); ++foldIndex)
|
79
|
+
for (int foldIndex = 0; foldIndex < numFolds; ++foldIndex)
|
80
|
+
{
|
81
|
+
vlcMessage.Begin("Estimating more...");
|
82
|
+
|
83
|
+
shared_ptr<GBMEstimator> estimator = dynamic_pointer_cast<GBMEstimator>(m_estimators.at(foldIndex));
|
84
|
+
estimator->estimateMore(numTrees);
|
85
|
+
|
86
|
+
vlcMessage.End();
|
87
|
+
}
|
88
|
+
// }, sp);
|
89
|
+
}
|
90
|
+
|
91
|
+
void GBMRunner::capTrees( int numTrees )
|
92
|
+
{
|
93
|
+
BOOST_FOREACH(shared_ptr<MLOutput>& output, m_outputObjects)
|
94
|
+
{
|
95
|
+
shared_ptr<GBMOutput> gbmOutput = dynamic_pointer_cast<GBMOutput>(output);
|
96
|
+
gbmOutput->capTrees(numTrees);
|
97
|
+
}
|
98
|
+
}
|
99
|
+
|
100
|
+
/// Returns up to howMany feature interactions found by a GBMEstimator built over
/// all experiments in m_data, using this runner's parameters.
///
/// config() runs first so featuresToRun is validated/defaulted; its exceptions propagate.
vector<FeatureInteraction> GBMRunner::getFeatureInteractions( int howMany )
{
    config();

    GBMEstimator interactionEstimator(m_data, m_data->getExperiments(), parameters);
    vector<FeatureInteraction> interactions = interactionEstimator.findInteractions(howMany);
    return interactions;
}
|
106
|
+
|
107
|
+
/// Factory hook: builds a GBMEstimator over the given data and training experiments,
/// sharing this runner's parameters, and hands it back through the base-class pointer.
shared_ptr<MLEstimator> GBMRunner::createEstimator(MLData* data, vector<shared_ptr<MLExperiment> > trainingExperiments)
{
    shared_ptr<GBMEstimator> gbmEstimator(new GBMEstimator(data, trainingExperiments, parameters));
    return gbmEstimator; // implicit upcast to shared_ptr<MLEstimator>
}
|
111
|
+
|
112
|
+
|
113
|
+
|
114
|
+
|
115
|
+
|
116
|
+
|
117
|
+
|
@@ -0,0 +1,94 @@
|
|
1
|
+
#include "MachineLearning/GBM/GaussianCalculator.h"
|
2
|
+
#include "MachineLearning/DecisionTree/DecisionTreeExperiment.h"
|
3
|
+
|
4
|
+
#include <boost/foreach.hpp>
|
5
|
+
|
6
|
+
/// Default constructor; this calculator keeps no state of its own here.
GaussianCalculator::GaussianCalculator() {}
|
10
|
+
|
11
|
+
/// Destructor: performs no explicit cleanup.
GaussianCalculator::~GaussianCalculator() {}
|
15
|
+
|
16
|
+
/// Weighted mean squared error of the current predictions.
///
/// deviance = sum(w * (y - prediction)^2) / sum(w)
///
/// @param experiments experiments whose y, prediction and weight are read.
/// @return the weighted MSE, or 0.0 when the total weight is zero (e.g. empty input).
///         The zero-weight fallback matches computeFIncrement; it previously returned NaN.
double GaussianCalculator::calculateDeviance(vector<shared_ptr<DecisionTreeExperiment> >& experiments)
{
    double sumSquaredErrors = 0.0;
    double sumWeight = 0.0;
    BOOST_FOREACH(shared_ptr<DecisionTreeExperiment>& experiment, experiments)
    {
        const double weight = experiment->getWeight();
        const double error = experiment->getY() - experiment->getPrediction();
        sumSquaredErrors += weight * error * error;
        sumWeight += weight;
    }

    if (sumWeight == 0)
        return 0.0;

    return sumSquaredErrors / sumWeight;
}
|
28
|
+
|
29
|
+
/// Initialises each experiment's F value.
///
/// @param experiments           experiments to initialise (modified in place).
/// @param useInitialPredictions when false, every prediction is first overwritten with the
///                              weighted mean of y; when true, existing predictions are kept.
///
/// In both cases F is then set from each prediction via calculateF.
void GaussianCalculator::populateInitialF(vector<shared_ptr<DecisionTreeExperiment> >& experiments, bool useInitialPredictions)
{
    if (!useInitialPredictions)
    {
        // Weighted mean of the response.
        double sumY = 0.0;
        double sumW = 0.0;
        BOOST_FOREACH(shared_ptr<DecisionTreeExperiment>& experiment, experiments)
        {
            const double weight = experiment->getWeight();
            sumY += weight * experiment->getY();
            sumW += weight;
        }

        // A zero total weight (e.g. all weights zero) would otherwise make every
        // prediction NaN; fall back to 0.0, the same convention computeFIncrement uses.
        const double meanY = (sumW == 0) ? 0.0 : sumY / sumW;

        BOOST_FOREACH(shared_ptr<DecisionTreeExperiment>& experiment, experiments)
        {
            experiment->setPrediction(meanY);
        }
    }

    BOOST_FOREACH(shared_ptr<DecisionTreeExperiment>& experiment, experiments)
    {
        experiment->setF(calculateF(experiment->getPrediction()));
    }
}
|
53
|
+
|
54
|
+
/// Sets each experiment's Z to its current residual, y - prediction.
void GaussianCalculator::updateZ(vector<shared_ptr<DecisionTreeExperiment> >& experiments)
{
    for (size_t i = 0; i < experiments.size(); ++i)
    {
        DecisionTreeExperiment& current = *experiments[i];
        const double residual = current.getY() - current.getPrediction();
        current.setZ(residual);
    }
}
|
62
|
+
|
63
|
+
/// Weighted mean of Z over the given experiments: sum(w * z) / sum(w).
///
/// @return the weighted mean, or 0.0 when the total weight is exactly zero.
double GaussianCalculator::computeFIncrement(vector<shared_ptr<DecisionTreeExperiment> >& experiments)
{
    double weightedZ = 0.0;
    double totalWeight = 0.0;

    for (vector<shared_ptr<DecisionTreeExperiment> >::iterator it = experiments.begin(); it != experiments.end(); ++it)
    {
        weightedZ += (*it)->getWeight() * (*it)->getZ();
        totalWeight += (*it)->getWeight();
    }

    return (totalWeight == 0) ? 0.0 : weightedZ / totalWeight;
}
|
77
|
+
|
78
|
+
/// Recomputes each experiment's prediction from its F via calculatePrediction.
void GaussianCalculator::updatePredictions(vector<shared_ptr<DecisionTreeExperiment> >& experiments)
{
    const size_t count = experiments.size();
    for (size_t idx = 0; idx != count; ++idx)
    {
        shared_ptr<DecisionTreeExperiment>& current = experiments[idx];
        current->setPrediction(calculatePrediction(current->getF()));
    }
}
|
85
|
+
|
86
|
+
/// Maps F to a prediction. For this calculator the mapping is the identity.
double GaussianCalculator::calculatePrediction(double f) { return f; }
|
90
|
+
|
91
|
+
/// Maps a prediction to F (inverse of calculatePrediction). For this calculator it is the identity.
double GaussianCalculator::calculateF(double prediction) { return prediction; }
|
@@ -0,0 +1,317 @@
|
|
1
|
+
// #include "MachineLearning/GBM/ZenithGBM.h"
|
2
|
+
// #include "MachineLearning/GBM/GBMRunner.h"
|
3
|
+
// #include "MachineLearning/DecisionTree/FeatureInteraction.h"
|
4
|
+
// #include "MachineLearning/DecisionTree/SplitDefinition.h"
|
5
|
+
// #include "MachineLearning/MLData/MLData.h"
|
6
|
+
// #include "MachineLearning/gbm/GBMParameters.h"
|
7
|
+
|
8
|
+
// #include "stringConversion.h"
|
9
|
+
// #include "RubyUtils.h"
|
10
|
+
// using namespace RubyUtils;
|
11
|
+
|
12
|
+
|
13
|
+
|
14
|
+
// void zenith_gbm_Free(void* v)
|
15
|
+
// {
|
16
|
+
// delete (reinterpret_cast<GBMRunner*>(v));
|
17
|
+
// }
|
18
|
+
|
19
|
+
// OtInterface::VALUE zenith_gbm_New(int argc, VALUE* argv, VALUE klass)
|
20
|
+
// {
|
21
|
+
// VALUE obj = otRuby->DataWrapStruct(klass, 0, zenith_gbm_Free, 0);
|
22
|
+
// otRuby->rb_obj_call_init(obj, argc, argv);
|
23
|
+
// return obj;
|
24
|
+
// }
|
25
|
+
|
26
|
+
// OtInterface::VALUE zenith_gbm_Initialize(VALUE self)
|
27
|
+
// {
|
28
|
+
// if (otRuby->GetDataPtr(self)) zenith_gbm_Free(otRuby->GetDataPtr(self));
|
29
|
+
// otRuby->SetDataPtr(self, NULL);
|
30
|
+
|
31
|
+
// GBMRunner* gbm = new GBMRunner();
|
32
|
+
// if (gbm == NULL) otRuby->rb_sys_fail("ZenithGBM class could not be created");
|
33
|
+
// otRuby->SetDataPtr(self, gbm);
|
34
|
+
// return self;
|
35
|
+
// }
|
36
|
+
|
37
|
+
// OtInterface::VALUE zenith_gbm_estimate(VALUE self)
|
38
|
+
// {
|
39
|
+
// GBMRunner* gbm = (GBMRunner*)otRuby->GetDataPtr(self);
|
40
|
+
// try
|
41
|
+
// {
|
42
|
+
// gbm->execute();
|
43
|
+
// }
|
44
|
+
// catch (std::exception e)
|
45
|
+
// {
|
46
|
+
// vlcMessage.Raise((string("Caught error: ") + e.what()).c_str());
|
47
|
+
// }
|
48
|
+
|
49
|
+
// return TOtRubyInterface::Qnil;
|
50
|
+
// }
|
51
|
+
|
52
|
+
|
53
|
+
// OtInterface::VALUE zenith_gbm_estimateMore(VALUE self, VALUE numTrees)
|
54
|
+
// {
|
55
|
+
// GBMRunner* gbm = (GBMRunner*)otRuby->GetDataPtr(self);
|
56
|
+
// try
|
57
|
+
// {
|
58
|
+
// gbm->estimateMore(RubyUtils::fromValue<int>(numTrees));
|
59
|
+
// }
|
60
|
+
// catch (std::exception e)
|
61
|
+
// {
|
62
|
+
// vlcMessage.Raise((string("Caught error: ") + e.what()).c_str());
|
63
|
+
// }
|
64
|
+
|
65
|
+
// return TOtRubyInterface::Qnil;
|
66
|
+
// }
|
67
|
+
|
68
|
+
|
69
|
+
// OtInterface::VALUE zenith_gbm_setFeaturesToRun(VALUE self, VALUE featuresValue)
|
70
|
+
// {
|
71
|
+
// GBMRunner* gbm = (GBMRunner*)otRuby->GetDataPtr(self);
|
72
|
+
// gbm->parameters->featuresToRun = RubyUtils::fromValue<vector<string> >(featuresValue);
|
73
|
+
// return TOtRubyInterface::Qnil;
|
74
|
+
// }
|
75
|
+
|
76
|
+
// OtInterface::VALUE zenith_gbm_setData(VALUE self, VALUE data)
|
77
|
+
// {
|
78
|
+
// GBMRunner* gbm = (GBMRunner*)otRuby->GetDataPtr(self);
|
79
|
+
// MLData* mlData = (MLData*)otRuby->GetDataPtr(data);
|
80
|
+
// gbm->setData(mlData);
|
81
|
+
// return TOtRubyInterface::Qnil;
|
82
|
+
// }
|
83
|
+
|
84
|
+
// OtInterface::VALUE zenith_gbm_setTryMVariables(VALUE self, VALUE mVariablesValue)
|
85
|
+
// {
|
86
|
+
// GBMRunner* gbm = (GBMRunner*)otRuby->GetDataPtr(self);
|
87
|
+
// gbm->parameters->tryMVariables = RubyUtils::fromValue<int>(mVariablesValue);
|
88
|
+
// return TOtRubyInterface::Qnil;
|
89
|
+
// }
|
90
|
+
|
91
|
+
// OtInterface::VALUE zenith_gbm_setKTerminalNodes(VALUE self, VALUE kNodesValue)
|
92
|
+
// {
|
93
|
+
// GBMRunner* gbm = (GBMRunner*)otRuby->GetDataPtr(self);
|
94
|
+
// gbm->parameters->growKDecisionTreeNodes = RubyUtils::fromValue<int>(kNodesValue);
|
95
|
+
// return TOtRubyInterface::Qnil;
|
96
|
+
// }
|
97
|
+
|
98
|
+
// OtInterface::VALUE zenith_gbm_setNumIterations(VALUE self, VALUE numIterationsValue)
|
99
|
+
// {
|
100
|
+
// GBMRunner* gbm = (GBMRunner*)otRuby->GetDataPtr(self);
|
101
|
+
// gbm->parameters->numIterations = RubyUtils::fromValue<int>(numIterationsValue);
|
102
|
+
// return TOtRubyInterface::Qnil;
|
103
|
+
// }
|
104
|
+
|
105
|
+
// OtInterface::VALUE zenith_gbm_setShrinkageFactor(VALUE self, VALUE shrinkageFactorValue)
|
106
|
+
// {
|
107
|
+
// GBMRunner* gbm = (GBMRunner*)otRuby->GetDataPtr(self);
|
108
|
+
// gbm->parameters->shrinkageFactor = RubyUtils::fromValue<double>(shrinkageFactorValue);
|
109
|
+
// return TOtRubyInterface::Qnil;
|
110
|
+
// }
|
111
|
+
|
112
|
+
// OtInterface::VALUE zenith_gbm_setBagFraction(VALUE self, VALUE bagFractionValue)
|
113
|
+
// {
|
114
|
+
// GBMRunner* gbm = (GBMRunner*)otRuby->GetDataPtr(self);
|
115
|
+
// gbm->parameters->bagFraction = RubyUtils::fromValue<double>(bagFractionValue);
|
116
|
+
// return TOtRubyInterface::Qnil;
|
117
|
+
// }
|
118
|
+
|
119
|
+
// OtInterface::VALUE zenith_gbm_setTrainingExperimentIds(VALUE self, VALUE experimentIdsValue)
|
120
|
+
// {
|
121
|
+
// GBMRunner* gbm = (GBMRunner*)otRuby->GetDataPtr(self);
|
122
|
+
// gbm->parameters->trainingExperimentIds = RubyUtils::fromValue<vector<int> >(experimentIdsValue);
|
123
|
+
// return TOtRubyInterface::Qnil;
|
124
|
+
// }
|
125
|
+
|
126
|
+
// OtInterface::VALUE zenith_gbm_predictions(VALUE self, VALUE newMlData)
|
127
|
+
// {
|
128
|
+
// GBMRunner* gbm = (GBMRunner*)otRuby->GetDataPtr(self);
|
129
|
+
// MLData* data = (MLData*)otRuby->GetDataPtr(newMlData);
|
130
|
+
|
131
|
+
// vector<double> predictions;
|
132
|
+
|
133
|
+
// try
|
134
|
+
// {
|
135
|
+
// predictions = gbm->getPredictions(data);
|
136
|
+
// }
|
137
|
+
// catch (std::exception e)
|
138
|
+
// {
|
139
|
+
// vlcMessage.Raise((string("Could not get predictions. Error: ") + e.what()).c_str());
|
140
|
+
// }
|
141
|
+
|
142
|
+
// return RubyUtils::toValue(predictions);
|
143
|
+
// }
|
144
|
+
|
145
|
+
// OtInterface::VALUE zenith_gbm_training_predictions(VALUE self)
|
146
|
+
// {
|
147
|
+
// GBMRunner* gbm = (GBMRunner*)otRuby->GetDataPtr(self);
|
148
|
+
// vector<double> predictions;
|
149
|
+
|
150
|
+
// try
|
151
|
+
// {
|
152
|
+
// predictions = gbm->getMeanTrainingPredictions();
|
153
|
+
// }
|
154
|
+
// catch (std::exception e)
|
155
|
+
// {
|
156
|
+
// vlcMessage.Raise((string("Could not get training predictions. Error: ") + e.what()).c_str());
|
157
|
+
// }
|
158
|
+
|
159
|
+
// return RubyUtils::toValue(predictions);
|
160
|
+
// }
|
161
|
+
|
162
|
+
// OtInterface::VALUE zenith_gbm_crossvalidation_predictions(VALUE self)
|
163
|
+
// {
|
164
|
+
// GBMRunner* gbm = (GBMRunner*)otRuby->GetDataPtr(self);
|
165
|
+
// vector<double> predictions;
|
166
|
+
|
167
|
+
// try
|
168
|
+
// {
|
169
|
+
// predictions = gbm->getCrossValidationPredictions();
|
170
|
+
// }
|
171
|
+
// catch (std::exception e)
|
172
|
+
// {
|
173
|
+
// vlcMessage.Raise((string("Could not get cross validation predictions. Error: ") + e.what()).c_str());
|
174
|
+
// }
|
175
|
+
|
176
|
+
// return RubyUtils::toValue(predictions);
|
177
|
+
// }
|
178
|
+
|
179
|
+
// OtInterface::VALUE zenith_gbm_minObservations(VALUE self, VALUE minObservations)
|
180
|
+
// {
|
181
|
+
// GBMRunner* gbm = (GBMRunner*)otRuby->GetDataPtr(self);
|
182
|
+
// gbm->parameters->minObservations = RubyUtils::fromValue<int>(minObservations);
|
183
|
+
// return TOtRubyInterface::Qnil;
|
184
|
+
// }
|
185
|
+
|
186
|
+
// OtInterface::VALUE zenith_gbm_setDistribution(VALUE self, VALUE rb_distribution)
|
187
|
+
// {
|
188
|
+
// GBMRunner* gbm = (GBMRunner*)otRuby->GetDataPtr(self);
|
189
|
+
// string distribution = stringToLower(RubyUtils::fromValue<string>(rb_distribution));
|
190
|
+
|
191
|
+
// if (distribution == "bernoulli")
|
192
|
+
// gbm->parameters->distribution = BERNOULLI;
|
193
|
+
// else if (distribution == "gaussian")
|
194
|
+
// gbm->parameters->distribution = GAUSSIAN;
|
195
|
+
// else
|
196
|
+
// throw std::invalid_argument("ZenithGBM::distribution = " + distribution);
|
197
|
+
|
198
|
+
// return TOtRubyInterface::Qnil;
|
199
|
+
// }
|
200
|
+
|
201
|
+
// OtInterface::VALUE zenith_gbm_verbose(VALUE self, VALUE rb_verbose)
|
202
|
+
// {
|
203
|
+
// GBMRunner* gbm = (GBMRunner*)otRuby->GetDataPtr(self);
|
204
|
+
// bool verbose = RubyUtils::fromValue<bool>(rb_verbose);
|
205
|
+
// gbm->parameters->verbose = verbose;
|
206
|
+
// return TOtRubyInterface::Qnil;
|
207
|
+
// }
|
208
|
+
|
209
|
+
// OtInterface::VALUE zenith_gbm_setGreedy( VALUE self, VALUE rb_greedy )
|
210
|
+
// {
|
211
|
+
// GBMRunner* gbm = (GBMRunner*)otRuby->GetDataPtr(self);
|
212
|
+
// bool greedy = RubyUtils::fromValue<bool>(rb_greedy);
|
213
|
+
// gbm->parameters->greedy = greedy;
|
214
|
+
// return TOtRubyInterface::Qnil;
|
215
|
+
// }
|
216
|
+
|
217
|
+
// OtInterface::VALUE zenith_gbm_setRfToLevel( VALUE self, VALUE rb_rfToLevel )
|
218
|
+
// {
|
219
|
+
// GBMRunner* gbm = (GBMRunner*)otRuby->GetDataPtr(self);
|
220
|
+
// int rfToLevel = RubyUtils::fromValue<int>(rb_rfToLevel);
|
221
|
+
// gbm->parameters->rfToLevel = rfToLevel;
|
222
|
+
// return TOtRubyInterface::Qnil;
|
223
|
+
// }
|
224
|
+
|
225
|
+
// OtInterface::VALUE zenith_gbm_capTrees( VALUE self, VALUE rb_cap )
|
226
|
+
// {
|
227
|
+
// GBMRunner* gbm = (GBMRunner*)otRuby->GetDataPtr(self);
|
228
|
+
// int cap = RubyUtils::fromValue<int>(rb_cap);
|
229
|
+
// gbm->capTrees(cap);
|
230
|
+
// return TOtRubyInterface::Qnil;
|
231
|
+
// }
|
232
|
+
|
233
|
+
// OtInterface::VALUE zenith_gbm_setScale( VALUE self, VALUE rb_scale )
|
234
|
+
// {
|
235
|
+
// GBMRunner* gbm = (GBMRunner*)otRuby->GetDataPtr(self);
|
236
|
+
// double scale = RubyUtils::fromValue<double>(rb_scale);
|
237
|
+
// gbm->parameters->scale = scale;;
|
238
|
+
// return TOtRubyInterface::Qnil;
|
239
|
+
// }
|
240
|
+
|
241
|
+
// OtInterface::VALUE splitValueOrCategories(MLData* data, shared_ptr<SplitDefinition> splitDefinition, Partition partition)
|
242
|
+
// {
|
243
|
+
// VALUE returnValue;
|
244
|
+
// if (splitDefinition->isCategorical())
|
245
|
+
// {
|
246
|
+
// // categories as an array
|
247
|
+
// if (partition == LHS)
|
248
|
+
// returnValue = RubyUtils::toValue(splitDefinition->getLhsCategories());
|
249
|
+
// else if (partition == RHS)
|
250
|
+
// returnValue = RubyUtils::toValue(splitDefinition->getRhsCategories());
|
251
|
+
// else if (partition == MISSING)
|
252
|
+
// {
|
253
|
+
// set<double> setMissingValue;
|
254
|
+
// setMissingValue.insert(data->getMissingValue());
|
255
|
+
// returnValue = RubyUtils::toValue(setMissingValue);
|
256
|
+
// }
|
257
|
+
// else
|
258
|
+
// throw std::runtime_error("Primary partition should be either LHS, RHS or MISSING!");
|
259
|
+
// }
|
260
|
+
// else
|
261
|
+
// {
|
262
|
+
// // split value as a double
|
263
|
+
// returnValue = RubyUtils::toValue(splitDefinition->getSplitValue());
|
264
|
+
// }
|
265
|
+
// return returnValue;
|
266
|
+
// }
|
267
|
+
|
268
|
+
// OtInterface::VALUE zenith_gbm_getFeatureInteractions( VALUE self, VALUE howMany )
|
269
|
+
// {
|
270
|
+
// GBMRunner* gbm = (GBMRunner*)otRuby->GetDataPtr(self);
|
271
|
+
// vector<FeatureInteraction> featureInteractions;
|
272
|
+
|
273
|
+
// try
|
274
|
+
// {
|
275
|
+
// featureInteractions = gbm->getFeatureInteractions(RubyUtils::fromValue<int>(howMany));
|
276
|
+
// }
|
277
|
+
// catch (std::exception e)
|
278
|
+
// {
|
279
|
+
// vlcMessage.Raise((string("Could not get calculate interactions. Error: ") + e.what()).c_str());
|
280
|
+
// }
|
281
|
+
// MLData* data = gbm->getData();
|
282
|
+
// vector<string> featureNames = data->getFeatures();
|
283
|
+
|
284
|
+
// vector<vector<VALUE> > returnVector;
|
285
|
+
// returnVector.reserve(featureInteractions.size());
|
286
|
+
|
287
|
+
// BOOST_FOREACH(auto& interaction, featureInteractions)
|
288
|
+
// {
|
289
|
+
// vector<VALUE> v;
|
290
|
+
// v.reserve(6);
|
291
|
+
|
292
|
+
// // improvement
|
293
|
+
// v.push_back(RubyUtils::toValue(interaction.secondarySplitDefinition->getImprovement()));
|
294
|
+
|
295
|
+
// // primary feature name
|
296
|
+
// v.push_back(RubyUtils::toValue(featureNames.at(interaction.primarySplitDefinition->getFeatureIndex())));
|
297
|
+
|
298
|
+
// // either the split value (as double), or the categories
|
299
|
+
// v.push_back(splitValueOrCategories(data, interaction.primarySplitDefinition, interaction.primaryPartition));
|
300
|
+
|
301
|
+
// // the partition chosen
|
302
|
+
// Partition p = interaction.primaryPartition;
|
303
|
+
// int partition = (p == LHS ? 1 : (p == RHS ? 2 : 3));
|
304
|
+
// v.push_back(RubyUtils::toValue(partition));
|
305
|
+
|
306
|
+
// // second feature name
|
307
|
+
// v.push_back(RubyUtils::toValue(featureNames.at(interaction.secondarySplitDefinition->getFeatureIndex())));
|
308
|
+
|
309
|
+
// // secondary split value / left hand side categories
|
310
|
+
// v.push_back(splitValueOrCategories(data, interaction.secondarySplitDefinition, LHS));
|
311
|
+
|
312
|
+
// // no need for a secondary partition, as all children of the second partition are important
|
313
|
+
// returnVector.push_back(v);
|
314
|
+
// }
|
315
|
+
// return RubyUtils::toValue(returnVector);
|
316
|
+
// }
|
317
|
+
|