change rk4 integrator to take vector of doubles, spin off Qt gui in separate thread
This commit is contained in:
		
							parent
							
								
									1b855b2997
								
							
						
					
					
						commit
						90e5289609
					
				@ -83,24 +83,24 @@ void MainWindow::on_testButton2_clicked()
 | 
				
			|||||||
   double ts = 0.01;
 | 
					   double ts = 0.01;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
   // X position/velocity. x[0] is X position, x[1] is X velocity
 | 
					   // X position/velocity. x[0] is X position, x[1] is X velocity
 | 
				
			||||||
   double x[2] = {0.0, initialVelocityX};
 | 
					   std::vector<double> x = {0.0, initialVelocityX};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
   // Y position/velocity. y[0] is Y position, y[1] is Y velocity
 | 
					   // Y position/velocity. y[0] is Y position, y[1] is Y velocity
 | 
				
			||||||
   double y[2] = {0.0, initialVelocityY};
 | 
					   std::vector<double> y = {0.0, initialVelocityY};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
   auto xvelODE = [mass, dragCoeff](double, double* x) -> double
 | 
					   auto xvelODE = [mass, dragCoeff](double, const std::vector<double>& x) -> double
 | 
				
			||||||
      {
 | 
					      {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
         return -dragCoeff * 1.225 * 0.00774192 / (2.0 * mass) * x[1]*x[1]; };
 | 
					         return -dragCoeff * 1.225 * 0.00774192 / (2.0 * mass) * x[1]*x[1]; };
 | 
				
			||||||
   auto xposODE = [](double, double* x) -> double { return x[1]; };
 | 
					   auto xposODE = [](double, const std::vector<double>& x) -> double { return x[1]; };
 | 
				
			||||||
   sim::RK4Solver xSolver(xposODE, xvelODE);
 | 
					   sim::RK4Solver xSolver(xposODE, xvelODE);
 | 
				
			||||||
   xSolver.setTimeStep(0.01);
 | 
					   xSolver.setTimeStep(0.01);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
   auto yvelODE = [mass, dragCoeff](double, double* y) -> double
 | 
					   auto yvelODE = [mass, dragCoeff](double, const std::vector<double>& y) -> double
 | 
				
			||||||
      {
 | 
					      {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
         return -dragCoeff * 1.225 * 0.00774192 / (2.0 * mass) * y[1]*y[1] - 9.8; };
 | 
					         return -dragCoeff * 1.225 * 0.00774192 / (2.0 * mass) * y[1]*y[1] - 9.8; };
 | 
				
			||||||
   auto yposODE = [](double, double* y) -> double { return y[1]; };
 | 
					   auto yposODE = [](double, const std::vector<double>& y) -> double { return y[1]; };
 | 
				
			||||||
   sim::RK4Solver ySolver(yposODE, yvelODE);
 | 
					   sim::RK4Solver ySolver(yposODE, yvelODE);
 | 
				
			||||||
   ySolver.setTimeStep(0.01);
 | 
					   ySolver.setTimeStep(0.01);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -110,18 +110,16 @@ void MainWindow::on_testButton2_clicked()
 | 
				
			|||||||
   QTextStream cout(stdout);
 | 
					   QTextStream cout(stdout);
 | 
				
			||||||
   cout << "Initial X velocity: " << initialVelocityX << "\n";
 | 
					   cout << "Initial X velocity: " << initialVelocityX << "\n";
 | 
				
			||||||
   cout << "Initial Y velocity: " << initialVelocityY << "\n";
 | 
					   cout << "Initial Y velocity: " << initialVelocityY << "\n";
 | 
				
			||||||
   double resX[2];
 | 
					   std::vector<double> resX(2);
 | 
				
			||||||
   double resY[2];
 | 
					   std::vector<double> resY(2);
 | 
				
			||||||
   for(size_t i = 0; i < maxTs; ++i)
 | 
					   for(size_t i = 0; i < maxTs; ++i)
 | 
				
			||||||
   {
 | 
					   {
 | 
				
			||||||
      xSolver.step(i * ts, x, resX);
 | 
					      xSolver.step(i * ts, x, resX);
 | 
				
			||||||
      ySolver.step(i * ts, y, resY);
 | 
					      ySolver.step(i * ts, y, resY);
 | 
				
			||||||
      position.emplace_back(resX[0], resY[0], 0.0);
 | 
					      position.emplace_back(resX[0], resY[0], 0.0);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
      x[0] = resX[0];
 | 
					      x = resX;
 | 
				
			||||||
      x[1] = resX[1];
 | 
					      y = resY;
 | 
				
			||||||
      y[0] = resY[0];
 | 
					 | 
				
			||||||
      y[1] = resY[1];
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
      cout << "(" << position[i].getX1() << ", " << position[i].getX2() << ")\n";
 | 
					      cout << "(" << position[i].getX1() << ", " << position[i].getX2() << ")\n";
 | 
				
			||||||
      if(y[0] < 0.0)
 | 
					      if(y[0] < 0.0)
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										38
									
								
								main.cpp
									
									
									
									
									
								
							
							
						
						
									
										38
									
								
								main.cpp
									
									
									
									
									
								
							@ -4,7 +4,9 @@
 | 
				
			|||||||
#include <QLocale>
 | 
					#include <QLocale>
 | 
				
			||||||
#include <QTranslator>
 | 
					#include <QTranslator>
 | 
				
			||||||
 | 
					
 | 
				
			||||||
int main(int argc, char *argv[])
 | 
					#include <thread>
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					void worker(int argc, char* argv[], int& ret)
 | 
				
			||||||
{
 | 
					{
 | 
				
			||||||
   QApplication a(argc, argv);
 | 
					   QApplication a(argc, argv);
 | 
				
			||||||
   a.setWindowIcon(QIcon(":/qtrocket.png"));
 | 
					   a.setWindowIcon(QIcon(":/qtrocket.png"));
 | 
				
			||||||
@ -26,5 +28,37 @@ int main(int argc, char *argv[])
 | 
				
			|||||||
   // Go!
 | 
					   // Go!
 | 
				
			||||||
   MainWindow w;
 | 
					   MainWindow w;
 | 
				
			||||||
   w.show();
 | 
					   w.show();
 | 
				
			||||||
   return a.exec();
 | 
					   ret = a.exec();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					int main(int argc, char *argv[])
 | 
				
			||||||
 | 
					{
 | 
				
			||||||
 | 
					   /*
 | 
				
			||||||
 | 
					   QApplication a(argc, argv);
 | 
				
			||||||
 | 
					   a.setWindowIcon(QIcon(":/qtrocket.png"));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					   // Start translation component.
 | 
				
			||||||
 | 
					   // TODO: Only support US English at the moment. Anyone want to help translate?
 | 
				
			||||||
 | 
					   QTranslator translator;
 | 
				
			||||||
 | 
					   const QStringList uiLanguages = QLocale::system().uiLanguages();
 | 
				
			||||||
 | 
					   for (const QString &locale : uiLanguages)
 | 
				
			||||||
 | 
					   {
 | 
				
			||||||
 | 
					      const QString baseName = "qtrocket_" + QLocale(locale).name();
 | 
				
			||||||
 | 
					      if (translator.load(":/i18n/" + baseName))
 | 
				
			||||||
 | 
					      {
 | 
				
			||||||
 | 
					         a.installTranslator(&translator);
 | 
				
			||||||
 | 
					         break;
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
 | 
					   }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					   // Go!
 | 
				
			||||||
 | 
					   //MainWindow w;
 | 
				
			||||||
 | 
					   //w.show();
 | 
				
			||||||
 | 
					   //return a.exec();
 | 
				
			||||||
 | 
					   */
 | 
				
			||||||
 | 
					   int ret = 0;
 | 
				
			||||||
 | 
					   std::thread guiThread(worker, argc, argv, std::ref(ret));
 | 
				
			||||||
 | 
					   guiThread.join();
 | 
				
			||||||
 | 
					   return ret;
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
				
			|||||||
@ -2,5 +2,7 @@
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
Rocket::Rocket()
 | 
					Rocket::Rocket()
 | 
				
			||||||
{
 | 
					{
 | 
				
			||||||
 | 
					   propagator.setTimeStep(0.01);
 | 
				
			||||||
 | 
					   //propagator.set
 | 
				
			||||||
 | 
					
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
				
			|||||||
@ -59,6 +59,7 @@ HEADERS += \
 | 
				
			|||||||
    utils/BinMap.h \
 | 
					    utils/BinMap.h \
 | 
				
			||||||
    utils/CurlConnection.h \
 | 
					    utils/CurlConnection.h \
 | 
				
			||||||
    utils/Logger.h \
 | 
					    utils/Logger.h \
 | 
				
			||||||
 | 
					    utils/TSQueue.h \
 | 
				
			||||||
    utils/ThreadPool.h \
 | 
					    utils/ThreadPool.h \
 | 
				
			||||||
    utils/ThrustCurveAPI.h \
 | 
					    utils/ThrustCurveAPI.h \
 | 
				
			||||||
    utils/math/Constants.h \
 | 
					    utils/math/Constants.h \
 | 
				
			||||||
 | 
				
			|||||||
@ -1,6 +1,8 @@
 | 
				
			|||||||
#ifndef SIM_DESOLVER_H
 | 
					#ifndef SIM_DESOLVER_H
 | 
				
			||||||
#define SIM_DESOLVER_H
 | 
					#define SIM_DESOLVER_H
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#include <vector>
 | 
				
			||||||
 | 
					
 | 
				
			||||||
namespace sim
 | 
					namespace sim
 | 
				
			||||||
{
 | 
					{
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -11,7 +13,7 @@ public:
 | 
				
			|||||||
   virtual ~DESolver() {}
 | 
					   virtual ~DESolver() {}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
   virtual void setTimeStep(double ts) = 0;
 | 
					   virtual void setTimeStep(double ts) = 0;
 | 
				
			||||||
   virtual void step(double t, double* curVal, double* res ) = 0;
 | 
					   virtual void step(double t, const std::vector<double>& curVal, std::vector<double>& res ) = 0;
 | 
				
			||||||
};
 | 
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
} // namespace sim
 | 
					} // namespace sim
 | 
				
			||||||
 | 
				
			|||||||
@ -2,27 +2,56 @@
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
#include "sim/RK4Solver.h"
 | 
					#include "sim/RK4Solver.h"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#include <utility>
 | 
				
			||||||
 | 
					
 | 
				
			||||||
namespace sim {
 | 
					namespace sim {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
Propagator::Propagator()
 | 
					Propagator::Propagator()
 | 
				
			||||||
 | 
					   : integrator()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
{
 | 
					{
 | 
				
			||||||
 | 
					
 | 
				
			||||||
//   solver = std::make_unique<sim::DESolver>(
 | 
					
 | 
				
			||||||
//            new(RK4Solver(/* xvel */ [this](double, double* x) -> double { return })))
 | 
					   // This is a little strange, but I have to populate the integrator unique_ptr
 | 
				
			||||||
 | 
					   // with reset. make_unique() doesn't work because the compiler can't seem to
 | 
				
			||||||
 | 
					   // deduce the template parameters correctly, and I don't want to specify them
 | 
				
			||||||
 | 
					   // manually either. RK4Solver constructor is perfectly capable of deducing it's
 | 
				
			||||||
 | 
					   // template types, and it derives from DESolver, so we can just reset the unique_ptr
 | 
				
			||||||
 | 
					   // and pass it a freshly allocated RK4Solver pointer
 | 
				
			||||||
 | 
					
 | 
				
			||||||
   // The state vector has components of the form: (x, y, z, xdot, ydot, zdot)
 | 
					   // The state vector has components of the form: (x, y, z, xdot, ydot, zdot)
 | 
				
			||||||
 | 
					   integrator.reset(new RK4Solver(
 | 
				
			||||||
   integrator = std::make_unique<sim::DESolver>(new RK4Solver(
 | 
					      /* dx/dt  */ [](double, const std::vector<double>& s) -> double {return s[3]; },
 | 
				
			||||||
      /* dvx/dt */ [this](double, double* ) -> double { return getForceX() / getMass(); },
 | 
					      /* dy/dt  */ [](double, const std::vector<double>& s) -> double {return s[4]; },
 | 
				
			||||||
      /* dx/dt  */ [this](double, double* s) -> double {return s[3]; },
 | 
					      /* dz/dt  */ [](double, const std::vector<double>& s) -> double {return s[5]; },
 | 
				
			||||||
      /* dvy/dt */ [this](double, double* ) -> double { return getForceY() / getMass() },
 | 
					      /* dvx/dt */ [this](double, const std::vector<double>& ) -> double { return getForceX() / getMass(); },
 | 
				
			||||||
      /* dy/dt  */ [this](double, double* s) -> double {return s[4]; },
 | 
					      /* dvy/dt */ [this](double, const std::vector<double>& ) -> double { return getForceY() / getMass(); },
 | 
				
			||||||
      /* dvz/dt */ [this](double, double* ) -> double { return getForceZ() / getMass() },
 | 
					      /* dvz/dt */ [this](double, const std::vector<double>& ) -> double { return getForceZ() / getMass(); }));
 | 
				
			||||||
      /* dz/dt  */ [this](double, double* s) -> double {return s[5]; }));
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					   integrator->setTimeStep(timeStep);
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Propagator::~Propagator()
 | 
				
			||||||
 | 
					{
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					void Propagator::runUntilTerminate()
 | 
				
			||||||
 | 
					{
 | 
				
			||||||
 | 
					   while(true)
 | 
				
			||||||
 | 
					   {
 | 
				
			||||||
 | 
					      // nextState gets overwritten
 | 
				
			||||||
 | 
					      integrator->step(currentTime, currentState, nextState);
 | 
				
			||||||
 | 
					      std::swap(currentState, nextState);
 | 
				
			||||||
 | 
					      if(saveStates)
 | 
				
			||||||
 | 
					      {
 | 
				
			||||||
 | 
					         states.push_back(currentState);
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
 | 
					      if(currentState[1] < 0.0)
 | 
				
			||||||
 | 
					         break;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      currentTime += timeStep;
 | 
				
			||||||
 | 
					   }
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
} // namespace sim
 | 
					} // namespace sim
 | 
				
			||||||
 | 
				
			|||||||
@ -4,6 +4,7 @@
 | 
				
			|||||||
#include "sim/DESolver.h"
 | 
					#include "sim/DESolver.h"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
#include <memory>
 | 
					#include <memory>
 | 
				
			||||||
 | 
					#include <vector>
 | 
				
			||||||
 | 
					
 | 
				
			||||||
namespace sim
 | 
					namespace sim
 | 
				
			||||||
{
 | 
					{
 | 
				
			||||||
@ -12,23 +13,50 @@ class Propagator
 | 
				
			|||||||
{
 | 
					{
 | 
				
			||||||
public:
 | 
					public:
 | 
				
			||||||
    Propagator();
 | 
					    Propagator();
 | 
				
			||||||
 | 
					    ~Propagator();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
   double getForceX();
 | 
					    void setInitialState(std::vector<double>& initialState)
 | 
				
			||||||
   double getForceY();
 | 
					    {
 | 
				
			||||||
   double getForceZ();
 | 
					       currentState = initialState;
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
   double getTorqueP();
 | 
					    const std::vector<double>& getCurrentState() const
 | 
				
			||||||
   double getTorqueQ();
 | 
					    {
 | 
				
			||||||
   double getTorqueR();
 | 
					       return currentState;
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
   double getMass();
 | 
					    void runUntilTerminate();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    void retainStates(bool s)
 | 
				
			||||||
 | 
					    {
 | 
				
			||||||
 | 
					       saveStates = s;
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    const std::vector<std::vector<double>>& getStates() const { return states; }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    void setTimeStep(double ts) { timeStep = ts; }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
private:
 | 
					private:
 | 
				
			||||||
 | 
					    double getForceX() { return 0.0; }
 | 
				
			||||||
 | 
					    double getForceY() { return 0.0; }
 | 
				
			||||||
 | 
					    double getForceZ() { return 0.0; }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    double getTorqueP() { return 0.0; }
 | 
				
			||||||
 | 
					    double getTorqueQ() { return 0.0; }
 | 
				
			||||||
 | 
					    double getTorqueR() { return 0.0; }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					   double getMass() { return 0.0; }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					//private:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
   std::unique_ptr<sim::DESolver> integrator;
 | 
					   std::unique_ptr<sim::DESolver> integrator;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
   double currentState[6]{0.0, 0.0, 0.0, 0.0, 0.0, 0.0};
 | 
					   std::vector<double> currentState{0.0, 0.0, 0.0, 0.0, 0.0, 0.0};
 | 
				
			||||||
   double nextState[6]{0.0, 0.0, 0.0, 0.0, 0.0, 0.0};
 | 
					   std::vector<double> nextState{0.0, 0.0, 0.0, 0.0, 0.0, 0.0};
 | 
				
			||||||
 | 
					   bool saveStates{true};
 | 
				
			||||||
 | 
					   double currentTime{0.0};
 | 
				
			||||||
 | 
					   double timeStep{0.01};
 | 
				
			||||||
 | 
					   std::vector<std::vector<double>> states;
 | 
				
			||||||
};
 | 
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
} // namespace sim
 | 
					} // namespace sim
 | 
				
			||||||
 | 
				
			|||||||
@ -19,13 +19,14 @@ public:
 | 
				
			|||||||
   RK4Solver(Ts... funcs)
 | 
					   RK4Solver(Ts... funcs)
 | 
				
			||||||
   {
 | 
					   {
 | 
				
			||||||
      (odes.push_back(funcs), ...);
 | 
					      (odes.push_back(funcs), ...);
 | 
				
			||||||
 | 
					      temp.resize(sizeof...(Ts));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
   }
 | 
					   }
 | 
				
			||||||
   virtual ~RK4Solver() {}
 | 
					   virtual ~RK4Solver() {}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
   void setTimeStep(double inTs) override { dt = inTs;  halfDT = dt / 2.0; }
 | 
					   void setTimeStep(double inTs) override { dt = inTs;  halfDT = dt / 2.0; }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
   void step(double t, double* curVal, double* res) override
 | 
					   void step(double t, const std::vector<double>& curVal, std::vector<double>& res) override
 | 
				
			||||||
   {
 | 
					   {
 | 
				
			||||||
      if(dt == std::numeric_limits<double>::quiet_NaN())
 | 
					      if(dt == std::numeric_limits<double>::quiet_NaN())
 | 
				
			||||||
      {
 | 
					      {
 | 
				
			||||||
@ -76,7 +77,7 @@ public:
 | 
				
			|||||||
   }
 | 
					   }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
private:
 | 
					private:
 | 
				
			||||||
   std::vector<std::function<double(double, double*)>> odes;
 | 
					   std::vector<std::function<double(double, const std::vector<double>&)>> odes;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
   static constexpr size_t len = sizeof...(Ts);
 | 
					   static constexpr size_t len = sizeof...(Ts);
 | 
				
			||||||
   double k1[len];
 | 
					   double k1[len];
 | 
				
			||||||
@ -84,7 +85,7 @@ private:
 | 
				
			|||||||
   double k3[len];
 | 
					   double k3[len];
 | 
				
			||||||
   double k4[len];
 | 
					   double k4[len];
 | 
				
			||||||
 | 
					
 | 
				
			||||||
   double temp[len];
 | 
					   std::vector<double> temp;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
   double dt = std::numeric_limits<double>::quiet_NaN();
 | 
					   double dt = std::numeric_limits<double>::quiet_NaN();
 | 
				
			||||||
   double halfDT = 0.0;
 | 
					   double halfDT = 0.0;
 | 
				
			||||||
 | 
				
			|||||||
@ -14,6 +14,7 @@ class StateData
 | 
				
			|||||||
public:
 | 
					public:
 | 
				
			||||||
   StateData();
 | 
					   StateData();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
private:
 | 
					private:
 | 
				
			||||||
   math::Vector3 position{0.0, 0.0, 0.0};
 | 
					   math::Vector3 position{0.0, 0.0, 0.0};
 | 
				
			||||||
   math::Vector3 velocity{0.0, 0.0, 0.0};
 | 
					   math::Vector3 velocity{0.0, 0.0, 0.0};
 | 
				
			||||||
@ -23,6 +24,8 @@ private:
 | 
				
			|||||||
   // Necessary?
 | 
					   // Necessary?
 | 
				
			||||||
   //math::Vector3 orientationAccel;
 | 
					   //math::Vector3 orientationAccel;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					   // This is an array because the integrator expects it
 | 
				
			||||||
 | 
					   double data[6];
 | 
				
			||||||
 | 
					
 | 
				
			||||||
};
 | 
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										83
									
								
								utils/TSQueue.h
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										83
									
								
								utils/TSQueue.h
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,83 @@
 | 
				
			|||||||
 | 
					#ifndef TSQUEUE_H
 | 
				
			||||||
 | 
					#define TSQUEUE_H
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#include <mutex>
 | 
				
			||||||
 | 
					#include <memory>
 | 
				
			||||||
 | 
					#include <queue>
 | 
				
			||||||
 | 
					#include <condition_variable>
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					/**
 | 
				
			||||||
 | 
					 * @brief The TSQueue class is a very basic thread-safe queue
 | 
				
			||||||
 | 
					 */
 | 
				
			||||||
 | 
					template<typename T>
 | 
				
			||||||
 | 
					class TSQueue
 | 
				
			||||||
 | 
					{
 | 
				
			||||||
 | 
					public:
 | 
				
			||||||
 | 
					   TSQueue()
 | 
				
			||||||
 | 
					      : mtx(),
 | 
				
			||||||
 | 
					        q(),
 | 
				
			||||||
 | 
					        cv()
 | 
				
			||||||
 | 
					   {}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					   void push(T newVal)
 | 
				
			||||||
 | 
					   {
 | 
				
			||||||
 | 
					      std::lock_guard<std::mutex> lck(mtx);
 | 
				
			||||||
 | 
					      q.push(newVal);
 | 
				
			||||||
 | 
					      cv.notify_one();
 | 
				
			||||||
 | 
					   }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					   void waitAndPop(T& val)
 | 
				
			||||||
 | 
					   {
 | 
				
			||||||
 | 
					      std::unique_lock<std::mutex> lck(mtx);
 | 
				
			||||||
 | 
					      cv.wait(lck, [this]{return !q.empty(); });
 | 
				
			||||||
 | 
					      val = std::move(q.front());
 | 
				
			||||||
 | 
					      q.pop();
 | 
				
			||||||
 | 
					   }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					   std::shared_ptr<T> waitAndPop()
 | 
				
			||||||
 | 
					   {
 | 
				
			||||||
 | 
					      std::unique_lock<std::mutex> lck(mtx);
 | 
				
			||||||
 | 
					      cv.wait(lck, [this] { return !q.empty(); });
 | 
				
			||||||
 | 
					      std::shared_ptr<T> res(std::make_shared<T>(std::move(q.front())));
 | 
				
			||||||
 | 
					      q.pop();
 | 
				
			||||||
 | 
					      return res;
 | 
				
			||||||
 | 
					   }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					   bool tryPop(T& val)
 | 
				
			||||||
 | 
					   {
 | 
				
			||||||
 | 
					      std::unique_lock<std::mutex> lck(mtx);
 | 
				
			||||||
 | 
					      if(q.empty())
 | 
				
			||||||
 | 
					      {
 | 
				
			||||||
 | 
					         return false;
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
 | 
					      val = std::move(q.front());
 | 
				
			||||||
 | 
					      q.pop();
 | 
				
			||||||
 | 
					      return true;
 | 
				
			||||||
 | 
					   }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					   std::shared_ptr<T> tryPop()
 | 
				
			||||||
 | 
					   {
 | 
				
			||||||
 | 
					      std::unique_lock<std::mutex> lck(mtx);
 | 
				
			||||||
 | 
					      if(q.empty())
 | 
				
			||||||
 | 
					      {
 | 
				
			||||||
 | 
					         return std::shared_ptr<T>(); // nullptr
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
 | 
					      std::shared_ptr<T> retVal(std::move(q.front()));
 | 
				
			||||||
 | 
					      q.pop();
 | 
				
			||||||
 | 
					      return retVal;
 | 
				
			||||||
 | 
					   }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					   bool empty() const
 | 
				
			||||||
 | 
					   {
 | 
				
			||||||
 | 
					      std::lock_guard<std::mutex> lck(mtx);
 | 
				
			||||||
 | 
					      return q.empty();
 | 
				
			||||||
 | 
					   }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					private:
 | 
				
			||||||
 | 
					   mutable std::mutex mtx;
 | 
				
			||||||
 | 
					   std::queue<T> q;
 | 
				
			||||||
 | 
					   std::condition_variable cv;
 | 
				
			||||||
 | 
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#endif // TSQUEUE_H
 | 
				
			||||||
@ -1,6 +1,45 @@
 | 
				
			|||||||
#include "ThreadPool.h"
 | 
					#include "ThreadPool.h"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
ThreadPool::ThreadPool()
 | 
					#include <cstdint>
 | 
				
			||||||
{
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					ThreadPool::ThreadPool()
 | 
				
			||||||
 | 
					   : done(false),
 | 
				
			||||||
 | 
					     q(),
 | 
				
			||||||
 | 
					     threads(),
 | 
				
			||||||
 | 
					     joiner(threads)
 | 
				
			||||||
 | 
					{
 | 
				
			||||||
 | 
					   const std::size_t threadCount = std::thread::hardware_concurrency();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					   try
 | 
				
			||||||
 | 
					   {
 | 
				
			||||||
 | 
					      for(size_t i = 0; i < threadCount; ++i)
 | 
				
			||||||
 | 
					      {
 | 
				
			||||||
 | 
					         threads.push_back(std::thread(&ThreadPool::worker, this));
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
 | 
					   }
 | 
				
			||||||
 | 
					   catch(...)
 | 
				
			||||||
 | 
					   {
 | 
				
			||||||
 | 
					      done = true;
 | 
				
			||||||
 | 
					      throw;
 | 
				
			||||||
 | 
					   }
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					ThreadPool::~ThreadPool()
 | 
				
			||||||
 | 
					{
 | 
				
			||||||
 | 
					   done = true;
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					void ThreadPool::worker()
 | 
				
			||||||
 | 
					{
 | 
				
			||||||
 | 
					   while(!done)
 | 
				
			||||||
 | 
					   {
 | 
				
			||||||
 | 
					      std::function<void()> task;
 | 
				
			||||||
 | 
					      if(q.tryPop(task))
 | 
				
			||||||
 | 
					      {
 | 
				
			||||||
 | 
					         task();
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
 | 
					      else
 | 
				
			||||||
 | 
					      {
 | 
				
			||||||
 | 
					         std::this_thread::yield();
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
 | 
					   }
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
				
			|||||||
@ -2,6 +2,12 @@
 | 
				
			|||||||
#define THREADPOOL_H
 | 
					#define THREADPOOL_H
 | 
				
			||||||
 | 
					
 | 
				
			||||||
#include <atomic>
 | 
					#include <atomic>
 | 
				
			||||||
 | 
					#include <functional>
 | 
				
			||||||
 | 
					#include <vector>
 | 
				
			||||||
 | 
					#include <thread>
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#include "TSQueue.h"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
/**
 | 
					/**
 | 
				
			||||||
 * @brief A basic ThreadPool class
 | 
					 * @brief A basic ThreadPool class
 | 
				
			||||||
@ -10,8 +16,46 @@ class ThreadPool
 | 
				
			|||||||
{
 | 
					{
 | 
				
			||||||
public:
 | 
					public:
 | 
				
			||||||
   ThreadPool();
 | 
					   ThreadPool();
 | 
				
			||||||
 | 
					   ~ThreadPool();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					   template<typename FunctionType>
 | 
				
			||||||
 | 
					   void submit(FunctionType f)
 | 
				
			||||||
 | 
					   {
 | 
				
			||||||
 | 
					      q.push(std::function<void()>(f));
 | 
				
			||||||
 | 
					   }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
private:
 | 
					private:
 | 
				
			||||||
 | 
					   class JoinThreads
 | 
				
			||||||
 | 
					   {
 | 
				
			||||||
 | 
					   public:
 | 
				
			||||||
 | 
					      explicit JoinThreads(std::vector<std::thread>& inThreads)
 | 
				
			||||||
 | 
					         : threads(inThreads)
 | 
				
			||||||
 | 
					      {}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      ~JoinThreads()
 | 
				
			||||||
 | 
					      {
 | 
				
			||||||
 | 
					         for(auto& i : threads)
 | 
				
			||||||
 | 
					         {
 | 
				
			||||||
 | 
					            if(i.joinable())
 | 
				
			||||||
 | 
					            {
 | 
				
			||||||
 | 
					               i.join();
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					         }
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
 | 
					   private:
 | 
				
			||||||
 | 
					      std::vector<std::thread>& threads;
 | 
				
			||||||
 | 
					   };
 | 
				
			||||||
 | 
					
 | 
				
			||||||
   std::atomic_bool done;
 | 
					   std::atomic_bool done;
 | 
				
			||||||
 | 
					   TSQueue<std::function<void()>> q;
 | 
				
			||||||
 | 
					   std::vector<std::thread> threads;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					   JoinThreads joiner;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					   void worker();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
};
 | 
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
#endif // THREADPOOL_H
 | 
					#endif // THREADPOOL_H
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user