mlpack 3.4.2
Loading...
Searching...
No Matches
continuous_double_pole_cart.hpp
Go to the documentation of this file.
1
14#ifndef MLPACK_METHODS_RL_ENVIRONMENT_CONTINUOUS_DOUBLE_POLE_CART_HPP
15#define MLPACK_METHODS_RL_ENVIRONMENT_CONTINUOUS_DOUBLE_POLE_CART_HPP
16
17#include <mlpack/prereqs.hpp>
18
19namespace mlpack {
20namespace rl {
21
29{
30 public:
36 class State
37 {
38 public:
42 State() : data(dimension)
43 { /* Nothing to do here. */ }
44
50 State(const arma::colvec& data) : data(data)
51 { /* Nothing to do here */ }
52
54 arma::colvec Data() const { return data; }
56 arma::colvec& Data() { return data; }
57
59 double Position() const { return data[0]; }
61 double& Position() { return data[0]; }
62
64 double Velocity() const { return data[1]; }
66 double& Velocity() { return data[1]; }
67
69 double Angle(const size_t i) const { return data[2 * i]; }
71 double& Angle(const size_t i) { return data[2 * i]; }
72
74 double AngularVelocity(const size_t i) const { return data[2 * i + 1]; }
76 double& AngularVelocity(const size_t i) { return data[2 * i + 1]; }
77
79 const arma::colvec& Encode() const { return data; }
80
82 static constexpr size_t dimension = 6;
83
84 private:
86 arma::colvec data;
87 };
88
/**
 * Implementation of action of Continuous Double Pole Cart.
 *
 * The action is a single continuous value kept in a one-element array.
 */
struct Action
{
  //! The action value itself; only element 0 exists.
  double action[1];
  //! Number of degrees of freedom of the action.
  const int size{1};
};
98
116 ContinuousDoublePoleCart(const double m1 = 0.1,
117 const double m2 = 0.01,
118 const double l1 = 0.5,
119 const double l2 = 0.05,
120 const double gravity = 9.8,
121 const double massCart = 1.0,
122 const double forceMag = 10.0,
123 const double tau = 0.02,
124 const double thetaThresholdRadians = 36 * 2 *
125 3.1416 / 360,
126 const double xThreshold = 2.4,
127 const double doneReward = 0.0,
128 const size_t maxSteps = 0) :
129 m1(m1),
130 m2(m2),
131 l1(l1),
132 l2(l2),
133 gravity(gravity),
134 massCart(massCart),
135 forceMag(forceMag),
136 tau(tau),
137 thetaThresholdRadians(thetaThresholdRadians),
138 xThreshold(xThreshold),
139 doneReward(doneReward),
140 maxSteps(maxSteps),
141 stepsPerformed(0)
142 { /* Nothing to do here */ }
143
153 double Sample(const State& state,
154 const Action& action,
155 State& nextState)
156 {
157 // Update the number of steps performed.
158 stepsPerformed++;
159
160 arma::vec dydx(6, arma::fill::zeros);
161 dydx[0] = state.Velocity();
162 dydx[2] = state.AngularVelocity(1);
163 dydx[4] = state.AngularVelocity(2);
164 Dsdt(state, action, dydx);
165 RK4(state, action, dydx, nextState);
166
167 // Check if the episode has terminated.
168 bool done = IsTerminal(nextState);
169
170 // Do not reward agent if it failed.
171 if (done && maxSteps != 0 && stepsPerformed >= maxSteps)
172 return doneReward;
173 else if (done)
174 return 0;
175
180 return 1.0;
181 }
182
191 void Dsdt(const State& state,
192 const Action& action,
193 arma::vec& dydx)
194 {
195 double totalForce = action.action[0];
196 double totalMass = massCart;
197 double omega1 = state.AngularVelocity(1);
198 double omega2 = state.AngularVelocity(2);
199 double sinTheta1 = std::sin(state.Angle(1));
200 double sinTheta2 = std::sin(state.Angle(2));
201 double cosTheta1 = std::cos(state.Angle(1));
202 double cosTheta2 = std::cos(state.Angle(2));
203
204 // Calculate total effective force.
205 totalForce += m1 * l1 * omega1 * omega1 * sinTheta1 + 0.375 * m1 * gravity *
206 std::sin(2 * state.Angle(1));
207 totalForce += m2 * l2 * omega2 * omega2 * sinTheta1 + 0.375 * m2 * gravity *
208 std::sin(2 * state.Angle(2));
209
210 // Calculate total effective mass.
211 totalMass += m1 * (0.25 + 0.75 * sinTheta1 * sinTheta1);
212 totalMass += m2 * (0.25 + 0.75 * sinTheta2 * sinTheta2);
213
214 // Calculate acceleration.
215 double xAcc = totalForce / totalMass;
216 dydx[1] = xAcc;
217
218 // Calculate angular acceleration.
219 dydx[3] = -0.75 * (xAcc * cosTheta1 + gravity * sinTheta1) / l1;
220 dydx[5] = -0.75 * (xAcc * cosTheta2 + gravity * sinTheta2) / l2;
221 }
222
232 void RK4(const State& state,
233 const Action& action,
234 arma::vec& dydx,
235 State& nextState)
236 {
237 const double hh = tau * 0.5;
238 const double h6 = tau / 6;
239 arma::vec yt(6);
240 arma::vec dyt(6);
241 arma::vec dym(6);
242
243 yt = state.Data() + (hh * dydx);
244 Dsdt(State(yt), action, dyt);
245 dyt[0] = yt[1];
246 dyt[2] = yt[3];
247 dyt[4] = yt[5];
248 yt = state.Data() + (hh * dyt);
249
250 Dsdt(State(yt), action, dym);
251 dym[0] = yt[1];
252 dym[2] = yt[3];
253 dym[4] = yt[5];
254 yt = state.Data() + (tau * dym);
255 dym += dyt;
256
257 Dsdt(State(yt), action, dyt);
258 dyt[0] = yt[1];
259 dyt[2] = yt[3];
260 dyt[4] = yt[5];
261 nextState.Data() = state.Data() + h6 * (dydx + dyt + 2 * dym);
262 }
263
272 double Sample(const State& state, const Action& action)
273 {
274 State nextState;
275 return Sample(state, action, nextState);
276 }
277
284 {
285 stepsPerformed = 0;
286 return State((arma::randu<arma::vec>(6) - 0.5) / 10.0);
287 }
288
295 bool IsTerminal(const State& state) const
296 {
297 if (maxSteps != 0 && stepsPerformed >= maxSteps)
298 {
299 Log::Info << "Episode terminated due to the maximum number of steps"
300 "being taken.";
301 return true;
302 }
303 if (std::abs(state.Position()) > xThreshold)
304 {
305 Log::Info << "Episode terminated due to cart crossing threshold";
306 return true;
307 }
308 if (std::abs(state.Angle(1)) > thetaThresholdRadians ||
309 std::abs(state.Angle(2)) > thetaThresholdRadians)
310 {
311 Log::Info << "Episode terminated due to pole falling";
312 return true;
313 }
314 return false;
315 }
316
318 size_t StepsPerformed() const { return stepsPerformed; }
319
321 size_t MaxSteps() const { return maxSteps; }
323 size_t& MaxSteps() { return maxSteps; }
324
325 private:
//! Locally-stored mass parameter of the first pole.
double m1;

//! Locally-stored mass parameter of the second pole.
double m2;

//! Locally-stored length parameter of the first pole.
double l1;

//! Locally-stored length parameter of the second pole.
double l2;

//! Locally-stored gravitational acceleration.
double gravity;

//! Locally-stored mass of the cart.
double massCart;

//! Locally-stored force magnitude.
double forceMag;

//! Locally-stored time increment of one integration step.
double tau;

//! Locally-stored largest absolute pole angle before termination.
double thetaThresholdRadians;

//! Locally-stored largest absolute cart position before termination.
double xThreshold;

//! Locally-stored reward returned when the step limit ends the episode.
double doneReward;

//! Locally-stored maximum number of steps; 0 means no limit.
size_t maxSteps;

//! Locally-stored number of steps performed in the current episode.
size_t stepsPerformed;
364};
365
366} // namespace rl
367} // namespace mlpack
368
369#endif
static MLPACK_EXPORT util::PrefixedOutStream Info
Prints informational messages if --verbose is specified, prefixed with [INFO ].
Definition: log.hpp:84
Implementation of the state of Continuous Double Pole Cart.
const arma::colvec & Encode() const
Encode the state to a vector.
arma::colvec Data() const
Get the internal representation of the state.
double & Velocity()
Modify the velocity of the cart.
double Velocity() const
Get the velocity of the cart.
double Angle(const size_t i) const
Get the angle of the $i^{th}$ pole.
State(const arma::colvec &data)
Construct a state instance from given data.
double & Angle(const size_t i)
Modify the angle of the $i^{th}$ pole.
double & Position()
Modify the position of the cart.
double Position() const
Get the position of the cart.
double AngularVelocity(const size_t i) const
Get the angular velocity of the $i^{th}$ pole.
static constexpr size_t dimension
Dimension of the encoded state.
double & AngularVelocity(const size_t i)
Modify the angular velocity of the $i^{th}$ pole.
arma::colvec & Data()
Modify the internal representation of the state.
Implementation of Continuous Double Pole Cart Balancing task.
double Sample(const State &state, const Action &action, State &nextState)
Dynamics of Continuous Double Pole Cart instance.
size_t & MaxSteps()
Set the maximum number of steps allowed.
size_t StepsPerformed() const
Get the number of steps performed.
bool IsTerminal(const State &state) const
This function checks if the cart has reached the terminal state.
ContinuousDoublePoleCart(const double m1=0.1, const double m2=0.01, const double l1=0.5, const double l2=0.05, const double gravity=9.8, const double massCart=1.0, const double forceMag=10.0, const double tau=0.02, const double thetaThresholdRadians=36 *2 *3.1416/360, const double xThreshold=2.4, const double doneReward=0.0, const size_t maxSteps=0)
Construct a Double Pole Cart instance using the given constants.
State InitialSample()
Initial state representation is randomly generated within [-0.05, 0.05].
size_t MaxSteps() const
Get the maximum number of steps allowed.
void Dsdt(const State &state, const Action &action, arma::vec &dydx)
This is the ordinary differential equations required for estimation of next state through RK4 method.
double Sample(const State &state, const Action &action)
Dynamics of Continuous Double Pole Cart.
void RK4(const State &state, const Action &action, arma::vec &dydx, State &nextState)
This function calls the RK4 iterative method to estimate the next state based on the given ordinary differential equations.
Linear algebra utility functions, generally performed on matrices or vectors.
The core includes that mlpack expects; standard C++ includes and Armadillo.
Implementation of action of Continuous Double Pole Cart.