/* Copyright (C) 2023 GSI Helmholtzzentrum fuer Schwerionenforschung, Darmstadt
   SPDX-License-Identifier: GPL-3.0-only
   Authors: Sergei Zharko [committer] */

/// \file   TrackingChain.cxx
/// \date   14.09.2023
/// \brief  A chain class to execute CA tracking algorithm in online reconstruction (implementation)
/// \author S.Zharko <s.zharko@gsi.de>

#include "TrackingChain.h"

#include "tof/Config.h"

#include <xpu/host.h>

#include "CaConstants.h"
#include "CaHit.h"
#include "CaInitManager.h"
#include "CaParameters.h"

using namespace cbm::algo;

using cbm::algo::TrackingChain;
using cbm::algo::ca::EDetectorID;
using cbm::algo::ca::Framework;
using cbm::algo::ca::HitTypes_t;
using cbm::algo::ca::InitManager;
using cbm::algo::ca::Parameters;
using cbm::algo::ca::Track;
using cbm::algo::ca::constants::clrs::CL;   // clear text
using cbm::algo::ca::constants::clrs::GNb;  // green bold text

// ---------------------------------------------------------------------------------------------------------------------
//
void TrackingChain::Init()
{
  // ------ Read parameters from binary
  std::string paramFileBase = "mcbm_beam_2022_05_23_nickel.ca.par";  // TODO: Get the setup name from Opts()
  auto paramFile            = Opts().ParamsDir();
  paramFile /= paramFileBase;
  L_(info) << "Tracking Chain: reading CA parameters file " << GNb << paramFile.string() << CL << '\n';
  auto manager = InitManager {};
  manager.ReadParametersObject(paramFile.string());
  auto parameters = manager.TakeParameters();
  L_(info) << "Tracking Chain: parameters object: \n" << parameters.ToString(1) << '\n';

  // ------ Initialize CA framework
  fCaMonitor.Reset();
  fCaFramework.Init(ca::Framework::TrackingMode::kMcbm);
  fCaFramework.ReceiveParameters(std::move(parameters));
}

// ---------------------------------------------------------------------------------------------------------------------
//
// Executes the CA tracking for one set of reconstructed hits.
//
// The per-call monitor data are reset, the hits from recoResults are converted into the CA input, the track
// finder is run, and the per-call monitor data are added to the accumulated monitor.
//
// Returns a pair of the reconstructed tracks (moved out of the CA framework) and a copy of the monitor data
// of this call.
TrackingChain::Return_t TrackingChain::Run(Input_t recoResults)
{
  //xpu::scoped_timer t_("CA");  // TODO: pass timings to monitoring for throughput?
  fCaMonitorData.Reset();
  fCaMonitorData.StartTimer(ca::ETimer::Tracking);

  // ----- Init input data ---------------------------------------------------------------------------------------------
  fCaMonitorData.StartTimer(ca::ETimer::PrepareInput);
  this->PrepareInput(recoResults);
  fCaMonitorData.StopTimer(ca::ETimer::PrepareInput);

  // ----- Run reconstruction ------------------------------------------------------------------------------------------
  // The framework updates its own monitor data object during the track finding, so the object is handed over
  // before the call and taken back afterwards
  fCaFramework.SetMonitorData(fCaMonitorData);
  fCaFramework.fTrackFinder.FindTracks();
  fCaMonitorData = fCaFramework.GetMonitorData();
  L_(info) << "Timeslice contains " << fCaMonitorData.GetCounterValue(ca::ECounter::RecoTrack) << " tracks";

  // ----- Init output data --------------------------------------------------------------------------------------------
  // FIXME: SZh 22.10.2023: Provide a class for the tracking output data (tracks, hit indices and monitor)
  // The Tracking timer is stopped BEFORE the data of this call are accumulated: otherwise the accumulated monitor
  // would be fed with a still running timer, while the returned copy contains the stopped one.
  fCaMonitorData.StopTimer(ca::ETimer::Tracking);
  fCaMonitor.AddMonitorData(fCaMonitorData);
  return std::make_pair(std::move(fCaFramework.fRecoTracks), fCaMonitorData);
}

// ---------------------------------------------------------------------------------------------------------------------
//
// Writes the string representation of the accumulated monitor to the info log.
void TrackingChain::Finalize()
{
  L_(info) << fCaMonitor.ToString();
}

// ---------------------------------------------------------------------------------------------------------------------
//
// Fills the CA input data from the reconstructed STS and TOF hits and passes them to the CA framework.
//
// The hit key counter is reset, the data manager is reset for the total number of the input hits, the hits
// of both detectors are read, and the resulting input data object is taken from the data manager and
// received by the framework.
void TrackingChain::PrepareInput(Input_t recoResults)
{
  //L_(info) << "TOF TEST: " << tof::Config::GetTofTrackingStation(0x00008036);  <- access to the TOF tracking station
  fNofHitKeys  = 0;
  int nHitsTot = recoResults.stsHits.NElements() + recoResults.tofHits.NElements();
  L_(info) << "Tracking chain: input has " << nHitsTot << " hits";
  fCaDataManager.ResetInputData(nHitsTot);
  ReadHits<EDetectorID::Sts>(recoResults.stsHits);
  ReadHits<EDetectorID::Tof>(recoResults.tofHits);
  fCaDataManager.SetNhitKeys(fNofHitKeys);
  // The number of the stored hits can be smaller than nHitsTot: ReadHits() skips some of the input hits
  L_(info) << "Tracking chain: " << fCaDataManager.GetNofHits() << " hits will be passed to the ca::Framework";
  fCaFramework.ReceiveInputData(fCaDataManager.TakeInputData());
}

// ---------------------------------------------------------------------------------------------------------------------
//
// Converts the hits of the detector DetID into ca::Hit objects and stores them in the CA data manager.
//
// Partitions, for which no active tracking station is found, are skipped entirely. A hit, which does not
// pass ca::Hit::Check(), is not stored: only the corresponding "undefined hit" counter is incremented.
// For every stored hit fNofHitKeys is raised, so that it stays above the front and the back keys of the hit.
template<EDetectorID DetID>
void TrackingChain::ReadHits(PartitionedSpan<const ca::HitTypes_t::at<DetID>> hits)
{
  using Hit_t = ca::HitTypes_t::at<DetID>;

  // Assumes call from Run, for existence of timer!
  // NOTE(review): the xpu::scoped_timer in Run() is commented out at the moment -- confirm, that this call
  //               is safe without an active timer.
  xpu::t_add_bytes(hits.NElements() * sizeof(Hit_t));

  // Raises the number of hit keys, if the given key is not covered by it yet
  auto coverHitKey = [this](auto key) {
    if (fNofHitKeys <= key) { fNofHitKeys = key + 1; }
  };

  // Keys of this detector are counted from the number of keys, reached before the call
  const ca::HitKeyIndex_t keyBase = fNofHitKeys;
  // Detector part of the data stream
  const int64_t detStreamBits = static_cast<int64_t>(DetID) << 60;

  for (size_t iPart = 0; iPart < hits.NPartitions(); ++iPart) {
    const auto& [partHits, partAddress] = hits.Partition(iPart);

    // ---- Define data stream and station index
    const int64_t dataStream = detStreamBits | partAddress;
    int iStLocal             = -1;
    // FIXME: This definition of the station index works only for STS, and there is no guarantee that it will
    //        work for other mCBM setups.
    if constexpr (DetID == EDetectorID::Sts) { iStLocal = (partAddress >> 4) & 0xF; }
    else if constexpr (DetID == EDetectorID::Tof) {
      iStLocal = tof::Config::GetTofTrackingStation(partAddress);
    }

    // ---- Skip partitions without an active tracking station
    if (iStLocal == -1) { continue; }
    const int iStActive = fCaFramework.GetParameters().GetStationIndexActive(iStLocal, DetID);
    if (iStActive < 0) { continue; }

    const size_t partOffset = hits.Offsets()[iPart];
    for (size_t iLocal = 0; iLocal < partHits.size(); ++iLocal) {
      const auto& recoHit = partHits[iLocal];
      const auto errX     = recoHit.Dx();
      const auto errY     = recoHit.Dy();
      const auto errT     = recoHit.TimeError();

      // ---- Fill ca::Hit
      ca::Hit caHit;
      if constexpr (DetID == EDetectorID::Sts) {
        // STS: the keys are defined by the front and the back cluster indices
        caHit.SetFrontKey(keyBase + recoHit.fFrontClusterId);
        caHit.SetBackKey(keyBase + recoHit.fBackClusterId);
      }
      else {
        // Other detectors: one key per hit, defined by the index of the hit among all the input hits
        const int iHitExt = partOffset + iLocal;
        caHit.SetFrontKey(keyBase + iHitExt);
        caHit.SetBackKey(caHit.FrontKey());
      }
      caHit.SetX(recoHit.X());
      caHit.SetY(recoHit.Y());
      caHit.SetZ(recoHit.Z());
      caHit.SetT(recoHit.Time());
      caHit.SetDx2(errX * errX);
      caHit.SetDy2(errY * errY);
      if constexpr (DetID == EDetectorID::Sts) { caHit.SetDxy(recoHit.fDxy); }
      caHit.SetDt2(errT * errT);
      /// FIXME: Define ranges from the hit, when will be available
      caHit.SetRangeX(3.5 * errX);
      caHit.SetRangeY(3.5 * errY);
      caHit.SetRangeT(3.5 * errT);
      caHit.SetStation(iStActive);
      // The hit ID is the index, which the hit gets in the data manager (must be set before pushing the hit)
      caHit.SetId(fCaDataManager.GetNofHits());

      // ---- Reject undefined hits
      if (!caHit.Check()) {
        if constexpr (DetID == EDetectorID::Mvd) { fCaMonitorData.IncrementCounter(ca::ECounter::UndefinedMvdHit); }
        else if constexpr (DetID == EDetectorID::Sts) {
          fCaMonitorData.IncrementCounter(ca::ECounter::UndefinedStsHit);
        }
        else if constexpr (DetID == EDetectorID::Much) {
          fCaMonitorData.IncrementCounter(ca::ECounter::UndefinedMuchHit);
        }
        else if constexpr (DetID == EDetectorID::Trd) {
          fCaMonitorData.IncrementCounter(ca::ECounter::UndefinedTrdHit);
        }
        else if constexpr (DetID == EDetectorID::Tof) {
          fCaMonitorData.IncrementCounter(ca::ECounter::UndefinedTofHit);
        }
        continue;
      }

      // ---- Store the hit and update number of hit keys
      fCaDataManager.PushBackHit(caHit, dataStream);
      coverHitKey(caHit.FrontKey());
      coverHitKey(caHit.BackKey());
    }  // iLocal
  }    // iPart
}

// template void TrackingChain::ReadHits<EDetectorID::Sts>(const PartitionedPODVector<HitTypes_t::at<EDetectorID::Sts>>&);