mlpack 3.4.2
Loading...
Searching...
No Matches
hmm.hpp
Go to the documentation of this file.
1
14#ifndef MLPACK_METHODS_HMM_HMM_HPP
15#define MLPACK_METHODS_HMM_HMM_HPP
16
17#include <mlpack/prereqs.hpp>
19
20namespace mlpack {
21namespace hmm {
22
84template<typename Distribution = distribution::DiscreteDistribution>
85class HMM
86{
87 public:
105 HMM(const size_t states = 0,
106 const Distribution emissions = Distribution(),
107 const double tolerance = 1e-5);
108
136 HMM(const arma::vec& initial,
137 const arma::mat& transition,
138 const std::vector<Distribution>& emission,
139 const double tolerance = 1e-5);
140
169 double Train(const std::vector<arma::mat>& dataSeq);
170
191 void Train(const std::vector<arma::mat>& dataSeq,
192 const std::vector<arma::Row<size_t> >& stateSeq);
193
212 double LogEstimate(const arma::mat& dataSeq,
213 arma::mat& stateLogProb,
214 arma::mat& forwardLogProb,
215 arma::mat& backwardLogProb,
216 arma::vec& logScales) const;
217
236 double Estimate(const arma::mat& dataSeq,
237 arma::mat& stateProb,
238 arma::mat& forwardProb,
239 arma::mat& backwardProb,
240 arma::vec& scales) const;
241
253 double Estimate(const arma::mat& dataSeq,
254 arma::mat& stateProb) const;
255
267 void Generate(const size_t length,
268 arma::mat& dataSequence,
269 arma::Row<size_t>& stateSequence,
270 const size_t startState = 0) const;
271
282 double Predict(const arma::mat& dataSeq,
283 arma::Row<size_t>& stateSeq) const;
284
291 double LogLikelihood(const arma::mat& dataSeq) const;
292
305 void Filter(const arma::mat& dataSeq,
306 arma::mat& filterSeq,
307 size_t ahead = 0) const;
308
320 void Smooth(const arma::mat& dataSeq,
321 arma::mat& smoothSeq) const;
322
324 const arma::vec& Initial() const { return initialProxy; }
326 arma::vec& Initial()
327 {
328 recalculateInitial = true;
329 return initialProxy;
330 }
331
333 const arma::mat& Transition() const { return transitionProxy; }
335 arma::mat& Transition()
336 {
337 recalculateTransition = true;
338 return transitionProxy;
339 }
340
342 const std::vector<Distribution>& Emission() const { return emission; }
344 std::vector<Distribution>& Emission() { return emission; }
345
347 size_t Dimensionality() const { return dimensionality; }
349 size_t& Dimensionality() { return dimensionality; }
350
352 double Tolerance() const { return tolerance; }
354 double& Tolerance() { return tolerance; }
355
359 template<typename Archive>
360 void load(Archive& ar, const unsigned int version);
361
365 template<typename Archive>
366 void save(Archive& ar, const unsigned int version) const;
367
369
370
371 protected:
372 // Helper functions.
383 void Forward(const arma::mat& dataSeq,
384 arma::vec& logScales,
385 arma::mat& forwardLogProb) const;
386
398 void Backward(const arma::mat& dataSeq,
399 const arma::vec& logScales,
400 arma::mat& backwardLogProb) const;
401
403 std::vector<Distribution> emission;
404
410
412 mutable arma::mat logTransition;
413
414 private:
420 void ConvertToLogSpace() const;
421
426 arma::vec initialProxy;
427
429 mutable arma::vec logInitial;
430
432 size_t dimensionality;
433
435 double tolerance;
436
441 mutable bool recalculateInitial;
442
447 mutable bool recalculateTransition;
448};
449
450} // namespace hmm
451} // namespace mlpack
452
453// Include implementation.
454#include "hmm_impl.hpp"
455
456#endif
A class that represents a Hidden Markov Model with an arbitrary type of emission distribution.
Definition: hmm.hpp:86
void Smooth(const arma::mat &dataSeq, arma::mat &smoothSeq) const
HMM smoothing.
double Train(const std::vector< arma::mat > &dataSeq)
Train the model using the Baum-Welch algorithm, with only the given unlabeled observations.
HMM(const size_t states=0, const Distribution emissions=Distribution(), const double tolerance=1e-5)
Create the Hidden Markov Model with the given number of hidden states and the given default distribution for emissions.
const std::vector< Distribution > & Emission() const
Return the emission distributions.
Definition: hmm.hpp:342
const arma::mat & Transition() const
Return the transition matrix.
Definition: hmm.hpp:333
void Train(const std::vector< arma::mat > &dataSeq, const std::vector< arma::Row< size_t > > &stateSeq)
Train the model using the given labeled observations; the transition matrix and emission distributions are estimated directly.
void Backward(const arma::mat &dataSeq, const arma::vec &logScales, arma::mat &backwardLogProb) const
The Backward algorithm (part of the Forward-Backward algorithm).
double Predict(const arma::mat &dataSeq, arma::Row< size_t > &stateSeq) const
Compute the most probable hidden state sequence for the given data sequence, using the Viterbi algorithm.
arma::mat logTransition
Transition probability matrix in log space. No need to be mutable in mlpack 4.0.
Definition: hmm.hpp:412
size_t & Dimensionality()
Set the dimensionality of observations.
Definition: hmm.hpp:349
const arma::vec & Initial() const
Return the vector of initial state probabilities.
Definition: hmm.hpp:324
void save(Archive &ar, const unsigned int version) const
Save the object.
void load(Archive &ar, const unsigned int version)
Load the object.
arma::mat & Transition()
Return a modifiable transition matrix reference.
Definition: hmm.hpp:335
double Estimate(const arma::mat &dataSeq, arma::mat &stateProb, arma::mat &forwardProb, arma::mat &backwardProb, arma::vec &scales) const
Estimate the probabilities of each hidden state at each time step for each given data observation, using the Forward-Backward algorithm.
size_t Dimensionality() const
Get the dimensionality of observations.
Definition: hmm.hpp:347
double & Tolerance()
Modify the tolerance of the Baum-Welch algorithm.
Definition: hmm.hpp:354
double Tolerance() const
Get the tolerance of the Baum-Welch algorithm.
Definition: hmm.hpp:352
arma::vec & Initial()
Modify the vector of initial state probabilities.
Definition: hmm.hpp:326
arma::mat transitionProxy
A proxy variable in linear space for logTransition.
Definition: hmm.hpp:409
double LogLikelihood(const arma::mat &dataSeq) const
Compute the log-likelihood of the given data sequence.
std::vector< Distribution > emission
Set of emission probability distributions; one for each state.
Definition: hmm.hpp:403
BOOST_SERIALIZATION_SPLIT_MEMBER()
void Generate(const size_t length, arma::mat &dataSequence, arma::Row< size_t > &stateSequence, const size_t startState=0) const
Generate a random data sequence of the given length.
void Forward(const arma::mat &dataSeq, arma::vec &logScales, arma::mat &forwardLogProb) const
The Forward algorithm (part of the Forward-Backward algorithm).
void Filter(const arma::mat &dataSeq, arma::mat &filterSeq, size_t ahead=0) const
HMM filtering.
double Estimate(const arma::mat &dataSeq, arma::mat &stateProb) const
Estimate the probabilities of each hidden state at each time step of each given data observation,...
HMM(const arma::vec &initial, const arma::mat &transition, const std::vector< Distribution > &emission, const double tolerance=1e-5)
Create the Hidden Markov Model with the given initial probability vector, the given transition matrix, and the given emission distributions.
std::vector< Distribution > & Emission()
Return a modifiable reference to the vector of emission distributions.
Definition: hmm.hpp:344
double LogEstimate(const arma::mat &dataSeq, arma::mat &stateLogProb, arma::mat &forwardLogProb, arma::mat &backwardLogProb, arma::vec &logScales) const
Estimate the probabilities of each hidden state at each time step for each given data observation, using the Forward-Backward algorithm, with all results in log space.
Linear algebra utility functions, generally performed on matrices or vectors.
The core includes that mlpack expects; standard C++ includes and Armadillo.