mlpack 3.4.2
Loading...
Searching...
No Matches
mean_imputation.hpp
Go to the documentation of this file.
1
12#ifndef MLPACK_CORE_DATA_IMPUTE_STRATEGIES_MEAN_IMPUTATION_HPP
13#define MLPACK_CORE_DATA_IMPUTE_STRATEGIES_MEAN_IMPUTATION_HPP
14
15#include <mlpack/prereqs.hpp>
16
17namespace mlpack {
18namespace data {
23template <typename T>
25{
26 public:
37 void Impute(arma::Mat<T>& input,
38 const T& mappedValue,
39 const size_t dimension,
40 const bool columnMajor = true)
41 {
42 double sum = 0;
43 size_t elems = 0; // excluding nan or missing target
44
45 using PairType = std::pair<size_t, size_t>;
46 // dimensions and indexes are saved as pairs inside this vector.
47 std::vector<PairType> targets;
48
49
50 // calculate number of elements and sum of them excluding mapped value or
51 // nan. while doing that, remember where mappedValue or NaN exists.
52 if (columnMajor)
53 {
54 for (size_t i = 0; i < input.n_cols; ++i)
55 {
56 if (input(dimension, i) == mappedValue ||
57 std::isnan(input(dimension, i)))
58 {
59 targets.emplace_back(dimension, i);
60 }
61 else
62 {
63 elems++;
64 sum += input(dimension, i);
65 }
66 }
67 }
68 else
69 {
70 for (size_t i = 0; i < input.n_rows; ++i)
71 {
72 if (input(i, dimension) == mappedValue ||
73 std::isnan(input(i, dimension)))
74 {
75 targets.emplace_back(i, dimension);
76 }
77 else
78 {
79 elems++;
80 sum += input(i, dimension);
81 }
82 }
83 }
84
85 if (elems == 0)
86 Log::Fatal << "it is impossible to calculate mean; no valid elements in "
87 << "the dimension" << std::endl;
88
89 // calculate mean;
90 const double mean = sum / elems;
91
92 // Now replace the calculated mean to the missing variables
93 // It only needs to loop through targets vector, not the whole matrix.
94 for (const PairType& target : targets)
95 {
96 input(target.first, target.second) = mean;
97 }
98 }
99}; // class MeanImputation
100
101} // namespace data
102} // namespace mlpack
103
104#endif
static MLPACK_EXPORT util::PrefixedOutStream Fatal
Prints fatal messages prefixed with [FATAL], then terminates the program.
Definition: log.hpp:90
A simple mean imputation class.
void Impute(arma::Mat< T > &input, const T &mappedValue, const size_t dimension, const bool columnMajor=true)
Impute function searches through the input looking for mappedValue (or NaN) and replaces it with the mean of the given dimension.
Linear algebra utility functions, generally performed on matrices or vectors.
The core includes that mlpack expects; standard C++ includes and Armadillo.