Add Nelder Mead.
This commit is contained in:
parent
abe406d7aa
commit
7af3714493
|
|
@ -8,17 +8,32 @@ find_package(ament_cmake REQUIRED)
|
|||
find_package(ament_cmake_ros REQUIRED)
|
||||
find_package(rclcpp REQUIRED)
|
||||
|
||||
add_library(cost_function
|
||||
src/CostFunction.cpp)
|
||||
target_compile_features(cost_function PUBLIC cxx_std_17)
|
||||
target_include_directories(cost_function PUBLIC
|
||||
$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include>
|
||||
$<INSTALL_INTERFACE:include>)
|
||||
|
||||
add_library(simplex_solver
|
||||
src/CostFunction.cpp
|
||||
src/SimplexSolver.cpp)
|
||||
target_compile_features(simplex_solver PUBLIC cxx_std_17) # Require C++17
|
||||
target_link_libraries(simplex_solver cost_function)
|
||||
target_compile_features(simplex_solver PUBLIC cxx_std_17)
|
||||
target_include_directories(simplex_solver PUBLIC
|
||||
$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include>
|
||||
$<INSTALL_INTERFACE:include>)
|
||||
|
||||
add_library(nelder_mead
|
||||
src/NelderMead.cpp)
|
||||
target_link_libraries(nelder_mead cost_function)
|
||||
target_compile_features(nelder_mead PUBLIC cxx_std_17)
|
||||
target_include_directories(nelder_mead PUBLIC
|
||||
$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include>
|
||||
$<INSTALL_INTERFACE:include>)
|
||||
|
||||
add_executable(main src/main.cpp)
|
||||
target_compile_features(main PUBLIC cxx_std_17) # Require C++17
|
||||
target_link_libraries(main simplex_solver)
|
||||
target_compile_features(main PUBLIC cxx_std_17)
|
||||
target_link_libraries(main cost_function simplex_solver nelder_mead)
|
||||
target_include_directories(main PUBLIC
|
||||
$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include>
|
||||
$<INSTALL_INTERFACE:include>)
|
||||
|
|
@ -30,11 +45,13 @@ install(
|
|||
)
|
||||
install(
|
||||
TARGETS
|
||||
cost_function
|
||||
simplex_solver
|
||||
nelder_mead
|
||||
EXPORT export_${PROJECT_NAME}
|
||||
ARCHIVE DESTINATION lib/${PROJECT_NAME}
|
||||
LIBRARY DESTINATION lib/${PROJECT_NAME}
|
||||
RUNTIME DESTINATION bin/${PROJECT_NAME}
|
||||
ARCHIVE DESTINATION lib
|
||||
LIBRARY DESTINATION lib
|
||||
RUNTIME DESTINATION bin
|
||||
)
|
||||
install(
|
||||
TARGETS
|
||||
|
|
@ -45,7 +62,9 @@ ament_export_include_directories(
|
|||
include
|
||||
)
|
||||
ament_export_libraries(
|
||||
optimizer
|
||||
cost_function
|
||||
simplex_solver
|
||||
nelder_mead
|
||||
)
|
||||
ament_export_targets(
|
||||
export_${PROJECT_NAME}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,44 @@
|
|||
// Copyright 2022 James Pace
|
||||
// All Rights Reserved.
|
||||
//
|
||||
// For a license to this software contact
|
||||
// James Pace at jpace121@gmail.com.
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
#ifndef J7S__NELDERMEAD_HPP_
|
||||
#define J7S__NELDERMEAD_HPP_
|
||||
|
||||
#include "j7s-optimization/CostFunction.hpp"
|
||||
#include "j7s-optimization/common.hpp"
|
||||
|
||||
#include <vector>
|
||||
|
||||
namespace j7s
|
||||
{
|
||||
|
||||
class NelderMead
|
||||
{
|
||||
public:
|
||||
NelderMead(const CostFunction & costFunction, const std::vector<double> initSimplex);
|
||||
|
||||
IterationState update();
|
||||
|
||||
Coordinate bestCoord() const;
|
||||
|
||||
private:
|
||||
const CostFunction m_costFunction;
|
||||
std::vector<Coordinate> m_currentSimplex;
|
||||
|
||||
// Helper functions.
|
||||
double newPoint() const;
|
||||
std::vector<Coordinate> contract() const;
|
||||
double calcVolume() const;
|
||||
double findCentroid() const;
|
||||
double secondNewPoint(const Coordinate& newPoint) const;
|
||||
};
|
||||
|
||||
} // namespace j7s
|
||||
|
||||
#endif // J7S__NELDERMEAD_HPP_
|
||||
|
|
@ -32,9 +32,9 @@ private:
|
|||
std::vector<Coordinate> m_currentSimplex;
|
||||
|
||||
// Helper functions.
|
||||
double newPoint();
|
||||
std::vector<Coordinate> contract();
|
||||
double calcVolume();
|
||||
double newPoint() const;
|
||||
std::vector<Coordinate> contract() const;
|
||||
double calcVolume() const;
|
||||
};
|
||||
|
||||
} // namespace j7s
|
||||
|
|
|
|||
|
|
@ -34,6 +34,11 @@ struct Coordinate
|
|||
|
||||
// Sort by cost.
|
||||
bool operator<(const Coordinate & other) const { return (cost < other.cost); }
|
||||
bool operator<=(const Coordinate & other) const { return (cost <= other.cost); }
|
||||
bool operator>(const Coordinate & other) const { return (cost > other.cost); }
|
||||
bool operator>=(const Coordinate & other) const { return (cost >= other.cost); }
|
||||
bool operator==(const Coordinate & other) const { return (cost == other.cost); }
|
||||
bool operator!=(const Coordinate & other) const { return (cost != other.cost); }
|
||||
};
|
||||
|
||||
} // namespace j7s
|
||||
|
|
|
|||
|
|
@ -0,0 +1,182 @@
|
|||
// Copyright 2022 James Pace
|
||||
// All Rights Reserved.
|
||||
//
|
||||
// For a license to this software contact
|
||||
// James Pace at jpace121@gmail.com.
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
#include "j7s-optimization/NelderMead.hpp"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
#include <stdexcept>
|
||||
|
||||
namespace j7s
|
||||
{
|
||||
NelderMead::NelderMead(
|
||||
const CostFunction & costFunction, const std::vector<double> initSimplex) :
|
||||
m_costFunction{costFunction}
|
||||
{
|
||||
m_currentSimplex.reserve(initSimplex.size());
|
||||
|
||||
for (const auto val : initSimplex)
|
||||
{
|
||||
m_currentSimplex.emplace_back(val, m_costFunction.eval(val));
|
||||
}
|
||||
std::sort(m_currentSimplex.begin(), m_currentSimplex.end());
|
||||
}
|
||||
|
||||
Coordinate NelderMead::bestCoord() const
|
||||
{
|
||||
return m_currentSimplex.front();
|
||||
}
|
||||
|
||||
double NelderMead::newPoint() const
|
||||
{
|
||||
if (m_currentSimplex.size() == 0)
|
||||
{
|
||||
throw std::runtime_error("Simplex can't be missing.");
|
||||
}
|
||||
|
||||
// Calculate sum.
|
||||
double biggest = m_currentSimplex.back().input;
|
||||
// Calculate volume.
|
||||
double sum = 0.0;
|
||||
// All but the biggest, (would be adding 0...)
|
||||
for (unsigned int index = 0; index < m_currentSimplex.size() - 1; index++)
|
||||
{
|
||||
const double diff = m_currentSimplex[index].input - biggest;
|
||||
sum += diff;
|
||||
}
|
||||
|
||||
const double newPoint = sum * (2.0 / m_currentSimplex.size());
|
||||
|
||||
return newPoint;
|
||||
}
|
||||
|
||||
double NelderMead::findCentroid() const
|
||||
{
|
||||
double sum = 0.0;
|
||||
for (unsigned int index = 0; index < m_currentSimplex.size() - 1; index++)
|
||||
{
|
||||
sum += m_currentSimplex[index].input;
|
||||
}
|
||||
|
||||
const double newPoint = sum * (1.0 / m_currentSimplex.size());
|
||||
|
||||
return newPoint;
|
||||
}
|
||||
|
||||
double NelderMead::secondNewPoint(const Coordinate& newPoint) const
|
||||
{
|
||||
// FxN in paper.
|
||||
const auto biggest = m_currentSimplex.back();
|
||||
// Fx0 in paper.
|
||||
const auto smallest = m_currentSimplex.front();
|
||||
// FxN-1 in paper.
|
||||
// TODO: Assuming simplex size here.
|
||||
const auto secondBiggest = *(m_currentSimplex.end() - 2);
|
||||
|
||||
const double centroid = findCentroid();
|
||||
|
||||
if(newPoint < smallest)
|
||||
{
|
||||
return 2.0*newPoint.input - centroid;
|
||||
}
|
||||
else if(smallest <= newPoint and newPoint < secondBiggest)
|
||||
{
|
||||
return 0.5*(centroid + newPoint.input);
|
||||
}
|
||||
else if(secondBiggest <= newPoint)
|
||||
{
|
||||
return 0.5*(biggest.input + centroid);
|
||||
}
|
||||
else
|
||||
{
|
||||
// Shouldn't be able to get here?
|
||||
throw std::logic_error("Shouldn't be able to get here.");
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<Coordinate> NelderMead::contract() const
|
||||
{
|
||||
const auto smallest = m_currentSimplex.front();
|
||||
|
||||
std::vector<Coordinate> newVector;
|
||||
newVector.reserve(m_currentSimplex.size());
|
||||
newVector.emplace_back(smallest);
|
||||
|
||||
// TODO: Really check size before I get here...
|
||||
for (auto it = m_currentSimplex.begin() + 1; it != m_currentSimplex.end(); it++)
|
||||
{
|
||||
const auto oldInput = it->input;
|
||||
const auto newInput = 0.5 * (oldInput + smallest.input);
|
||||
const auto newCost = m_costFunction.eval(newInput);
|
||||
newVector.emplace_back(newInput, newCost);
|
||||
}
|
||||
std::sort(newVector.begin(), newVector.end());
|
||||
|
||||
return newVector;
|
||||
}
|
||||
|
||||
double NelderMead::calcVolume() const
|
||||
{
|
||||
// TODO: For reals do something like:
|
||||
// https://math.stackexchange.com/questions/337197/finding-the-volume-of-a-tetrahedron-by-given-vertices
|
||||
// For now:
|
||||
// Sort by input and find the difference squared between the first and last.
|
||||
|
||||
const auto inputLess = [](const Coordinate & first, const Coordinate & second)
|
||||
{ return first.input < second.input; };
|
||||
|
||||
// Copy the vector so we don't sort the original.
|
||||
std::vector<Coordinate> simplexCopy = m_currentSimplex;
|
||||
std::sort(simplexCopy.begin(), simplexCopy.end(), inputLess);
|
||||
const auto smallest = simplexCopy.front();
|
||||
const auto biggest = simplexCopy.back();
|
||||
return std::pow(biggest.input - smallest.input, 2.0);
|
||||
}
|
||||
|
||||
IterationState NelderMead::update()
|
||||
{
|
||||
if (m_currentSimplex.size() < 3)
|
||||
{
|
||||
throw std::runtime_error("Simplex can't be a line.");
|
||||
}
|
||||
// Check for convergence and potentially early return.
|
||||
// TODO: Make configurable.
|
||||
const auto volume = calcVolume();
|
||||
if (volume < 1e-4)
|
||||
{
|
||||
return IterationState::CONVERGED;
|
||||
}
|
||||
|
||||
// Do update.
|
||||
Coordinate potential;
|
||||
potential.input = newPoint();
|
||||
potential.cost = m_costFunction.eval(potential.input);
|
||||
Coordinate secondPotential;
|
||||
secondPotential.input = secondNewPoint(potential);
|
||||
secondPotential.cost = m_costFunction.eval(secondPotential.input);
|
||||
|
||||
const auto minPotential = std::min(potential, secondPotential);
|
||||
|
||||
const auto secondBiggest = *(m_currentSimplex.end() - 2);
|
||||
if (minPotential.cost < secondBiggest.cost)
|
||||
{
|
||||
// Replace the last simplex value with the better one.
|
||||
*(m_currentSimplex.end() - 1) = minPotential;
|
||||
std::sort(m_currentSimplex.begin(), m_currentSimplex.end());
|
||||
}
|
||||
else
|
||||
{
|
||||
// Do a contraction.
|
||||
m_currentSimplex = contract();
|
||||
}
|
||||
|
||||
return IterationState::OK;
|
||||
}
|
||||
|
||||
} // namespace j7s
|
||||
|
|
@ -33,7 +33,7 @@ Coordinate SimplexSolver::bestCoord() const
|
|||
return m_currentSimplex.front();
|
||||
}
|
||||
|
||||
double SimplexSolver::newPoint()
|
||||
double SimplexSolver::newPoint() const
|
||||
{
|
||||
if (m_currentSimplex.size() == 0)
|
||||
{
|
||||
|
|
@ -56,7 +56,7 @@ double SimplexSolver::newPoint()
|
|||
return newPoint;
|
||||
}
|
||||
|
||||
std::vector<Coordinate> SimplexSolver::contract()
|
||||
std::vector<Coordinate> SimplexSolver::contract() const
|
||||
{
|
||||
const auto smallest = m_currentSimplex.front();
|
||||
|
||||
|
|
@ -77,7 +77,7 @@ std::vector<Coordinate> SimplexSolver::contract()
|
|||
return newVector;
|
||||
}
|
||||
|
||||
double SimplexSolver::calcVolume()
|
||||
double SimplexSolver::calcVolume() const
|
||||
{
|
||||
// TODO: For reals do something like:
|
||||
// https://math.stackexchange.com/questions/337197/finding-the-volume-of-a-tetrahedron-by-given-vertices
|
||||
|
|
|
|||
43
src/main.cpp
43
src/main.cpp
|
|
@ -11,12 +11,24 @@
|
|||
|
||||
#include "j7s-optimization/CostFunction.hpp"
|
||||
#include "j7s-optimization/SimplexSolver.hpp"
|
||||
#include "j7s-optimization/NelderMead.hpp"
|
||||
#include "j7s-optimization/common.hpp"
|
||||
|
||||
void runSimpleSimplex(const j7s::CostFunction& cost, const std::vector<double>& init_simplex);
|
||||
void runNelderMead(const j7s::CostFunction& cost, const std::vector<double>& init_simplex);
|
||||
|
||||
int main(int, char **)
|
||||
{
|
||||
const j7s::CostFunction cost(2.0, 3.0, 4.0);
|
||||
const std::vector<double> init_simplex = {-10, 0, 10};
|
||||
|
||||
runSimpleSimplex(cost, init_simplex);
|
||||
runNelderMead(cost, init_simplex);
|
||||
return 0;
|
||||
}
|
||||
|
||||
void runSimpleSimplex(const j7s::CostFunction& cost, const std::vector<double>& init_simplex)
|
||||
{
|
||||
j7s::SimplexSolver solver(cost, init_simplex);
|
||||
|
||||
j7s::IterationState state = j7s::IterationState::OK;
|
||||
|
|
@ -31,13 +43,38 @@ int main(int, char **)
|
|||
if (state == j7s::IterationState::CONVERGED)
|
||||
{
|
||||
const auto best = solver.bestCoord();
|
||||
std::cout << "Converged! Best Input: " << best.input << " Cost: " << best.cost << std::endl;
|
||||
std::cout << "SimpleSiplex Converged! Best Input: " << best.input << " Cost: " << best.cost << std::endl;
|
||||
std::cout << "Actual Best: " << cost.actualBest()
|
||||
<< " Cost: " << cost.eval(cost.actualBest()) << std::endl;
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << "Did not converge." << std::endl;
|
||||
std::cout << "SimpleSimplex did not converge." << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
void runNelderMead(const j7s::CostFunction& cost, const std::vector<double>& init_simplex)
|
||||
{
|
||||
j7s::NelderMead solver(cost, init_simplex);
|
||||
|
||||
j7s::IterationState state = j7s::IterationState::OK;
|
||||
for (int cnt = 0; cnt < 1000; cnt++)
|
||||
{
|
||||
state = solver.update();
|
||||
if (state == j7s::IterationState::CONVERGED)
|
||||
{
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (state == j7s::IterationState::CONVERGED)
|
||||
{
|
||||
const auto best = solver.bestCoord();
|
||||
std::cout << "Nelder Mead Converged! Best Input: " << best.input << " Cost: " << best.cost << std::endl;
|
||||
std::cout << "Actual Best: " << cost.actualBest()
|
||||
<< " Cost: " << cost.eval(cost.actualBest()) << std::endl;
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << "Nelder Mead did not converge." << std::endl;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue