diff --git a/src/spices.cpp b/src/spices.cpp index 19b333d..b406e1b 100644 --- a/src/spices.cpp +++ b/src/spices.cpp @@ -4,6 +4,10 @@ #include #include #include +#include + +#include "SharedIndex.hpp" +#include "BucketedZstdData.hpp" #ifdef _WIN32 #include @@ -27,8 +31,6 @@ std::mutex cerrMutex; fprintf(stderr, __VA_ARGS__); \ } -#include "SharedIndex.hpp" -#include "BucketedZstdData.hpp" std::optional> readSharedIndex(const std::filesystem::path &filePath) { std::ifstream sharedIndexFile(filePath, std::ios::binary | std::ios::ate); @@ -42,6 +44,23 @@ std::optional> readSharedIndex(const std::filesystem::path &fi return sharedIndexData; } +void processRDA(const std::filesystem::directory_entry &file, std::atomic_size_t &totalEntries, uint64_t id, std::mutex &outputMutex) { + if(file.path().extension() == ".rda") { + std::ifstream fileStream(file.path(), std::ios::binary); + CERRLOG("Reading %s\n", std::string(file.path().filename()).c_str()); + BucketedZstdData bucket(fileStream); + + if(std::optional>> data = bucket.getEntriesByID(id)) { + totalEntries += data.value().size(); + const std::lock_guard lock(outputMutex); + CERRLOG("Writing %s\n", std::string(file.path().filename()).c_str()); + for(const auto &entry : data.value()) { + std::cout.write(entry.data(), entry.size()) << '\n'; + } + } + } +} + int main(int argc, char **argv) { if(argc != 3 && argc != 4) { std::cerr << "usage: subreddit rootdirectory [force write to terminal(true | false)]" << std::endl; @@ -58,13 +77,13 @@ int main(int argc, char **argv) { std::cerr << "Loading shared index..." << std::flush; std::vector sharedIndexData; - if(auto data = readSharedIndex(sharedIndexPath); data.has_value()) { - sharedIndexData.swap(data.value()); + if(auto data = readSharedIndex(sharedIndexPath)) { + sharedIndexData = data.value(); } else { std::cerr << "cannot find '" << sharedIndexPath << "'" << std::endl; - return 1; + return 2; } - std::cerr << "Loaded shared index" << std::endl; + std::cerr << "Loaded shared index" << std::endl; std::string datesetString = argv[1]; std::vector datasetName(datesetString.begin(), datesetString.end()); @@ -72,34 +91,21 @@ int main(int argc, char **argv) { SharedIndex sharedIndex(sharedIndexData); std::cerr << "Fetching ID from shared index... " << std::flush; - size_t totalEntries = 0; - if(auto id = sharedIndex.getID(datasetName)) { + std::atomic_size_t totalEntries(0); + if(std::optional id = sharedIndex.getID(datasetName)) { std::cerr << "Found ID: " << id.value() << std::endl; - std::mutex outputMutex; std::for_each( std::execution::par, std::filesystem::begin(std::filesystem::directory_iterator(rootDirectory)), std::filesystem::end(std::filesystem::directory_iterator()), - [&totalEntries, &id, &outputMutex](const auto& file) - { - if(file.path().extension() == ".rda") { - std::ifstream fileStream(file.path(), std::ios::binary); - CERRLOG("Reading %s\n", std::string(file.path().filename()).c_str()); - BucketedZstdData bucket(fileStream); - - if(auto data = bucket.getEntriesByID(id.value()); data.has_value()) { - totalEntries += data.value().size(); - const std::lock_guard lock(outputMutex); - CERRLOG("Writing %s\n", std::string(file.path().filename()).c_str()); - for(const auto &entry : data.value()) { - std::cout.write(entry.data(), entry.size()) << '\n'; - } - } - } - }); + [&totalEntries, &id, &outputMutex](const auto& file) {processRDA(file, totalEntries, id.value(), outputMutex);} + ); std::cerr << "Found a total of " << totalEntries << " entries" << std::endl; + } else { + std::cerr << "Cannot find '" << argv[1] << "' in the shared index" << std::endl; + return 2; } } \ No newline at end of file