mlpack 3.4.2
Loading...
Searching...
No Matches
rp_tree_mean_split.hpp
Go to the documentation of this file.
1
13#ifndef MLPACK_CORE_TREE_BINARY_SPACE_TREE_RP_TREE_MEAN_SPLIT_HPP
14#define MLPACK_CORE_TREE_BINARY_SPACE_TREE_RP_TREE_MEAN_SPLIT_HPP
15
16#include <mlpack/prereqs.hpp>
17#include "rp_tree_max_split.hpp"
20
21namespace mlpack {
22namespace tree {
23
32template<typename BoundType, typename MatType = arma::mat>
34{
35 public:
37 typedef typename MatType::elem_type ElemType;
39 struct SplitInfo
40 {
42 arma::Col<ElemType> direction;
44 arma::Col<ElemType> mean;
50 };
51
64 static bool SplitNode(const BoundType& /* bound */,
65 MatType& data,
66 const size_t begin,
67 const size_t count,
68 SplitInfo& splitInfo);
69
82 static size_t PerformSplit(MatType& data,
83 const size_t begin,
84 const size_t count,
85 const SplitInfo& splitInfo)
86 {
87 return split::PerformSplit<MatType, RPTreeMeanSplit>(data, begin, count,
88 splitInfo);
89 }
90
106 static size_t PerformSplit(MatType& data,
107 const size_t begin,
108 const size_t count,
109 const SplitInfo& splitInfo,
110 std::vector<size_t>& oldFromNew)
111 {
112 return split::PerformSplit<MatType, RPTreeMeanSplit>(data, begin, count,
113 splitInfo, oldFromNew);
114 }
115
122 template<typename VecType>
123 static bool AssignToLeftNode(const VecType& point, const SplitInfo& splitInfo)
124 {
125 if (splitInfo.meanSplit)
126 return arma::dot(point - splitInfo.mean, point - splitInfo.mean) <=
127 splitInfo.splitVal;
128
129 return (arma::dot(point, splitInfo.direction) <= splitInfo.splitVal);
130 }
131
132 private:
139 static ElemType GetAveragePointDistance(MatType& data,
140 const arma::uvec& samples);
141
151 static bool GetDotMedian(const MatType& data,
152 const arma::uvec& samples,
153 const arma::Col<ElemType>& direction,
154 ElemType& splitVal);
155
165 static bool GetMeanMedian(const MatType& data,
166 const arma::uvec& samples,
167 arma::Col<ElemType>& mean,
168 ElemType& splitVal);
169};
170
171} // namespace tree
172} // namespace mlpack
173
174// Include implementation.
175#include "rp_tree_mean_split_impl.hpp"
176
177#endif // MLPACK_CORE_TREE_BINARY_SPACE_TREE_RP_TREE_MEAN_SPLIT_HPP
This class splits a binary space tree.
MatType::elem_type ElemType
The element type held by the matrix type.
static bool SplitNode(const BoundType &, MatType &data, const size_t begin, const size_t count, SplitInfo &splitInfo)
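Compute how to split the node: either by the squared distance of its points to the mean of sampled points (mean split) or by their projection onto a hyperplane normal (median split).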
static size_t PerformSplit(MatType &data, const size_t begin, const size_t count, const SplitInfo &splitInfo)
Perform the split process according to the information about the split.
static bool AssignToLeftNode(const VecType &point, const SplitInfo &splitInfo)
Indicates whether a point should be assigned to the left subtree.
static size_t PerformSplit(MatType &data, const size_t begin, const size_t count, const SplitInfo &splitInfo, std::vector< size_t > &oldFromNew)
Perform the split process according to the information about the split and return the list of changed indices.
Linear algebra utility functions, generally performed on matrices or vectors.
The core includes that mlpack expects; standard C++ includes and Armadillo.
Information about the partition.
ElemType splitVal
The value according to which the split will be performed.
arma::Col< ElemType > mean
The mean of some sampled points.
arma::Col< ElemType > direction
The normal to the hyperplane that will split the node.
bool meanSplit
Indicates whether the mean split algorithm should be used instead of the median split.