diff --git a/src/RdaReader/CMakeLists.txt b/src/RdaReader/CMakeLists.txt index 7d8743a..ec23afe 100644 --- a/src/RdaReader/CMakeLists.txt +++ b/src/RdaReader/CMakeLists.txt @@ -2,6 +2,8 @@ project(RdaReader) set(CMAKE_CXX_STANDARD 20) +find_package(Threads REQUIRED) + add_library(RdaReader src/RdaReader src/SharedIndex.cpp src/BucketedZstdData.cpp) target_include_directories(RdaReader PUBLIC include) -target_link_libraries(RdaReader zstd) \ No newline at end of file +target_link_libraries(RdaReader PRIVATE zstd Threads::Threads) \ No newline at end of file diff --git a/src/RdaReader/src/RdaReader.cpp b/src/RdaReader/src/RdaReader.cpp index d342329..997d148 100644 --- a/src/RdaReader/src/RdaReader.cpp +++ b/src/RdaReader/src/RdaReader.cpp @@ -6,6 +6,8 @@ #include #include #include +#include +#include void RdaReader::log(const std::string &string) { if(logger) logger(string); @@ -37,19 +39,40 @@ size_t RdaReader::readRda(std::istream &rda, uint64_t id, std::ostream &output) return readRda(rda, id, output, dummyMutex); } +std::optional getVictimIfAvailable(std::vector &victims, std::mutex &victimMutex) { + std::lock_guard lock(victimMutex); + if(victims.size() > 0) { + std::istream *tmp = victims.back(); + victims.pop_back(); + return tmp; + } + return {}; +} + size_t RdaReader::readDataset(const std::vector &datasetName, const std::vector &sharedIndex, const std::vector &rdas, std::ostream &output, std::mutex &outputMutex) { log("Reading shared index... "); SharedIndex sharedIndexReader(sharedIndex); + std::vector victims = rdas; + std::mutex victimMutex; + if(std::optional id = sharedIndexReader.getID(datasetName)) { log("Found ID: " + std::to_string(id.value()) + '\n'); std::atomic_size_t totalEntries(0); - std::for_each( - std::execution::par, - rdas.begin(), - rdas.end(), - [this, &totalEntries, &id, &output, &outputMutex](std::istream * const &rda) {totalEntries += readRda(*rda, id.value(), output, outputMutex);} - ); + std::vector threads(std::thread::hardware_concurrency()); + for(int i = 0; i < std::thread::hardware_concurrency(); ++i) { + threads[i] = std::thread( + [&victims, &victimMutex, this, &totalEntries, &id, &output, &outputMutex](){ + for(std::optional victim; victim = getVictimIfAvailable(victims, victimMutex); ) { + totalEntries += readRda(*victim.value(), id.value(), output, outputMutex); + } + } + ); + } + for(std::thread &thread : threads) { + thread.join(); + } + return totalEntries; } log("No entries found\n");