From f0c74832b6e9c27de395cedb8a0ffcc0c5587474 Mon Sep 17 00:00:00 2001
From: Felix Weiglhofer <weiglhofer@fias.uni-frankfurt.de>
Date: Mon, 19 Jun 2023 16:43:05 +0000
Subject: [PATCH] algo: Enable parallel sorting when libTBB is available.

---
 algo/CMakeLists.txt             | 15 ++++++++++
 algo/base/BuildInfo.h           |  7 +++++
 algo/base/compat/Algorithm.h    | 51 +++++++++++++++++++++++++++++++
 algo/unpack/Unpack.cxx          | 53 +++++++++++----------------------
 reco/tasks/CMakeLists.txt       | 21 -------------
 reco/tasks/CbmTaskUnpack.cxx    |  3 --
 reco/tasks/CbmTaskUnpackXpu.cxx | 15 ++++------
 7 files changed, 95 insertions(+), 70 deletions(-)
 create mode 100644 algo/base/compat/Algorithm.h

diff --git a/algo/CMakeLists.txt b/algo/CMakeLists.txt
index 9d8de75311..eb3ee186eb 100644
--- a/algo/CMakeLists.txt
+++ b/algo/CMakeLists.txt
@@ -87,6 +87,21 @@ target_link_libraries(Algo
 target_compile_definitions(Algo PUBLIC NO_ROOT)
 xpu_attach(Algo ${DEVICE_SRCS})
 
+# Try to enable parallel execution in c++17 if TBB is available
+if (CMAKE_SYSTEM_NAME STREQUAL "Linux")
+  list(APPEND CMAKE_PREFIX_PATH "/opt/intel/oneapi/tbb/latest/")
+  find_package(TBB)
+
+  if (TBB_FOUND)
+    message(STATUS "Found TBB")
+    # PUBLIC: compat/Algorithm.h is also compiled by targets that link Algo and must see HAVE_TBB there too.
+    target_compile_definitions(Algo PUBLIC HAVE_TBB)
+    target_link_libraries(Algo PUBLIC TBB::tbb)
+  else()
+    message(STATUS "TBB not found")
+  endif()
+endif()
+
 install(TARGETS Algo DESTINATION lib)
 install(DIRECTORY base/compat TYPE INCLUDE FILES_MATCHING PATTERN "*.h")
 install(DIRECTORY base/config TYPE INCLUDE FILES_MATCHING PATTERN "*.h")
diff --git a/algo/base/BuildInfo.h b/algo/base/BuildInfo.h
index 8b6225c930..ddfe4bb44d 100644
--- a/algo/base/BuildInfo.h
+++ b/algo/base/BuildInfo.h
@@ -13,6 +13,13 @@ namespace cbm::algo::BuildInfo
   extern const std::string BUILD_TYPE;
   extern const bool GPU_DEBUG;
 
+  /// True if built with TBB. Must not be named HAVE_TBB: that is the build macro and would expand here.
+#ifdef HAVE_TBB
+  inline constexpr bool WITH_TBB = true;
+#else
+  inline constexpr bool WITH_TBB = false;
+#endif
+
 }  // namespace cbm::algo::BuildInfo
 
 #endif  // CBM_ALGO_BUILD_INFO_H
diff --git a/algo/base/compat/Algorithm.h b/algo/base/compat/Algorithm.h
new file mode 100644
index 0000000000..378cc57582
--- /dev/null
+++ b/algo/base/compat/Algorithm.h
@@ -0,0 +1,51 @@
+/* Copyright (C) 2023 FIAS Frankfurt Institute for Advanced Studies, Frankfurt / Main
+   SPDX-License-Identifier: GPL-3.0-only
+   Authors: Felix Weiglhofer [committer] */
+#ifndef CBM_ALGO_BASE_COMPAT_ALGORITHMS_H
+#define CBM_ALGO_BASE_COMPAT_ALGORITHMS_H
+
+/**
+ * @file Algorithm.h
+ * @brief This file contains compatibility wrappers for parallel stl algorithms.
+ *
+ * The parallel algorithms are only available if the compiler supports C++17. Some older
+ * compilers don't ship with the parallel algorithms, so this wrapper falls back to
+ * sequential algorithms in that case.
+ * Also gcc requires the TBB library to be installed to use the parallel algorithms.
+ * If TBB is not available, we also fall back to sequential algorithms.
+**/
+
+#include <algorithm>
+#if __has_include(<execution>)
+#include <execution>
+#endif
+// <execution> may exist without parallel overloads (e.g. libc++), so rely on the feature-test macro.
+#if defined(__cpp_lib_parallel_algorithm) && !defined(WITH_EXECUTION)
+#define WITH_EXECUTION
+#endif
+
+namespace cbm::algo
+{
+  namespace detail
+  {
+    // Parallel policy only if the build defines HAVE_TBB; otherwise the sequential policy.
+#if defined(WITH_EXECUTION) && defined(HAVE_TBB)
+    inline constexpr auto ExecPolicy = std::execution::par_unseq;
+#elif defined(WITH_EXECUTION)
+    inline constexpr auto ExecPolicy = std::execution::seq;
+#endif
+  }  // namespace detail
+
+  /// @brief Sort [first, last) by comp; runs in parallel only with <execution> support and HAVE_TBB.
+  template<typename It, typename Compare>
+  void ParallelSort(It first, It last, Compare comp)
+  {
+#ifdef WITH_EXECUTION
+    std::sort(detail::ExecPolicy, first, last, comp);
+#else
+    std::sort(first, last, comp);
+#endif
+  }
+}  // namespace cbm::algo
+
+#endif
diff --git a/algo/unpack/Unpack.cxx b/algo/unpack/Unpack.cxx
index b6c9b40d79..a4000d7c60 100644
--- a/algo/unpack/Unpack.cxx
+++ b/algo/unpack/Unpack.cxx
@@ -8,10 +8,7 @@
 #include <chrono>
 
 #include "AlgoFairloggerCompat.h"
-
-#ifdef WITH_EXECUTION
-#include <execution>
-#endif
+#include "compat/Algorithm.h"
 
 using namespace std;
 
@@ -62,38 +59,22 @@ namespace cbm::algo
       }
     }  //# component
 
-    // --- Sorting of output digis. Is required by both digi trigger and event builder.
-#ifdef WITH_EXECUTION
-    std::sort(std::execution::par_unseq, digiTs.fData.fSts.fDigis.begin(), digiTs.fData.fSts.fDigis.end(),
-              [](CbmStsDigi digi1, CbmStsDigi digi2) { return digi1.GetTime() < digi2.GetTime(); });
-    std::sort(std::execution::par_unseq, digiTs.fData.fMuch.fDigis.begin(), digiTs.fData.fMuch.fDigis.end(),
-              [](CbmMuchDigi digi1, CbmMuchDigi digi2) { return digi1.GetTime() < digi2.GetTime(); });
-    std::sort(std::execution::par_unseq, digiTs.fData.fTof.fDigis.begin(), digiTs.fData.fTof.fDigis.end(),
-              [](CbmTofDigi digi1, CbmTofDigi digi2) { return digi1.GetTime() < digi2.GetTime(); });
-    std::sort(std::execution::par_unseq, digiTs.fData.fT0.fDigis.begin(), digiTs.fData.fT0.fDigis.end(),
-              [](CbmTofDigi digi1, CbmTofDigi digi2) { return digi1.GetTime() < digi2.GetTime(); });
-    std::sort(std::execution::par_unseq, digiTs.fData.fTrd.fDigis.begin(), digiTs.fData.fTrd.fDigis.end(),
-              [](CbmTrdDigi digi1, CbmTrdDigi digi2) { return digi1.GetTime() < digi2.GetTime(); });
-    std::sort(std::execution::par_unseq, digiTs.fData.fTrd2d.fDigis.begin(), digiTs.fData.fTrd2d.fDigis.end(),
-              [](CbmTrdDigi digi1, CbmTrdDigi digi2) { return digi1.GetTime() < digi2.GetTime(); });
-    std::sort(std::execution::par_unseq, digiTs.fData.fRich.fDigis.begin(), digiTs.fData.fRich.fDigis.end(),
-              [](CbmRichDigi digi1, CbmRichDigi digi2) { return digi1.GetTime() < digi2.GetTime(); });
-#else
-    std::sort(digiTs.fData.fSts.fDigis.begin(), digiTs.fData.fSts.fDigis.end(),
-              [](CbmStsDigi digi1, CbmStsDigi digi2) { return digi1.GetTime() < digi2.GetTime(); });
-    std::sort(digiTs.fData.fMuch.fDigis.begin(), digiTs.fData.fMuch.fDigis.end(),
-              [](CbmMuchDigi digi1, CbmMuchDigi digi2) { return digi1.GetTime() < digi2.GetTime(); });
-    std::sort(digiTs.fData.fTof.fDigis.begin(), digiTs.fData.fTof.fDigis.end(),
-              [](CbmTofDigi digi1, CbmTofDigi digi2) { return digi1.GetTime() < digi2.GetTime(); });
-    std::sort(digiTs.fData.fT0.fDigis.begin(), digiTs.fData.fT0.fDigis.end(),
-              [](CbmTofDigi digi1, CbmTofDigi digi2) { return digi1.GetTime() < digi2.GetTime(); });
-    std::sort(digiTs.fData.fTrd.fDigis.begin(), digiTs.fData.fTrd.fDigis.end(),
-              [](CbmTrdDigi digi1, CbmTrdDigi digi2) { return digi1.GetTime() < digi2.GetTime(); });
-    std::sort(digiTs.fData.fTrd2d.fDigis.begin(), digiTs.fData.fTrd2d.fDigis.end(),
-              [](CbmTrdDigi digi1, CbmTrdDigi digi2) { return digi1.GetTime() < digi2.GetTime(); });
-    std::sort(digiTs.fData.fRich.fDigis.begin(), digiTs.fData.fRich.fDigis.end(),
-              [](CbmRichDigi digi1, CbmRichDigi digi2) { return digi1.GetTime() < digi2.GetTime(); });
-#endif
+    // --- Sorting of output digis. Is required by both digi trigger and event builder.
+    ParallelSort(digiTs.fData.fSts.fDigis.begin(), digiTs.fData.fSts.fDigis.end(),
+                 [](CbmStsDigi digi1, CbmStsDigi digi2) { return digi1.GetTime() < digi2.GetTime(); });
+    ParallelSort(digiTs.fData.fMuch.fDigis.begin(), digiTs.fData.fMuch.fDigis.end(),
+                 [](CbmMuchDigi digi1, CbmMuchDigi digi2) { return digi1.GetTime() < digi2.GetTime(); });
+    ParallelSort(digiTs.fData.fTof.fDigis.begin(), digiTs.fData.fTof.fDigis.end(),
+                 [](CbmTofDigi digi1, CbmTofDigi digi2) { return digi1.GetTime() < digi2.GetTime(); });
+    ParallelSort(digiTs.fData.fT0.fDigis.begin(), digiTs.fData.fT0.fDigis.end(),
+                 [](CbmTofDigi digi1, CbmTofDigi digi2) { return digi1.GetTime() < digi2.GetTime(); });
+    ParallelSort(digiTs.fData.fTrd.fDigis.begin(), digiTs.fData.fTrd.fDigis.end(),
+                 [](CbmTrdDigi digi1, CbmTrdDigi digi2) { return digi1.GetTime() < digi2.GetTime(); });
+    ParallelSort(digiTs.fData.fTrd2d.fDigis.begin(), digiTs.fData.fTrd2d.fDigis.end(),
+                 [](CbmTrdDigi digi1, CbmTrdDigi digi2) { return digi1.GetTime() < digi2.GetTime(); });
+    ParallelSort(digiTs.fData.fRich.fDigis.begin(), digiTs.fData.fRich.fDigis.end(),
+                 [](CbmRichDigi digi1, CbmRichDigi digi2) { return digi1.GetTime() < digi2.GetTime(); });
+
     return result;
   }
   // ----------------------------------------------------------------------------
diff --git a/reco/tasks/CMakeLists.txt b/reco/tasks/CMakeLists.txt
index 135c5a4f98..6fdef364b6 100644
--- a/reco/tasks/CMakeLists.txt
+++ b/reco/tasks/CMakeLists.txt
@@ -48,25 +48,4 @@ set(INTERFACE_DEPENDENCIES
   external::fles_ipc
   )
 
-# Check if the compiler supports std::execution in the respective STL
-# library
-CHECK_CXX_SOURCE_COMPILES("
-#include <numeric>
-#include <vector>
-#include <execution>
-
-int main(int argc, char *argv[])
-{
-    std::vector<double> v(10, 1);
-
-    auto result = std::reduce(std::execution::par, v.begin(), v.end());
-    return 0;
-}" HAS_STD_EXECUTION)
-
-if (HAS_STD_EXECUTION)
-  message("Execution is available in STL")
-  add_definitions(-DWITH_EXECUTION)
-endif()
-
 generate_cbm_library()
-
diff --git a/reco/tasks/CbmTaskUnpack.cxx b/reco/tasks/CbmTaskUnpack.cxx
index 6724a37b3c..c45ff9e312 100644
--- a/reco/tasks/CbmTaskUnpack.cxx
+++ b/reco/tasks/CbmTaskUnpack.cxx
@@ -29,9 +29,6 @@
 #include <algorithm>
 #include <cassert>
 #include <cstdint>
-#ifdef WITH_EXECUTION
-#include <execution>
-#endif
 #include <iomanip>
 #include <memory>
 #include <sstream>
diff --git a/reco/tasks/CbmTaskUnpackXpu.cxx b/reco/tasks/CbmTaskUnpackXpu.cxx
index f095169d69..5fac8e8bb6 100644
--- a/reco/tasks/CbmTaskUnpackXpu.cxx
+++ b/reco/tasks/CbmTaskUnpackXpu.cxx
@@ -22,9 +22,6 @@
 #include <algorithm>
 #include <cassert>
 #include <cstdint>
-#ifdef WITH_EXECUTION
-#include <execution>
-#endif
 #include <iomanip>
 #include <memory>
 #include <sstream>
@@ -32,6 +29,8 @@
 
 #include <xpu/host.h>
 
+#include "compat/Algorithm.h"
+
 using namespace std;
 using cbm::algo::UnpackStsXpuElinkPar;
 using cbm::algo::UnpackStsXpuPar;
@@ -70,13 +69,9 @@ void CbmTaskUnpackXpu::Exec(Option_t*)
                                        resultSts.first.end());
 
   // --- Sorting of output digis. Is required by both digi trigger and event builder.
-#ifdef WITH_EXECUTION
-  std::sort(std::execution::par_unseq, fTimeslice->fData.fSts.fDigis.begin(), fTimeslice->fData.fSts.fDigis.end(),
-            [](CbmStsDigi digi1, CbmStsDigi digi2) { return digi1.GetTime() < digi2.GetTime(); });
-#else
-  std::sort(fTimeslice->fData.fSts.fDigis.begin(), fTimeslice->fData.fSts.fDigis.end(),
-            [](CbmStsDigi digi1, CbmStsDigi digi2) { return digi1.GetTime() < digi2.GetTime(); });
-#endif
+  cbm::algo::ParallelSort(fTimeslice->fData.fSts.fDigis.begin(), fTimeslice->fData.fSts.fDigis.end(),
+                          [](CbmStsDigi digi1, CbmStsDigi digi2) { return digi1.GetTime() < digi2.GetTime(); });
+
 
   // --- Timeslice log
   timer.Stop();
-- 
GitLab