1 #ifndef RF_EARLY_STOPPING_P_HXX 2 #define RF_EARLY_STOPPING_P_HXX 4 #include "rf_common.hxx" 13 T
power(T
const & in,
int n)
15 T result = NumericTraits<T>::one();
16 for(
int ii = 0; ii < n ;++ii)
34 void set_external_parameters(
ProblemSpec<T> const &prob,
int tree_count = 0,
bool is_weighted =
false)
37 is_weighted_ = is_weighted;
38 tree_count_ = tree_count;
48 template<
class WeightIter,
class T,
class C>
51 template<
class WeightIter,
class T,
class C>
78 void set_external_parameters(
ProblemSpec<T> const &prob,
int tree_count = 0,
bool is_weighted =
false)
80 max_tree_ =
ceil(max_tree_p * tree_count);
81 SB::set_external_parameters(prob, tree_count, is_weighted);
84 template<
class WeightIter,
class T,
class C>
87 if(k == SB::tree_count_ -1)
89 depths.push_back(
double(k+1)/
double(SB::tree_count_));
94 depths.push_back(
double(k+1)/
double(SB::tree_count_));
116 proportion_(proportion)
119 template<
class WeightIter,
class T,
class C>
122 if(k == SB::tree_count_ -1)
124 depths.push_back(
double(k+1)/
double(SB::tree_count_));
131 if(prob[
argMax(prob)] > proportion_ *SB::ext_param_.actual_msample_* SB::tree_count_)
133 depths.push_back(
double(k+1)/
double(SB::tree_count_));
139 if(prob[
argMax(prob)] > proportion_ * SB::tree_count_)
141 depths.push_back(
double(k+1)/
double(SB::tree_count_));
174 void set_external_parameters(
ProblemSpec<T> const &prob,
int tree_count = 0,
bool is_weighted =
false)
178 SB::set_external_parameters(prob, tree_count, is_weighted);
180 template<
class WeightIter,
class T,
class C>
183 if(k == SB::tree_count_ -1)
185 depths.push_back(
double(k+1)/
double(SB::tree_count_));
191 last_/= last_.
norm(1);
197 cur_ /= cur_.
norm(1);
199 double nrm = last_.
norm();
202 depths.push_back(
double(k+1)/
double(SB::tree_count_));
232 proportion_(proportion)
235 template<
class WeightIter,
class T,
class C>
238 if(k == SB::tree_count_ -1)
240 depths.push_back(
double(k+1)/
double(SB::tree_count_));
244 double a = prob[
argMax(prob)];
246 double b = prob[
argMax(prob)];
248 double margin = a - b;
251 if(margin > proportion_ *SB::ext_param_.actual_msample_ * SB::tree_count_)
253 depths.push_back(
double(k+1)/
double(SB::tree_count_));
259 if(prob[
argMax(prob)] > proportion_ * SB::tree_count_)
261 depths.push_back(
double(k+1)/
double(SB::tree_count_));
299 double binomial(
int N,
int k,
double p)
302 return n_choose_k(N, k) * std::pow(p, k) * std::pow(1 - p, N-k);
305 template<
class WeightIter,
class T,
class C>
308 if(k == SB::tree_count_ -1)
310 depths.push_back(
double(k+1)/
double(SB::tree_count_));
318 int n_a = prob[index];
319 int n_b = prob[(index+1)%2];
320 int n_tilde = (SB::tree_count_ - n_a + n_b);
321 double p_a = double(n_b - n_a + n_tilde)/double(2* n_tilde);
322 vigra_precondition(p_a <= 1,
"probability should be smaller than 1");
328 for(
int ii = 0; ii <= n_b + n_a;++ii)
331 cum_val += binomial(n_b + n_a, ii, p_a);
332 if(cum_val >= 1 -alpha_)
341 depths.push_back(
double(k+1)/
double(SB::tree_count_));
379 double binomial(
int N,
int k,
double p)
382 return n_choose_k(N, k) * std::pow(p, k) * std::pow(1 - p, N-k);
385 template<
class WeightIter,
class T,
class C>
388 if(k == SB::tree_count_ -1)
390 depths.push_back(
double(k+1)/
double(SB::tree_count_));
398 int n_a = prob[index];
399 int n_b = prob[(index+1)%2];
400 int n_needed =
ceil(
double(SB::tree_count_)/2.0)-n_a;
401 int n_tilde = SB::tree_count_ - (n_a +n_b);
402 if(n_tilde <= 0) n_tilde = 0;
403 if(n_needed <= 0) n_needed = 0;
405 for(
int ii = n_needed; ii < n_tilde; ++ii)
406 p += binomial(n_tilde, ii, 0.5);
410 depths.push_back(
double(k+1)/
double(SB::tree_count_));
418 #endif //RF_EARLY_STOPPING_P_HXX Definition: rf_earlystopping.hxx:278
problem specification class for the random forest.
Definition: rf_common.hxx:533
StopAfterVoteCount(double proportion)
Definition: rf_earlystopping.hxx:114
V power(const V &x)
Exponentiation to a positive integer power by squaring.
Definition: mathutil.hxx:389
void reshape(const difference_type &shape)
Definition: multi_array.hxx:2738
ArrayVector< double > depths
Definition: rf_earlystopping.hxx:377
Definition: accessor.hxx:43
NormTraits< MultiArrayView >::NormType norm(int type=2, bool useSquaredNorm=true) const
Definition: multi_array.hxx:2255
Definition: rf_earlystopping.hxx:60
StopIfConverging(double thresh, int num=10)
Definition: rf_earlystopping.hxx:167
StopIfProb(double alpha, MultiArrayView< 2, double > nck_)
Definition: rf_earlystopping.hxx:369
Iterator argMax(Iterator first, Iterator last)
Find the maximum element in a sequence.
Definition: algorithm.hxx:96
Definition: rf_earlystopping.hxx:152
Definition: rf_earlystopping.hxx:25
Class for fixed size vectors.This class contains an array of size SIZE of the specified VALUETYPE...
Definition: accessor.hxx:940
Definition: rf_earlystopping.hxx:104
Base class for, and view to, vigra::MultiArray.
Definition: multi_array.hxx:655
Definition: rf_earlystopping.hxx:220
StopAfterTree(double max_tree)
Definition: rf_earlystopping.hxx:72
int ceil(FixedPoint< IntBits, FracBits > v)
rounding up.
Definition: fixedpoint.hxx:675
Definition: rf_earlystopping.hxx:357
ArrayVector< double > depths
Definition: rf_earlystopping.hxx:297
StopIfBinTest(double alpha, MultiArrayView< 2, double > nck_)
Definition: rf_earlystopping.hxx:288
StopIfMargin(double proportion)
Definition: rf_earlystopping.hxx:230