From 6e65296663ac520276116d6b1c80c0d16f91c29c Mon Sep 17 00:00:00 2001 From: "se.gorbunov" <se.gorbunov@gsi.de> Date: Mon, 13 Nov 2023 21:49:58 +0000 Subject: [PATCH] KF: utility to refit global tracks with Kalman Filter Smoother --- algo/ca/core/data/CaMeasurementTime.h | 6 +- algo/ca/core/tracking/CaTrackFit.cxx | 6 +- algo/ca/core/tracking/CaTrackFit.h | 9 +- reco/KF/CbmKFTrackFitter.cxx | 637 ++++++++++++++++++++++++++ reco/KF/CbmKFTrackFitter.h | 160 +++++++ reco/KF/KF.cmake | 1 + reco/KF/KFLinkDef.h | 1 + 7 files changed, 813 insertions(+), 7 deletions(-) create mode 100644 reco/KF/CbmKFTrackFitter.cxx create mode 100644 reco/KF/CbmKFTrackFitter.h diff --git a/algo/ca/core/data/CaMeasurementTime.h b/algo/ca/core/data/CaMeasurementTime.h index f74597788e..a8a24532eb 100644 --- a/algo/ca/core/data/CaMeasurementTime.h +++ b/algo/ca/core/data/CaMeasurementTime.h @@ -85,14 +85,14 @@ namespace cbm::algo::ca ///------------------------------ /// Data members - DataT fT {constants::Undef<DataT>}; ///< time coordinate of the measurement - DataT fDt2 {constants::Undef<DataT>}; ///< rms^2 of the time coordinate measurement + DataT fT {0}; ///< time coordinate of the measurement + DataT fDt2 {1.}; ///< rms^2 of the time coordinate measurement /// number of degrees of freedom (used for chi2 calculation) /// if ndf == 1, the measurement is used in fit and in the chi2 calculation /// if ndf == 0, the measurement is used neither in fit nor in the chi2 calculation - DataT fNdfT = constants::Undef<DataT>; ///< ndf for the time coordinate measurement + DataT fNdfT {0}; ///< ndf for the time coordinate measurement } _fvecalignment; diff --git a/algo/ca/core/tracking/CaTrackFit.cxx b/algo/ca/core/tracking/CaTrackFit.cxx index f2385979c2..508b7e0fff 100644 --- a/algo/ca/core/tracking/CaTrackFit.cxx +++ b/algo/ca/core/tracking/CaTrackFit.cxx @@ -879,7 +879,7 @@ namespace cbm::algo::ca } - void TrackFit::MultipleScattering(fvec radThick) + void TrackFit::MultipleScattering(fvec radThick, fvec qp0) { cnst ONE = 1.; @@ -891,13 +891,13 @@ namespace cbm::algo::ca fvec h = txtx + tyty; fvec t = sqrt(txtx1 + tyty); fvec h2 = h * h; - fvec qp0t = fQp0 * t; + fvec qp0t = qp0 * t; cnst c1 = 0.0136f, c2 = c1 * 0.038f, c3 = c2 * 0.5f, c4 = -c3 / 2.0f, c5 = c3 / 3.0f, c6 = -c3 / 4.0f; fvec s0 = (c1 + c2 * log(radThick) + c3 * h + h2 * (c4 + c5 * h + c6 * h2)) * qp0t; //fvec a = ( (ONE+mass2*qp0*qp0t)*radThick*s0*s0 ); - fvec a = ((t + fMass2 * fQp0 * qp0t) * radThick * s0 * s0); + fvec a = ((t + fMass2 * qp0 * qp0t) * radThick * s0 * s0); fTr.C22()(fMask) += txtx1 * a; fTr.C32()(fMask) += tx * ty * a; diff --git a/algo/ca/core/tracking/CaTrackFit.h b/algo/ca/core/tracking/CaTrackFit.h index 1bb492e339..aa01221011 100644 --- a/algo/ca/core/tracking/CaTrackFit.h +++ b/algo/ca/core/tracking/CaTrackFit.h @@ -11,6 +11,7 @@ #pragma once // include this header only once per compilation unit #include "CaField.h" +#include "CaMeasurementTime.h" #include "CaMeasurementU.h" #include "CaMeasurementXy.h" #include "CaSimd.h" @@ -100,6 +101,9 @@ namespace cbm::algo::ca /// filter the track with the time measurement void FilterTime(fvec t, fvec dt2, fvec timeInfo); + /// filter the track with the time measurement + void FilterTime(MeasurementTime<fvec> mt) { FilterTime(mt.T(), mt.Dt2(), mt.NdfT()); } + /// filter the track with the hit void FilterHit(const ca::Station& s, const ca::Hit& h); @@ -142,8 +146,11 @@ namespace cbm::algo::ca fvec upstreamDirection); + /// apply multiple scattering correction to the track with the given Qp0 + void MultipleScattering(fvec radThick, fvec qp0); + /// apply multiple scattering correction to the track - void MultipleScattering(fvec radThick); + void MultipleScattering(fvec radThick) { MultipleScattering(radThick, fQp0); } /// apply multiple scattering correction in thick material to the track void MultipleScatteringInThickMaterial(fvec radThick, fvec thickness, bool fDownstream); diff --git a/reco/KF/CbmKFTrackFitter.cxx b/reco/KF/CbmKFTrackFitter.cxx new file mode 100644 index 0000000000..ba8347cc1f --- /dev/null +++ b/reco/KF/CbmKFTrackFitter.cxx @@ -0,0 +1,637 @@ +/* Copyright (C) 2023 GSI Helmholtzzentrum fuer Schwerionenforschung, Darmstadt + SPDX-License-Identifier: GPL-3.0-only + Authors: Sergey Gorbunov [committer] */ + +#include "CbmKFTrackFitter.h" + +#include "CbmGlobalTrack.h" +#include "CbmL1.h" +#include "CbmL1Util.h" +#include "CbmMuchPixelHit.h" +#include "CbmMuchTrack.h" +#include "CbmMuchTrackingInterface.h" +#include "CbmMvdHit.h" +#include "CbmMvdTrackingInterface.h" +#include "CbmStsAddress.h" +#include "CbmStsHit.h" +#include "CbmStsSetup.h" +#include "CbmStsTrack.h" +#include "CbmStsTrackingInterface.h" +#include "CbmTofHit.h" +#include "CbmTofTrack.h" +#include "CbmTofTrackingInterface.h" +#include "CbmTrdHit.h" +#include "CbmTrdTrack.h" +#include "CbmTrdTrackingInterface.h" + +#include "FairRootManager.h" + +#include "TClonesArray.h" +#include "TDatabasePDG.h" + +#include "CaConstants.h" +#include "CaFramework.h" +#include "CaSimd.h" +#include "CaStation.h" +#include "CaTrackFit.h" +#include "CaTrackParam.h" +#include "KFParticleDatabase.h" + +using std::vector; +using namespace std; +using ca::fmask; +using ca::fvec; + +namespace +{ + using namespace cbm::algo; +} + +void CbmKFTrackFitter::Track::MakeConsistent() +{ + // sort the nodes in z + std::sort(fNodes.begin(), fNodes.end(), [](const FitNode& a, const FitNode& b) { return a.fZ < b.fZ; }); + + // set the first and last hit nodes + fFirstHitNode = fNodes.size() - 1; + fLastHitNode = 0; + for (int i = 0; i < (int) fNodes.size(); i++) { + if (fNodes[i].fMxy.NdfX()[0] + fNodes[i].fMxy.NdfY()[0] > 0) { + fFirstHitNode = std::min(fFirstHitNode, i); + fLastHitNode = std::max(fLastHitNode, i); + } + } +} + + +CbmKFTrackFitter::CbmKFTrackFitter() {} + +CbmKFTrackFitter::~CbmKFTrackFitter() {} + +void CbmKFTrackFitter::Init() +{ + if (fIsInitialized) return; + + if (!CbmL1::Instance() || !CbmL1::Instance()->fpAlgo) { + LOG(fatal) << "CbmKFTrackFitter: no CbmL1 task initialized "; + } + + FairRootManager* ioman = FairRootManager::Instance(); + + if (!ioman) { LOG(fatal) << "CbmKFTrackFitter: no FairRootManager"; } + + // Get hits + + fInputMvdHits = dynamic_cast<TClonesArray*>(ioman->GetObject("MvdHit")); + fInputStsHits = dynamic_cast<TClonesArray*>(ioman->GetObject("StsHit")); + fInputTrdHits = dynamic_cast<TClonesArray*>(ioman->GetObject("TrdHit")); + fInputMuchHits = dynamic_cast<TClonesArray*>(ioman->GetObject("MuchHit")); + fInputTofHits = dynamic_cast<TClonesArray*>(ioman->GetObject("TofHit")); + + // Get global tracks + fInputGlobalTracks = dynamic_cast<TClonesArray*>(ioman->GetObject("GlobalTrack")); + + // Get detector tracks + fInputStsTracks = dynamic_cast<TClonesArray*>(ioman->GetObject("StsTrack")); + fInputMuchTracks = dynamic_cast<TClonesArray*>(ioman->GetObject("MuchTrack")); + fInputTrdTracks = dynamic_cast<TClonesArray*>(ioman->GetObject("TrdTrack")); + fInputTofTracks = dynamic_cast<TClonesArray*>(ioman->GetObject("TofTrack")); + + fIsInitialized = true; +} + +void CbmKFTrackFitter::SetParticleHypothesis(int pdg) +{ + TParticlePDG* particlePDG = TDatabasePDG::Instance()->GetParticle(pdg); + if (!particlePDG) { + LOG(fatal) << "CbmKFTrackFitter: particle PDG " << pdg << " is not in the data base, please set the mass manually"; + return; + } + fMass = particlePDG->Mass(); + fIsElectron = (abs(pdg) == 11); +} + +void CbmKFTrackFitter::SetMassHypothesis(double mass) +{ + assert(mass >= 0.); + fMass = mass; +} + + +void CbmKFTrackFitter::SetElectronFlag(bool isElectron) { fIsElectron = isElectron; } + + +bool CbmKFTrackFitter::CreateGlobalTrack(CbmKFTrackFitter::Track& kfTrack, const CbmGlobalTrack& globalTrack) +{ + Init(); + if (!fIsInitialized) return false; + + std::vector<CbmMvdHit> mvdHits; + std::vector<CbmStsHit> stsHits; + std::vector<CbmMuchPixelHit> muchHits; + std::vector<CbmTrdHit> trdHits; + std::vector<CbmTofHit> tofHits; + + kfTrack = {}; + + // Read MVD & STS hits + + if (globalTrack.GetStsTrackIndex() >= 0) { + + int stsTrackIndex = globalTrack.GetStsTrackIndex(); + + if (!fInputStsTracks) { + LOG(error) << "CbmKFTrackFitter: Sts track array not found!"; + return false; + } + if (stsTrackIndex >= fInputStsTracks->GetEntriesFast()) { + LOG(error) << "CbmKFTrackFitter: Sts track index " << stsTrackIndex << " is out of range!"; + return false; + } + auto* stsTrack = dynamic_cast<const CbmStsTrack*>(fInputStsTracks->At(stsTrackIndex)); + if (!stsTrack) { + LOG(error) << "CbmKFTrackFitter: Sts track is null!"; + return false; + } + + // Read MVD hits + + int nMvdHits = stsTrack->GetNofMvdHits(); + if (nMvdHits > 0) { + if (!fInputMvdHits) { + LOG(error) << "CbmKFTrackFitter: Mvd hit array not found!"; + return false; + } + mvdHits.reserve(nMvdHits); + for (int ih = 0; ih < nMvdHits; ih++) { + int hitIndex = stsTrack->GetMvdHitIndex(ih); + if (hitIndex >= fInputMvdHits->GetEntriesFast()) { + LOG(error) << "CbmKFTrackFitter: Mvd hit index " << hitIndex << " is out of range!"; + return false; + } + mvdHits.push_back(*dynamic_cast<const CbmMvdHit*>(fInputMvdHits->At(hitIndex))); + } + } + + // Read STS hits + + int nStsHits = stsTrack->GetNofStsHits(); + if (nStsHits > 0) { + if (!fInputStsHits) { + LOG(error) << "CbmKFTrackFitter: Sts hit array not found!"; + return false; + } + stsHits.reserve(nStsHits); + for (int ih = 0; ih < nStsHits; ih++) { + int hitIndex = stsTrack->GetStsHitIndex(ih); + if (hitIndex >= fInputStsHits->GetEntriesFast()) { + LOG(error) << "CbmKFTrackFitter: Sts hit index " << hitIndex << " is out of range!"; + return false; + } + stsHits.push_back(*dynamic_cast<const CbmStsHit*>(fInputStsHits->At(hitIndex))); + } + } + } // MVD & STS hits + + + // Read Much hits + + if (globalTrack.GetMuchTrackIndex() >= 0) { + int muchTrackIndex = globalTrack.GetMuchTrackIndex(); + if (!fInputMuchTracks) { + LOG(error) << "CbmKFTrackFitter: Much track array not found!"; + return false; + } + if (muchTrackIndex >= fInputMuchTracks->GetEntriesFast()) { + LOG(error) << "CbmKFTrackFitter: Much track index " << muchTrackIndex << " is out of range!"; + return false; + } + auto* track = dynamic_cast<const CbmMuchTrack*>(fInputMuchTracks->At(muchTrackIndex)); + if (!track) { + LOG(error) << "CbmKFTrackFitter: Much track is null!"; + return false; + } + int nHits = track->GetNofHits(); + if (nHits > 0) { + if (!fInputMuchHits) { + LOG(error) << "CbmKFTrackFitter: Much hit array not found!"; + return false; + } + muchHits.reserve(nHits); + for (int ih = 0; ih < nHits; ih++) { + int hitIndex = track->GetHitIndex(ih); + if (hitIndex >= fInputMuchHits->GetEntriesFast()) { + LOG(error) << "CbmKFTrackFitter: Much hit index " << hitIndex << " is out of range!"; + return false; + } + muchHits.push_back(*dynamic_cast<const CbmMuchPixelHit*>(fInputMuchHits->At(hitIndex))); + } + } + } + + // Read TRD hits + + if (globalTrack.GetTrdTrackIndex() >= 0) { + int trdTrackIndex = globalTrack.GetTrdTrackIndex(); + if (!fInputTrdTracks) { + LOG(error) << "CbmKFTrackFitter: Trd track array not found!"; + return false; + } + if (trdTrackIndex >= fInputTrdTracks->GetEntriesFast()) { + LOG(error) << "CbmKFTrackFitter: Trd track index " << trdTrackIndex << " is out of range!"; + return false; + } + auto* track = dynamic_cast<const CbmTrdTrack*>(fInputTrdTracks->At(trdTrackIndex)); + if (!track) { + LOG(error) << "CbmKFTrackFitter: Trd track is null!"; + return false; + } + int nHits = track->GetNofHits(); + if (nHits > 0) { + if (!fInputTrdHits) { + LOG(error) << "CbmKFTrackFitter: Trd hit array not found!"; + return false; + } + trdHits.reserve(nHits); + for (int ih = 0; ih < nHits; ih++) { + int hitIndex = track->GetHitIndex(ih); + if (hitIndex >= fInputTrdHits->GetEntriesFast()) { + LOG(error) << "CbmKFTrackFitter: Trd hit index " << hitIndex << " is out of range!"; + return false; + } + trdHits.push_back(*dynamic_cast<const CbmTrdHit*>(fInputTrdHits->At(hitIndex))); + } + } + } + + + // Read TOF hits + + if (globalTrack.GetTofTrackIndex() >= 0) { + int tofTrackIndex = globalTrack.GetTofTrackIndex(); + if (!fInputTofTracks) { + LOG(error) << "CbmKFTrackFitter: Trd track array not found!"; + return false; + } + if (tofTrackIndex >= fInputTofTracks->GetEntriesFast()) { + LOG(error) << "CbmKFTrackFitter: Trd track index " << tofTrackIndex << " is out of range!"; + return false; + } + auto* track = dynamic_cast<const CbmTofTrack*>(fInputTofTracks->At(tofTrackIndex)); + if (!track) { + LOG(error) << "CbmKFTrackFitter: Tof track is null!"; + return false; + } + + int nHits = track->GetNofHits(); + if (nHits > 0) { + if (!fInputTofHits) { + LOG(error) << "CbmKFTrackFitter: Tof hit array not found!"; + return false; + } + tofHits.reserve(nHits); + for (int ih = 0; ih < nHits; ih++) { + int hitIndex = track->GetHitIndex(ih); + if (hitIndex >= fInputTofHits->GetEntriesFast()) { + LOG(error) << "CbmKFTrackFitter: Tof hit index " << hitIndex << " is out of range!"; + return false; + } + tofHits.push_back(*dynamic_cast<const CbmTofHit*>(fInputTofHits->At(hitIndex))); + } + } + } + + return CreateTrack(kfTrack, *globalTrack.GetParamFirst(), mvdHits, stsHits, muchHits, trdHits, tofHits); +} + + +bool CbmKFTrackFitter::CreateMvdStsTrack(CbmKFTrackFitter::Track& kfTrack, const CbmStsTrack& stsTrack) +{ + Init(); + if (!fIsInitialized) return false; + + std::vector<CbmMvdHit> mvdHits; + std::vector<CbmStsHit> stsHits; + std::vector<CbmMuchPixelHit> muchHits; + std::vector<CbmTrdHit> trdHits; + std::vector<CbmTofHit> tofHits; + + kfTrack = {}; + + // Read MVD hits + + int nMvdHits = stsTrack.GetNofMvdHits(); + if (nMvdHits > 0) { + if (!fInputMvdHits) { + LOG(error) << "CbmKFTrackFitter: Mvd hit array not found!"; + return false; + } + mvdHits.reserve(nMvdHits); + for (int ih = 0; ih < nMvdHits; ih++) { + int hitIndex = stsTrack.GetMvdHitIndex(ih); + mvdHits.push_back(*dynamic_cast<const CbmMvdHit*>(fInputMvdHits->At(hitIndex))); + } + } + + // Read STS hits + + int nStsHits = stsTrack.GetNofStsHits(); + if (nStsHits > 0) { + if (!fInputStsHits) { + LOG(error) << "CbmKFTrackFitter: Sts hit array not found!"; + return false; + } + stsHits.reserve(nStsHits); + for (int ih = 0; ih < nStsHits; ih++) { + int hitIndex = stsTrack.GetStsHitIndex(ih); + stsHits.push_back(*dynamic_cast<const CbmStsHit*>(fInputStsHits->At(hitIndex))); + } + } + + return CreateTrack(kfTrack, *stsTrack.GetParamFirst(), mvdHits, stsHits, muchHits, trdHits, tofHits); +} + + +bool CbmKFTrackFitter::CreateTrack(CbmKFTrackFitter::Track& kfTrack, const FairTrackParam& trackFirst, + const std::vector<CbmMvdHit>& mvdHits, const std::vector<CbmStsHit>& stsHits, + const std::vector<CbmMuchPixelHit>& muchHits, const std::vector<CbmTrdHit>& trdHits, + const std::vector<CbmTofHit>& tofHits + +) +{ + kfTrack = {}; + Init(); + if (!fIsInitialized) return false; + + std::vector<const CbmPixelHit*> hits; + + std::vector<int> hitStations; + + const ca::Parameters& caPar = CbmL1::Instance()->fpAlgo->GetParameters(); + + for (auto& h : mvdHits) { + hits.push_back(dynamic_cast<const CbmPixelHit*>(&h)); + int stIdx = CbmMvdTrackingInterface::Instance()->GetTrackingStationIndex(&h); + hitStations.push_back(caPar.GetStationIndexActive(stIdx, ca::EDetectorID::kMvd)); + } + + for (auto& h : stsHits) { + hits.push_back(dynamic_cast<const CbmPixelHit*>(&h)); + int stIdx = CbmStsTrackingInterface::Instance()->GetTrackingStationIndex(&h); + hitStations.push_back(caPar.GetStationIndexActive(stIdx, ca::EDetectorID::kSts)); + } + + for (auto& h : muchHits) { + hits.push_back(dynamic_cast<const CbmPixelHit*>(&h)); + int stIdx = CbmMuchTrackingInterface::Instance()->GetTrackingStationIndex(&h); + hitStations.push_back(caPar.GetStationIndexActive(stIdx, ca::EDetectorID::kMuch)); + } + + for (auto& h : trdHits) { + hits.push_back(dynamic_cast<const CbmPixelHit*>(&h)); + int stIdx = CbmTrdTrackingInterface::Instance()->GetTrackingStationIndex(&h); + hitStations.push_back(caPar.GetStationIndexActive(stIdx, ca::EDetectorID::kTrd)); + } + + for (auto& h : tofHits) { + hits.push_back(dynamic_cast<const CbmPixelHit*>(&h)); + int stIdx = CbmTofTrackingInterface::Instance()->GetTrackingStationIndex(&h); + hitStations.push_back(caPar.GetStationIndexActive(stIdx, ca::EDetectorID::kTof)); + } + + CbmKFTrackFitter::Track t; + + int nStations = caPar.GetNstationsActive(); + + t.fNodes.resize(nStations); + for (int i = 0; i < nStations; i++) { + t.fNodes[i].fMaterialLayer = i; + t.fNodes[i].fZ = caPar.GetStation(i).GetZScal(); + t.fNodes[i].fRadThick = 0.; + t.fNodes[i].fIsRadThickSet = false; + t.fNodes[i].fIsFitted = false; + } + + t.fFirstHitNode = nStations - 1; + t.fLastHitNode = 0; + + for (unsigned int i = 0; i < hits.size(); i++) { + + assert(hits[i]); + const CbmPixelHit& h = *hits[i]; + + int ista = hitStations[i]; + + if (ista < 0) continue; + assert(ista < nStations); + + CbmKFTrackFitter::FitNode& n = t.fNodes[ista]; + + n.fZ = h.GetZ(); + + n.fMxy.SetX(h.GetX()); + n.fMxy.SetY(h.GetY()); + n.fMxy.SetDx2(h.GetDx() * h.GetDx()); + n.fMxy.SetDy2(h.GetDy() * h.GetDy()); + n.fMxy.SetDxy(h.GetDxy()); + n.fMxy.SetNdfX(1); + n.fMxy.SetNdfY(1); + + n.fMt.SetT(h.GetTime()); + n.fMt.SetDt2(h.GetTimeError() * h.GetTimeError()); + n.fMt.SetNdfT(1); + + n.fRadThick = 0.; + + t.fFirstHitNode = std::min(t.fFirstHitNode, ista); + t.fLastHitNode = std::max(t.fLastHitNode, ista); + } + + ca::TrackParamD tmp = cbm::L1Util::ConvertTrackParam(trackFirst); + + t.fNodes[t.fFirstHitNode].fTrack.Set(tmp); + t.fNodes[t.fFirstHitNode].fIsFitted = 1; + + kfTrack = t; + return true; +} + + +void CbmKFTrackFitter::FilterFirstMeasurement(const FitNode& n) +{ + // a special routine to filter the first measurement. + // the measurement errors are simply copied to the track covariance matrix + + const auto& mxy = n.fMxy; + const auto& mt = n.fMt; + + auto& tr = fFit.Tr(); + tr.ResetErrors(mxy.Dx2(), mxy.Dy2(), 1., 1., 1., 1.e4, 1.e2); + tr.SetC10(mxy.Dxy()); + tr.SetX(mxy.X()); + tr.SetY(mxy.Y()); + tr.SetNdf(-5 + mxy.NdfX() + mxy.NdfY()); + if (mt.NdfT()[0] > 0) { + tr.SetTime(mt.T()); + tr.SetC55(mt.Dt2()); + tr.SetNdfTime(-2 + 1); + } + else { + tr.SetNdfTime(-2); + } + tr.SetVi(0.); +} + + +void CbmKFTrackFitter::AddMaterialEffects(const CbmKFTrackFitter::Track& t, CbmKFTrackFitter::FitNode& n, + bool upstreamDirection) +{ + // add material effects + if (n.fMaterialLayer < 0) { return; } + + // calculate the radiation thickness from the current track + if (!n.fIsRadThickSet && !n.fIsFitted) { + n.fRadThick = CbmL1::Instance()->fpAlgo->GetParameters().GetMaterialThicknessScal( + n.fMaterialLayer, fFit.Tr().GetX()[0], fFit.Tr().GetY()[0]); + } + + fvec msQp0 = t.fMsQp0; + if (!t.fIsMsQp0Set) { + if (n.fIsFitted) { msQp0 = n.fTrack.GetQp(); } + else { + msQp0 = fFit.Tr().GetQp(); + } + } + fFit.MultipleScattering(n.fRadThick, msQp0); + fFit.EnergyLossCorrection(n.fRadThick, upstreamDirection ? fvec::One() : fvec::Zero()); +} + +void CbmKFTrackFitter::FitTrack(CbmKFTrackFitter::Track& t) +{ + // fit the track + + // ensure that the fitter is initialized + Init(); + + t.MakeConsistent(); + + fFit.SetMask(fmask::One()); + fFit.SetParticleMass(fMass); + + ca::FieldRegion field _fvecalignment; + field.SetUseOriginalField(); + + int nNodes = t.fNodes.size(); + + // fit downstream. The approximation is taken from the first hit node + { + FitNode& n = t.fNodes[t.fFirstHitNode]; + fFit.SetTrack(n.fTrack); + FilterFirstMeasurement(n); + n.fTrack = fFit.Tr(); + n.fIsFitted = false; + } + + for (int iNode = t.fFirstHitNode + 1; iNode < nNodes; iNode++) { + FitNode& n = t.fNodes[iNode]; + fFit.Extrapolate(n.fZ, field); + if (n.fIsFitted) { fFit.SetQp0(n.fTrack.GetQp()); } + AddMaterialEffects(t, n, false); + n.fTrack = fFit.Tr(); + n.fIsFitted = false; + fFit.FilterXY(n.fMxy); + fFit.FilterTime(n.fMt); + if (iNode == t.fLastHitNode) { n.fTrack = fFit.Tr(); } + if (iNode >= t.fLastHitNode) { n.fIsFitted = true; } + } + + + // fit upstream + { + FitNode& n = t.fNodes[t.fLastHitNode]; + fFit.SetTrack(n.fTrack); + FilterFirstMeasurement(n); + n.fIsFitted = true; + } + + for (int iNode = t.fLastHitNode - 1; iNode >= 0; iNode--) { + FitNode& n = t.fNodes[iNode]; + fFit.Extrapolate(n.fZ, field); + fFit.FilterXY(n.fMxy); + fFit.FilterTime(n.fMt); + + // combine partially fitted downstream and upstream tracks + if (iNode > t.fFirstHitNode) { Smooth(n.fTrack, fFit.Tr()); } + else { + n.fTrack = fFit.Tr(); + } + n.fIsFitted = true; + fFit.SetQp0(n.fTrack.GetQp()); + AddMaterialEffects(t, n, true); + if (iNode == t.fFirstHitNode) { n.fTrack = fFit.Tr(); } + } + + // distribute the final chi2, ndf to all nodes + + const auto& tt = t.fNodes[t.fFirstHitNode].fTrack; + for (auto& n : t.fNodes) { + n.fTrack.SetNdf(tt.GetNdf()); + n.fTrack.SetNdfTime(tt.GetNdfTime()); + n.fTrack.SetChiSq(tt.GetChiSq()); + n.fTrack.SetChiSqTime(tt.GetChiSqTime()); + } +} + + +void CbmKFTrackFitter::Smooth(ca::TrackParamV& t1, const ca::TrackParamV& t2) +{ + // combine two tracks + std::tie(t1.X(), t1.Y(), t1.C00(), t1.C10(), t1.C11()) = + Smooth2D(t1.X(), t1.Y(), t1.C00(), t1.C10(), t1.C11(), t2.X(), t2.Y(), t2.C00(), t2.C10(), t2.C11()); + + std::tie(t1.Tx(), t1.Ty(), t1.C22(), t1.C32(), t1.C33()) = + Smooth2D(t1.Tx(), t1.Ty(), t1.C22(), t1.C32(), t1.C33(), t2.Tx(), t2.Ty(), t2.C22(), t2.C32(), t2.C33()); + + std::tie(t1.Qp(), t1.C44()) = Smooth1D(t1.Qp(), t1.C44(), t2.Qp(), t2.C44()); + + std::tie(t1.Time(), t1.Vi(), t1.C55(), t1.C65(), t1.C66()) = + Smooth2D(t1.Time(), t1.Vi(), t1.C55(), t1.C65(), t1.C66(), t2.Time(), t2.Vi(), t2.C55(), t2.C65(), t2.C66()); + + t1.C20() = 0.; + t1.C21() = 0.; + t1.C30() = 0.; + t1.C31() = 0.; + t1.C40() = 0.; + t1.C41() = 0.; + t1.C42() = 0.; + t1.C43() = 0.; + t1.C50() = 0.; + t1.C51() = 0.; + t1.C52() = 0.; + t1.C53() = 0.; + t1.C54() = 0.; + t1.C60() = 0.; + t1.C61() = 0.; + t1.C62() = 0.; + t1.C63() = 0.; + t1.C64() = 0.; +} + +std::tuple<ca::fvec, ca::fvec> CbmKFTrackFitter::Smooth1D(ca::fvec x1, ca::fvec Cxx1, ca::fvec x2, ca::fvec Cxx2) +{ + // combine two 1D values + ca::fvec x = (x1 * Cxx1 + x2 * Cxx2) / (Cxx1 + Cxx2); + ca::fvec Cxx = Cxx1 * Cxx2 / (Cxx1 + Cxx2); + return std::tuple(x, Cxx); +} + +std::tuple<ca::fvec, ca::fvec, ca::fvec, ca::fvec, ca::fvec> +CbmKFTrackFitter::Smooth2D(ca::fvec x1, ca::fvec y1, ca::fvec Cxx1, ca::fvec /*Cxy1*/, ca::fvec Cyy1, ca::fvec x2, + ca::fvec y2, ca::fvec Cxx2, ca::fvec /*Cxy2*/, ca::fvec Cyy2) +{ + // combine two 2D values + // TODO: do it right + auto [x, Cxx] = Smooth1D(x1, Cxx1, x2, Cxx2); + auto [y, Cyy] = Smooth1D(y1, Cyy1, y2, Cyy2); + return std::tuple(x, y, Cxx, ca::fvec::Zero(), Cyy); +} \ No newline at end of file diff --git a/reco/KF/CbmKFTrackFitter.h b/reco/KF/CbmKFTrackFitter.h new file mode 100644 index 0000000000..25d88910bb --- /dev/null +++ b/reco/KF/CbmKFTrackFitter.h @@ -0,0 +1,160 @@ +/* Copyright (C) 2023 GSI Helmholtzzentrum fuer Schwerionenforschung, Darmstadt + SPDX-License-Identifier: GPL-3.0-only + Authors: Sergey Gorbunov [committer] */ + +#pragma once // include this header only once per compilation unit + + +#include "CbmDefs.h" + +#include <vector> + +#include "CaConstants.h" +#include "CaMeasurementTime.h" +#include "CaMeasurementXy.h" +#include "CaSimd.h" +#include "CaTrackFit.h" +#include "CaTrackParam.h" + +class CbmMvdHit; +class CbmStsHit; +class CbmStsTrack; +class CbmGlobalTrack; +class CbmKFVertex; +class CbmMuchPixelHit; +class CbmTrdHit; +class CbmTofHit; + +class FairTrackParam; +class TClonesArray; + +namespace cbm::algo::ca +{ + class FieldRegion; +} // namespace cbm::algo::ca + +namespace +{ + using namespace cbm::algo; +} + +/// A fitter for the Cbm tracks +/// +class CbmKFTrackFitter { +public: + /// A node on the trajectory where the track parameters are: + /// a) measured and / or + /// b) scattered and / or + /// c) need to be estimated + /// The nodes must be ordered by increasing Z + /// + struct FitNode { + + double fZ {0.}; ///< Z coordinate + + ca::TrackParamV fTrack {}; ///< fitted track + + /// == Material information (if present) + // TODO: change to the material layer index when the material layer is implemented + int fMaterialLayer {-1}; ///< index of the material layer. Currently equal to the active tracking station index + + double fRadThick {0.}; ///< material radiation thickness at fZ + + /// == Hit information ( if present ) + + ca::MeasurementXy<ca::fvec> fMxy {}; ///< XY-measurement at fZ + + ca::MeasurementTime<ca::fvec> fMt {}; ///< time measurement at fZ + + ECbmModuleId fSystemId {ECbmModuleId::kNotExist}; ///< detector system ID of the hit + int fDetectorId {-1}; ///< detector ID of the hit + int fHitId {-1}; ///< hit ID + + /// == Flags etc + + bool fIsFitted {false}; ///< true if the node is fitted, false if the fit failed + bool fIsRadThickSet {false}; ///< true if the radiation thickness is set + int fReference1 {-1}; ///< some reference can be set by the user + int fReference2 {-1}; ///< some reference can be set by the user + }; + + /// A track to be fitted + struct Track { + std::vector<FitNode> fNodes; ///< nodes on the track + int fFirstHitNode {-1}; ///< index of the first node with the XY measurement + int fLastHitNode {-1}; ///< index of the last node with the XY measurement + + /// externally defined inverse momentum for the Multiple Scattering calculation. + /// It is used for the tracks in field-free regions. + /// When the momentum can be fitted, the fitted value is used. + /// the default value is set to 0.1 GeV/c + double fMsQp0 {1. / 0.1}; + bool fIsMsQp0Set {false}; + + void MakeConsistent(); // make the structure fields consistent + }; + + CbmKFTrackFitter(); + ~CbmKFTrackFitter(); + + /// initialize the fitter. It must be called in the Init() of the user task. + /// when called later, it sees track branches always empty for whatever reason + void Init(); + + /// set particle hypothesis (mass and electron flag) via particle PDG + void SetParticleHypothesis(int pid); + + /// set particle mass + void SetMassHypothesis(double mass); + + /// set electron flag (bremmstrallung will be applied) + void SetElectronFlag(bool isElectron); + + bool CreateTrack(Track& kfTrack, const FairTrackParam& trackFirst, const std::vector<CbmMvdHit>& mvdHits, + const std::vector<CbmStsHit>& stsHits, const std::vector<CbmMuchPixelHit>& muchHits, + const std::vector<CbmTrdHit>& trdHits, const std::vector<CbmTofHit>& tofHits); + + bool CreateMvdStsTrack(Track& kfTrack, const CbmStsTrack& stsTrack); + bool CreateGlobalTrack(Track& kfTrack, const CbmGlobalTrack& globalTrack); + + /// fit the track + void FitTrack(CbmKFTrackFitter::Track& t); + + /// fit sts tracks + // void FitStsTracks(vector<CbmStsTrack>& Tracks, const vector<int>& pidHypo); + +private: + void FilterFirstMeasurement(const FitNode& n); + void AddMaterialEffects(const Track& t, FitNode& n, bool upstreamDirection); + // combine two tracks + void Smooth(ca::TrackParamV& t1, const ca::TrackParamV& t2); + + std::tuple<ca::fvec, ca::fvec> Smooth1D(ca::fvec x1, ca::fvec Cxx1, ca::fvec x2, ca::fvec Cxx2); + std::tuple<ca::fvec, ca::fvec, ca::fvec, ca::fvec, ca::fvec> Smooth2D(ca::fvec x1, ca::fvec y1, ca::fvec Cxx1, + ca::fvec Cxy1, ca::fvec Cyy1, ca::fvec x2, + ca::fvec y2, ca::fvec Cxx2, ca::fvec Cxy2, + ca::fvec Cyy2); + +private: + // input data arrays + TClonesArray* fInputMvdHits {nullptr}; + TClonesArray* fInputStsHits {nullptr}; + TClonesArray* fInputMuchHits {nullptr}; + TClonesArray* fInputTrdHits {nullptr}; + TClonesArray* fInputTofHits {nullptr}; + + TClonesArray* fInputGlobalTracks {nullptr}; + TClonesArray* fInputStsTracks {nullptr}; + TClonesArray* fInputMuchTracks {nullptr}; + TClonesArray* fInputTrdTracks {nullptr}; + TClonesArray* fInputTofTracks {nullptr}; + + // + + bool fIsInitialized = {false}; // is the fitter initialized + + ca::TrackFit fFit; // track fit object + + double fMass {ca::constants::phys::PionMass}; // mass hypothesis for the fit + bool fIsElectron {false}; // fit track as an electron (with the bermsstrallung effect) +}; diff --git a/reco/KF/KF.cmake b/reco/KF/KF.cmake index 273f26b066..af1bd8f834 100644 --- a/reco/KF/KF.cmake +++ b/reco/KF/KF.cmake @@ -21,6 +21,7 @@ set(SRCS CbmKFTrackInterface.cxx CbmKFUMeasurement.cxx CbmKFVertexInterface.cxx + CbmKFTrackFitter.cxx #Interface/CbmEcalTrackExtrapolationKF.cxx Interface/CbmKFStsHit.cxx diff --git a/reco/KF/KFLinkDef.h b/reco/KF/KFLinkDef.h index 6f16bb68b6..d34ed89840 100644 --- a/reco/KF/KFLinkDef.h +++ b/reco/KF/KFLinkDef.h @@ -40,6 +40,7 @@ #pragma link C++ class CbmL1TofMerger + ; #pragma link C++ class CbmL1TrdTracklet + ; #pragma link C++ class CbmL1TrdTracklet4 + ; +#pragma link C++ class CbmKFTrackFitter + ; //#pragma link C++ class CbmL1TrdTrackFinderSts+; //#pragma link C++ class CbmL1CATrdTrackFinderSA+; -- GitLab