mlpack 3.4.2
Loading...
Searching...
No Matches
information_gain.hpp
Go to the documentation of this file.
1
13#ifndef MLPACK_METHODS_DECISION_TREE_INFORMATION_GAIN_HPP
14#define MLPACK_METHODS_DECISION_TREE_INFORMATION_GAIN_HPP
15
16#include <mlpack/prereqs.hpp>
17
18namespace mlpack {
19namespace tree {
20
26{
27 public:
/**
 * Compute the information gain from an array of per-class counts.
 *
 * Each count is divided by totalCount to get a fraction f, and the result is
 * the sum of f * log2(f) over every strictly positive fraction. The result
 * is therefore never positive when all fractions lie in (0, 1].
 *
 * The UseWeights template parameter is not referenced in the body.
 *
 * @param counts Pointer to the first of countLength per-class counts.
 * @param countLength Number of entries readable through counts.
 * @param totalCount Value each count is divided by.
 * @return Sum of f * log2(f) over all classes with f > 0.
 */
template<bool UseWeights, typename CountType>
static double EvaluatePtr(const CountType* counts,
                          const size_t countLength,
                          const CountType totalCount)
{
  const double denominator = static_cast<double>(totalCount);
  const CountType* const end = counts + countLength;

  double total = 0.0;
  for (const CountType* cur = counts; cur != end; ++cur)
  {
    const double fraction = static_cast<double>(*cur) / denominator;

    // Classes with no mass contribute nothing (and log2(0) is not finite).
    if (fraction > 0.0)
      total += fraction * std::log2(fraction);
  }

  return total;
}
47
59 template<bool UseWeights>
60 static double Evaluate(const arma::Row<size_t>& labels,
61 const size_t numClasses,
62 const arma::Row<double>& weights)
63 {
64 // Edge case: if there are no elements, the gain is zero.
65 if (labels.n_elem == 0)
66 return 0.0;
67
68 // Calculate the information gain.
69 double gain = 0.0;
70
71 // Count the number of elements in each class. Use four auxiliary vectors
72 // to exploit SIMD instructions if possible.
73 arma::vec countSpace(4 * numClasses, arma::fill::zeros);
74 arma::vec counts(countSpace.memptr(), numClasses, false, true);
75 arma::vec counts2(countSpace.memptr() + numClasses, numClasses, false,
76 true);
77 arma::vec counts3(countSpace.memptr() + 2 * numClasses, numClasses, false,
78 true);
79 arma::vec counts4(countSpace.memptr() + 3 * numClasses, numClasses, false,
80 true);
81
82 if (UseWeights)
83 {
84 // Sum all the weights up.
85 double accWeights[4] = { 0.0, 0.0, 0.0, 0.0 };
86
87 // SIMD loop: add counts for four elements simultaneously (if the compiler
88 // manages to vectorize the loop).
89 for (size_t i = 3; i < labels.n_elem; i += 4)
90 {
91 const double weight1 = weights[i - 3];
92 const double weight2 = weights[i - 2];
93 const double weight3 = weights[i - 1];
94 const double weight4 = weights[i];
95
96 counts[labels[i - 3]] += weight1;
97 counts2[labels[i - 2]] += weight2;
98 counts3[labels[i - 1]] += weight3;
99 counts4[labels[i]] += weight4;
100
101 accWeights[0] += weight1;
102 accWeights[1] += weight2;
103 accWeights[2] += weight3;
104 accWeights[3] += weight4;
105 }
106
107 // Handle leftovers.
108 if (labels.n_elem % 4 == 1)
109 {
110 const double weight1 = weights[labels.n_elem - 1];
111 counts[labels[labels.n_elem - 1]] += weight1;
112 accWeights[0] += weight1;
113 }
114 else if (labels.n_elem % 4 == 2)
115 {
116 const double weight1 = weights[labels.n_elem - 2];
117 const double weight2 = weights[labels.n_elem - 1];
118
119 counts[labels[labels.n_elem - 2]] += weight1;
120 counts2[labels[labels.n_elem - 1]] += weight2;
121
122 accWeights[0] += weight1;
123 accWeights[1] += weight2;
124 }
125 else if (labels.n_elem % 4 == 3)
126 {
127 const double weight1 = weights[labels.n_elem - 3];
128 const double weight2 = weights[labels.n_elem - 2];
129 const double weight3 = weights[labels.n_elem - 1];
130
131 counts[labels[labels.n_elem - 3]] += weight1;
132 counts2[labels[labels.n_elem - 2]] += weight2;
133 counts3[labels[labels.n_elem - 1]] += weight3;
134
135 accWeights[0] += weight1;
136 accWeights[1] += weight2;
137 accWeights[2] += weight3;
138 }
139
140 accWeights[0] += accWeights[1] + accWeights[2] + accWeights[3];
141 counts += counts2 + counts3 + counts4;
142
143 // Corner case: return 0 if no weight.
144 if (accWeights[0] == 0.0)
145 return 0.0;
146
147 for (size_t i = 0; i < numClasses; ++i)
148 {
149 const double f = ((double) counts[i] / (double) accWeights[0]);
150 if (f > 0.0)
151 gain += f * std::log2(f);
152 }
153 }
154 else
155 {
156 // SIMD loop: add counts for four elements simultaneously (if the compiler
157 // manages to vectorize the loop).
158 for (size_t i = 3; i < labels.n_elem; i += 4)
159 {
160 counts[labels[i - 3]]++;
161 counts2[labels[i - 2]]++;
162 counts3[labels[i - 1]]++;
163 counts4[labels[i]]++;
164 }
165
166 // Handle leftovers.
167 if (labels.n_elem % 4 == 1)
168 {
169 counts[labels[labels.n_elem - 1]]++;
170 }
171 else if (labels.n_elem % 4 == 2)
172 {
173 counts[labels[labels.n_elem - 2]]++;
174 counts2[labels[labels.n_elem - 1]]++;
175 }
176 else if (labels.n_elem % 4 == 3)
177 {
178 counts[labels[labels.n_elem - 3]]++;
179 counts2[labels[labels.n_elem - 2]]++;
180 counts3[labels[labels.n_elem - 1]]++;
181 }
182
183 counts += counts2 + counts3 + counts4;
184
185 for (size_t i = 0; i < numClasses; ++i)
186 {
187 const double f = ((double) counts[i] / (double) labels.n_elem);
188 if (f > 0.0)
189 gain += f * std::log2(f);
190 }
191 }
192
193 return gain;
194 }
195
/**
 * Return the range of the information gain for the given number of classes.
 *
 * The best case yields a gain of 0.  The worst case is an even split over
 * all n classes, giving n * (1/n * log2(1/n)) = log2(1/n) = -log2(n); the
 * width of that interval is log2(n).
 *
 * @param numClasses Number of classes n.
 * @return log2(numClasses).
 */
static double Range(const size_t numClasses)
{
  const double classCount = static_cast<double>(numClasses);
  return std::log2(classCount);
}
210};
211
212} // namespace tree
213} // namespace mlpack
214
215#endif
The standard information gain criterion, used for calculating gain in decision trees.
static double EvaluatePtr(const CountType *counts, const size_t countLength, const CountType totalCount)
Evaluate the information gain given a vector of class weight counts.
static double Evaluate(const arma::Row< size_t > &labels, const size_t numClasses, const arma::Row< double > &weights)
Given a set of labels, calculate the information gain of those labels.
static double Range(const size_t numClasses)
Return the range of the information gain for the given number of classes.
Linear algebra utility functions, generally performed on matrices or vectors.
The core includes that mlpack expects; standard C++ includes and Armadillo.