diff --git a/Makefile b/Makefile index 50c9487..e94ed1f 100644 --- a/Makefile +++ b/Makefile @@ -1,5 +1,5 @@ CXX := g++ -CXXFLAGS := -std=c++17 -O3 -fPIC +CXXFLAGS := -std=c++17 -O3 -fPIC -fopenmp # Python / pybind11 include flags PYBIND11_INCLUDES := $(shell python3 -m pybind11 --includes) @@ -40,6 +40,9 @@ else LDFLAGS := -Wl,-rpath,$$ORIGIN/../extern/faiss/build/install/lib endif +# Add OpenMP linking +LDFLAGS += -fopenmp + .PHONY: all clean prepare all: prepare $(TARGET) diff --git a/src/IVFFlatIndex.cpp b/src/IVFFlatIndex.cpp index 4b4abf0..53e3386 100644 --- a/src/IVFFlatIndex.cpp +++ b/src/IVFFlatIndex.cpp @@ -1,5 +1,6 @@ #include "IVFFlatIndex.h" #include "SimdUtils.h" +#include #include #include #include @@ -47,35 +48,63 @@ SearchResult IVFFlatIndex::search(const Vector& query, size_t k) const { std::vector cdist(nlist_); std::vector heap; - // calculate distance from centroids + // Calculate distance from query to all centroids (parallelized) + #pragma omp parallel for schedule(static) for (size_t c = 0; c < nlist_; ++c) { float d = l2_naive(query.data(), centroids_[c].data(), dimension_); cdist[c] = {d, c}; } - std::partial_sort(cdist.begin(), cdist.begin() + nprobe_, cdist.end(), - [](auto& a, auto& b) { - return a.first < b.first; + + // Select top-nprobe nearest centroids + std::partial_sort(cdist.begin(), cdist.begin() + nprobe_, cdist.end(), + [](auto& a, auto& b) { + return a.first < b.first; } ); - // probe nprobe lists + // Probe nprobe nearest lists in parallel + // Each thread maintains a local heap, then merges into global heap heap.reserve(k); const auto& data = datastore_->getAll(); + #pragma omp parallel for schedule(dynamic) for (size_t pi = 0; pi < nprobe_; ++pi) { size_t c = cdist[pi].second; + // Thread-local heap for this cluster + std::vector local; + local.reserve(k); + + // Search within this cluster's inverted list for (size_t id : lists_[c]) { float dist = l2_naive(query.data(), data[id].data(), dimension_); - if (heap.size() < k) { - heap.emplace_back(dist, id); - if (heap.size() == k) - std::make_heap(heap.begin(), heap.end()); - } else if (dist < heap.front().first) { - std::pop_heap(heap.begin(), heap.end()); - heap.back() = {dist, id}; - std::push_heap(heap.begin(), heap.end()); + if (local.size() < k) { + local.emplace_back(dist, id); + if (local.size() == k) { + std::make_heap(local.begin(), local.end()); + } + } else if (dist < local.front().first) { + std::pop_heap(local.begin(), local.end()); + local.back() = {dist, id}; + std::push_heap(local.begin(), local.end()); + } + } + + // Merge local results into global heap (thread-safe) + #pragma omp critical + { + for (auto& p : local) { + if (heap.size() < k) { + heap.push_back(p); + if (heap.size() == k) { + std::make_heap(heap.begin(), heap.end()); + } + } else if (p.first < heap.front().first) { + std::pop_heap(heap.begin(), heap.end()); + heap.back() = p; + std::push_heap(heap.begin(), heap.end()); + } } } } @@ -105,6 +134,8 @@ std::vector IVFFlatIndex::search_batch(const Dataset& queries, siz const size_t nq = queries.size(); std::vector results(nq); + // Parallel batch search with dynamic scheduling + #pragma omp parallel for schedule(dynamic) for (size_t i = 0; i < nq; ++i) { results[i] = search(queries[i], k); }