adds proper multi-threading
This commit is contained in:
		| @@ -2,6 +2,8 @@ project(RdaReader) | |||||||
|  |  | ||||||
| set(CMAKE_CXX_STANDARD 20) | set(CMAKE_CXX_STANDARD 20) | ||||||
|  |  | ||||||
|  | find_package(Threads REQUIRED) | ||||||
|  |  | ||||||
| add_library(RdaReader src/RdaReader src/SharedIndex.cpp src/BucketedZstdData.cpp) | add_library(RdaReader src/RdaReader src/SharedIndex.cpp src/BucketedZstdData.cpp) | ||||||
| target_include_directories(RdaReader PUBLIC include) | target_include_directories(RdaReader PUBLIC include) | ||||||
| target_link_libraries(RdaReader zstd) | target_link_libraries(RdaReader PRIVATE zstd Threads::Threads) | ||||||
| @@ -6,6 +6,8 @@ | |||||||
| #include <execution> | #include <execution> | ||||||
| #include <algorithm> | #include <algorithm> | ||||||
| #include <atomic> | #include <atomic> | ||||||
|  | #include <optional> | ||||||
|  | #include <thread> | ||||||
|  |  | ||||||
| void RdaReader::log(const std::string &string) { | void RdaReader::log(const std::string &string) { | ||||||
| 	if(logger) logger(string); | 	if(logger) logger(string); | ||||||
| @@ -37,19 +39,40 @@ size_t RdaReader::readRda(std::istream &rda, uint64_t id, std::ostream &output) | |||||||
| 	return readRda(rda, id, output, dummyMutex); | 	return readRda(rda, id, output, dummyMutex); | ||||||
| } | } | ||||||
|  |  | ||||||
|  | std::optional<std::istream *> getVictimIfAvailable(std::vector<std::istream *> &victims, std::mutex &victimMutex) { | ||||||
|  | 	std::lock_guard lock(victimMutex); | ||||||
|  | 	if(victims.size() > 0) { | ||||||
|  | 		std::istream *tmp = victims.back(); | ||||||
|  | 		victims.pop_back(); | ||||||
|  | 		return tmp; | ||||||
|  | 	} | ||||||
|  | 	return {}; | ||||||
|  | } | ||||||
|  |  | ||||||
| size_t RdaReader::readDataset(const std::vector<char> &datasetName, const std::vector<char> &sharedIndex, const std::vector<std::istream *> &rdas, std::ostream &output, std::mutex &outputMutex) { | size_t RdaReader::readDataset(const std::vector<char> &datasetName, const std::vector<char> &sharedIndex, const std::vector<std::istream *> &rdas, std::ostream &output, std::mutex &outputMutex) { | ||||||
| 	log("Reading shared index... "); | 	log("Reading shared index... "); | ||||||
| 	SharedIndex sharedIndexReader(sharedIndex); | 	SharedIndex sharedIndexReader(sharedIndex); | ||||||
|  |  | ||||||
|  | 	std::vector<std::istream *> victims = rdas; | ||||||
|  | 	std::mutex victimMutex; | ||||||
|  |  | ||||||
| 	if(std::optional<std::uint64_t> id = sharedIndexReader.getID(datasetName)) { | 	if(std::optional<std::uint64_t> id = sharedIndexReader.getID(datasetName)) { | ||||||
| 		log("Found ID: " + std::to_string(id.value()) + '\n'); | 		log("Found ID: " + std::to_string(id.value()) + '\n'); | ||||||
| 		std::atomic_size_t totalEntries(0); | 		std::atomic_size_t totalEntries(0); | ||||||
| 		std::for_each( | 		std::vector<std::thread> threads(std::thread::hardware_concurrency()); | ||||||
| 			std::execution::par, | 		for(int i = 0; i < std::thread::hardware_concurrency(); ++i) { | ||||||
| 			rdas.begin(), | 			threads[i] = std::thread( | ||||||
| 			rdas.end(), | 				[&victims, &victimMutex, this, &totalEntries, &id, &output, &outputMutex](){ | ||||||
| 			[this, &totalEntries, &id, &output, &outputMutex](std::istream * const &rda) {totalEntries += readRda(*rda, id.value(), output, outputMutex);} | 					for(std::optional<std::istream *> victim; victim = getVictimIfAvailable(victims, victimMutex); ) { | ||||||
|  | 						totalEntries += readRda(*victim.value(), id.value(), output, outputMutex); | ||||||
|  | 					} | ||||||
|  | 				} | ||||||
| 			); | 			); | ||||||
|  | 		} | ||||||
|  | 		for(std::thread &thread : threads) { | ||||||
|  | 			thread.join(); | ||||||
|  | 		} | ||||||
|  |  | ||||||
| 		return totalEntries; | 		return totalEntries; | ||||||
| 	} | 	} | ||||||
| 	log("No entries found\n"); | 	log("No entries found\n"); | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user