mlpack 3.4.2
Loading...
Searching...
No Matches
double_pole_cart.hpp
Go to the documentation of this file.
1
13#ifndef MLPACK_METHODS_RL_ENVIRONMENT_DOUBLE_POLE_CART_HPP
14#define MLPACK_METHODS_RL_ENVIRONMENT_DOUBLE_POLE_CART_HPP
15
16#include <mlpack/prereqs.hpp>
17
18namespace mlpack {
19namespace rl {
20
28{
29 public:
35 class State
36 {
37 public:
41 State() : data(dimension)
42 { /* Nothing to do here. */ }
43
49 State(const arma::colvec& data) : data(data)
50 { /* Nothing to do here */ }
51
53 arma::colvec Data() const { return data; }
55 arma::colvec& Data() { return data; }
56
58 double Position() const { return data[0]; }
60 double& Position() { return data[0]; }
61
63 double Velocity() const { return data[1]; }
65 double& Velocity() { return data[1]; }
66
68 double Angle(const size_t i) const { return data[2 * i]; }
70 double& Angle(const size_t i) { return data[2 * i]; }
71
73 double AngularVelocity(const size_t i) const { return data[2 * i + 1]; }
75 double& AngularVelocity(const size_t i) { return data[2 * i + 1]; }
76
78 const arma::colvec& Encode() const { return data; }
79
81 static constexpr size_t dimension = 6;
82
83 private:
85 arma::colvec data;
86 };
87
/**
 * Implementation of action of Double Pole Cart.
 *
 * The action is one of two discrete choices.  The enumerator values matter:
 * the dynamics code in this file tests the stored action for truthiness, so
 * `backward` must be 0 and `forward` must be non-zero.
 */
class Action
{
 public:
  //! The two possible pushes on the cart.
  enum actions
  {
    backward,
    forward
  };
  // To store the action.
  Action::actions action;

  // Track the size of the action space.
  static const size_t size = 2;
};
105
123 DoublePoleCart(const size_t maxSteps = 0,
124 const double m1 = 0.1,
125 const double m2 = 0.01,
126 const double l1 = 0.5,
127 const double l2 = 0.05,
128 const double gravity = 9.8,
129 const double massCart = 1.0,
130 const double forceMag = 10.0,
131 const double tau = 0.02,
132 const double thetaThresholdRadians = 36 * 2 * 3.1416 / 360,
133 const double xThreshold = 2.4,
134 const double doneReward = 0.0) :
135 maxSteps(maxSteps),
136 m1(m1),
137 m2(m2),
138 l1(l1),
139 l2(l2),
140 gravity(gravity),
141 massCart(massCart),
142 forceMag(forceMag),
143 tau(tau),
144 thetaThresholdRadians(thetaThresholdRadians),
145 xThreshold(xThreshold),
146 doneReward(doneReward),
147 stepsPerformed(0)
148 { /* Nothing to do here */ }
149
159 double Sample(const State& state,
160 const Action& action,
161 State& nextState)
162 {
163 // Update the number of steps performed.
164 stepsPerformed++;
165
166 arma::vec dydx(6, arma::fill::zeros);
167 dydx[0] = state.Velocity();
168 dydx[2] = state.AngularVelocity(1);
169 dydx[4] = state.AngularVelocity(2);
170 Dsdt(state, action, dydx);
171 RK4(state, action, dydx, nextState);
172
173 // Check if the episode has terminated.
174 bool done = IsTerminal(nextState);
175
176 // Do not reward agent if it failed.
177 if (done && maxSteps != 0 && stepsPerformed >= maxSteps)
178 return doneReward;
179 else if (done)
180 return 0;
181
186 return 1.0;
187 }
188
197 void Dsdt(const State& state,
198 const Action& action,
199 arma::vec& dydx)
200 {
201 double totalForce = action.action ? forceMag : -forceMag;
202 double totalMass = massCart;
203 double omega1 = state.AngularVelocity(1);
204 double omega2 = state.AngularVelocity(2);
205 double sinTheta1 = std::sin(state.Angle(1));
206 double sinTheta2 = std::sin(state.Angle(2));
207 double cosTheta1 = std::cos(state.Angle(1));
208 double cosTheta2 = std::cos(state.Angle(2));
209
210 // Calculate total effective force.
211 totalForce += m1 * l1 * omega1 * omega1 * sinTheta1 + 0.375 * m1 * gravity *
212 std::sin(2 * state.Angle(1));
213 totalForce += m2 * l2 * omega2 * omega2 * sinTheta1 + 0.375 * m2 * gravity *
214 std::sin(2 * state.Angle(2));
215
216 // Calculate total effective mass.
217 totalMass += m1 * (0.25 + 0.75 * sinTheta1 * sinTheta1);
218 totalMass += m2 * (0.25 + 0.75 * sinTheta2 * sinTheta2);
219
220 // Calculate acceleration.
221 double xAcc = totalForce / totalMass;
222 dydx[1] = xAcc;
223
224 // Calculate angular acceleration.
225 dydx[3] = -0.75 * (xAcc * cosTheta1 + gravity * sinTheta1) / l1;
226 dydx[5] = -0.75 * (xAcc * cosTheta2 + gravity * sinTheta2) / l2;
227 }
228
238 void RK4(const State& state,
239 const Action& action,
240 arma::vec& dydx,
241 State& nextState)
242 {
243 const double hh = tau * 0.5;
244 const double h6 = tau / 6;
245 arma::vec yt(6);
246 arma::vec dyt(6);
247 arma::vec dym(6);
248
249 yt = state.Data() + (hh * dydx);
250 Dsdt(State(yt), action, dyt);
251 dyt[0] = yt[1];
252 dyt[2] = yt[3];
253 dyt[4] = yt[5];
254 yt = state.Data() + (hh * dyt);
255
256 Dsdt(State(yt), action, dym);
257 dym[0] = yt[1];
258 dym[2] = yt[3];
259 dym[4] = yt[5];
260 yt = state.Data() + (tau * dym);
261 dym += dyt;
262
263 Dsdt(State(yt), action, dyt);
264 dyt[0] = yt[1];
265 dyt[2] = yt[3];
266 dyt[4] = yt[5];
267 nextState.Data() = state.Data() + h6 * (dydx + dyt + 2 * dym);
268 }
269
278 double Sample(const State& state, const Action& action)
279 {
280 State nextState;
281 return Sample(state, action, nextState);
282 }
283
290 {
291 stepsPerformed = 0;
292 return State((arma::randu<arma::vec>(6) - 0.5) / 10.0);
293 }
294
301 bool IsTerminal(const State& state) const
302 {
303 if (maxSteps != 0 && stepsPerformed >= maxSteps)
304 {
305 Log::Info << "Episode terminated due to the maximum number of steps"
306 "being taken.";
307 return true;
308 }
309 if (std::abs(state.Position()) > xThreshold)
310 {
311 Log::Info << "Episode terminated due to cart crossing threshold";
312 return true;
313 }
314 if (std::abs(state.Angle(1)) > thetaThresholdRadians ||
315 std::abs(state.Angle(2)) > thetaThresholdRadians)
316 {
317 Log::Info << "Episode terminated due to pole falling";
318 return true;
319 }
320 return false;
321 }
322
324 size_t StepsPerformed() const { return stepsPerformed; }
325
327 size_t MaxSteps() const { return maxSteps; }
329 size_t& MaxSteps() { return maxSteps; }
330
331 private:
// Maximum number of steps per episode; the limit is only checked when this is
// non-zero.
333 size_t maxSteps;
334
// Mass of the first pole (enters the effective force and mass in Dsdt).
336 double m1;
337
// Mass of the second pole (enters the effective force and mass in Dsdt).
339 double m2;
340
// Length parameter of the first pole as used in Dsdt.
// NOTE(review): confirm whether this is the full or the half length.
342 double l1;
343
// Length parameter of the second pole as used in Dsdt.
// NOTE(review): confirm whether this is the full or the half length.
345 double l2;
346
// Gravitational acceleration used in Dsdt.
348 double gravity;
349
// Mass of the cart; the starting value of the total effective mass in Dsdt.
351 double massCart;
352
// Magnitude of the force an action applies to the cart.
354 double forceMag;
355
// Time step used by RK4 when integrating the dynamics.
357 double tau;
358
// A state is terminal once the absolute angle of either pole exceeds this.
360 double thetaThresholdRadians;
361
// A state is terminal once the absolute cart position exceeds this.
363 double xThreshold;
364
// Reward returned by Sample when the episode ends due to the step limit.
366 double doneReward;
367
// Steps taken so far: incremented by Sample, reset by InitialSample.
369 size_t stepsPerformed;
370};
371
372} // namespace rl
373} // namespace mlpack
374
375#endif
static MLPACK_EXPORT util::PrefixedOutStream Info
Prints informational messages if --verbose is specified, prefixed with [INFO ].
Definition: log.hpp:84
Implementation of action of Double Pole Cart.
Implementation of the state of Double Pole Cart.
const arma::colvec & Encode() const
Encode the state to a vector.
arma::colvec Data() const
Get the internal representation of the state.
double & Velocity()
Modify the velocity of the cart.
double Velocity() const
Get the velocity of the cart.
double Angle(const size_t i) const
Get the angle of the $i^{th}$ pole.
State()
Construct a state instance.
State(const arma::colvec &data)
Construct a state instance from given data.
double & Angle(const size_t i)
Modify the angle of the $i^{th}$ pole.
double & Position()
Modify the position of the cart.
double Position() const
Get the position of the cart.
double AngularVelocity(const size_t i) const
Get the angular velocity of the $i^{th}$ pole.
static constexpr size_t dimension
Dimension of the encoded state.
double & AngularVelocity(const size_t i)
Modify the angular velocity of the $i^{th}$ pole.
arma::colvec & Data()
Modify the internal representation of the state.
Implementation of Double Pole Cart Balancing task.
DoublePoleCart(const size_t maxSteps=0, const double m1=0.1, const double m2=0.01, const double l1=0.5, const double l2=0.05, const double gravity=9.8, const double massCart=1.0, const double forceMag=10.0, const double tau=0.02, const double thetaThresholdRadians=36 *2 *3.1416/360, const double xThreshold=2.4, const double doneReward=0.0)
Construct a Double Pole Cart instance using the given constants.
double Sample(const State &state, const Action &action, State &nextState)
Dynamics of Double Pole Cart instance.
size_t & MaxSteps()
Set the maximum number of steps allowed.
size_t StepsPerformed() const
Get the number of steps performed.
bool IsTerminal(const State &state) const
This function checks if the cart has reached the terminal state.
State InitialSample()
Initial state representation is randomly generated within [-0.05, 0.05].
size_t MaxSteps() const
Get the maximum number of steps allowed.
void Dsdt(const State &state, const Action &action, arma::vec &dydx)
These are the ordinary differential equations required for estimating the next state through the RK4 method.
double Sample(const State &state, const Action &action)
Dynamics of Double Pole Cart.
void RK4(const State &state, const Action &action, arma::vec &dydx, State &nextState)
This function calls the RK4 iterative method to estimate the next state based on the given ordinary differential equations.
Linear algebra utility functions, generally performed on matrices or vectors.
The core includes that mlpack expects; standard C++ includes and Armadillo.