diff --git a/algo/detectors/trd/Hitfind.cxx b/algo/detectors/trd/Hitfind.cxx index f1e3efc0e870d8fc3c89aa12d569caa0566f86ab..36670ae31f267c276c053064da5db0af091a1b81 100644 --- a/algo/detectors/trd/Hitfind.cxx +++ b/algo/detectors/trd/Hitfind.cxx @@ -13,6 +13,12 @@ using namespace std; using fles::Subsystem; +// By default all steps (clusterizing, hit building, hit merging) are parallelized +// by row index. Hit merging is thereby only done between pairs of neighboring rows, +// which can miss some large digi clusters. Enable flag below to instead parallelize +// the last step (hit merging) by module index. +#define MERGE_BY_MODULE + namespace cbm::algo::trd { // ----- Constructor ------------------------------------------------------ @@ -102,9 +108,19 @@ namespace cbm::algo::trd } // ---------------------------------------------------------------------------- - - // ----- Execution ------------------------------------------------------- + // ----- Execution -------------------------------------------------------- Hitfind::resultType Hitfind::operator()(gsl::span<CbmTrdDigi> digiIn) + { +#ifdef MERGE_BY_MODULE + return RunModuleParallelMerge(digiIn); +#else + return RunRowParallel(digiIn); +#endif + } + // ---------------------------------------------------------------------------- + + // ----- Execution fully row parallel -------------------------------------- + Hitfind::resultType Hitfind::RunRowParallel(gsl::span<CbmTrdDigi> digiIn) { constexpr bool DebugCheckInput = true; @@ -298,4 +314,182 @@ namespace cbm::algo::trd } // ---------------------------------------------------------------------------- + + // ----- Execution merging module parallel --------------------------------- + Hitfind::resultType Hitfind::RunModuleParallelMerge(gsl::span<CbmTrdDigi> digiIn) + { + constexpr bool DebugCheckInput = true; + + // --- Output data + resultType result = {}; + auto& hitsOut = std::get<0>(result); + auto& monitor = std::get<1>(result); + + // Intermediate digi storage variables (digi, index) per module and row + std::unordered_map<int, std::vector<std::vector<std::pair<CbmTrdDigi, int32_t>>>> digiBuffer; //[modAddress][row] + + // Intermediate hits per module and row + std::unordered_map<int, std::vector<std::vector<hitDataType>>> hitBuffer; //[modAddress][row] + + // Initialize storage buffers + for (size_t mod = 0; mod < fModList.size(); mod++) { + const int address = std::get<0>(fModList[mod]); + const size_t numRows = std::get<2>(fModList[mod]); + digiBuffer[address].resize(numRows); + hitBuffer[address].resize(numRows); + } + + // Loop over the digis array and store the digis in separate vectors for + // each module and row + xpu::push_timer("DigiModuleSort"); + for (size_t idigi = 0; idigi < digiIn.size(); idigi++) { + const CbmTrdDigi* digi = &digiIn[idigi]; + const int address = digi->GetAddressModule(); + if constexpr (DebugCheckInput) { + auto modInfo = + std::find_if(fModList.begin(), fModList.end(), [&](auto m) { return std::get<0>(m) == address; }); + if (modInfo == fModList.end()) { + L_(error) << "TRD: Unknown module ID"; + continue; + } + bool digiIs2D = digi->IsFASP(); + if (std::get<1>(*modInfo) != digiIs2D) { + L_(error) << "TRD: Module + Digi type mismatch: " << std::get<0>(*modInfo) << ": " << std::get<1>(*modInfo) + << " " << digiIs2D; + continue; + } + } + const size_t modId = fModId[address]; + const size_t numCols = std::get<3>(fModList[modId]); + const int row = digi->GetAddressChannel() / numCols; + digiBuffer[address][row].emplace_back(*digi, idigi); + } + monitor.sortTime = xpu::pop_timer(); + + xpu::push_timer("BuildClusters"); + xpu::t_add_bytes(digiIn.size_bytes()); + + // Cluster building + CBM_PARALLEL_FOR(schedule(dynamic)) + for (size_t row = 0; row < fRowList.size(); row++) { + + const int address = std::get<0>(fRowList[row]); + const bool is2D = std::get<1>(fRowList[row]); + const size_t rowInMod = std::get<2>(fRowList[row]); + const auto& digiInput = digiBuffer[address][rowInMod]; + if (is2D) { + auto clusters = (*fClusterBuild2d[address])(digiInput, 0.); // Number is TS start time (T0) + hitBuffer[address][rowInMod] = (*fHitFind2d[address])(&clusters); + } + else { + auto clusters = (*fClusterBuild[address])(digiInput); + hitBuffer[address][rowInMod] = (*fHitFind[address])(&clusters); + } + } + monitor.timeClusterize = xpu::pop_timer(); + + // Hit finding + PODVector<Hit> hitsFlat; // hit storage + PODVector<size_t> modSizes; // nHits per modules + PODVector<uint> modAddresses; // address of modules + + // Prefix array for parallelization + std::vector<size_t> hitsPrefix; + std::vector<size_t> sizePrefix; + std::vector<size_t> addrPrefix; + + xpu::push_timer("FindHits"); + CBM_PARALLEL() + { + const int ithread = openmp::GetThreadNum(); + const int nthreads = openmp::GetNumThreads(); + + CBM_OMP(single) + { + hitsPrefix.resize(nthreads + 1); + sizePrefix.resize(nthreads + 1); + addrPrefix.resize(nthreads + 1); + } + + std::vector<Hit> local_hits; + std::vector<size_t> local_sizes; + std::vector<uint> local_addresses; + + CBM_OMP(for schedule(dynamic) nowait) + for (size_t mod = 0; mod < fModList.size(); mod++) { + const int address = std::get<0>(fModList[mod]); + const bool is2D = std::get<1>(fModList[mod]); + + // Lambda expression for vector concatenation + auto concatVec = [](auto& acc, const auto& innerVec) { + acc.insert(acc.end(), innerVec.begin(), innerVec.end()); + return std::move(acc); + }; + + + // Flatten the input vector of vectors and merge hits + auto& hitbuffer = hitBuffer[address]; + auto hitData = std::accumulate(hitbuffer.begin(), hitbuffer.end(), std::vector<hitDataType>(), concatVec); + + std::vector<hitDataType> mod_hitdata; + std::vector<hitDataType> dummy; + if (is2D) { + mod_hitdata = (*fHitMerge2d[address])(hitData, dummy).first; + } + else { + mod_hitdata = (*fHitMerge[address])(hitData, dummy).first; + } + std::vector<Hit> mod_hits; + std::transform(mod_hitdata.begin(), mod_hitdata.end(), std::back_inserter(mod_hits), + [](const auto& p) { return p.first; }); + + // store partition size + local_sizes.push_back(mod_hits.size()); + + // store hw address of partition + local_addresses.push_back(address); + + // Append clusters to output + local_hits.insert(local_hits.end(), std::make_move_iterator(mod_hits.begin()), + std::make_move_iterator(mod_hits.end())); + } + + hitsPrefix[ithread + 1] = local_hits.size(); + sizePrefix[ithread + 1] = local_sizes.size(); + addrPrefix[ithread + 1] = local_addresses.size(); + CBM_OMP(barrier) + CBM_OMP(single) + { + for (int i = 1; i < (nthreads + 1); i++) { + hitsPrefix[i] += hitsPrefix[i - 1]; + sizePrefix[i] += sizePrefix[i - 1]; + addrPrefix[i] += addrPrefix[i - 1]; + } + hitsFlat.resize(hitsPrefix[nthreads]); + modSizes.resize(sizePrefix[nthreads]); + modAddresses.resize(addrPrefix[nthreads]); + } + std::move(local_hits.begin(), local_hits.end(), hitsFlat.begin() + hitsPrefix[ithread]); + std::move(local_sizes.begin(), local_sizes.end(), modSizes.begin() + sizePrefix[ithread]); + std::move(local_addresses.begin(), local_addresses.end(), modAddresses.begin() + addrPrefix[ithread]); + } + // Monitoring + monitor.timeHitfind = xpu::pop_timer(); + monitor.numDigis = digiIn.size(); + monitor.numHits = hitsFlat.size(); + + // Create ouput vector + hitsOut = PartitionedVector(std::move(hitsFlat), modSizes, modAddresses); + + // Ensure hits are time sorted + CBM_PARALLEL_FOR(schedule(dynamic)) + for (size_t i = 0; i < hitsOut.NPartitions(); i++) { + auto part = hitsOut[i]; + std::sort(part.begin(), part.end(), [](const auto& h0, const auto& h1) { return h0.Time() < h1.Time(); }); + } + + return result; + } + // ---------------------------------------------------------------------------- + } // namespace cbm::algo::trd diff --git a/algo/detectors/trd/Hitfind.h b/algo/detectors/trd/Hitfind.h index ee2566317fc1f1fed5839b403170cb438e62cf63..acf2138d0aab6276295bf690873b62e273d24149 100644 --- a/algo/detectors/trd/Hitfind.h +++ b/algo/detectors/trd/Hitfind.h @@ -68,6 +68,12 @@ namespace cbm::algo::trd **/ resultType operator()(gsl::span<CbmTrdDigi> digiIn); + /** @brief Run all steps row-parallel **/ + resultType RunRowParallel(gsl::span<CbmTrdDigi> digiIn); + + /** @brief Run merge step module-parallel all others row-parallel **/ + resultType RunModuleParallelMerge(gsl::span<CbmTrdDigi> digiIn); + /** @brief Constructor **/ explicit Hitfind(trd::HitfindSetup, trd::Hitfind2DSetup);