/* Copyright (C) 2022 GSI Helmholtzzentrum fuer Schwerionenforschung, Darmstadt
   SPDX-License-Identifier: GPL-3.0-only
   Authors: Volker Friese [committer] */


#include "Unpack.h"

#include <chrono>
#include <memory>

#include <xpu/host.h>

#include "compat/Algorithm.h"
#include "log.hpp"

using namespace std;

namespace cbm::algo
{
  // -----   Execution   -------------------------------------------------------
  // Unpacks one timeslice and returns the digis together with the unpacking monitor data.
  //
  // Processing order:
  //   1. ParallelInit() collects the STS microslices (called regardless of whether STS is enabled).
  //   2. If STS is enabled, all STS microslices are unpacked through ParallelMsLoop().
  //   3. All components are visited in order; components of the other enabled subsystems are
  //      unpacked sequentially through MsLoop().
  //   4. The digi vector of every subsystem is sorted by ascending digi time.
  //
  // @param timeslice  Input timeslice. It is dereferenced without a check, so it must not be null.
  // @return Pair of digi timeslice (first) and monitor data (second).
  Unpack::resultType Unpack::operator()(const fles::Timeslice* timeslice)
  {
    xpu::scoped_timer t0("Unpack");

    // --- Output data
    resultType result          = {};
    CbmDigiTimeslice& digiTs   = result.first;
    UnpackMonitorData& monitor = result.second;

    ParallelInit(*timeslice);

    if (DetectorEnabled(fles::SubsystemIdentifier::STS)) {
      xpu::scoped_timer t1("STS");
      ParallelMsLoop(fParallelStsSetup, digiTs.fData.fSts.fDigis, monitor.fSts, *timeslice, fAlgoSts, 0x20);
    }

    // ---  Component loop
    for (uint64_t comp = 0; comp < timeslice->num_components(); comp++) {

      // The descriptor of the first microslice provides system ID and equipment ID of the component
      const auto& firstDescriptor = timeslice->descriptor(comp, 0);

      // System ID of current component
      const auto systemId = static_cast<fles::SubsystemIdentifier>(firstDescriptor.sys_id);

      if (!DetectorEnabled(systemId)) continue;

      xpu::scoped_timer t1(fles::to_string(systemId));

      // Equipment ID of current component
      const uint16_t equipmentId = firstDescriptor.eq_id;

      // The current algorithms work for the format versions hard-coded as parameters to MsLoop() below.
      // Other versions are not yet supported.
      // In the future, different data formats will be supported by instantiating different
      // algorithms depending on the version.
      switch (systemId) {
        case fles::SubsystemIdentifier::MUCH:
          MsLoop(timeslice, fAlgoMuch, comp, equipmentId, &digiTs.fData.fMuch.fDigis, monitor, &monitor.fMuch, 0x20);
          break;
        case fles::SubsystemIdentifier::RPC:
          MsLoop(timeslice, fAlgoTof, comp, equipmentId, &digiTs.fData.fTof.fDigis, monitor, &monitor.fTof, 0x00);
          break;
        case fles::SubsystemIdentifier::T0:
          MsLoop(timeslice, fAlgoBmon, comp, equipmentId, &digiTs.fData.fT0.fDigis, monitor, &monitor.fBmon, 0x00);
          break;
        case fles::SubsystemIdentifier::TRD:
          MsLoop(timeslice, fAlgoTrd, comp, equipmentId, &digiTs.fData.fTrd.fDigis, monitor, &monitor.fTrd, 0x01);
          break;
        case fles::SubsystemIdentifier::TRD2D:
          MsLoop(timeslice, fAlgoTrd2d, comp, equipmentId, &digiTs.fData.fTrd2d.fDigis, monitor, &monitor.fTrd2d,
                 0x02);
          break;
        case fles::SubsystemIdentifier::RICH:
          MsLoop(timeslice, fAlgoRich, comp, equipmentId, &digiTs.fData.fRich.fDigis, monitor, &monitor.fRich, 0x03);
          break;
        default:
          // STS is unpacked by ParallelMsLoop() above; no sequential unpacker is called for other systems.
          break;
      }
    }  //# component

    // --- Sorting of output digis. Is required by both digi trigger and event builder.

    xpu::scoped_timer t2("Sort");

    // Time ordering shared by all digi types. The arguments are taken by const reference so that
    // no digi is copied for a comparison.
    const auto byTime = [](const auto& digi1, const auto& digi2) { return digi1.GetTime() < digi2.GetTime(); };

    Sort(digiTs.fData.fSts.fDigis.begin(), digiTs.fData.fSts.fDigis.end(), byTime);
    Sort(digiTs.fData.fMuch.fDigis.begin(), digiTs.fData.fMuch.fDigis.end(), byTime);
    Sort(digiTs.fData.fTof.fDigis.begin(), digiTs.fData.fTof.fDigis.end(), byTime);
    Sort(digiTs.fData.fT0.fDigis.begin(), digiTs.fData.fT0.fDigis.end(), byTime);
    Sort(digiTs.fData.fTrd.fDigis.begin(), digiTs.fData.fTrd.fDigis.end(), byTime);
    Sort(digiTs.fData.fTrd2d.fDigis.begin(), digiTs.fData.fTrd2d.fDigis.end(), byTime);
    Sort(digiTs.fData.fRich.fDigis.begin(), digiTs.fData.fRich.fDigis.end(), byTime);

    return result;
  }
  // ----------------------------------------------------------------------------

  // ----------------- Microslice loop ------------------------------------------
  // Sequentially unpacks all microslices of one timeslice component.
  //
  // The unpacker is looked up in algoMap by equipment ID. If there is none, or if the format version
  // of the component's first microslice differs from sys_ver, the corresponding error counter in
  // monitor is incremented and the component is skipped. Otherwise the digis of every microslice are
  // appended to *digis, the per-microslice monitor data to *monitorMs, and the component totals are
  // added to monitor.
  //
  // @param timeslice  Input timeslice
  // @param algoMap    Unpack algorithms, keyed by equipment ID
  // @param comp       Index of the component within the timeslice
  // @param eqId       Equipment ID of the component
  // @param digis      Output vector the unpacked digis are appended to
  // @param monitor    Monitor data, updated with counters for this component
  // @param monitorMs  Output vector receiving one monitor entry per microslice
  // @param sys_ver    Format version the unpacker supports
  template<class Digi, class UnpackAlgo, class MonitorData>
  void Unpack::MsLoop(const fles::Timeslice* timeslice, std::map<uint16_t, UnpackAlgo>& algoMap, const uint64_t comp,
                      const uint16_t eqId, std::vector<Digi>* digis, UnpackMonitorData& monitor,
                      std::vector<MonitorData>* monitorMs, uint8_t sys_ver)
  {
    using Clock = std::chrono::high_resolution_clock;

    // Start of the profiling interval
    const Clock::time_point tBegin = Clock::now();

    // Component is skipped if no unpacker is registered for its equipment ID
    const auto entry = algoMap.find(eqId);
    if (entry == algoMap.end()) {
      ++monitor.fNumErrInvalidEqId;
      return;
    }

    // Component is skipped if its format version is not the supported one
    if (timeslice->descriptor(comp, 0).sys_ver != sys_ver) {
      ++monitor.fNumErrInvalidSysVer;
      return;
    }

    UnpackAlgo& unpacker = entry->second;
    const uint64_t nMs   = timeslice->num_microslices(comp);

    // Totals for this component
    size_t bytesTotal = 0;
    size_t digisTotal = 0;

    uint64_t ms = 0;
    while (ms < nMs) {
      const auto descriptor = timeslice->descriptor(comp, ms);
      const auto content    = timeslice->content(comp, ms);
      auto unpacked         = unpacker(content, descriptor, timeslice->start_time());
      const auto nDigisInMs = unpacked.first.size();

      L_(debug) << "Unpack::MsLoop(): Component " << comp << ", microslice " << ms << ", digis " << nDigisInMs << ", "
                << unpacked.second.print();

      bytesTotal += descriptor.size;
      digisTotal += nDigisInMs;
      digis->insert(digis->end(), unpacked.first.begin(), unpacked.first.end());
      monitorMs->push_back(unpacked.second);
      ++ms;
    }

    // Elapsed time, measured in microseconds and reported in milliseconds
    const auto elapsedUs = std::chrono::duration_cast<std::chrono::microseconds>(Clock::now() - tBegin);
    const auto subsystem = static_cast<fles::SubsystemIdentifier>(timeslice->descriptor(comp, 0).sys_id);

    L_(debug) << "Unpack(): Component " << comp << ", subsystem " << fles::to_string(subsystem) << ", microslices "
              << nMs << " input size " << bytesTotal << " bytes,"
              << " digis " << digisTotal << ", CPU time " << elapsedUs.count() / 1000. << " ms";

    // Add the component totals to the timeslice monitor
    monitor.fNumMs += nMs;
    monitor.fNumBytes += bytesTotal;
    monitor.fNumDigis += digisTotal;
    monitor.fNumCompUsed++;
  }
  // ----------------------------------------------------------------------------


  // -----   Initialisation   ---------------------------------------------------
  // Stores the list of subsystems and configures one unpack algorithm per equipment ID for
  // STS, MUCH, TOF, T0 (Bmon), RICH, TRD and TRD2D, using the respective readout configuration
  // and the per-subsystem time offsets in fSystemTimeOffset.
  //
  // @param subIds  Subsystem identifiers, stored in fSubIds
  void Unpack::Init(std::vector<fles::SubsystemIdentifier> subIds)
  {
    // The argument is a by-value copy already, so it can be moved into the member
    fSubIds = std::move(subIds);

    // --- Common parameters for all components for STS
    constexpr uint32_t numChansPerAsicSts   = 128;  // R/O channels per ASIC for STS
    constexpr uint32_t numAsicsPerModuleSts = 16;   // Number of ASICs per module for STS

    // Create one algorithm per component for STS and configure it with parameters
    const auto equipIdsSts = fStsConfig.GetEquipmentIds();
    for (const auto& equip : equipIdsSts) {
      auto par                = std::make_unique<UnpackStsPar>();
      par->fNumChansPerAsic   = numChansPerAsicSts;
      par->fNumAsicsPerModule = numAsicsPerModuleSts;
      const size_t numElinks  = fStsConfig.GetNumElinks(equip);
      for (size_t elink = 0; elink < numElinks; elink++) {
        UnpackStsElinkPar elinkPar;
        const auto mapEntry  = fStsConfig.Map(equip, elink);
        elinkPar.fAddress    = mapEntry.first;   // Module address for this elink
        elinkPar.fAsicNr     = mapEntry.second;  // ASIC number within module
        elinkPar.fTimeOffset = fSystemTimeOffset[fles::SubsystemIdentifier::STS];
        elinkPar.fAdcOffset  = 1.;
        elinkPar.fAdcGain    = 1.;
        if (fApplyWalkCorrection) elinkPar.fWalk = fStsConfig.WalkMap(elinkPar.fAddress, elinkPar.fAsicNr);
        // TODO: Add parameters for time and ADC calibration
        par->fElinkParams.push_back(elinkPar);
      }
      fAlgoSts[equip].SetParams(std::move(par));
      L_(debug) << "--- Configured equipment " << equip << " with " << numElinks << " elinks";
    }  //# equipments

    // Create one algorithm per component for MUCH and configure it with parameters
    const auto equipIdsMuch = fMuchConfig.GetEquipmentIds();
    for (const auto& equip : equipIdsMuch) {
      auto par               = std::make_unique<UnpackMuchPar>();
      const size_t numElinks = fMuchConfig.GetNumElinks(equip);
      for (size_t elink = 0; elink < numElinks; elink++) {
        UnpackMuchElinkPar elinkPar;
        elinkPar.fAddress    = fMuchConfig.Map(equip, elink);  // Vector of MUCH addresses for this elink
        elinkPar.fTimeOffset = fSystemTimeOffset[fles::SubsystemIdentifier::MUCH];
        par->fElinkParams.push_back(elinkPar);
      }
      fAlgoMuch[equip].SetParams(std::move(par));
      L_(debug) << "--- Configured equipment " << equip << " with " << numElinks << " elinks";
    }

    // Create one algorithm per component for TOF and configure it with parameters
    const auto equipIdsTof = fTofConfig.GetEquipmentIds();
    for (const auto& equip : equipIdsTof) {
      auto par               = std::make_unique<UnpackTofPar>();
      const size_t numElinks = fTofConfig.GetNumElinks(equip);
      for (size_t elink = 0; elink < numElinks; elink++) {
        UnpackTofElinkPar elinkPar;
        elinkPar.fChannelUId = fTofConfig.Map(equip, elink);  // Vector of TOF addresses for this elink
        elinkPar.fTimeOffset = fSystemTimeOffset[fles::SubsystemIdentifier::RPC];
        par->fElinkParams.push_back(elinkPar);
      }
      fAlgoTof[equip].SetParams(std::move(par));
      L_(debug) << "--- Configured equipment " << equip << " with " << numElinks << " elinks";
    }

    // Create one algorithm per component for T0 and configure it with parameters
    const auto equipIdsBmon = fBmonConfig.GetEquipmentIds();
    for (const auto& equip : equipIdsBmon) {
      auto par               = std::make_unique<UnpackBmonPar>();
      const size_t numElinks = fBmonConfig.GetNumElinks(equip);
      for (size_t elink = 0; elink < numElinks; elink++) {
        UnpackBmonElinkPar elinkPar;
        elinkPar.fChannelUId = fBmonConfig.Map(equip, elink);  // Vector of T0 addresses for this elink
        elinkPar.fTimeOffset = fSystemTimeOffset[fles::SubsystemIdentifier::T0];
        par->fElinkParams.push_back(elinkPar);
      }
      fAlgoBmon[equip].SetParams(std::move(par));
      L_(debug) << "--- Configured equipment " << equip << " with " << numElinks << " elinks";
    }

    // Create one algorithm per component for RICH and configure it with parameters
    const auto equipIdsRich = fRichConfig.GetEquipmentIds();
    for (const auto& equip : equipIdsRich) {
      auto par                                              = std::make_unique<UnpackRichPar>();
      const std::map<uint32_t, std::vector<double>> compMap = fRichConfig.Map(equip);
      for (const auto& val : compMap) {
        const uint32_t address                 = val.first;
        par->fElinkParams[address].fToTshift   = val.second;
        par->fElinkParams[address].fTimeOffset = fSystemTimeOffset[fles::SubsystemIdentifier::RICH];
      }
      fAlgoRich[equip].SetParams(std::move(par));
      // NOTE(review): logged at info level, whereas the other subsystems log this line at debug level.
      L_(info) << "--- Configured equipment " << equip << " with " << fRichConfig.GetNumElinks(equip) << " elinks";
    }

    // Create one algorithm per component for TRD and configure it with parameters
    const auto equipIdsTrd = fTrdConfig.GetEquipmentIds();
    for (const auto& equip : equipIdsTrd) {

      auto par              = std::make_unique<UnpackTrdPar>();
      const size_t numCrobs = fTrdConfig.GetNumCrobs(equip);

      for (size_t crob = 0; crob < numCrobs; crob++) {
        UnpackTrdCrobPar crobPar;
        const size_t numElinks = fTrdConfig.GetNumElinks(equip, crob);

        for (size_t elink = 0; elink < numElinks; elink++) {
          UnpackTrdElinkPar elinkPar;
          const auto addresses  = fTrdConfig.Map(equip, crob, elink);
          elinkPar.fAddress     = addresses.first;   // Asic address for this elink
          elinkPar.fChanAddress = addresses.second;  // Channel addresses for this elink
          elinkPar.fTimeOffset  = fSystemTimeOffset[fles::SubsystemIdentifier::TRD];
          crobPar.fElinkParams.push_back(elinkPar);
        }
        par->fCrobParams.push_back(crobPar);
      }
      fAlgoTrd[equip].SetParams(std::move(par));
      L_(debug) << "--- Configured equipment " << equip << " with " << numCrobs << " crobs";
    }

    // Create one algorithm per component for TRD2D and configure it with parameters
    const auto equipIdsTrd2d = fTrd2dConfig.GetEquipmentIds();
    for (const auto& equip : equipIdsTrd2d) {

      auto par              = std::make_unique<UnpackTrd2dPar>();
      const size_t numAsics = fTrd2dConfig.GetNumAsics(equip);

      // Component-level parameters do not depend on the ASIC, so they are set once per equipment.
      // This also sets them for an equipment without ASICs.
      const auto comppars    = fTrd2dConfig.CompMap(equip);
      par->fSystemTimeOffset = fSystemTimeOffset[fles::SubsystemIdentifier::TRD2D];
      par->fModId            = comppars.first;
      par->fCrobId           = comppars.second;

      for (size_t asic = 0; asic < numAsics; asic++) {
        UnpackTrd2dAsicPar asicPar;
        const size_t numChans = fTrd2dConfig.GetNumChans(equip, asic);

        for (size_t chan = 0; chan < numChans; chan++) {
          UnpackTrd2dChannelPar chanPar;
          const auto pars     = fTrd2dConfig.ChanMap(equip, asic, chan);
          chanPar.fPadAddress = std::get<0>(pars);  // Pad address for channel
          chanPar.fMask       = std::get<1>(pars);  // Flag channel mask
          chanPar.fDaqOffset  = std::get<2>(pars);  // Time calibration parameter
          asicPar.fChanParams.push_back(chanPar);
        }
        par->fAsicParams.push_back(asicPar);
      }
      fAlgoTrd2d[equip].SetParams(std::move(par));
      L_(debug) << "--- Configured equipment " << equip << " with " << numAsics << " asics";
    }

    // --- Summary of the configured algorithms
    L_(info) << "--- Configured " << fAlgoSts.size()
             << " unpacker algorithms for STS. (Walk correction = " << fApplyWalkCorrection << ")";
    L_(debug) << "Readout map:" << fStsConfig.PrintReadoutMap();
    L_(info) << "--- Configured " << fAlgoMuch.size() << " unpacker algorithms for MUCH.";
    L_(info) << "--- Configured " << fAlgoRich.size() << " unpacker algorithms for RICH.";
    L_(debug) << "Readout map:" << fRichConfig.PrintReadoutMap();
    L_(info) << "--- Configured " << fAlgoTof.size() << " unpacker algorithms for TOF.";
    L_(info) << "--- Configured " << fAlgoTrd.size() << " unpacker algorithms for TRD.";
    L_(info) << "--- Configured " << fAlgoTrd2d.size() << " unpacker algorithms for TRD2D.";
    L_(info) << "--- Configured " << fAlgoBmon.size() << " unpacker algorithms for T0.";
    L_(info) << "==================================================";
  }
  // ----------------------------------------------------------------------------

  // ----------------------------------------------------------------------------
  // Prepares fParallelStsSetup for the parallel unpacking of STS.
  //
  // The setup is reset, then for every microslice of every STS component the equipment ID of the
  // component, the microslice descriptor and the microslice content are appended to the setup.
  // The per-microslice output vectors (digis and monitor data) are sized to the number of
  // STS microslices found.
  //
  // @param timeslice  Input timeslice
  void Unpack::ParallelInit(const fles::Timeslice& timeslice)
  {
    xpu::scoped_timer t("ParallelInit");

    fParallelStsSetup = {};

    size_t numMs = 0;  // Total number of STS microslices in the timeslice
    for (uint64_t comp = 0; comp < timeslice.num_components(); comp++) {
      const auto systemId = static_cast<fles::SubsystemIdentifier>(timeslice.descriptor(comp, 0).sys_id);
      if (systemId != fles::SubsystemIdentifier::STS) continue;

      const uint64_t numMsInComp = timeslice.num_microslices(comp);
      numMs += numMsInComp;

      // The equipment ID is taken from the first microslice and used for all microslices of the component
      const u16 equipmentId = timeslice.descriptor(comp, 0).eq_id;

      for (uint64_t mslice = 0; mslice < numMsInComp; mslice++) {
        fParallelStsSetup.msEquipmentIds.push_back(equipmentId);
        fParallelStsSetup.msDescriptors.push_back(timeslice.descriptor(comp, mslice));
        fParallelStsSetup.msContent.push_back(timeslice.content(comp, mslice));
      }
    }

    // One output slot per microslice
    fParallelStsSetup.msDigis.resize(numMs);
    fParallelStsSetup.msMonitorData.resize(numMs);
  }
  // ----------------------------------------------------------------------------

  // ----------------------------------------------------------------------------
  // Unpacks all microslices collected in setup in parallel (OpenMP) and merges the results.
  //
  // Each microslice is unpacked into its own slot of setup.msDigis / setup.msMonitorData. The
  // per-microslice digi vectors are then copied into digisOut in microslice order, and the
  // per-microslice monitor data is moved into monitorOut (leaving setup.msMonitorData moved-from).
  //
  // NOTE(review): the parameters algos and sys_ver are not used. The unpacker is always taken from
  // fAlgoSts, and no format version check is made (MsLoop() performs one) - confirm whether
  // this is intended.
  // NOTE(review): fAlgoSts.at() throws for an equipment ID without configured unpacker, and that
  // would happen inside the OpenMP parallel region - confirm that unknown equipment IDs cannot occur.
  //
  // @param setup       Microslice descriptors, contents, equipment IDs and per-microslice outputs
  // @param digisOut    Receives the digis of all microslices; previous content is overwritten
  // @param monitorOut  Receives the per-microslice monitor data
  // @param ts          Timeslice; only its start time is used
  // @param algos       Unused, see note above
  // @param sys_ver     Unused, see note above
  template<class Digi, class UnpackAlgo, class Monitor>
  void Unpack::ParallelMsLoop(ParallelSetup<Digi, Monitor>& setup, std::vector<Digi>& digisOut,
                              std::vector<Monitor>& monitorOut, const fles::Timeslice& ts,
                              const std::map<u16, UnpackAlgo>& algos, u8 sys_ver)
  {
    const auto& msContent = setup.msContent;
    const auto& msDesc    = setup.msDescriptors;
    const auto& msEqIds   = setup.msEquipmentIds;
    auto& monitor         = setup.msMonitorData;
    auto& msDigis         = setup.msDigis;
    const size_t numMs    = msDigis.size();

    xpu::push_timer("Unpack");
#pragma omp parallel for schedule(dynamic)
    for (size_t i = 0; i < numMs; i++) {
      auto result = fAlgoSts.at(msEqIds[i])(msContent[i], msDesc[i], ts.start_time());
      msDigis[i]  = std::move(result.first);
      monitor[i]  = std::move(result.second);
    }
    xpu::pop_timer();

    // Exclusive prefix sum of the digi counts: offsets[i] is the position in digisOut at which the
    // digis of microslice i start. Computed once, so the merge below does not have to re-sum the
    // sizes of all preceding microslices for every microslice.
    std::vector<size_t> offsets(numMs);
    size_t nDigisTotal = 0;
    for (size_t i = 0; i < numMs; i++) {
      offsets[i] = nDigisTotal;
      nDigisTotal += msDigis[i].size();
    }

    xpu::push_timer("Resize");
    digisOut.resize(nDigisTotal);
    xpu::pop_timer();

    // Every microslice writes to its own disjoint range of digisOut
    xpu::push_timer("Merge");
#pragma omp parallel for schedule(dynamic)
    for (size_t i = 0; i < numMs; i++) {
      std::copy(msDigis[i].begin(), msDigis[i].end(), digisOut.begin() + offsets[i]);
    }
    xpu::pop_timer();

    monitorOut = std::move(monitor);

    // Todo: Combine monitor Data
  }

} /* namespace cbm::algo */