/* Copyright (C) 2018-2020 GSI Helmholtzzentrum fuer Schwerionenforschung, Darmstadt
   SPDX-License-Identifier: GPL-3.0-only
   Authors: Etienne Bechtel, Florian Uhlig [committer], Etienne Bechtel */

#include "CbmTrdModuleRecR.h"

#include "CbmDigiManager.h"
#include "CbmTrdAddress.h"
#include "CbmTrdCluster.h"
#include "CbmTrdClusterFinder.h"
#include "CbmTrdDigi.h"
#include "CbmTrdHit.h"
#include "CbmTrdParModDigi.h"
#include "CbmTrdParSetDigi.h"
#include "TGeoMatrix.h"

#include <Logger.h>

#include <TCanvas.h>
#include <TClonesArray.h>
#include <TH2F.h>
#include <TImage.h>
#include <TVector3.h>

#include <algorithm>
#include <iostream>
#include <vector>

// Out-of-class definitions of the static constexpr position-error tables used in MakeHit
// (first index: 0 in EB mode, 1 otherwise; second index: digi error class). They are
// required when the tables are odr-used under pre-C++17 rules. Since C++17 static constexpr
// members are implicitly inline, so these lines are redundant but harmless.
constexpr Double_t CbmTrdModuleRecR::kxVar_Value[2][5];
constexpr Double_t CbmTrdModuleRecR::kyVar_Value[2][5];

//_______________________________________________________________________________
CbmTrdModuleRecR::CbmTrdModuleRecR()
  : CbmTrdModuleRec()
  , fDigiCounter(0)
  , fDigiMap()
  , fClusterMap()
{
  // Default constructor: empty digi/cluster buffers and a generic object name/title.
  SetNameTitle("TrdModuleRecR", "Reconstructor for rectangular pad TRD module");
}

//_______________________________________________________________________________
CbmTrdModuleRecR::CbmTrdModuleRecR(Int_t mod, Int_t ly, Int_t rot)
  : CbmTrdModuleRec(mod, ly, rot), fDigiCounter(0), fDigiMap(), fClusterMap()
{
  // The object name carries the module id as two digits, e.g. "TrdModuleRecR05".
  const char* objectName = Form("TrdModuleRecR%02d", mod);
  SetNameTitle(objectName, "Reconstructor for rectangular pad TRD module");
}

//_______________________________________________________________________________
// Nothing to release explicitly: all members clean up after themselves.
CbmTrdModuleRecR::~CbmTrdModuleRecR() = default;

//_______________________________________________________________________________
/**
 * Buffer one digi for the next BuildClusters() call.
 *
 * The digi is stored together with its id and a "processed" flag (initially false).
 * BuildClusters() dereferences every buffered digi (time search and neighbor scan),
 * so null pointers are rejected here instead of crashing later.
 *
 * @param digi  digi to buffer; must not be null
 * @param id    index of the digi, later stored in the cluster's digi list
 * @return kTRUE if the digi was buffered, kFALSE if digi was null
 */
Bool_t CbmTrdModuleRecR::AddDigi(const CbmTrdDigi* digi, Int_t id)
{
  if (!digi) return kFALSE;

  // fill the digimap: (id, processed flag, digi)
  fDigiMap.push_back(std::make_tuple(id, false, digi));
  fDigiCounter++;
  return kTRUE;
}

//_______________________________________________________________________________
void CbmTrdModuleRecR::Clear(Option_t* opt)
{
  // Option "cls" additionally drops the buffered digis, the cluster map and the digi
  // counter; every option is then forwarded to the base class.
  const bool resetClusterBuffers = (strcmp(opt, "cls") == 0);
  if (resetClusterBuffers) {
    fDigiMap.clear();
    fClusterMap.clear();
    fDigiCounter = 0;
  }
  CbmTrdModuleRec::Clear(opt);
}

//_______________________________________________________________________________
// Build clusters from the digis buffered in fDigiMap (see AddDigi).
//
// Each cluster is seeded by a not-yet-processed self-triggered (kSelf) digi. Digis whose
// time lies within +/- one SPADIC clock period of the seed are then scanned repeatedly.
// A digi is attached when its channel is adjacent to the current cluster extent in the
// seed row or, with row merging, in the row directly above or below.
// Attached digis are marked as processed through the bool flag of their fDigiMap entry.
// The scan repeats until a pass adds no digi or the seal flags signal the cluster's end.
// The unnamed bool argument is not used.
//
// NOTE(review): std::lower_bound below requires fDigiMap to be ordered by digi time;
// presumably guaranteed by the order of the AddDigi calls. Confirm with the caller.
// NOTE(review): channels are treated as row * ncols + column (ch % ncols, channel +/- ncols).
//
// @return the clusters found; each stores the ids (as given to AddDigi) of its digis.
std::vector<CbmTrdCluster> CbmTrdModuleRecR::BuildClusters(bool)
{
  std::vector<CbmTrdCluster> clustersOut;

  // half-width of the time window in which digis may join a seed's cluster
  const double interval = CbmTrdDigi::Clk(CbmTrdDigi::eCbmTrdAsicType::kSPADIC);
  auto start            = fDigiMap.begin();  // marker to skip already processed entries from the map

  // search for an unprocessed main triggered digi and then start a subloop to
  // directly construct the cluster  (search for main-trigger then add the neighbors)
  for (auto mainit = fDigiMap.begin(); mainit != fDigiMap.end(); mainit++) {

    // skip invalid digis
    // NOTE(review): null entries are skipped here, but the lower_bound comparator and the
    // neighbor scan below dereference fDigiMap entries without a null check.
    const CbmTrdDigi* digi = (const CbmTrdDigi*) std::get<2>(*mainit);
    if (!digi) continue;

    ///////////// To do: Perhaps separate self and neighbor digis?
    ////////////// To do: Rethink the start position logic

    // get digi time and type; only unprocessed self-triggered digis seed a cluster
    const double time                   = digi->GetTime();
    const CbmTrdDigi::eTriggerType type = static_cast<CbmTrdDigi::eTriggerType>(digi->GetTriggerType());
    const Bool_t marked                 = std::get<1>(*mainit);
    if (type != CbmTrdDigi::eTriggerType::kSelf || marked) continue;

    // variety of necessary address information; uses the "combiId" for the
    // comparison of digi positions
    const Int_t digiId  = std::get<0>(*mainit);
    const Int_t channel = digi->GetAddressChannel();
    const Int_t ncols   = fDigiPar->GetNofColumns();

    // current cluster extent (as channels) in the seed row (lowcol/highcol) and in the
    // merged neighbor row (lowrow/highrow); all start at the seed channel
    Int_t lowcol  = channel;
    Int_t highcol = channel;
    Int_t lowrow  = channel;
    Int_t highrow = channel;

    // information buffer to handle neighbor rows and cluster over two rows; the
    // identification of adjacent rows is done by comparing their center of
    // gravity.
    // data[0]: charge of the central pad, data[1]: pad at channel - 1, data[2]: pad at
    // channel + 1. GetCoG() returns (data[2] - data[1]) / data[0].
    // The non-const operator[] increments count on EVERY call, so count is the number of
    // writes, not the number of distinct filled slots.
    struct Buffer {
      size_t count   = 0;
      double data[3] = {0, 0, 0};
      const double& operator[](size_t ind) const { return data[ind]; }
      double& operator[](size_t ind) { return (count++, data[ind]); }
      double GetCoG() { return (data[2] / data[0]) - (data[1] / data[0]); }
    };

    Buffer buffertop, bufferbot, bufferrow;
    bufferrow[0] = digi->GetCharge();

    // some logical flags to reject unnecessary steps
    Bool_t sealtopcol = false;  // the "seal" bools register when the logical end
                                // of the cluster was found
    Bool_t sealbotcol = false;
    Bool_t sealtoprow = false;
    Bool_t sealbotrow = false;
    Bool_t rowchange  = false;  // flags that there is a possible two row cluster
    Bool_t addtop     = false;  // adds the buffered information of the second row
    Bool_t addbot     = false;

    // vector which contains the actual cluster as (digi id, digi) pairs; starts with the seed
    std::vector<std::pair<Int_t, const CbmTrdDigi*>> cluster;
    cluster.push_back(std::make_pair(digiId, digi));
    std::get<1>(*mainit) = true;

    // Bool_t mergerow=CbmTrdClusterFinder::HasRowMerger();
    Bool_t mergerow = true;

    // update start position: first entry with time >= seed time - interval
    // (binary search, hence the time-ordering requirement on fDigiMap)
    start = std::lower_bound(start, fDigiMap.end(), time - interval,
                             [](const auto& obj, double val) { return std::get<2>(obj)->GetTime() < val; });

    // loop to find the other pads corresponding to the main trigger
    // is exited either if the implemented trigger logic is fulfilled
    // or if there are no more adjacent pads due to edges, etc.
    while (true) {

      // counter which is used to easily break clusters which are at the edge and
      // therefore do not fulfill the classical look
      const size_t oldSize = cluster.size();

      // find the FN digis of main trigger or adjacent main triggers
      for (auto FNit = start; FNit != fDigiMap.end(); FNit++) {

        // Skip already processed digis
        bool& filled = std::get<1>(*FNit);
        if (filled) continue;

        // some information to separate the time space; stop once past seed time + interval
        const CbmTrdDigi* d  = (const CbmTrdDigi*) std::get<2>(*FNit);
        const double newtime = d->GetTime();
        if (newtime > time + interval) break;

        // position information of the possible neighbor digis
        const double charge                        = d->GetCharge();
        const int digiid                           = std::get<0>(*FNit);
        const int ch                               = d->GetAddressChannel();
        const CbmTrdDigi::eTriggerType triggertype = static_cast<CbmTrdDigi::eTriggerType>(d->GetTriggerType());

        // Attach the current digi if its channel equals val and it is still unprocessed;
        // marks it processed and returns true on success.
        auto TryAdd = [&](int val) {
          if (ch == val && !filled) {
            return (cluster.emplace_back(std::make_pair(digiid, d)), filled = true);
          }
          return false;
        };

        // logical implementation of the trigger logic in the same row as the
        // main trigger: self-triggered pads extend the extent, a neighbor-triggered pad
        // closes ("seals") that side.
        // NOTE(review): highcol + 1 / lowcol - 1 are plain channel arithmetic and can
        // wrap into the adjacent row at the module's column edges.
        if (triggertype == CbmTrdDigi::eTriggerType::kSelf) {
          if (TryAdd(lowcol - 1)) {
            lowcol = ch;
          }
          if (TryAdd(highcol + 1)) {
            highcol = ch;
          }
        }
        if (triggertype == CbmTrdDigi::eTriggerType::kNeighbor) {
          if (!sealtopcol && TryAdd(highcol + 1)) {
            sealtopcol = true;
          }
          if (!sealbotcol && TryAdd(lowcol - 1)) {
            sealbotcol = true;
          }
        }

        // NOTE(review): col is in [0, ncols - 1], so `col == ncols` never holds and the top
        // edge is never sealed here (presumably ncols - 1 was intended). Both tests also use
        // the scanned digi `ch`, which need not belong to the cluster.
        const int col = ch % ncols;
        if (col == ncols) sealtopcol = true;
        if (col == 0) sealbotcol = true;

        if (mergerow) {
          // multiple row processing
          // first buffering of the pads below/above the seed and their left/right neighbors
          if (ch == channel - ncols && !rowchange && triggertype == CbmTrdDigi::eTriggerType::kSelf) {
            rowchange    = true;
            bufferbot[0] = charge;
          }
          if (ch == (channel - ncols) - 1 && rowchange) {
            bufferbot[1] = charge;
          }
          if (ch == (channel - ncols) + 1 && rowchange) {
            bufferbot[2] = charge;
          }
          if (ch == channel + ncols && !rowchange && triggertype == CbmTrdDigi::eTriggerType::kSelf) {
            rowchange    = true;
            buffertop[0] = charge;
          }
          if (ch == (channel + ncols) - 1 && rowchange) {
            buffertop[1] = charge;
          }
          if (ch == (channel + ncols) + 1 && rowchange) {
            buffertop[2] = charge;
          }

          if (ch == channel - 1) {
            bufferrow[1] = charge;
          }
          if (ch == channel + 1) {
            bufferrow[2] = charge;
          }

          /// To do: charge equal zero produces problems. use NaN or other invalid value instead
          //   int num = 3 - std::count(&bufferrow.data[0], &bufferrow.data[3], 0);
          //   std::cout << bufferrow.count << " " << num << " " << charge << std::endl;
          //   assert(bufferrow.count == num);

          // then the calculation of the center of gravity with the
          // identification of common CoGs (within 25 % of the seed row's value)
          if (buffertop.count == 3 && bufferrow.count == 3 && !addtop
              && TMath::Abs((buffertop.GetCoG() - bufferrow.GetCoG())) < 0.25 * bufferrow.GetCoG()) {
            addtop = true;
          }
          if (bufferbot.count == 3 && bufferrow.count == 3 && !addbot
              && TMath::Abs((bufferbot.GetCoG() - bufferrow.GetCoG())) < 0.25 * bufferrow.GetCoG()) {
            addbot = true;
          }

          // adding of the neighboring row; its extent starts at the pad below/above the seed
          if (addbot && TryAdd(channel - ncols)) {
            lowrow  = ch;
            highrow = ch;
          }
          if (addtop && TryAdd(channel + ncols)) {
            lowrow  = ch;
            highrow = ch;
          }
          if (triggertype == CbmTrdDigi::eTriggerType::kSelf) {
            if (rowchange && lowrow != channel && TryAdd(lowrow - 1)) {
              lowrow = ch;
            }
            if (rowchange && highrow != channel && TryAdd(highrow + 1)) {
              highrow = ch;
            }
          }

          if (triggertype == CbmTrdDigi::eTriggerType::kNeighbor) {
            if (rowchange && highrow != channel && !sealtoprow && TryAdd(highrow + 1)) {
              sealtoprow = true;
            }
            if (rowchange && lowrow != channel && !sealbotrow && TryAdd(lowrow - 1)) {
              sealbotrow = true;
            }
          }
        }  //! if (mergerow)
      }    //! for (auto FNit = start; FNit != fDigiMap.end(); FNit++)

      // some finish criteria: no growth in this pass, or all relevant sides sealed
      if (cluster.size() - oldSize == 0) break;
      if (sealbotcol && sealtopcol && !rowchange) break;
      if (sealbotcol && sealtopcol && sealtoprow && sealbotrow) break;
    }  //!  while (true)

    addClusters(cluster, &clustersOut);
  }  //! for (auto mainit = fDigiMap.begin(); mainit != fDigiMap.end(); mainit++)

  return clustersOut;
}

//_____________________________________________________________________
/**
 * Append one CbmTrdCluster built from a list of (digi id, digi) pairs to clustersOut.
 *
 * The new cluster gets this module's address and the digi ids in list order.
 * NOTE(review): NCols is set to the number of digis, which equals the column count only
 * for single-row clusters; confirm this is what CbmTrdCluster consumers expect.
 */
void CbmTrdModuleRecR::addClusters(std::vector<std::pair<Int_t, const CbmTrdDigi*>> cluster,
                                   std::vector<CbmTrdCluster>* clustersOut)
{
  std::vector<Int_t> ids;
  ids.reserve(cluster.size());
  for (const auto& entry : cluster) {
    ids.push_back(entry.first);
  }

  clustersOut->emplace_back();
  CbmTrdCluster& added = clustersOut->back();
  added.SetAddress(fModAddress);
  added.SetDigis(ids);
  added.SetNCols(ids.size());
}

//_______________________________________________________________________________
Bool_t CbmTrdModuleRecR::MakeHits()
{
  // No-op that always reports success; this class builds hits per cluster in MakeHit().
  return kTRUE;
}

//_______________________________________________________________________________
/**
 * Build one CbmTrdHit from the digis of a cluster and store it in fHits.
 *
 * - Position: charge-weighted mean of the pad positions (local frame), transformed with
 *   LocalToMaster. Outside EB mode a small empirical correction is applied.
 * - Time: mean time of the digis that enter the position. Only these digis are averaged,
 *   not all entries of `digis`.
 * - Errors: taken from kxVar_Value / kyVar_Value by EB mode and error class, then one
 *   component is overwritten by sqrt(pad size Y) depending on the module orientation.
 * - Charge: sum of the used digi charges, set to -1 if the cluster is incomplete.
 * Null digis and digis with charge <= 0.05 are ignored.
 *
 * @param clusterId  cluster index stored in the hit
 * @param cluster    cluster the digis belong to (used for the completeness check)
 * @param digis      digis of the cluster
 * @return the new hit, or nullptr if no digi passes the charge cut
 */
CbmTrdHit* CbmTrdModuleRecR::MakeHit(Int_t clusterId, const CbmTrdCluster* cluster,
                                     std::vector<const CbmTrdDigi*>* digis)
{
  TVector3 hit_posV(0., 0., 0.);         // charge-weighted position sum (local frame)
  TVector3 local_pad_posV(0., 0., 0.);   // pad position, filled by GetPadPosition
  TVector3 local_pad_dposV(0., 0., 0.);  // pad extent, filled by GetPadPosition (unused here)

  Double_t totalCharge = 0.;
  Double_t time        = 0.;
  Int_t nTimed         = 0;  // number of digis whose time entered the sum
  Int_t errorclass     = 0;
  Bool_t EB            = false;
  Bool_t EBP           = false;

  for (const CbmTrdDigi* digi : *digis) {
    if (!digi) continue;  // skip missing digis

    const Double_t digiCharge = digi->GetCharge();
    // Error class and EB/EBP flags come from the last non-null digi, also when that digi
    // is rejected by the charge cut below.
    errorclass = digi->GetErrorClass();
    EB         = digi->IsFlagged(0);
    EBP        = digi->IsFlagged(1);

    if (digiCharge <= 0.05) continue;  // drop (almost) empty pads

    time += digi->GetTime();
    ++nTimed;
    totalCharge += digiCharge;

    fDigiPar->GetPadPosition(digi->GetAddressChannel(), true, local_pad_posV, local_pad_dposV);
    for (Int_t iDim = 0; iDim < 3; iDim++) {
      hit_posV[iDim] += local_pad_posV[iDim] * digiCharge;
    }
  }

  if (totalCharge <= 0) return nullptr;

  // Average over the digis actually summed; totalCharge > 0 guarantees nTimed > 0.
  time /= nTimed;

  Double_t hit_pos[3];
  for (Int_t iDim = 0; iDim < 3; iDim++) {
    hit_posV[iDim] /= totalCharge;
    hit_pos[iDim] = hit_posV[iDim];
  }

  // Position errors from the error-class tables (first index 0: EB mode, 1: otherwise).
  // NOTE(review): errorclass indexes a 5-column table without a range check; confirm that
  // GetErrorClass() is limited to [0, 4].
  Double_t xVar = 0.;
  Double_t yVar = 0.;
  if (EB) {
    xVar = kxVar_Value[0][errorclass];
    yVar = kyVar_Value[0][errorclass];
  }
  else {
    if (EBP) time -= 46;  //due to the event time of 0 in the EB mode and the ULong in the the digi time
    //TODO: move to parameter file
    xVar = kxVar_Value[1][errorclass];
    yVar = kyVar_Value[1][errorclass];
  }

  TVector3 cluster_pad_dposV(xVar, yVar, 0);

  // --- If a TGeoNode is attached, transform into global coordinate system
  Double_t global[3];
  LocalToMaster(hit_pos, global);

  if (!EB) {  // preliminary correction for angle dependence in the position reconstruction
    global[0] = global[0] + (0.00214788 + global[0] * 0.000195394);
    global[1] = global[1] + (0.00370566 + global[1] * 0.000213235);
  }

  fDigiPar->TransformHitError(cluster_pad_dposV);

  // TODO: get momentum for more exact spatial error
  if ((fDigiPar->GetOrientation() == 1) || (fDigiPar->GetOrientation() == 3)) {
    cluster_pad_dposV[0] = sqrt(fDigiPar->GetPadSizeY(1));
  }
  else {
    cluster_pad_dposV[1] = sqrt(fDigiPar->GetPadSizeY(1));
  }

  // Set charge of incomplete clusters (missing NTs) to -1 (not deleting them because they are still relevant for tracking)
  if (!IsClusterComplete(cluster)) totalCharge = -1.0;

  const Int_t nofHits = fHits->GetEntriesFast();
  return new ((*fHits)[nofHits])
    CbmTrdHit(fModAddress, global, cluster_pad_dposV, 0, clusterId, totalCharge / 1e6, time,
              Double_t(8.5));  // TODO: move to parameter file
}

/**
 * Piecewise-linear lookup of a space resolution from a fixed (value, resolution) table.
 *
 * Values below the first node return the first node's resolution (0.4). Values above the
 * last node, and NaN, return the last node's resolution. In between, the result is
 * linearly interpolated.
 * The original loop assigned the below-range value without leaving the loop, so it was
 * overwritten by the last node and every val < 0.5 returned 0.26 instead of 0.4.
 *
 * @param val  abscissa of the table (meaning/units not visible here)
 * @return interpolated resolution
 */
Double_t CbmTrdModuleRecR::GetSpaceResolution(Double_t val)
{
  // Nodes in ascending order of .first; the old duplicate (8.5, 0.26) node was a
  // zero-width segment that was never selected, so removing it changes no result.
  static constexpr std::pair<Double_t, Double_t> res[] = {
    {0.5, 0.4},  {1., 0.35},  {2., 0.3},   {2.5, 0.3},  {3.5, 0.28}, {4.5, 0.26},
    {5.5, 0.26}, {6.5, 0.26}, {7.5, 0.26}, {8.5, 0.26}, {9.5, 0.26}};
  constexpr Int_t nRes = sizeof(res) / sizeof(res[0]);

  if (val < res[0].first) return res[0].second;

  for (Int_t n = 0; n + 1 < nRes; n++) {
    if (val >= res[n].first && val <= res[n + 1].first) {
      const Double_t dx    = res[n + 1].first - res[n].first;
      const Double_t dy    = res[n + 1].second - res[n].second;
      const Double_t slope = dy / dx;
      return (val - res[n].first) * slope + res[n].second;
    }
  }

  // above the table (or NaN): clamp to the last node
  return res[nRes - 1].second;
}

bool CbmTrdModuleRecR::IsClusterComplete(const CbmTrdCluster* cluster)
{
  int colMin = fDigiPar->GetNofColumns();
  int rowMin = fDigiPar->GetNofRows();

  for (int i = 0; i < cluster->GetNofDigis(); ++i) {
    const CbmTrdDigi* digi = CbmDigiManager::Instance()->Get<CbmTrdDigi>(cluster->GetDigi(i));
    int digiCol            = fDigiPar->GetPadColumn(digi->GetAddressChannel());
    int digiRow            = fDigiPar->GetPadRow(digi->GetAddressChannel());

    if (digiCol < colMin) colMin = digiCol;
    if (digiRow < rowMin) rowMin = digiRow;
  }

  const UShort_t nCols = cluster->GetNCols();
  const UShort_t nRows = cluster->GetNRows();

  CbmTrdDigi* digiMap[nRows][nCols];                        //create array on stack for optimal performance
  memset(digiMap, 0, sizeof(CbmTrdDigi*) * nCols * nRows);  //init with nullpointers

  for (int i = 0; i < cluster->GetNofDigis(); ++i) {
    const CbmTrdDigi* digi = CbmDigiManager::Instance()->Get<CbmTrdDigi>(cluster->GetDigi(i));
    int digiCol            = fDigiPar->GetPadColumn(digi->GetAddressChannel());
    int digiRow            = fDigiPar->GetPadRow(digi->GetAddressChannel());

    if (digiMap[digiRow - rowMin][digiCol - colMin])
      return false;  // To be investigated why this sometimes happens (Redmin Issue 2914)

    digiMap[digiRow - rowMin][digiCol - colMin] = const_cast<CbmTrdDigi*>(digi);
  }

  // check if each row of the cluster starts and ends with a kNeighbor digi
  for (int iRow = 0; iRow < nRows; ++iRow) {
    int colStart = 0;
    while (digiMap[iRow][colStart] == nullptr)
      ++colStart;
    if (digiMap[iRow][colStart]->GetTriggerType() != static_cast<Int_t>(CbmTrdDigi::eTriggerType::kNeighbor))
      return false;

    int colStop = nCols - 1;
    while (digiMap[iRow][colStop] == nullptr)
      --colStop;
    if (digiMap[iRow][colStop]->GetTriggerType() != static_cast<Int_t>(CbmTrdDigi::eTriggerType::kNeighbor))
      return false;
  }

  return true;
}

// ROOT ClassImp macro: class-implementation bookkeeping for ROOT's type system.
ClassImp(CbmTrdModuleRecR)