106 lines
2.3 KiB
C++
106 lines
2.3 KiB
C++
#ifndef SIM_RK4SOLVER_H
|
|
#define SIM_RK4SOLVER_H
|
|
|
|
/// \cond
|
|
// C headers
|
|
// C++ headers
|
|
#include <cmath>
|
|
#include <functional>
|
|
#include <limits>
|
|
#include <vector>
|
|
|
|
// 3rd party headers
|
|
|
|
/// \endcond
|
|
|
|
// qtrocket headers
|
|
#include "sim/DESolver.h"
|
|
#include "utils/Logger.h"
|
|
|
|
namespace sim {
|
|
|
|
template<typename... Ts>
|
|
class RK4Solver : public DESolver
|
|
{
|
|
public:
|
|
|
|
RK4Solver(Ts... funcs)
|
|
{
|
|
(odes.push_back(funcs), ...);
|
|
temp.resize(sizeof...(Ts));
|
|
|
|
}
|
|
virtual ~RK4Solver() {}
|
|
|
|
void setTimeStep(double inTs) override { dt = inTs; halfDT = dt / 2.0; }
|
|
|
|
void step(double t, const std::vector<double>& curVal, std::vector<double>& res) override
|
|
{
|
|
if(dt == std::numeric_limits<double>::quiet_NaN())
|
|
{
|
|
utils::Logger::getInstance()->error("Calling RK4Solver without setting dt first is an error");
|
|
res[0] = std::numeric_limits<double>::quiet_NaN();
|
|
}
|
|
|
|
for(size_t i = 0; i < len; ++i)
|
|
{
|
|
k1[i] = odes[i](t, curVal);
|
|
}
|
|
// compute k2 values. This involves stepping the current values forward a half-step
|
|
// based on k1, so we do the stepping first
|
|
for(size_t i = 0; i < len; ++i)
|
|
{
|
|
temp[i] = curVal[i] + k1[i]*dt / 2.0;
|
|
}
|
|
for(size_t i = 0; i < len; ++i)
|
|
{
|
|
k2[i] = odes[i](t + halfDT, temp);
|
|
}
|
|
// repeat for k3
|
|
for(size_t i = 0; i < len; ++i)
|
|
{
|
|
temp[i] = curVal[i] + k2[i]*dt / 2.0;
|
|
}
|
|
for(size_t i = 0; i < len; ++i)
|
|
{
|
|
k3[i] = odes[i](t + halfDT, temp);
|
|
}
|
|
|
|
// now k4
|
|
for(size_t i = 0; i < len; ++i)
|
|
{
|
|
temp[i] = curVal[i] + k3[i]*dt;
|
|
}
|
|
for(size_t i = 0; i < len; ++i)
|
|
{
|
|
k4[i] = odes[i](t + dt, temp);
|
|
}
|
|
|
|
// now compute the result
|
|
for(size_t i = 0; i < len; ++i)
|
|
{
|
|
res[i] = curVal[i] + (dt / 6.0)*(k1[i] + 2.0*k2[i] + 2.0*k3[i] + k4[i]);
|
|
}
|
|
|
|
}
|
|
|
|
private:
|
|
std::vector<std::function<double(double, const std::vector<double>&)>> odes;
|
|
|
|
static constexpr size_t len = sizeof...(Ts);
|
|
double k1[len];
|
|
double k2[len];
|
|
double k3[len];
|
|
double k4[len];
|
|
|
|
std::vector<double> temp;
|
|
|
|
double dt = std::numeric_limits<double>::quiet_NaN();
|
|
double halfDT = 0.0;
|
|
|
|
};
|
|
|
|
} // namespace sim
|
|
|
|
#endif // SIM_RK4SOLVER_H
|