-
Notifications
You must be signed in to change notification settings - Fork 1
Add OpenMP parallelization to IVFFlatIndex #5
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,5 +1,5 @@ | ||
| CXX := g++ | ||
| CXXFLAGS := -std=c++17 -O3 -fPIC | ||
| CXXFLAGS := -std=c++17 -O3 -fPIC -fopenmp | ||
|
|
||
| # Python / pybind11 include flags | ||
| PYBIND11_INCLUDES := $(shell python3 -m pybind11 --includes) | ||
|
|
@@ -40,6 +40,9 @@ else | |
| LDFLAGS := -Wl,-rpath,$$ORIGIN/../extern/faiss/build/install/lib | ||
| endif | ||
|
|
||
| # Add OpenMP linking | ||
| LDFLAGS += -fopenmp | ||
|
Comment on lines
+43
to
+44
Owner
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. added |
||
|
|
||
| .PHONY: all clean prepare | ||
|
|
||
| all: prepare $(TARGET) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,5 +1,6 @@ | ||
| #include "IVFFlatIndex.h" | ||
| #include "SimdUtils.h" | ||
| #include <omp.h> | ||
| #include <limits> | ||
| #include <random> | ||
| #include <algorithm> | ||
|
|
@@ -47,35 +48,63 @@ SearchResult IVFFlatIndex::search(const Vector& query, size_t k) const { | |
| std::vector<Pair> cdist(nlist_); | ||
| std::vector<Pair> heap; | ||
|
|
||
| // calculate distance from centroids | ||
| // Calculate distance from query to all centroids (parallelized) | ||
| #pragma omp parallel for schedule(static) | ||
| for (size_t c = 0; c < nlist_; ++c) { | ||
| float d = l2_naive(query.data(), centroids_[c].data(), dimension_); | ||
| cdist[c] = {d, c}; | ||
| } | ||
|
Comment on lines
+52
to
56
Owner
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 質心 query 分給多個 threads |
||
| std::partial_sort(cdist.begin(), cdist.begin() + nprobe_, cdist.end(), | ||
| [](auto& a, auto& b) { | ||
| return a.first < b.first; | ||
|
|
||
| // Select top-nprobe nearest centroids | ||
| std::partial_sort(cdist.begin(), cdist.begin() + nprobe_, cdist.end(), | ||
| [](auto& a, auto& b) { | ||
| return a.first < b.first; | ||
| } | ||
| ); | ||
|
|
||
| // probe nprobe lists | ||
| // Probe nprobe nearest lists in parallel | ||
| // Each thread maintains a local heap, then merges into global heap | ||
| heap.reserve(k); | ||
| const auto& data = datastore_->getAll(); | ||
|
|
||
| #pragma omp parallel for schedule(dynamic) | ||
| for (size_t pi = 0; pi < nprobe_; ++pi) { | ||
| size_t c = cdist[pi].second; | ||
|
|
||
| // Thread-local heap for this cluster | ||
| std::vector<Pair> local; | ||
| local.reserve(k); | ||
|
|
||
| // Search within this cluster's inverted list | ||
| for (size_t id : lists_[c]) { | ||
| float dist = l2_naive(query.data(), data[id].data(), dimension_); | ||
|
|
||
| if (heap.size() < k) { | ||
| heap.emplace_back(dist, id); | ||
| if (heap.size() == k) | ||
| std::make_heap(heap.begin(), heap.end()); | ||
| } else if (dist < heap.front().first) { | ||
| std::pop_heap(heap.begin(), heap.end()); | ||
| heap.back() = {dist, id}; | ||
| std::push_heap(heap.begin(), heap.end()); | ||
| if (local.size() < k) { | ||
| local.emplace_back(dist, id); | ||
| if (local.size() == k) { | ||
| std::make_heap(local.begin(), local.end()); | ||
| } | ||
| } else if (dist < local.front().first) { | ||
| std::pop_heap(local.begin(), local.end()); | ||
| local.back() = {dist, id}; | ||
| std::push_heap(local.begin(), local.end()); | ||
| } | ||
| } | ||
|
Comment on lines
-71
to
+92
Owner
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Search within this cluster's inverted list, 將原本 heap, 改成讓 OpenMP 每個 thread 各自有自己的 heap (local) |
||
|
|
||
| // Merge local results into global heap (thread-safe) | ||
| #pragma omp critical | ||
| { | ||
| for (auto& p : local) { | ||
| if (heap.size() < k) { | ||
| heap.push_back(p); | ||
| if (heap.size() == k) { | ||
| std::make_heap(heap.begin(), heap.end()); | ||
| } | ||
| } else if (p.first < heap.front().first) { | ||
| std::pop_heap(heap.begin(), heap.end()); | ||
| heap.back() = p; | ||
| std::push_heap(heap.begin(), heap.end()); | ||
| } | ||
|
Comment on lines
+94
to
+107
Owner
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 將 local 資料合併 |
||
| } | ||
| } | ||
| } | ||
|
|
@@ -105,6 +134,8 @@ std::vector<SearchResult> IVFFlatIndex::search_batch(const Dataset& queries, siz | |
| const size_t nq = queries.size(); | ||
| std::vector<SearchResult> results(nq); | ||
|
|
||
| // Parallel batch search with dynamic scheduling | ||
| #pragma omp parallel for schedule(dynamic) | ||
| for (size_t i = 0; i < nq; ++i) { | ||
| results[i] = search(queries[i], k); | ||
| } | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added `-fopenmp` for compilation.