mlpack 3.4.2
Loading...
Searching...
No Matches
async_learning.hpp
Go to the documentation of this file.
1
14#ifndef MLPACK_METHODS_RL_ASYNC_LEARNING_HPP
15#define MLPACK_METHODS_RL_ASYNC_LEARNING_HPP
16
17#include <mlpack/prereqs.hpp>
21#include "training_config.hpp"
22
23namespace mlpack {
24namespace rl {
25
50template <
51 typename WorkerType,
52 typename EnvironmentType,
53 typename NetworkType,
54 typename UpdaterType,
55 typename PolicyType
56>
58{
59 public:
70 NetworkType network,
71 PolicyType policy,
72 UpdaterType updater = UpdaterType(),
73 EnvironmentType environment = EnvironmentType());
74
88 template <typename Measure>
89 void Train(Measure& measure);
90
92 TrainingConfig& Config() { return config; }
94 const TrainingConfig& Config() const { return config; }
95
97 NetworkType& Network() { return learningNetwork; }
99 const NetworkType& Network() const { return learningNetwork; }
100
102 PolicyType& Policy() { return policy; }
104 const PolicyType& Policy() const { return policy; }
105
107 UpdaterType& Updater() { return updater; }
109 const UpdaterType& Updater() const { return updater; }
110
112 EnvironmentType& Environment() { return environment; }
114 const EnvironmentType& Environment() const { return environment; }
115
116 private:
118 TrainingConfig config;
119
121 NetworkType learningNetwork;
122
124 PolicyType policy;
125
127 UpdaterType updater;
128
130 EnvironmentType environment;
131};
132
/**
 * Forward declaration of OneStepQLearningWorker.
 *
 * @tparam EnvironmentType Type of the reinforcement learning environment.
 * @tparam NetworkType Type of the learning network.
 * @tparam UpdaterType Type of the optimizer.
 * @tparam PolicyType Type of the behavior policy.
 */
template <
  typename EnvironmentType,
  typename NetworkType,
  typename UpdaterType,
  typename PolicyType
>
class OneStepQLearningWorker;
148
/**
 * Forward declaration of OneStepSarsaWorker.
 *
 * @tparam EnvironmentType Type of the reinforcement learning environment.
 * @tparam NetworkType Type of the learning network.
 * @tparam UpdaterType Type of the optimizer.
 * @tparam PolicyType Type of the behavior policy.
 */
template <
  typename EnvironmentType,
  typename NetworkType,
  typename UpdaterType,
  typename PolicyType
>
class OneStepSarsaWorker;
164
/**
 * Forward declaration of NStepQLearningWorker.
 *
 * @tparam EnvironmentType Type of the reinforcement learning environment.
 * @tparam NetworkType Type of the learning network.
 * @tparam UpdaterType Type of the optimizer.
 * @tparam PolicyType Type of the behavior policy.
 */
template <
  typename EnvironmentType,
  typename NetworkType,
  typename UpdaterType,
  typename PolicyType
>
class NStepQLearningWorker;
180
189template <
190 typename EnvironmentType,
191 typename NetworkType,
192 typename UpdaterType,
193 typename PolicyType
194>
196 NetworkType, UpdaterType, PolicyType>, EnvironmentType, NetworkType,
197 UpdaterType, PolicyType>;
198
207template <
208 typename EnvironmentType,
209 typename NetworkType,
210 typename UpdaterType,
211 typename PolicyType
212>
214 NetworkType, UpdaterType, PolicyType>, EnvironmentType, NetworkType,
215 UpdaterType, PolicyType>;
216
225template <
226 typename EnvironmentType,
227 typename NetworkType,
228 typename UpdaterType,
229 typename PolicyType
230>
232 NetworkType, UpdaterType, PolicyType>, EnvironmentType, NetworkType,
233 UpdaterType, PolicyType>;
234
235} // namespace rl
236} // namespace mlpack
237
238// Include implementation
239#include "async_learning_impl.hpp"
240
241#endif
Wrapper of various asynchronous learning algorithms, e.g. async one-step Q-learning, async one-step Sarsa and async n-step Q-learning.
const TrainingConfig & Config() const
Get training config.
AsyncLearning(TrainingConfig config, NetworkType network, PolicyType policy, UpdaterType updater=UpdaterType(), EnvironmentType environment=EnvironmentType())
Construct an instance of the given async learning algorithm.
NetworkType & Network()
Modify learning network.
TrainingConfig & Config()
Modify training config.
const PolicyType & Policy() const
Get behavior policy.
PolicyType & Policy()
Modify behavior policy.
const UpdaterType & Updater() const
Get optimizer.
void Train(Measure &measure)
Start asynchronous training.
UpdaterType & Updater()
Modify optimizer.
EnvironmentType & Environment()
Modify the environment.
const EnvironmentType & Environment() const
Get the environment.
const NetworkType & Network() const
Get learning network.
Forward declaration of NStepQLearningWorker.
Forward declaration of OneStepQLearningWorker.
Forward declaration of OneStepSarsaWorker.
Linear algebra utility functions, generally performed on matrices or vectors.
The core includes that mlpack expects; standard C++ includes and Armadillo.