adds proper multi-threading
This commit is contained in:
parent
5fc5ba6032
commit
fe2af2978b
@ -2,6 +2,8 @@ project(RdaReader)
|
||||
|
||||
set(CMAKE_CXX_STANDARD 20)
|
||||
|
||||
find_package(Threads REQUIRED)
|
||||
|
||||
add_library(RdaReader src/RdaReader src/SharedIndex.cpp src/BucketedZstdData.cpp)
|
||||
target_include_directories(RdaReader PUBLIC include)
|
||||
target_link_libraries(RdaReader zstd)
|
||||
target_link_libraries(RdaReader PRIVATE zstd Threads::Threads)
|
@ -6,6 +6,8 @@
|
||||
#include <execution>
|
||||
#include <algorithm>
|
||||
#include <atomic>
|
||||
#include <optional>
|
||||
#include <thread>
|
||||
|
||||
void RdaReader::log(const std::string &string) {
|
||||
if(logger) logger(string);
|
||||
@ -37,19 +39,40 @@ size_t RdaReader::readRda(std::istream &rda, uint64_t id, std::ostream &output)
|
||||
return readRda(rda, id, output, dummyMutex);
|
||||
}
|
||||
|
||||
std::optional<std::istream *> getVictimIfAvailable(std::vector<std::istream *> &victims, std::mutex &victimMutex) {
|
||||
std::lock_guard lock(victimMutex);
|
||||
if(victims.size() > 0) {
|
||||
std::istream *tmp = victims.back();
|
||||
victims.pop_back();
|
||||
return tmp;
|
||||
}
|
||||
return {};
|
||||
}
|
||||
|
||||
size_t RdaReader::readDataset(const std::vector<char> &datasetName, const std::vector<char> &sharedIndex, const std::vector<std::istream *> &rdas, std::ostream &output, std::mutex &outputMutex) {
|
||||
log("Reading shared index... ");
|
||||
SharedIndex sharedIndexReader(sharedIndex);
|
||||
|
||||
std::vector<std::istream *> victims = rdas;
|
||||
std::mutex victimMutex;
|
||||
|
||||
if(std::optional<std::uint64_t> id = sharedIndexReader.getID(datasetName)) {
|
||||
log("Found ID: " + std::to_string(id.value()) + '\n');
|
||||
std::atomic_size_t totalEntries(0);
|
||||
std::for_each(
|
||||
std::execution::par,
|
||||
rdas.begin(),
|
||||
rdas.end(),
|
||||
[this, &totalEntries, &id, &output, &outputMutex](std::istream * const &rda) {totalEntries += readRda(*rda, id.value(), output, outputMutex);}
|
||||
std::vector<std::thread> threads(std::thread::hardware_concurrency());
|
||||
for(int i = 0; i < std::thread::hardware_concurrency(); ++i) {
|
||||
threads[i] = std::thread(
|
||||
[&victims, &victimMutex, this, &totalEntries, &id, &output, &outputMutex](){
|
||||
for(std::optional<std::istream *> victim; victim = getVictimIfAvailable(victims, victimMutex); ) {
|
||||
totalEntries += readRda(*victim.value(), id.value(), output, outputMutex);
|
||||
}
|
||||
}
|
||||
);
|
||||
}
|
||||
for(std::thread &thread : threads) {
|
||||
thread.join();
|
||||
}
|
||||
|
||||
return totalEntries;
|
||||
}
|
||||
log("No entries found\n");
|
||||
|
Loading…
Reference in New Issue
Block a user