qtrocket/sim/RK4Solver.h

115 lines
2.7 KiB
C++

#ifndef SIM_RK4SOLVER_H
#define SIM_RK4SOLVER_H
/// \cond
// C headers
// C++ headers
#include <array>
#include <cmath>
#include <cstddef>
#include <functional>
#include <limits>
#include <utility>
#include <vector>
// 3rd party headers
/// \endcond
// qtrocket headers
#include "sim/DESolver.h"
#include "utils/Logger.h"
namespace sim {
/**
* @brief Runge-Kutta 4th order coupled ODE solver.
* @note This was written outside of the context of QtRocket, and it is very generic. There are
* some features of this solver that are note used by QtRocket, for example, it can solve
* and arbitrarily large system of coupled ODEs, but QtRocket only makes use of a system
* of size 6 (x, y, z, xDot, yDot, zDot) at a time.
*
* @tparam Ts
*/
template<typename... Ts>
class RK4Solver : public DESolver
{
public:
RK4Solver(Ts... funcs)
{
(odes.push_back(funcs), ...);
temp.resize(sizeof...(Ts));
}
virtual ~RK4Solver() {}
void setTimeStep(double inTs) override { dt = inTs; halfDT = dt / 2.0; }
void step(const std::vector<double>& curVal, std::vector<double>& res, double t = 0.0) override
{
if(dt == std::numeric_limits<double>::quiet_NaN())
{
utils::Logger::getInstance()->error("Calling RK4Solver without setting dt first is an error");
res[0] = std::numeric_limits<double>::quiet_NaN();
}
for(size_t i = 0; i < len; ++i)
{
k1[i] = odes[i](curVal, t);
}
// compute k2 values. This involves stepping the current values forward a half-step
// based on k1, so we do the stepping first
for(size_t i = 0; i < len; ++i)
{
temp[i] = curVal[i] + k1[i]*dt / 2.0;
}
for(size_t i = 0; i < len; ++i)
{
k2[i] = odes[i](temp, t + halfDT);
}
// repeat for k3
for(size_t i = 0; i < len; ++i)
{
temp[i] = curVal[i] + k2[i]*dt / 2.0;
}
for(size_t i = 0; i < len; ++i)
{
k3[i] = odes[i](temp, t + halfDT);
}
// now k4
for(size_t i = 0; i < len; ++i)
{
temp[i] = curVal[i] + k3[i]*dt;
}
for(size_t i = 0; i < len; ++i)
{
k4[i] = odes[i](temp, t + dt);
}
// now compute the result
for(size_t i = 0; i < len; ++i)
{
res[i] = curVal[i] + (dt / 6.0)*(k1[i] + 2.0*k2[i] + 2.0*k3[i] + k4[i]);
}
}
private:
std::vector<std::function<double(const std::vector<double>&, double)>> odes;
static constexpr size_t len = sizeof...(Ts);
double k1[len];
double k2[len];
double k3[len];
double k4[len];
std::vector<double> temp;
double dt = std::numeric_limits<double>::quiet_NaN();
double halfDT = 0.0;
};
} // namespace sim
#endif // SIM_RK4SOLVER_H