Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
CXX := g++
CXXFLAGS := -std=c++17 -O3 -fPIC
CXXFLAGS := -std=c++17 -O3 -fPIC -fopenmp
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added `-fopenmp` to `CXXFLAGS` so the `#pragma omp` directives in the sources are compiled with OpenMP enabled.


# Python / pybind11 include flags
PYBIND11_INCLUDES := $(shell python3 -m pybind11 --includes)
Expand Down Expand Up @@ -40,6 +40,9 @@ else
LDFLAGS := -Wl,-rpath,$$ORIGIN/../extern/faiss/build/install/lib
endif

# Add OpenMP linking
LDFLAGS += -fopenmp
Comment on lines +43 to +44
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added `-fopenmp` to `LDFLAGS` so the OpenMP runtime library is linked into the target.


.PHONY: all clean prepare

all: prepare $(TARGET)
Expand Down
57 changes: 44 additions & 13 deletions src/IVFFlatIndex.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include "IVFFlatIndex.h"
#include "SimdUtils.h"
#include <omp.h>
#include <limits>
#include <random>
#include <algorithm>
Expand Down Expand Up @@ -47,35 +48,63 @@ SearchResult IVFFlatIndex::search(const Vector& query, size_t k) const {
std::vector<Pair> cdist(nlist_);
std::vector<Pair> heap;

// calculate distance from centroids
// Calculate distance from query to all centroids (parallelized)
#pragma omp parallel for schedule(static)
for (size_t c = 0; c < nlist_; ++c) {
float d = l2_naive(query.data(), centroids_[c].data(), dimension_);
cdist[c] = {d, c};
}
Comment on lines +52 to 56
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

將 query 到各個質心 (centroid) 的距離計算分給多個 threads 平行處理;每個 thread 只寫入自己負責的 cdist[c],因此不需要額外同步。

std::partial_sort(cdist.begin(), cdist.begin() + nprobe_, cdist.end(),
[](auto& a, auto& b) {
return a.first < b.first;

// Select top-nprobe nearest centroids
std::partial_sort(cdist.begin(), cdist.begin() + nprobe_, cdist.end(),
[](auto& a, auto& b) {
return a.first < b.first;
}
);

// probe nprobe lists
// Probe nprobe nearest lists in parallel
// Each thread maintains a local heap, then merges into global heap
heap.reserve(k);
const auto& data = datastore_->getAll();

#pragma omp parallel for schedule(dynamic)
for (size_t pi = 0; pi < nprobe_; ++pi) {
size_t c = cdist[pi].second;

// Thread-local heap for this cluster
std::vector<Pair> local;
local.reserve(k);

// Search within this cluster's inverted list
for (size_t id : lists_[c]) {
float dist = l2_naive(query.data(), data[id].data(), dimension_);

if (heap.size() < k) {
heap.emplace_back(dist, id);
if (heap.size() == k)
std::make_heap(heap.begin(), heap.end());
} else if (dist < heap.front().first) {
std::pop_heap(heap.begin(), heap.end());
heap.back() = {dist, id};
std::push_heap(heap.begin(), heap.end());
if (local.size() < k) {
local.emplace_back(dist, id);
if (local.size() == k) {
std::make_heap(local.begin(), local.end());
}
} else if (dist < local.front().first) {
std::pop_heap(local.begin(), local.end());
local.back() = {dist, id};
std::push_heap(local.begin(), local.end());
}
}
Comment on lines -71 to +92
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Search within this cluster's inverted list:將原本所有 cluster 共用的 heap,改成在每個 cluster 的迴圈迭代中各自建立一個 local heap,讓 OpenMP 的各個 thread 在搜尋時不會同時寫入同一個 heap。


// Merge local results into global heap (thread-safe)
#pragma omp critical
{
for (auto& p : local) {
if (heap.size() < k) {
heap.push_back(p);
if (heap.size() == k) {
std::make_heap(heap.begin(), heap.end());
}
} else if (p.first < heap.front().first) {
std::pop_heap(heap.begin(), heap.end());
heap.back() = p;
std::push_heap(heap.begin(), heap.end());
}
Comment on lines +94 to +107
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

在 `#pragma omp critical` 區段內,將每個 local heap 的結果合併進全域 heap,確保同一時間只有一個 thread 修改全域 heap。

}
}
}
Expand Down Expand Up @@ -105,6 +134,8 @@ std::vector<SearchResult> IVFFlatIndex::search_batch(const Dataset& queries, siz
const size_t nq = queries.size();
std::vector<SearchResult> results(nq);

// Parallel batch search with dynamic scheduling
#pragma omp parallel for schedule(dynamic)
for (size_t i = 0; i < nq; ++i) {
results[i] = search(queries[i], k);
}
Expand Down