Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 23 additions & 1 deletion include/pysa/branching/branching.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,10 @@ specific language governing permissions and limitations under the License.

namespace pysa::branching {

enum struct BranchResult{
BRANCHOK, // default exit status
BRANCHEXIT // terminate all threads
};
/**
* @brief Split a collection of branches into two.
* Branches is a container that supports efficient .front() and .pop_front()
Expand Down Expand Up @@ -96,7 +100,7 @@ void branching_impl(const Function &fn, Branches &branches,

// Define core
auto core_ = [fn = fn, &branches](std::size_t idx, auto &&stop) {
fn(branches[idx], stop);
return fn(branches[idx], stop);
};

// Initialize threads
Expand Down Expand Up @@ -133,12 +137,30 @@ void branching_impl(const Function &fn, Branches &branches,
return std::tuple{min_, max_};
};

bool exit_all = false;
// Avoid a race condition where count_n_branches_() may be 0 temporarily at
// start
std::this_thread::sleep_for(sleep_time);

// Keep going if there are still branches or the stop signal is off
while (count_n_branches_() && !*stop) {
for (auto& t: threads_){
if(t.is_ready()){
if(t.get() == BranchResult::BRANCHEXIT){
exit_all = true;
}
}
}
if(exit_all){
#ifndef NDEBUG
std::cerr << "# Brancher exit " << std::endl;
#endif
for (auto& t: threads_){
if(t.is_running())
t.stop();
}
return;
}
// Propagate branches between two threads
if (const auto [ei_, ni_] = balance_indexes_(); ei_ && ni_) {
const auto e_idx_ = ei_.value();
Expand Down
17 changes: 12 additions & 5 deletions include/pysa/dpll/dpll.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ specific language governing permissions and limitations under the License.

namespace pysa::branching {

template <bool depth_first = true, typename Branches, typename Collect>
void DPLL_(Branches &&branches, Collect &&collect, ConstStopPtr stop) {
template <bool depth_first = true, bool exit_on_first = false, typename Branches, typename Collect>
BranchResult DPLL_(Branches &&branches, Collect &&collect, ConstStopPtr stop) {
// While there are still branches ...
while (std::size(branches) && !*stop) {
// Get last branch (depth first)
Expand All @@ -45,11 +45,18 @@ void DPLL_(Branches &&branches, Collect &&collect, ConstStopPtr stop) {
branches.splice(std::end(branches), branch_.branch());

// Collect
collect(std::move(branch_));
if constexpr (exit_on_first){
if(collect(std::move(branch_))){
return BranchResult::BRANCHEXIT;
}
} else {
collect(std::move(branch_));
}
}
return BranchResult::BRANCHOK;
}

template <bool depth_first = true, typename Branches, typename Collect,
template <bool depth_first = true, bool exit_on_first = false, typename Branches, typename Collect,
typename... Args>
auto DPLL(Branches &&branches, Collect &&collect, Args &&...args) {
/*
Expand All @@ -59,7 +66,7 @@ auto DPLL(Branches &&branches, Collect &&collect, Args &&...args) {
// Get brancher
return branching(
[collect](auto &&branches, auto &&stop) {
DPLL_<depth_first>(branches, collect, stop);
return DPLL_<depth_first, exit_on_first>(branches, collect, stop);
},
std::forward<Branches>(branches), std::forward<Args>(args)...);
}
Expand Down
3 changes: 2 additions & 1 deletion tests/test_branching.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,10 @@ int main() {
{
// Use one thread
pysa::branching::TestBranching(28, 1, true);

pysa::branching::TestBranching<true>(28, 1, true);
// Use number of threads provided by the implementation
pysa::branching::TestBranching(30, 0, true);
pysa::branching::TestBranching<true>(30, 0, true);
}
#ifdef USE_MPI
MPI_Barrier(mpi_comm_world);
Expand Down
26 changes: 16 additions & 10 deletions tests/test_branching.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ struct Branch {

using Branches = std::list<Branch>;

template<bool stop_on_first=false>
void TestBranching(const std::size_t n, const std::size_t n_threads = 0,
const bool verbose = false) {
// How to collect the branches
Expand All @@ -72,11 +73,14 @@ void TestBranching(const std::size_t n, const std::size_t n_threads = 0,
if (CheckBranch(branch)) {
const std::scoped_lock<std::mutex> lock_(mutex_);
collected_.push_back(branch.state);
return true;
} else {
return false;
}
};

// Get branches
auto brancher_ = DPLL(Branches{Branch{n, 0, 0}}, collect_, n_threads);
auto brancher_ = DPLL<true, stop_on_first>(Branches{Branch{n, 0, 0}}, collect_, n_threads);

// Start brancher
auto it_ = std::chrono::high_resolution_clock::now();
Expand All @@ -98,18 +102,20 @@ void TestBranching(const std::size_t n, const std::size_t n_threads = 0,
.count()
<< std::endl;

// Sort collected numbers
std::sort(std::begin(collected_), std::end(collected_));
if constexpr (!stop_on_first){
// Sort collected numbers
std::sort(std::begin(collected_), std::end(collected_));

// Get head
auto head_ = std::cbegin(collected_);
// Get head
auto head_ = std::cbegin(collected_);

// Check results
for (std::size_t i_ = 0, end_ = std::size_t{1} << n; i_ < end_; ++i_)
if (CheckBranch(Branch{n, i_, 0})) assert(*head_++ == i_);
// Check results
for (std::size_t i_ = 0, end_ = std::size_t{1} << n; i_ < end_; ++i_)
if (CheckBranch(Branch{n, i_, 0})) assert(*head_++ == i_);

// All numbers should have been checked at this point
assert(head_ == std::cend(collected_));
// All numbers should have been checked at this point
assert(head_ == std::cend(collected_));
}
}

#ifdef USE_MPI
Expand Down