diff --git a/libfqfft/evaluation_domain/domains/arithmetic_sequence_domain.hpp b/libfqfft/evaluation_domain/domains/arithmetic_sequence_domain.hpp index e3f42ba..26304c2 100755 --- a/libfqfft/evaluation_domain/domains/arithmetic_sequence_domain.hpp +++ b/libfqfft/evaluation_domain/domains/arithmetic_sequence_domain.hpp @@ -28,6 +28,8 @@ namespace libfqfft { FieldT arithmetic_generator; void do_precomputation(); + static bool valid_for_size(const size_t m); + arithmetic_sequence_domain(const size_t m); void FFT(std::vector &a); diff --git a/libfqfft/evaluation_domain/domains/arithmetic_sequence_domain.tcc b/libfqfft/evaluation_domain/domains/arithmetic_sequence_domain.tcc index 3b59ae2..74ce02c 100755 --- a/libfqfft/evaluation_domain/domains/arithmetic_sequence_domain.tcc +++ b/libfqfft/evaluation_domain/domains/arithmetic_sequence_domain.tcc @@ -23,6 +23,20 @@ namespace libfqfft { +template +bool arithmetic_sequence_domain::valid_for_size(const size_t m) +{ + if (m <= 1) { + return false; + } + + if (FieldT::arithmetic_generator() == FieldT::zero()) { + return false; + } + + return true; +} + template arithmetic_sequence_domain::arithmetic_sequence_domain(const size_t m) : evaluation_domain(m) { @@ -42,7 +56,7 @@ void arithmetic_sequence_domain::FFT(std::vector &a) /* Monomial to Newton */ monomial_to_newton_basis(a, this->subproduct_tree, this->m); - + /* Newton to Evaluation */ std::vector S(this->m); /* i! * arithmetic_generator */ S[0] = FieldT::one(); @@ -70,7 +84,7 @@ template void arithmetic_sequence_domain::iFFT(std::vector &a) { if (a.size() != this->m) throw DomainSizeException("arithmetic: expected a.size() == this->m"); - + if (!this->precomputation_sentinel) do_precomputation(); /* Interpolation to Newton */ @@ -152,7 +166,7 @@ std::vector arithmetic_sequence_domain::evaluate_all_lagrange_po std::vector w(this->m); w[0] = g_vanish.inverse() * (this->arithmetic_generator^(this->m-1)); - + l[0] = l_vanish * l[0].inverse() * w[0]; for (size_t i = 1; i < this->m; i++) { diff --git a/libfqfft/evaluation_domain/domains/basic_radix2_domain.hpp b/libfqfft/evaluation_domain/domains/basic_radix2_domain.hpp index 7b77b29..2ce3dc1 100755 --- a/libfqfft/evaluation_domain/domains/basic_radix2_domain.hpp +++ b/libfqfft/evaluation_domain/domains/basic_radix2_domain.hpp @@ -26,6 +26,8 @@ class basic_radix2_domain : public evaluation_domain { FieldT omega; + static bool valid_for_size(const size_t m); + basic_radix2_domain(const size_t m); void FFT(std::vector &a); diff --git a/libfqfft/evaluation_domain/domains/basic_radix2_domain.tcc b/libfqfft/evaluation_domain/domains/basic_radix2_domain.tcc index 2486f7a..cf4a69c 100755 --- a/libfqfft/evaluation_domain/domains/basic_radix2_domain.tcc +++ b/libfqfft/evaluation_domain/domains/basic_radix2_domain.tcc @@ -22,6 +22,20 @@ namespace libfqfft { +template +bool basic_radix2_domain::valid_for_size(const size_t m) +{ + if (m <= 1) { + return false; + } + + if (!libff::has_root_of_unity(m)) { + return false; + } + + return true; +} + template basic_radix2_domain::basic_radix2_domain(const size_t m) : evaluation_domain(m) { diff --git a/libfqfft/evaluation_domain/domains/extended_radix2_domain.hpp b/libfqfft/evaluation_domain/domains/extended_radix2_domain.hpp index 7637925..53e9c11 100755 --- a/libfqfft/evaluation_domain/domains/extended_radix2_domain.hpp +++ b/libfqfft/evaluation_domain/domains/extended_radix2_domain.hpp @@ -27,6 +27,8 @@ class extended_radix2_domain : public evaluation_domain { FieldT omega; FieldT shift; + static bool valid_for_size(const size_t m); + extended_radix2_domain(const size_t m); void FFT(std::vector &a); diff --git a/libfqfft/evaluation_domain/domains/extended_radix2_domain.tcc b/libfqfft/evaluation_domain/domains/extended_radix2_domain.tcc index b6a31fa..de4a686 100755 --- a/libfqfft/evaluation_domain/domains/extended_radix2_domain.tcc +++ b/libfqfft/evaluation_domain/domains/extended_radix2_domain.tcc @@ -17,6 +17,32 @@ namespace libfqfft { +template +bool extended_radix2_domain::valid_for_size(const size_t m) +{ + if (m <= 1) { + return false; + } + + // Will `get_root_of_unity` throw? + if (!std::is_same::value) + { + const size_t logm = libff::log2(m); + + if (logm != (FieldT::s + 1)) { + return false; + } + } + + size_t small_m = m / 2; + + if (!libff::has_root_of_unity(small_m)) { + return false; + } + + return true; +} + template extended_radix2_domain::extended_radix2_domain(const size_t m) : evaluation_domain(m) { diff --git a/libfqfft/evaluation_domain/domains/geometric_sequence_domain.hpp b/libfqfft/evaluation_domain/domains/geometric_sequence_domain.hpp index 22bd783..70a2da2 100755 --- a/libfqfft/evaluation_domain/domains/geometric_sequence_domain.hpp +++ b/libfqfft/evaluation_domain/domains/geometric_sequence_domain.hpp @@ -27,6 +27,8 @@ namespace libfqfft { std::vector geometric_triangular_sequence; void do_precomputation(); + static bool valid_for_size(const size_t m); + geometric_sequence_domain(const size_t m); void FFT(std::vector &a); diff --git a/libfqfft/evaluation_domain/domains/geometric_sequence_domain.tcc b/libfqfft/evaluation_domain/domains/geometric_sequence_domain.tcc index 1535810..530f23f 100755 --- a/libfqfft/evaluation_domain/domains/geometric_sequence_domain.tcc +++ b/libfqfft/evaluation_domain/domains/geometric_sequence_domain.tcc @@ -23,19 +23,33 @@ namespace libfqfft { +template +bool geometric_sequence_domain::valid_for_size(const size_t m) +{ + if (m <= 1) { + return false; + } + + if (FieldT::geometric_generator() == FieldT::zero()) { + return false; + } + + return true; +} + template geometric_sequence_domain::geometric_sequence_domain(const size_t m) : evaluation_domain(m) { if (m <= 1) throw InvalidSizeException("geometric(): expected m > 1"); if (FieldT::geometric_generator() == FieldT::zero()) throw InvalidSizeException("geometric(): expected FieldT::geometric_generator() != FieldT::zero()"); - + precomputation_sentinel = 0; } template void geometric_sequence_domain::FFT(std::vector &a) -{ +{ if (a.size() != this->m) throw DomainSizeException("geometric: expected a.size() == this->m"); if (!this->precomputation_sentinel) do_precomputation(); @@ -71,7 +85,7 @@ template void geometric_sequence_domain::iFFT(std::vector &a) { if (a.size() != this->m) throw DomainSizeException("geometric: expected a.size() == this->m"); - + if (!this->precomputation_sentinel) do_precomputation(); /* Interpolation to Newton */ diff --git a/libfqfft/evaluation_domain/domains/step_radix2_domain.hpp b/libfqfft/evaluation_domain/domains/step_radix2_domain.hpp index 33ba7f4..72fc090 100755 --- a/libfqfft/evaluation_domain/domains/step_radix2_domain.hpp +++ b/libfqfft/evaluation_domain/domains/step_radix2_domain.hpp @@ -29,6 +29,8 @@ class step_radix2_domain : public evaluation_domain { FieldT big_omega; FieldT small_omega; + static bool valid_for_size(const size_t m); + step_radix2_domain(const size_t m); void FFT(std::vector &a); diff --git a/libfqfft/evaluation_domain/domains/step_radix2_domain.tcc b/libfqfft/evaluation_domain/domains/step_radix2_domain.tcc index e9a984e..b5aaac3 100755 --- a/libfqfft/evaluation_domain/domains/step_radix2_domain.tcc +++ b/libfqfft/evaluation_domain/domains/step_radix2_domain.tcc @@ -17,6 +17,33 @@ namespace libfqfft { +template +bool step_radix2_domain::valid_for_size(const size_t m) +{ + if (m <= 1) { + return false; + } + + const size_t big_m = 1ul<<(libff::log2(m)-1); + const size_t small_m = m - big_m; + + if (small_m != 1ul<(1ul<(1ul< step_radix2_domain::step_radix2_domain(const size_t m) : evaluation_domain(m) { @@ -30,7 +57,7 @@ step_radix2_domain::step_radix2_domain(const size_t m) : evaluation_doma try { omega = libff::get_root_of_unity(1ul<(small_m); } diff --git a/libfqfft/evaluation_domain/evaluation_domain.hpp b/libfqfft/evaluation_domain/evaluation_domain.hpp index 43c7498..2745330 100755 --- a/libfqfft/evaluation_domain/evaluation_domain.hpp +++ b/libfqfft/evaluation_domain/evaluation_domain.hpp @@ -27,6 +27,7 @@ #define EVALUATION_DOMAIN_HPP_ #include +#include namespace libfqfft { diff --git a/libfqfft/evaluation_domain/get_evaluation_domain.tcc b/libfqfft/evaluation_domain/get_evaluation_domain.tcc index 299537c..9ab4100 100755 --- a/libfqfft/evaluation_domain/get_evaluation_domain.tcc +++ b/libfqfft/evaluation_domain/get_evaluation_domain.tcc @@ -38,15 +38,33 @@ std::shared_ptr > get_evaluation_domain(const size_t m const size_t small = min_size - big; const size_t rounded_small = (1ul<(min_size)); } - catch(...) { try { result.reset(new extended_radix2_domain(min_size)); } - catch(...) { try { result.reset(new step_radix2_domain(min_size)); } - catch(...) { try { result.reset(new basic_radix2_domain(big + rounded_small)); } - catch(...) { try { result.reset(new extended_radix2_domain(big + rounded_small)); } - catch(...) { try { result.reset(new step_radix2_domain(big + rounded_small)); } - catch(...) { try { result.reset(new geometric_sequence_domain(min_size)); } - catch(...) { try { result.reset(new arithmetic_sequence_domain(min_size)); } - catch(...) { throw DomainSizeException("get_evaluation_domain: no matching domain"); }}}}}}}} + if (basic_radix2_domain::valid_for_size(min_size)) { + result.reset(new basic_radix2_domain(min_size)); + } + else if (extended_radix2_domain::valid_for_size(min_size)) { + result.reset(new extended_radix2_domain(min_size)); + } + else if (step_radix2_domain::valid_for_size(min_size)) { + result.reset(new step_radix2_domain(min_size)); + } + else if (basic_radix2_domain::valid_for_size(big + rounded_small)) { + result.reset(new basic_radix2_domain(big + rounded_small)); + } + else if (extended_radix2_domain::valid_for_size(big + rounded_small)) { + result.reset(new extended_radix2_domain(big + rounded_small)); + } + else if (step_radix2_domain::valid_for_size(big + rounded_small)) { + result.reset(new step_radix2_domain(big + rounded_small)); + } + else if (geometric_sequence_domain::valid_for_size(min_size)) { + result.reset(new geometric_sequence_domain(min_size)); + } + else if (arithmetic_sequence_domain::valid_for_size(min_size)) { + result.reset(new arithmetic_sequence_domain(min_size)); + } + else { + throw DomainSizeException("get_evaluation_domain: no matching domain"); + } return result; }