From 6e65296663ac520276116d6b1c80c0d16f91c29c Mon Sep 17 00:00:00 2001
From: "se.gorbunov" <se.gorbunov@gsi.de>
Date: Mon, 13 Nov 2023 21:49:58 +0000
Subject: [PATCH] KF: utility to refit global tracks with Kalman Filter
 Smoother

---
 algo/ca/core/data/CaMeasurementTime.h |   6 +-
 algo/ca/core/tracking/CaTrackFit.cxx  |   6 +-
 algo/ca/core/tracking/CaTrackFit.h    |   9 +-
 reco/KF/CbmKFTrackFitter.cxx          | 637 ++++++++++++++++++++++++++
 reco/KF/CbmKFTrackFitter.h            | 160 +++++++
 reco/KF/KF.cmake                      |   1 +
 reco/KF/KFLinkDef.h                   |   1 +
 7 files changed, 813 insertions(+), 7 deletions(-)
 create mode 100644 reco/KF/CbmKFTrackFitter.cxx
 create mode 100644 reco/KF/CbmKFTrackFitter.h

diff --git a/algo/ca/core/data/CaMeasurementTime.h b/algo/ca/core/data/CaMeasurementTime.h
index f74597788e..a8a24532eb 100644
--- a/algo/ca/core/data/CaMeasurementTime.h
+++ b/algo/ca/core/data/CaMeasurementTime.h
@@ -85,14 +85,14 @@ namespace cbm::algo::ca
     ///------------------------------
     /// Data members
 
-    DataT fT {constants::Undef<DataT>};    ///< time coordinate of the measurement
-    DataT fDt2 {constants::Undef<DataT>};  ///< rms^2 of the time coordinate measurement
+    DataT fT {0};     ///< time coordinate of the measurement
+    DataT fDt2 {1.};  ///< rms^2 of the time coordinate measurement
 
     /// number of degrees of freedom (used for chi2 calculation)
     /// if ndf == 1, the measurement is used in fit and in the chi2 calculation
     /// if ndf == 0, the measurement is used neither in fit nor in the chi2 calculation
 
-    DataT fNdfT = constants::Undef<DataT>;  ///< ndf for the time coordinate measurement
+    DataT fNdfT {0};  ///< ndf for the time coordinate measurement
 
   } _fvecalignment;
 
diff --git a/algo/ca/core/tracking/CaTrackFit.cxx b/algo/ca/core/tracking/CaTrackFit.cxx
index f2385979c2..508b7e0fff 100644
--- a/algo/ca/core/tracking/CaTrackFit.cxx
+++ b/algo/ca/core/tracking/CaTrackFit.cxx
@@ -879,7 +879,7 @@ namespace cbm::algo::ca
   }
 
 
-  void TrackFit::MultipleScattering(fvec radThick)
+  void TrackFit::MultipleScattering(fvec radThick, fvec qp0)
   {
     cnst ONE = 1.;
 
@@ -891,13 +891,13 @@ namespace cbm::algo::ca
     fvec h     = txtx + tyty;
     fvec t     = sqrt(txtx1 + tyty);
     fvec h2    = h * h;
-    fvec qp0t  = fQp0 * t;
+    fvec qp0t  = qp0 * t;
 
     cnst c1 = 0.0136f, c2 = c1 * 0.038f, c3 = c2 * 0.5f, c4 = -c3 / 2.0f, c5 = c3 / 3.0f, c6 = -c3 / 4.0f;
 
     fvec s0 = (c1 + c2 * log(radThick) + c3 * h + h2 * (c4 + c5 * h + c6 * h2)) * qp0t;
     //fvec a = ( (ONE+mass2*qp0*qp0t)*radThick*s0*s0 );
-    fvec a = ((t + fMass2 * fQp0 * qp0t) * radThick * s0 * s0);
+    fvec a = ((t + fMass2 * qp0 * qp0t) * radThick * s0 * s0);
 
     fTr.C22()(fMask) += txtx1 * a;
     fTr.C32()(fMask) += tx * ty * a;
diff --git a/algo/ca/core/tracking/CaTrackFit.h b/algo/ca/core/tracking/CaTrackFit.h
index 1bb492e339..aa01221011 100644
--- a/algo/ca/core/tracking/CaTrackFit.h
+++ b/algo/ca/core/tracking/CaTrackFit.h
@@ -11,6 +11,7 @@
 #pragma once  // include this header only once per compilation unit
 
 #include "CaField.h"
+#include "CaMeasurementTime.h"
 #include "CaMeasurementU.h"
 #include "CaMeasurementXy.h"
 #include "CaSimd.h"
@@ -100,6 +101,9 @@ namespace cbm::algo::ca
     /// filter the track with the time measurement
     void FilterTime(fvec t, fvec dt2, fvec timeInfo);
 
+    /// filter the track with the time measurement
+    void FilterTime(MeasurementTime<fvec> mt) { FilterTime(mt.T(), mt.Dt2(), mt.NdfT()); }
+
     /// filter the track with the hit
     void FilterHit(const ca::Station& s, const ca::Hit& h);
 
@@ -142,8 +146,11 @@ namespace cbm::algo::ca
                               fvec upstreamDirection);
 
 
+    /// apply multiple scattering correction to the track with the given Qp0
+    void MultipleScattering(fvec radThick, fvec qp0);
+
     /// apply multiple scattering correction to the track
-    void MultipleScattering(fvec radThick);
+    void MultipleScattering(fvec radThick) { MultipleScattering(radThick, fQp0); }
 
     /// apply multiple scattering correction in thick material to the track
     void MultipleScatteringInThickMaterial(fvec radThick, fvec thickness, bool fDownstream);
diff --git a/reco/KF/CbmKFTrackFitter.cxx b/reco/KF/CbmKFTrackFitter.cxx
new file mode 100644
index 0000000000..ba8347cc1f
--- /dev/null
+++ b/reco/KF/CbmKFTrackFitter.cxx
@@ -0,0 +1,637 @@
+/* Copyright (C) 2023 GSI Helmholtzzentrum fuer Schwerionenforschung, Darmstadt
+   SPDX-License-Identifier: GPL-3.0-only
+   Authors: Sergey Gorbunov [committer] */
+
+#include "CbmKFTrackFitter.h"
+
+#include "CbmGlobalTrack.h"
+#include "CbmL1.h"
+#include "CbmL1Util.h"
+#include "CbmMuchPixelHit.h"
+#include "CbmMuchTrack.h"
+#include "CbmMuchTrackingInterface.h"
+#include "CbmMvdHit.h"
+#include "CbmMvdTrackingInterface.h"
+#include "CbmStsAddress.h"
+#include "CbmStsHit.h"
+#include "CbmStsSetup.h"
+#include "CbmStsTrack.h"
+#include "CbmStsTrackingInterface.h"
+#include "CbmTofHit.h"
+#include "CbmTofTrack.h"
+#include "CbmTofTrackingInterface.h"
+#include "CbmTrdHit.h"
+#include "CbmTrdTrack.h"
+#include "CbmTrdTrackingInterface.h"
+
+#include "FairRootManager.h"
+
+#include "TClonesArray.h"
+#include "TDatabasePDG.h"
+
+#include "CaConstants.h"
+#include "CaFramework.h"
+#include "CaSimd.h"
+#include "CaStation.h"
+#include "CaTrackFit.h"
+#include "CaTrackParam.h"
+#include "KFParticleDatabase.h"
+
+using std::vector;
+using namespace std;
+using ca::fmask;
+using ca::fvec;
+
+namespace
+{
+  using namespace cbm::algo;
+}
+
+void CbmKFTrackFitter::Track::MakeConsistent()
+{
+  // sort the nodes in z
+  std::sort(fNodes.begin(), fNodes.end(), [](const FitNode& a, const FitNode& b) { return a.fZ < b.fZ; });
+
+  // set the first and last hit nodes
+  fFirstHitNode = fNodes.size() - 1;
+  fLastHitNode  = 0;
+  for (int i = 0; i < (int) fNodes.size(); i++) {
+    if (fNodes[i].fMxy.NdfX()[0] + fNodes[i].fMxy.NdfY()[0] > 0) {
+      fFirstHitNode = std::min(fFirstHitNode, i);
+      fLastHitNode  = std::max(fLastHitNode, i);
+    }
+  }
+}
+
+
+CbmKFTrackFitter::CbmKFTrackFitter() {}
+
+CbmKFTrackFitter::~CbmKFTrackFitter() {}
+
+void CbmKFTrackFitter::Init()
+{
+  if (fIsInitialized) return;
+
+  if (!CbmL1::Instance() || !CbmL1::Instance()->fpAlgo) {
+    LOG(fatal) << "CbmKFTrackFitter: no CbmL1 task initialized ";
+  }
+
+  FairRootManager* ioman = FairRootManager::Instance();
+
+  if (!ioman) { LOG(fatal) << "CbmKFTrackFitter: no FairRootManager"; }
+
+  // Get hits
+
+  fInputMvdHits  = dynamic_cast<TClonesArray*>(ioman->GetObject("MvdHit"));
+  fInputStsHits  = dynamic_cast<TClonesArray*>(ioman->GetObject("StsHit"));
+  fInputTrdHits  = dynamic_cast<TClonesArray*>(ioman->GetObject("TrdHit"));
+  fInputMuchHits = dynamic_cast<TClonesArray*>(ioman->GetObject("MuchHit"));
+  fInputTofHits  = dynamic_cast<TClonesArray*>(ioman->GetObject("TofHit"));
+
+  // Get global tracks
+  fInputGlobalTracks = dynamic_cast<TClonesArray*>(ioman->GetObject("GlobalTrack"));
+
+  // Get detector tracks
+  fInputStsTracks  = dynamic_cast<TClonesArray*>(ioman->GetObject("StsTrack"));
+  fInputMuchTracks = dynamic_cast<TClonesArray*>(ioman->GetObject("MuchTrack"));
+  fInputTrdTracks  = dynamic_cast<TClonesArray*>(ioman->GetObject("TrdTrack"));
+  fInputTofTracks  = dynamic_cast<TClonesArray*>(ioman->GetObject("TofTrack"));
+
+  fIsInitialized = true;
+}
+
+void CbmKFTrackFitter::SetParticleHypothesis(int pdg)
+{
+  TParticlePDG* particlePDG = TDatabasePDG::Instance()->GetParticle(pdg);
+  if (!particlePDG) {
+    LOG(fatal) << "CbmKFTrackFitter: particle PDG " << pdg << " is not in the data base, please set the mass manually";
+    return;
+  }
+  fMass       = particlePDG->Mass();
+  fIsElectron = (abs(pdg) == 11);
+}
+
+void CbmKFTrackFitter::SetMassHypothesis(double mass)
+{
+  assert(mass >= 0.);
+  fMass = mass;
+}
+
+
+void CbmKFTrackFitter::SetElectronFlag(bool isElectron) { fIsElectron = isElectron; }
+
+
+bool CbmKFTrackFitter::CreateGlobalTrack(CbmKFTrackFitter::Track& kfTrack, const CbmGlobalTrack& globalTrack)
+{
+  Init();
+  if (!fIsInitialized) return false;
+
+  std::vector<CbmMvdHit> mvdHits;
+  std::vector<CbmStsHit> stsHits;
+  std::vector<CbmMuchPixelHit> muchHits;
+  std::vector<CbmTrdHit> trdHits;
+  std::vector<CbmTofHit> tofHits;
+
+  kfTrack = {};
+
+  // Read MVD & STS hits
+
+  if (globalTrack.GetStsTrackIndex() >= 0) {
+
+    int stsTrackIndex = globalTrack.GetStsTrackIndex();
+
+    if (!fInputStsTracks) {
+      LOG(error) << "CbmKFTrackFitter: Sts track array not found!";
+      return false;
+    }
+    if (stsTrackIndex >= fInputStsTracks->GetEntriesFast()) {
+      LOG(error) << "CbmKFTrackFitter: Sts track index " << stsTrackIndex << " is out of range!";
+      return false;
+    }
+    auto* stsTrack = dynamic_cast<const CbmStsTrack*>(fInputStsTracks->At(stsTrackIndex));
+    if (!stsTrack) {
+      LOG(error) << "CbmKFTrackFitter: Sts track is null!";
+      return false;
+    }
+
+    // Read MVD hits
+
+    int nMvdHits = stsTrack->GetNofMvdHits();
+    if (nMvdHits > 0) {
+      if (!fInputMvdHits) {
+        LOG(error) << "CbmKFTrackFitter: Mvd hit array not found!";
+        return false;
+      }
+      mvdHits.reserve(nMvdHits);
+      for (int ih = 0; ih < nMvdHits; ih++) {
+        int hitIndex = stsTrack->GetMvdHitIndex(ih);
+        if (hitIndex >= fInputMvdHits->GetEntriesFast()) {
+          LOG(error) << "CbmKFTrackFitter: Mvd hit index " << hitIndex << " is out of range!";
+          return false;
+        }
+        mvdHits.push_back(*dynamic_cast<const CbmMvdHit*>(fInputMvdHits->At(hitIndex)));
+      }
+    }
+
+    // Read STS hits
+
+    int nStsHits = stsTrack->GetNofStsHits();
+    if (nStsHits > 0) {
+      if (!fInputStsHits) {
+        LOG(error) << "CbmKFTrackFitter: Sts hit array not found!";
+        return false;
+      }
+      stsHits.reserve(nStsHits);
+      for (int ih = 0; ih < nStsHits; ih++) {
+        int hitIndex = stsTrack->GetStsHitIndex(ih);
+        if (hitIndex >= fInputStsHits->GetEntriesFast()) {
+          LOG(error) << "CbmKFTrackFitter: Sts hit index " << hitIndex << " is out of range!";
+          return false;
+        }
+        stsHits.push_back(*dynamic_cast<const CbmStsHit*>(fInputStsHits->At(hitIndex)));
+      }
+    }
+  }  // MVD & STS hits
+
+
+  // Read Much hits
+
+  if (globalTrack.GetMuchTrackIndex() >= 0) {
+    int muchTrackIndex = globalTrack.GetMuchTrackIndex();
+    if (!fInputMuchTracks) {
+      LOG(error) << "CbmKFTrackFitter: Much track array not found!";
+      return false;
+    }
+    if (muchTrackIndex >= fInputMuchTracks->GetEntriesFast()) {
+      LOG(error) << "CbmKFTrackFitter: Much track index " << muchTrackIndex << " is out of range!";
+      return false;
+    }
+    auto* track = dynamic_cast<const CbmMuchTrack*>(fInputMuchTracks->At(muchTrackIndex));
+    if (!track) {
+      LOG(error) << "CbmKFTrackFitter: Much track is null!";
+      return false;
+    }
+    int nHits = track->GetNofHits();
+    if (nHits > 0) {
+      if (!fInputMuchHits) {
+        LOG(error) << "CbmKFTrackFitter: Much hit array not found!";
+        return false;
+      }
+      muchHits.reserve(nHits);
+      for (int ih = 0; ih < nHits; ih++) {
+        int hitIndex = track->GetHitIndex(ih);
+        if (hitIndex >= fInputMuchHits->GetEntriesFast()) {
+          LOG(error) << "CbmKFTrackFitter: Much hit index " << hitIndex << " is out of range!";
+          return false;
+        }
+        muchHits.push_back(*dynamic_cast<const CbmMuchPixelHit*>(fInputMuchHits->At(hitIndex)));
+      }
+    }
+  }
+
+  // Read TRD hits
+
+  if (globalTrack.GetTrdTrackIndex() >= 0) {
+    int trdTrackIndex = globalTrack.GetTrdTrackIndex();
+    if (!fInputTrdTracks) {
+      LOG(error) << "CbmKFTrackFitter: Trd track array not found!";
+      return false;
+    }
+    if (trdTrackIndex >= fInputTrdTracks->GetEntriesFast()) {
+      LOG(error) << "CbmKFTrackFitter: Trd track index " << trdTrackIndex << " is out of range!";
+      return false;
+    }
+    auto* track = dynamic_cast<const CbmTrdTrack*>(fInputTrdTracks->At(trdTrackIndex));
+    if (!track) {
+      LOG(error) << "CbmKFTrackFitter: Trd track is null!";
+      return false;
+    }
+    int nHits = track->GetNofHits();
+    if (nHits > 0) {
+      if (!fInputTrdHits) {
+        LOG(error) << "CbmKFTrackFitter: Trd hit array not found!";
+        return false;
+      }
+      trdHits.reserve(nHits);
+      for (int ih = 0; ih < nHits; ih++) {
+        int hitIndex = track->GetHitIndex(ih);
+        if (hitIndex >= fInputTrdHits->GetEntriesFast()) {
+          LOG(error) << "CbmKFTrackFitter: Trd hit index " << hitIndex << " is out of range!";
+          return false;
+        }
+        trdHits.push_back(*dynamic_cast<const CbmTrdHit*>(fInputTrdHits->At(hitIndex)));
+      }
+    }
+  }
+
+
+  // Read TOF hits
+
+  if (globalTrack.GetTofTrackIndex() >= 0) {
+    int tofTrackIndex = globalTrack.GetTofTrackIndex();
+    if (!fInputTofTracks) {
+      LOG(error) << "CbmKFTrackFitter: Trd track array not found!";
+      return false;
+    }
+    if (tofTrackIndex >= fInputTofTracks->GetEntriesFast()) {
+      LOG(error) << "CbmKFTrackFitter: Trd track index " << tofTrackIndex << " is out of range!";
+      return false;
+    }
+    auto* track = dynamic_cast<const CbmTofTrack*>(fInputTofTracks->At(tofTrackIndex));
+    if (!track) {
+      LOG(error) << "CbmKFTrackFitter: Tof track is null!";
+      return false;
+    }
+
+    int nHits = track->GetNofHits();
+    if (nHits > 0) {
+      if (!fInputTofHits) {
+        LOG(error) << "CbmKFTrackFitter: Tof hit array not found!";
+        return false;
+      }
+      tofHits.reserve(nHits);
+      for (int ih = 0; ih < nHits; ih++) {
+        int hitIndex = track->GetHitIndex(ih);
+        if (hitIndex >= fInputTofHits->GetEntriesFast()) {
+          LOG(error) << "CbmKFTrackFitter: Tof hit index " << hitIndex << " is out of range!";
+          return false;
+        }
+        tofHits.push_back(*dynamic_cast<const CbmTofHit*>(fInputTofHits->At(hitIndex)));
+      }
+    }
+  }
+
+  return CreateTrack(kfTrack, *globalTrack.GetParamFirst(), mvdHits, stsHits, muchHits, trdHits, tofHits);
+}
+
+
+bool CbmKFTrackFitter::CreateMvdStsTrack(CbmKFTrackFitter::Track& kfTrack, const CbmStsTrack& stsTrack)
+{
+  Init();
+  if (!fIsInitialized) return false;
+
+  std::vector<CbmMvdHit> mvdHits;
+  std::vector<CbmStsHit> stsHits;
+  std::vector<CbmMuchPixelHit> muchHits;
+  std::vector<CbmTrdHit> trdHits;
+  std::vector<CbmTofHit> tofHits;
+
+  kfTrack = {};
+
+  // Read MVD hits
+
+  int nMvdHits = stsTrack.GetNofMvdHits();
+  if (nMvdHits > 0) {
+    if (!fInputMvdHits) {
+      LOG(error) << "CbmKFTrackFitter: Mvd hit array not found!";
+      return false;
+    }
+    mvdHits.reserve(nMvdHits);
+    for (int ih = 0; ih < nMvdHits; ih++) {
+      int hitIndex = stsTrack.GetMvdHitIndex(ih);
+      mvdHits.push_back(*dynamic_cast<const CbmMvdHit*>(fInputMvdHits->At(hitIndex)));
+    }
+  }
+
+  // Read STS hits
+
+  int nStsHits = stsTrack.GetNofStsHits();
+  if (nStsHits > 0) {
+    if (!fInputStsHits) {
+      LOG(error) << "CbmKFTrackFitter: Sts hit array not found!";
+      return false;
+    }
+    stsHits.reserve(nStsHits);
+    for (int ih = 0; ih < nStsHits; ih++) {
+      int hitIndex = stsTrack.GetStsHitIndex(ih);
+      stsHits.push_back(*dynamic_cast<const CbmStsHit*>(fInputStsHits->At(hitIndex)));
+    }
+  }
+
+  return CreateTrack(kfTrack, *stsTrack.GetParamFirst(), mvdHits, stsHits, muchHits, trdHits, tofHits);
+}
+
+
+bool CbmKFTrackFitter::CreateTrack(CbmKFTrackFitter::Track& kfTrack, const FairTrackParam& trackFirst,
+                                   const std::vector<CbmMvdHit>& mvdHits, const std::vector<CbmStsHit>& stsHits,
+                                   const std::vector<CbmMuchPixelHit>& muchHits, const std::vector<CbmTrdHit>& trdHits,
+                                   const std::vector<CbmTofHit>& tofHits
+
+)
+{
+  kfTrack = {};
+  Init();
+  if (!fIsInitialized) return false;
+
+  std::vector<const CbmPixelHit*> hits;
+
+  std::vector<int> hitStations;
+
+  const ca::Parameters& caPar = CbmL1::Instance()->fpAlgo->GetParameters();
+
+  for (auto& h : mvdHits) {
+    hits.push_back(dynamic_cast<const CbmPixelHit*>(&h));
+    int stIdx = CbmMvdTrackingInterface::Instance()->GetTrackingStationIndex(&h);
+    hitStations.push_back(caPar.GetStationIndexActive(stIdx, ca::EDetectorID::kMvd));
+  }
+
+  for (auto& h : stsHits) {
+    hits.push_back(dynamic_cast<const CbmPixelHit*>(&h));
+    int stIdx = CbmStsTrackingInterface::Instance()->GetTrackingStationIndex(&h);
+    hitStations.push_back(caPar.GetStationIndexActive(stIdx, ca::EDetectorID::kSts));
+  }
+
+  for (auto& h : muchHits) {
+    hits.push_back(dynamic_cast<const CbmPixelHit*>(&h));
+    int stIdx = CbmMuchTrackingInterface::Instance()->GetTrackingStationIndex(&h);
+    hitStations.push_back(caPar.GetStationIndexActive(stIdx, ca::EDetectorID::kMuch));
+  }
+
+  for (auto& h : trdHits) {
+    hits.push_back(dynamic_cast<const CbmPixelHit*>(&h));
+    int stIdx = CbmTrdTrackingInterface::Instance()->GetTrackingStationIndex(&h);
+    hitStations.push_back(caPar.GetStationIndexActive(stIdx, ca::EDetectorID::kTrd));
+  }
+
+  for (auto& h : tofHits) {
+    hits.push_back(dynamic_cast<const CbmPixelHit*>(&h));
+    int stIdx = CbmTofTrackingInterface::Instance()->GetTrackingStationIndex(&h);
+    hitStations.push_back(caPar.GetStationIndexActive(stIdx, ca::EDetectorID::kTof));
+  }
+
+  CbmKFTrackFitter::Track t;
+
+  int nStations = caPar.GetNstationsActive();
+
+  t.fNodes.resize(nStations);
+  for (int i = 0; i < nStations; i++) {
+    t.fNodes[i].fMaterialLayer = i;
+    t.fNodes[i].fZ             = caPar.GetStation(i).GetZScal();
+    t.fNodes[i].fRadThick      = 0.;
+    t.fNodes[i].fIsRadThickSet = false;
+    t.fNodes[i].fIsFitted      = false;
+  }
+
+  t.fFirstHitNode = nStations - 1;
+  t.fLastHitNode  = 0;
+
+  for (unsigned int i = 0; i < hits.size(); i++) {
+
+    assert(hits[i]);
+    const CbmPixelHit& h = *hits[i];
+
+    int ista = hitStations[i];
+
+    if (ista < 0) continue;
+    assert(ista < nStations);
+
+    CbmKFTrackFitter::FitNode& n = t.fNodes[ista];
+
+    n.fZ = h.GetZ();
+
+    n.fMxy.SetX(h.GetX());
+    n.fMxy.SetY(h.GetY());
+    n.fMxy.SetDx2(h.GetDx() * h.GetDx());
+    n.fMxy.SetDy2(h.GetDy() * h.GetDy());
+    n.fMxy.SetDxy(h.GetDxy());
+    n.fMxy.SetNdfX(1);
+    n.fMxy.SetNdfY(1);
+
+    n.fMt.SetT(h.GetTime());
+    n.fMt.SetDt2(h.GetTimeError() * h.GetTimeError());
+    n.fMt.SetNdfT(1);
+
+    n.fRadThick = 0.;
+
+    t.fFirstHitNode = std::min(t.fFirstHitNode, ista);
+    t.fLastHitNode  = std::max(t.fLastHitNode, ista);
+  }
+
+  ca::TrackParamD tmp = cbm::L1Util::ConvertTrackParam(trackFirst);
+
+  t.fNodes[t.fFirstHitNode].fTrack.Set(tmp);
+  t.fNodes[t.fFirstHitNode].fIsFitted = 1;
+
+  kfTrack = t;
+  return true;
+}
+
+
+void CbmKFTrackFitter::FilterFirstMeasurement(const FitNode& n)
+{
+  // a special routine to filter the first measurement.
+  // the measurement errors are simply copied to the track covariance matrix
+
+  const auto& mxy = n.fMxy;
+  const auto& mt  = n.fMt;
+
+  auto& tr = fFit.Tr();
+  tr.ResetErrors(mxy.Dx2(), mxy.Dy2(), 1., 1., 1., 1.e4, 1.e2);
+  tr.SetC10(mxy.Dxy());
+  tr.SetX(mxy.X());
+  tr.SetY(mxy.Y());
+  tr.SetNdf(-5 + mxy.NdfX() + mxy.NdfY());
+  if (mt.NdfT()[0] > 0) {
+    tr.SetTime(mt.T());
+    tr.SetC55(mt.Dt2());
+    tr.SetNdfTime(-2 + 1);
+  }
+  else {
+    tr.SetNdfTime(-2);
+  }
+  tr.SetVi(0.);
+}
+
+
+void CbmKFTrackFitter::AddMaterialEffects(const CbmKFTrackFitter::Track& t, CbmKFTrackFitter::FitNode& n,
+                                          bool upstreamDirection)
+{
+  // add material effects
+  if (n.fMaterialLayer < 0) { return; }
+
+  // calculate the radiation thickness from the current track
+  if (!n.fIsRadThickSet && !n.fIsFitted) {
+    n.fRadThick = CbmL1::Instance()->fpAlgo->GetParameters().GetMaterialThicknessScal(
+      n.fMaterialLayer, fFit.Tr().GetX()[0], fFit.Tr().GetY()[0]);
+  }
+
+  fvec msQp0 = t.fMsQp0;
+  if (!t.fIsMsQp0Set) {
+    if (n.fIsFitted) { msQp0 = n.fTrack.GetQp(); }
+    else {
+      msQp0 = fFit.Tr().GetQp();
+    }
+  }
+  fFit.MultipleScattering(n.fRadThick, msQp0);
+  fFit.EnergyLossCorrection(n.fRadThick, upstreamDirection ? fvec::One() : fvec::Zero());
+}
+
+void CbmKFTrackFitter::FitTrack(CbmKFTrackFitter::Track& t)
+{
+  // fit the track
+
+  // ensure that the fitter is initialized
+  Init();
+
+  t.MakeConsistent();
+
+  fFit.SetMask(fmask::One());
+  fFit.SetParticleMass(fMass);
+
+  ca::FieldRegion field _fvecalignment;
+  field.SetUseOriginalField();
+
+  int nNodes = t.fNodes.size();
+
+  // fit downstream. The approximation is taken from the first hit node
+  {
+    FitNode& n = t.fNodes[t.fFirstHitNode];
+    fFit.SetTrack(n.fTrack);
+    FilterFirstMeasurement(n);
+    n.fTrack    = fFit.Tr();
+    n.fIsFitted = false;
+  }
+
+  for (int iNode = t.fFirstHitNode + 1; iNode < nNodes; iNode++) {
+    FitNode& n = t.fNodes[iNode];
+    fFit.Extrapolate(n.fZ, field);
+    if (n.fIsFitted) { fFit.SetQp0(n.fTrack.GetQp()); }
+    AddMaterialEffects(t, n, false);
+    n.fTrack    = fFit.Tr();
+    n.fIsFitted = false;
+    fFit.FilterXY(n.fMxy);
+    fFit.FilterTime(n.fMt);
+    if (iNode == t.fLastHitNode) { n.fTrack = fFit.Tr(); }
+    if (iNode >= t.fLastHitNode) { n.fIsFitted = true; }
+  }
+
+
+  // fit upstream
+  {
+    FitNode& n = t.fNodes[t.fLastHitNode];
+    fFit.SetTrack(n.fTrack);
+    FilterFirstMeasurement(n);
+    n.fIsFitted = true;
+  }
+
+  for (int iNode = t.fLastHitNode - 1; iNode >= 0; iNode--) {
+    FitNode& n = t.fNodes[iNode];
+    fFit.Extrapolate(n.fZ, field);
+    fFit.FilterXY(n.fMxy);
+    fFit.FilterTime(n.fMt);
+
+    // combine partially fitted downstream and upstream tracks
+    if (iNode > t.fFirstHitNode) { Smooth(n.fTrack, fFit.Tr()); }
+    else {
+      n.fTrack = fFit.Tr();
+    }
+    n.fIsFitted = true;
+    fFit.SetQp0(n.fTrack.GetQp());
+    AddMaterialEffects(t, n, true);
+    if (iNode == t.fFirstHitNode) { n.fTrack = fFit.Tr(); }
+  }
+
+  // distribute the final chi2, ndf to all nodes
+
+  const auto& tt = t.fNodes[t.fFirstHitNode].fTrack;
+  for (auto& n : t.fNodes) {
+    n.fTrack.SetNdf(tt.GetNdf());
+    n.fTrack.SetNdfTime(tt.GetNdfTime());
+    n.fTrack.SetChiSq(tt.GetChiSq());
+    n.fTrack.SetChiSqTime(tt.GetChiSqTime());
+  }
+}
+
+
+void CbmKFTrackFitter::Smooth(ca::TrackParamV& t1, const ca::TrackParamV& t2)
+{
+  // combine two tracks
+  std::tie(t1.X(), t1.Y(), t1.C00(), t1.C10(), t1.C11()) =
+    Smooth2D(t1.X(), t1.Y(), t1.C00(), t1.C10(), t1.C11(), t2.X(), t2.Y(), t2.C00(), t2.C10(), t2.C11());
+
+  std::tie(t1.Tx(), t1.Ty(), t1.C22(), t1.C32(), t1.C33()) =
+    Smooth2D(t1.Tx(), t1.Ty(), t1.C22(), t1.C32(), t1.C33(), t2.Tx(), t2.Ty(), t2.C22(), t2.C32(), t2.C33());
+
+  std::tie(t1.Qp(), t1.C44()) = Smooth1D(t1.Qp(), t1.C44(), t2.Qp(), t2.C44());
+
+  std::tie(t1.Time(), t1.Vi(), t1.C55(), t1.C65(), t1.C66()) =
+    Smooth2D(t1.Time(), t1.Vi(), t1.C55(), t1.C65(), t1.C66(), t2.Time(), t2.Vi(), t2.C55(), t2.C65(), t2.C66());
+
+  t1.C20() = 0.;
+  t1.C21() = 0.;
+  t1.C30() = 0.;
+  t1.C31() = 0.;
+  t1.C40() = 0.;
+  t1.C41() = 0.;
+  t1.C42() = 0.;
+  t1.C43() = 0.;
+  t1.C50() = 0.;
+  t1.C51() = 0.;
+  t1.C52() = 0.;
+  t1.C53() = 0.;
+  t1.C54() = 0.;
+  t1.C60() = 0.;
+  t1.C61() = 0.;
+  t1.C62() = 0.;
+  t1.C63() = 0.;
+  t1.C64() = 0.;
+}
+
+std::tuple<ca::fvec, ca::fvec> CbmKFTrackFitter::Smooth1D(ca::fvec x1, ca::fvec Cxx1, ca::fvec x2, ca::fvec Cxx2)
+{
+  // combine two 1D values
+  ca::fvec x   = (x1 * Cxx1 + x2 * Cxx2) / (Cxx1 + Cxx2);
+  ca::fvec Cxx = Cxx1 * Cxx2 / (Cxx1 + Cxx2);
+  return std::tuple(x, Cxx);
+}
+
+std::tuple<ca::fvec, ca::fvec, ca::fvec, ca::fvec, ca::fvec>
+CbmKFTrackFitter::Smooth2D(ca::fvec x1, ca::fvec y1, ca::fvec Cxx1, ca::fvec /*Cxy1*/, ca::fvec Cyy1, ca::fvec x2,
+                           ca::fvec y2, ca::fvec Cxx2, ca::fvec /*Cxy2*/, ca::fvec Cyy2)
+{
+  // combine two 2D values
+  // TODO: do it right
+  auto [x, Cxx] = Smooth1D(x1, Cxx1, x2, Cxx2);
+  auto [y, Cyy] = Smooth1D(y1, Cyy1, y2, Cyy2);
+  return std::tuple(x, y, Cxx, ca::fvec::Zero(), Cyy);
+}
\ No newline at end of file
diff --git a/reco/KF/CbmKFTrackFitter.h b/reco/KF/CbmKFTrackFitter.h
new file mode 100644
index 0000000000..25d88910bb
--- /dev/null
+++ b/reco/KF/CbmKFTrackFitter.h
@@ -0,0 +1,160 @@
+/* Copyright (C) 2023 GSI Helmholtzzentrum fuer Schwerionenforschung, Darmstadt
+   SPDX-License-Identifier: GPL-3.0-only
+   Authors: Sergey Gorbunov [committer] */
+
+#pragma once  // include this header only once per compilation unit
+
+
+#include "CbmDefs.h"
+
+#include <vector>
+
+#include "CaConstants.h"
+#include "CaMeasurementTime.h"
+#include "CaMeasurementXy.h"
+#include "CaSimd.h"
+#include "CaTrackFit.h"
+#include "CaTrackParam.h"
+
+class CbmMvdHit;
+class CbmStsHit;
+class CbmStsTrack;
+class CbmGlobalTrack;
+class CbmKFVertex;
+class CbmMuchPixelHit;
+class CbmTrdHit;
+class CbmTofHit;
+
+class FairTrackParam;
+class TClonesArray;
+
+namespace cbm::algo::ca
+{
+  class FieldRegion;
+}  // namespace cbm::algo::ca
+
+namespace
+{
+  using namespace cbm::algo;
+}
+
+/// A fitter for the Cbm tracks
+///
+class CbmKFTrackFitter {
+public:
+  /// A node on the trajectory where the track parameters are:
+  /// a) measured and / or
+  /// b) scattered and / or
+  /// c) need to be estimated
+  /// The nodes must be ordered by increasing Z
+  ///
+  struct FitNode {
+
+    double fZ {0.};  ///< Z coordinate
+
+    ca::TrackParamV fTrack {};  ///< fitted track
+
+    /// == Material information (if present)
+    // TODO: change to the material layer index when the material layer is implemented
+    int fMaterialLayer {-1};  ///< index of the material layer. Currently equal to the active tracking station index
+
+    double fRadThick {0.};  ///< material radiation thickness at fZ
+
+    /// == Hit information ( if present )
+
+    ca::MeasurementXy<ca::fvec> fMxy {};  ///< XY-measurement at fZ
+
+    ca::MeasurementTime<ca::fvec> fMt {};  ///< time measurement at fZ
+
+    ECbmModuleId fSystemId {ECbmModuleId::kNotExist};  ///< detector system ID of the hit
+    int fDetectorId {-1};                              ///< detector ID of the hit
+    int fHitId {-1};                                   ///< hit ID
+
+    /// == Flags etc
+
+    bool fIsFitted {false};       ///< true if the node is fitted, false if the fit failed
+    bool fIsRadThickSet {false};  ///< true if the radiation thickness is set
+    int fReference1 {-1};         ///< some reference can be set by the user
+    int fReference2 {-1};         ///< some reference can be set by the user
+  };
+
+  /// A track to be fitted
+  struct Track {
+    std::vector<FitNode> fNodes;  ///< nodes on the track
+    int fFirstHitNode {-1};       ///< index of the first node with the XY measurement
+    int fLastHitNode {-1};        ///< index of the last node with the XY measurement
+
+    /// externally defined inverse momentum for the Multiple Scattering calculation.
+    /// It is used for the tracks in field-free regions.
+    /// When the momentum can be fitted, the fitted value is used.
+    /// the default value is set to 0.1 GeV/c
+    double fMsQp0 {1. / 0.1};
+    bool fIsMsQp0Set {false};
+
+    void MakeConsistent();  // make the structure fields consistent
+  };
+
+  CbmKFTrackFitter();
+  ~CbmKFTrackFitter();
+
+  /// initialize the fitter. It must be called in the Init() of the user task.
+  /// when called later, it sees track branches always empty for whatever reason
+  void Init();
+
+  /// set particle hypothesis (mass and electron flag) via particle PDG
+  void SetParticleHypothesis(int pid);
+
+  /// set particle  mass
+  void SetMassHypothesis(double mass);
+
+  /// set electron flag (bremmstrallung will be applied)
+  void SetElectronFlag(bool isElectron);
+
+  bool CreateTrack(Track& kfTrack, const FairTrackParam& trackFirst, const std::vector<CbmMvdHit>& mvdHits,
+                   const std::vector<CbmStsHit>& stsHits, const std::vector<CbmMuchPixelHit>& muchHits,
+                   const std::vector<CbmTrdHit>& trdHits, const std::vector<CbmTofHit>& tofHits);
+
+  bool CreateMvdStsTrack(Track& kfTrack, const CbmStsTrack& stsTrack);
+  bool CreateGlobalTrack(Track& kfTrack, const CbmGlobalTrack& globalTrack);
+
+  /// fit the track
+  void FitTrack(CbmKFTrackFitter::Track& t);
+
+  /// fit sts tracks
+  // void FitStsTracks(vector<CbmStsTrack>& Tracks, const vector<int>& pidHypo);
+
+private:
+  void FilterFirstMeasurement(const FitNode& n);
+  void AddMaterialEffects(const Track& t, FitNode& n, bool upstreamDirection);
+  // combine two tracks
+  void Smooth(ca::TrackParamV& t1, const ca::TrackParamV& t2);
+
+  std::tuple<ca::fvec, ca::fvec> Smooth1D(ca::fvec x1, ca::fvec Cxx1, ca::fvec x2, ca::fvec Cxx2);
+  std::tuple<ca::fvec, ca::fvec, ca::fvec, ca::fvec, ca::fvec> Smooth2D(ca::fvec x1, ca::fvec y1, ca::fvec Cxx1,
+                                                                        ca::fvec Cxy1, ca::fvec Cyy1, ca::fvec x2,
+                                                                        ca::fvec y2, ca::fvec Cxx2, ca::fvec Cxy2,
+                                                                        ca::fvec Cyy2);
+
+private:
+  // input data arrays
+  TClonesArray* fInputMvdHits {nullptr};
+  TClonesArray* fInputStsHits {nullptr};
+  TClonesArray* fInputMuchHits {nullptr};
+  TClonesArray* fInputTrdHits {nullptr};
+  TClonesArray* fInputTofHits {nullptr};
+
+  TClonesArray* fInputGlobalTracks {nullptr};
+  TClonesArray* fInputStsTracks {nullptr};
+  TClonesArray* fInputMuchTracks {nullptr};
+  TClonesArray* fInputTrdTracks {nullptr};
+  TClonesArray* fInputTofTracks {nullptr};
+
+  //
+
+  bool fIsInitialized = {false};  // is the fitter initialized
+
+  ca::TrackFit fFit;  // track fit object
+
+  double fMass {ca::constants::phys::PionMass};  // mass hypothesis for the fit
+  bool fIsElectron {false};                      // fit track as an electron (with the bermsstrallung effect)
+};
diff --git a/reco/KF/KF.cmake b/reco/KF/KF.cmake
index 273f26b066..af1bd8f834 100644
--- a/reco/KF/KF.cmake
+++ b/reco/KF/KF.cmake
@@ -21,6 +21,7 @@ set(SRCS
   CbmKFTrackInterface.cxx 
   CbmKFUMeasurement.cxx 
   CbmKFVertexInterface.cxx 
+  CbmKFTrackFitter.cxx
 
   #Interface/CbmEcalTrackExtrapolationKF.cxx
   Interface/CbmKFStsHit.cxx 
diff --git a/reco/KF/KFLinkDef.h b/reco/KF/KFLinkDef.h
index 6f16bb68b6..d34ed89840 100644
--- a/reco/KF/KFLinkDef.h
+++ b/reco/KF/KFLinkDef.h
@@ -40,6 +40,7 @@
 #pragma link C++ class CbmL1TofMerger + ;
 #pragma link C++ class CbmL1TrdTracklet + ;
 #pragma link C++ class CbmL1TrdTracklet4 + ;
+#pragma link C++ class CbmKFTrackFitter + ;
 //#pragma link C++ class  CbmL1TrdTrackFinderSts+;
 //#pragma link C++ class  CbmL1CATrdTrackFinderSA+;
 
-- 
GitLab