From 35c6c621343232a9bbf50dd57f4ca23bb8de732a Mon Sep 17 00:00:00 2001
From: Felix Weiglhofer <weiglhofer@fias.uni-frankfurt.de>
Date: Tue, 14 Mar 2023 10:49:34 +0000
Subject: [PATCH] StsDigi: Make address-functions callable on GPU.

---
 core/data/CbmDefs.h             |  7 +--
 core/data/sts/CbmStsAddress.cxx | 82 ------------------------------
 core/data/sts/CbmStsAddress.h   | 88 +++++++++++++++++++++++++++------
 core/data/sts/CbmStsDigi.cxx    | 18 -------
 core/data/sts/CbmStsDigi.h      | 32 +++++++-----
 5 files changed, 97 insertions(+), 130 deletions(-)

diff --git a/core/data/CbmDefs.h b/core/data/CbmDefs.h
index b2a5c9405b..d55c7d9c45 100644
--- a/core/data/CbmDefs.h
+++ b/core/data/CbmDefs.h
@@ -12,10 +12,11 @@
 #ifndef CBMDEFS_H
 #define CBMDEFS_H 1
 
-#include <type_traits>  // for underlying_type
-
 #include <iostream>  // for ostream
 #include <string>
+#include <type_traits>  // for underlying_type
+
+#include <xpu/defines.h>  // for XPU_D
 
 // Convert an element of enum class to its underlying intergral type
 // since with C++11 the return type can't be deduced automatically it has
@@ -25,7 +26,7 @@
 // E.g. ToIntegralType(ECbmModuleId::KSts) should be evaluated at compile
 // time and should not affect the run time performance at all
 template<typename T>
-constexpr auto ToIntegralType(T enumerator) -> typename std::underlying_type<T>::type
+XPU_D constexpr auto ToIntegralType(T enumerator) -> typename std::underlying_type<T>::type
 {
   return static_cast<typename std::underlying_type<T>::type>(enumerator);
 }
diff --git a/core/data/sts/CbmStsAddress.cxx b/core/data/sts/CbmStsAddress.cxx
index e09090a54b..7af19304f5 100644
--- a/core/data/sts/CbmStsAddress.cxx
+++ b/core/data/sts/CbmStsAddress.cxx
@@ -16,62 +16,6 @@
 #include <cassert>  // for assert
 #include <sstream>  // for operator<<, basic_ostream, stringstream
 
-namespace CbmStsAddress::Detail
-{
-
-  // clang-format off
-  // -----    Definition of address bit field   ------------------------------
-  inline constexpr uint16_t kBits[kCurrentVersion + 1][kStsNofLevels] = {
-
-    // Version 0 (until 23 August 2017)
-    {
-      4,  // system
-      4,  // unit / station
-      4,  // ladder
-      1,  // half-ladder
-      3,  // module
-      2,  // sensor
-      1   // side
-    },
-
-    // Version 1 (current, since 23 August 2017)
-    {
-      4,  // system
-      6,  // unit
-      5,  // ladder
-      1,  // half-ladder
-      5,  // module
-      4,  // sensor
-      1   // side
-    }
-
-  };
-  // -------------------------------------------------------------------------
-
-
-  // -----    Bit shifts -----------------------------------------------------
-  inline constexpr int32_t kShift[kCurrentVersion + 1][kStsNofLevels] = {
-    {0, kShift[0][0] + kBits[0][0], kShift[0][1] + kBits[0][1], kShift[0][2] + kBits[0][2], kShift[0][3] + kBits[0][3],
-     kShift[0][4] + kBits[0][4], kShift[0][5] + kBits[0][5]},
-
-    {0, kShift[1][0] + kBits[1][0], kShift[1][1] + kBits[1][1], kShift[1][2] + kBits[1][2], kShift[1][3] + kBits[1][3],
-     kShift[1][4] + kBits[1][4], kShift[1][5] + kBits[1][5]}};
-  // -------------------------------------------------------------------------
-
-
-  // -----    Bit masks  -----------------------------------------------------
-  inline constexpr int32_t kMask[kCurrentVersion + 1][kStsNofLevels] = {
-    {(1 << kBits[0][0]) - 1, (1 << kBits[0][1]) - 1, (1 << kBits[0][2]) - 1, (1 << kBits[0][3]) - 1,
-     (1 << kBits[0][4]) - 1, (1 << kBits[0][5]) - 1, (1 << kBits[0][6]) - 1},
-
-    {(1 << kBits[1][0]) - 1, (1 << kBits[1][1]) - 1, (1 << kBits[1][2]) - 1, (1 << kBits[1][3]) - 1,
-     (1 << kBits[1][4]) - 1, (1 << kBits[1][5]) - 1, (1 << kBits[1][6]) - 1}};
-  // -------------------------------------------------------------------------
-  // clang-format on
-
-}  // Namespace CbmStsAddress::Detail
-
-
 // -----   Construct address from element Ids   ------------------------------
 int32_t CbmStsAddress::GetAddress(uint32_t unit, uint32_t ladder, uint32_t halfladder, uint32_t module, uint32_t sensor,
                                   uint32_t side, uint32_t version)
@@ -200,32 +144,6 @@ int32_t CbmStsAddress::SetElementId(int32_t address, int32_t level, uint32_t new
 }
 // -------------------------------------------------------------------------
 
-// -----   Pack Digi Address    --------------------------------------------
-int32_t CbmStsAddress::PackDigiAddress(int32_t address)
-{
-  using namespace Detail;
-  constexpr int32_t kDMask = kMask[1][kStsUnit] << kShift[1][kStsUnit] | kMask[1][kStsLadder] << kShift[1][kStsLadder]
-                             | kMask[1][kStsHalfLadder] << kShift[1][kStsHalfLadder]
-                             | kMask[1][kStsModule] << kShift[1][kStsModule];
-
-  int32_t ret = (address & kDMask) >> kShift[1][kStsUnit];
-
-  // Check that no bits were set, that are stripped by this function.
-  assert(address == UnpackDigiAddress(ret));
-
-  return ret;
-}
-// -------------------------------------------------------------------------
-
-// -----   Unpack Digi Address    -------------------------------------------
-int32_t CbmStsAddress::UnpackDigiAddress(int32_t digiAddress)
-{
-  using namespace Detail;
-  return digiAddress << kShift[1][kStsUnit] | ToIntegralType(ECbmModuleId::kSts) << kShift[1][kStsSystem]
-         | 1u << kVersionShift;
-}
-// -------------------------------------------------------------------------
-
 // -----   String output   -------------------------------------------------
 std::string CbmStsAddress::ToString(int32_t address)
 {
diff --git a/core/data/sts/CbmStsAddress.h b/core/data/sts/CbmStsAddress.h
index 4bf1bf097a..6cc0b87804 100644
--- a/core/data/sts/CbmStsAddress.h
+++ b/core/data/sts/CbmStsAddress.h
@@ -12,8 +12,12 @@
 
 #include "CbmDefs.h"  // for ECbmModuleId
 
+#include <cassert>  // for assert
+#include <cstdint>  // for uint32_t
 #include <sstream>  // for string
 
+#include <xpu/defines.h>  // for XPU_D
+
 /** Enumerator for the hierarchy levels of the STS setup **/
 enum EStsElementLevel
 {
@@ -56,6 +60,57 @@ namespace CbmStsAddress
   inline constexpr int32_t kVersionShift = 28;  // First bit for version number
   inline constexpr int32_t kVersionMask  = (1 << kVersionSize) - 1;
 
+  namespace Detail
+  {
+    // clang-format off
+    // -----    Definition of address bit field   ------------------------------
+    inline constexpr uint16_t kBits[kCurrentVersion + 1][kStsNofLevels] = {
+
+      // Version 0 (until 23 August 2017)
+      {
+        4,  // system
+        4,  // unit / station
+        4,  // ladder
+        1,  // half-ladder
+        3,  // module
+        2,  // sensor
+        1   // side
+      },
+
+      // Version 1 (current, since 23 August 2017)
+      {
+        4,  // system
+        6,  // unit
+        5,  // ladder
+        1,  // half-ladder
+        5,  // module
+        4,  // sensor
+        1   // side
+      }
+
+    };
+    // -------------------------------------------------------------------------
+
+    // -----    Bit shifts -----------------------------------------------------
+    inline constexpr int32_t kShift[kCurrentVersion + 1][kStsNofLevels] = {
+      {0, kShift[0][0] + kBits[0][0], kShift[0][1] + kBits[0][1], kShift[0][2] + kBits[0][2], kShift[0][3] + kBits[0][3],
+      kShift[0][4] + kBits[0][4], kShift[0][5] + kBits[0][5]},
+
+      {0, kShift[1][0] + kBits[1][0], kShift[1][1] + kBits[1][1], kShift[1][2] + kBits[1][2], kShift[1][3] + kBits[1][3],
+      kShift[1][4] + kBits[1][4], kShift[1][5] + kBits[1][5]}};
+    // -------------------------------------------------------------------------
+
+    // -----    Bit masks  -----------------------------------------------------
+    inline constexpr int32_t kMask[kCurrentVersion + 1][kStsNofLevels] = {
+      {(1 << kBits[0][0]) - 1, (1 << kBits[0][1]) - 1, (1 << kBits[0][2]) - 1, (1 << kBits[0][3]) - 1,
+      (1 << kBits[0][4]) - 1, (1 << kBits[0][5]) - 1, (1 << kBits[0][6]) - 1},
+
+      {(1 << kBits[1][0]) - 1, (1 << kBits[1][1]) - 1, (1 << kBits[1][2]) - 1, (1 << kBits[1][3]) - 1,
+      (1 << kBits[1][4]) - 1, (1 << kBits[1][5]) - 1, (1 << kBits[1][6]) - 1}};
+    // -------------------------------------------------------------------------
+    // clang-format on
+  }  // Namespace Detail
+
 
   /** @brief Construct address
    ** @param unit         Unit index
@@ -123,34 +178,35 @@ namespace CbmStsAddress
    **/
   int32_t SetElementId(int32_t address, int32_t level, uint32_t newId);
 
-
-  /** @brief Strip address to contain only unit, (half)ladder and module.
-   ** @param address Full address
-   ** @return 17 bit address that can be stored in a Digi
-   **/
-  int32_t PackDigiAddress(int32_t address);
-
-
   /** @brief Add version and system to compressed address that's stored in a digi
    ** @param digiAddress Compressed address from digi
    ** @return Full address
    **/
-  int32_t UnpackDigiAddress(int32_t digiAddress);
-
+  XPU_D inline int32_t UnpackDigiAddress(int32_t digiAddress)
+  {
+    using namespace Detail;
+    return digiAddress << kShift[1][kStsUnit] | ToIntegralType(ECbmModuleId::kSts) << kShift[1][kStsSystem]
+           | 1u << kVersionShift;
+  }
 
   /** @brief Strip address to contain only unit, (half)ladder and module.
    ** @param address Full address
    ** @return 17 bit address that can be stored in a Digi
    **/
-  int32_t PackDigiAddress(int32_t address);
+  XPU_D inline int32_t PackDigiAddress(int32_t address)
+  {
+    using namespace Detail;
+    constexpr int32_t kDMask = kMask[1][kStsUnit] << kShift[1][kStsUnit] | kMask[1][kStsLadder] << kShift[1][kStsLadder]
+                               | kMask[1][kStsHalfLadder] << kShift[1][kStsHalfLadder]
+                               | kMask[1][kStsModule] << kShift[1][kStsModule];
 
+    int32_t ret = (address & kDMask) >> kShift[1][kStsUnit];
 
-  /** @brief Add version and system to compressed address that's stored in a digi
-   ** @param digiAddress Compressed address from digi
-   ** @return Full address
-   **/
-  int32_t UnpackDigiAddress(int32_t digiAddress);
+    // Check that no bits were set, that are stripped by this function.
+    assert(address == UnpackDigiAddress(ret));
 
+    return ret;
+  }
 
   /** @brief String output
    ** @param address Unique element address
diff --git a/core/data/sts/CbmStsDigi.cxx b/core/data/sts/CbmStsDigi.cxx
index 15618dc4c0..87fee5b58f 100644
--- a/core/data/sts/CbmStsDigi.cxx
+++ b/core/data/sts/CbmStsDigi.cxx
@@ -29,24 +29,6 @@ string CbmStsDigi::ToString() const
   return ss.str();
 }
 
-void CbmStsDigi::PackAddressAndTime(int32_t newAddress, uint32_t newTime)
-{
-  int32_t packedAddr = CbmStsAddress::PackDigiAddress(newAddress);
-
-  uint32_t highestBitAddr = packedAddr >> kNumLowerAddrBits;
-  uint32_t lowerAddr      = packedAddr & ((1 << kNumLowerAddrBits) - 1);
-
-  fAddress = lowerAddr;
-  fTime    = (highestBitAddr << kNumTimestampBits) | (kTimestampMask & newTime);
-}
-
-int32_t CbmStsDigi::UnpackAddress() const
-{
-  int32_t highestBitAddr = fTime >> kNumTimestampBits;
-  int32_t packedAddress  = (highestBitAddr << kNumLowerAddrBits) | int32_t(fAddress);
-  return CbmStsAddress::UnpackDigiAddress(packedAddress);
-}
-
 #ifndef NO_ROOT
 ClassImp(CbmStsDigi)
 #endif
diff --git a/core/data/sts/CbmStsDigi.h b/core/data/sts/CbmStsDigi.h
index 513234f7ff..a184059e07 100644
--- a/core/data/sts/CbmStsDigi.h
+++ b/core/data/sts/CbmStsDigi.h
@@ -44,14 +44,13 @@ public:
   CbmStsDigi() = default;
 
 
-#if XPU_IS_CPU
   /** Standard constructor
    ** @param  address  Unique element address
    ** @param  channel  Channel number
    ** @param  time     Measurement time [ns]
    ** @param  charge   Charge [ADC units]
    **/
-  CbmStsDigi(int32_t address, int32_t channel, double time, uint16_t charge)
+  XPU_D CbmStsDigi(int32_t address, int32_t channel, double time, uint16_t charge)
   {
     // StsDigi is not able to store negative timestamps.
     assert(time >= 0);
@@ -60,7 +59,6 @@ public:
     PackAddressAndTime(address, time);
     PackChannelAndCharge(channel, charge);
   }
-#endif
 
   /** Destructor **/
   ~CbmStsDigi() = default;
@@ -68,7 +66,7 @@ public:
   /** Unique detector element address  (see CbmStsAddress)
    ** @value Unique address of readout channel
    **/
-  int32_t GetAddress() const { return UnpackAddress(); }
+  XPU_D int32_t GetAddress() const { return UnpackAddress(); }
 
 
   /** @brief Get the desired name of the branch for this obj in the cbm output tree  (static)
@@ -138,9 +136,7 @@ public:
 
   XPU_D void SetCharge(uint16_t charge) { PackChannelAndCharge(UnpackChannel(), charge); }
 
-#if XPU_IS_CPU
-  void SetAddress(int32_t address) { PackAddressAndTime(address, UnpackTime()); }
-#endif
+  XPU_D void SetAddress(int32_t address) { PackAddressAndTime(address, UnpackTime()); }
 
 
   /** Set new channel and charge.
@@ -153,7 +149,7 @@ public:
    **
    ** Slightly more efficient than calling both individual setters.
    **/
-  void SetAddressAndTime(int32_t address, uint32_t time) { PackAddressAndTime(address, time); }
+  XPU_D void SetAddressAndTime(int32_t address, uint32_t time) { PackAddressAndTime(address, time); }
 
 
   /** String output **/
@@ -197,9 +193,23 @@ private:
   XPU_D uint16_t UnpackChannel() const { return fChannelAndCharge >> kNumAdcBits; }
   XPU_D uint16_t UnpackCharge() const { return fChannelAndCharge & kAdcMask; }
 
-  // Packing / Unpacking address not available on GPU for now...
-  void PackAddressAndTime(int32_t address, uint32_t time);
-  int32_t UnpackAddress() const;
+  XPU_D void PackAddressAndTime(int32_t newAddress, uint32_t newTime)
+  {
+    int32_t packedAddr = CbmStsAddress::PackDigiAddress(newAddress);
+
+    uint32_t highestBitAddr = packedAddr >> kNumLowerAddrBits;
+    uint32_t lowerAddr      = packedAddr & ((1 << kNumLowerAddrBits) - 1);
+
+    fAddress = lowerAddr;
+    fTime    = (highestBitAddr << kNumTimestampBits) | (kTimestampMask & newTime);
+  }
+
+  XPU_D int32_t UnpackAddress() const
+  {
+    int32_t highestBitAddr = fTime >> kNumTimestampBits;
+    int32_t packedAddress  = (highestBitAddr << kNumLowerAddrBits) | int32_t(fAddress);
+    return CbmStsAddress::UnpackDigiAddress(packedAddress);
+  }
 
 #ifndef NO_ROOT
   ClassDefNV(CbmStsDigi, 8);
-- 
GitLab