mlpack 3.4.2
Loading...
Searching...
No Matches
diagonal_gmm.hpp
Go to the documentation of this file.
1
14#ifndef MLPACK_METHODS_GMM_DIAGONAL_GMM_HPP
15#define MLPACK_METHODS_GMM_DIAGONAL_GMM_HPP
16
17#include <mlpack/prereqs.hpp>
19
20// This is the default fitting method class.
21#include "em_fit.hpp"
22
23// This is the default covariance matrix constraint.
25
26namespace mlpack {
27namespace gmm {
28
75{
76 private:
78 size_t gaussians;
80 size_t dimensionality;
81
83 std::vector<distribution::DiagonalGaussianDistribution> dists;
84
86 arma::vec weights;
87
88 public:
93 gaussians(0),
94 dimensionality(0)
95 {
96 // Warn the user. They probably don't want to do this. If this
97 // constructor is being used (because it is required by some template
98 // classes), the user should know that it is potentially dangerous.
99 Log::Debug << "DiagonalGMM::DiagonalGMM(): no parameters given;"
100 "Estimate() may fail " << "unless parameters are set." << std::endl;
101 }
102
110 DiagonalGMM(const size_t gaussians, const size_t dimensionality);
111
118 DiagonalGMM(const std::vector<distribution::DiagonalGaussianDistribution>&
119 dists, const arma::vec& weights) :
120 gaussians(dists.size()),
121 dimensionality((!dists.empty()) ? dists[0].Mean().n_elem : 0),
122 dists(dists),
123 weights(weights) { /* Nothing to do. */ }
124
127
130
132 size_t Gaussians() const { return gaussians; }
134 size_t Dimensionality() const { return dimensionality; }
135
142 {
143 return dists[i];
144 }
145
152 {
153 return dists[i];
154 }
155
157 const arma::vec& Weights() const { return weights; }
159 arma::vec& Weights() { return weights; }
160
167 double Probability(const arma::vec& observation) const;
168
175 double LogProbability(const arma::vec& observation) const;
176
184 double Probability(const arma::vec& observation,
185 const size_t component) const;
186
194 double LogProbability(const arma::vec& observation,
195 const size_t component) const;
202 arma::vec Random() const;
203
226 template<typename FittingType = EMFit<kmeans::KMeans<>, DiagonalConstraint,
227 distribution::DiagonalGaussianDistribution>>
228 double Train(const arma::mat& observations,
229 const size_t trials = 1,
230 const bool useExistingModel = false,
231 FittingType fitter = FittingType());
232
258 template<typename FittingType = EMFit<kmeans::KMeans<>, DiagonalConstraint,
259 distribution::DiagonalGaussianDistribution>>
260 double Train(const arma::mat& observations,
261 const arma::vec& probabilities,
262 const size_t trials = 1,
263 const bool useExistingModel = false,
264 FittingType fitter = FittingType());
265
283 void Classify(const arma::mat& observations,
284 arma::Row<size_t>& labels) const;
285
289 template<typename Archive>
290 void serialize(Archive& ar, const unsigned int /* version */);
291
292 private:
302 double LogLikelihood(
303 const arma::mat& observations,
304 const std::vector<distribution::DiagonalGaussianDistribution>& dists,
305 const arma::vec& weights) const;
306};
307
308} // namespace gmm
309} // namespace mlpack
310
311// Include implementation.
312#include "diagonal_gmm_impl.hpp"
313
314#endif // MLPACK_METHODS_GMM_DIAGONAL_GMM_HPP
static MLPACK_EXPORT util::NullOutStream Debug
MLPACK_EXPORT is required for global variables, so that they are properly exported by the Windows compiler.
Definition: log.hpp:79
A single multivariate Gaussian distribution with diagonal covariance.
A Diagonal Gaussian Mixture Model.
size_t Gaussians() const
Return the number of Gaussians in the model.
DiagonalGMM & operator=(const DiagonalGMM &other)
Copy operator for DiagonalGMMs.
arma::vec Random() const
Return a randomly generated observation according to the probability distribution defined by this object.
distribution::DiagonalGaussianDistribution & Component(size_t i)
Return a reference to a component distribution.
void Classify(const arma::mat &observations, arma::Row< size_t > &labels) const
Classify the given observations as being from an individual component in this DiagonalGMM.
arma::vec & Weights()
Return a reference to the a priori weights of each Gaussian.
double LogProbability(const arma::vec &observation, const size_t component) const
Return the log probability that the given observation came from the given Gaussian component in this distribution.
const arma::vec & Weights() const
Return a const reference to the a priori weights of each Gaussian.
double Train(const arma::mat &observations, const arma::vec &probabilities, const size_t trials=1, const bool useExistingModel=false, FittingType fitter=FittingType())
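Estimate the probability distribution directly from the given observations, taking into account the probability of each observation actually being from this distribution, and using the given algorithm in the FittingType class to fit the data.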
double LogProbability(const arma::vec &observation) const
Return the log probability that the given observation came from this distribution.
DiagonalGMM()
Create an empty Diagonal Gaussian Mixture Model, with zero Gaussians.
size_t Dimensionality() const
Return the dimensionality of the model.
DiagonalGMM(const size_t gaussians, const size_t dimensionality)
Create a GMM with the given number of Gaussians, each of which has the specified dimensionality.
double Probability(const arma::vec &observation) const
Return the probability that the given observation came from this distribution.
double Train(const arma::mat &observations, const size_t trials=1, const bool useExistingModel=false, FittingType fitter=FittingType())
Estimate the probability distribution directly from the given observations, using the given algorithm in the FittingType class to fit the data.
DiagonalGMM(const DiagonalGMM &other)
Copy constructor for DiagonalGMMs.
void serialize(Archive &ar, const unsigned int)
Serialize the DiagonalGMM.
const distribution::DiagonalGaussianDistribution & Component(size_t i) const
Return a const reference to a component distribution.
DiagonalGMM(const std::vector< distribution::DiagonalGaussianDistribution > &dists, const arma::vec &weights)
Create a DiagonalGMM with the given dists and weights.
double Probability(const arma::vec &observation, const size_t component) const
Return the probability that the given observation came from the given Gaussian component in this distribution.
Linear algebra utility functions, generally performed on matrices or vectors.
The core includes that mlpack expects; standard C++ includes and Armadillo.