Skip to content
Snippets Groups Projects
Commit c0fd9b8f authored by Axel Puntke's avatar Axel Puntke
Browse files

Added Option to manually specify the number of ONNX threads

parent 9e1309aa
No related branches found
No related tags found
No related merge requests found
......@@ -30,7 +30,7 @@ void ATreePredictionAdder::Init()
std::cout << "Loading ONNX model file " << model_file_name_ << std::endl;
onnx_runner_ = new ONNXRunner();
onnx_runner_->Init(model_file_name_);
onnx_runner_->Init(model_file_name_, num_threads_);
if (onnx_runner_->GetFeatureCount() != feature_field_ids_.size())
{
......
......@@ -23,6 +23,7 @@ class ATreePredictionAdder : public AnalysisTree::Task {
// Sets the output branch name by copying the argument into output_branch_name_.
void SetOutputBranchName(std::string output_branch_name) {output_branch_name_ = output_branch_name;}
// Sets the feature field names from a single string: the argument is split on ","
// via stringSplit() and the resulting list is stored in feature_field_names_.
void SetFeatureFieldNames(std::string feature_field_name_arg) {feature_field_names_ = stringSplit(feature_field_name_arg, ",");}
// Sets the model file name by copying the argument into model_file_name_.
void SetModelFileName(std::string model_file_name) {model_file_name_ = model_file_name;}
// Stores the requested thread count in num_threads_.
// NOTE(review): presumably this is the value later handed to the ONNX runner, where
// a non-positive value appears to mean "keep the ONNX default" — confirm against Init().
void SetNumThreads(int num_threads) {num_threads_ = num_threads;}
protected:
ONNXRunner* onnx_runner_;
......@@ -40,6 +41,7 @@ protected:
std::vector<std::string> feature_field_names_;
std::vector<int> feature_field_ids_;
std::string model_file_name_{"model_onnx.onnx"};
int num_threads_ = -1;
//**** input fields ***********
int mass2_first_field_id_r_{AnalysisTree::UndefValueInt};
......
......@@ -27,10 +27,12 @@ int ONNXRunner::calculate_product(const std::vector<int64_t>& v)
return total;
}
void ONNXRunner::Init(std::string model_file)
void ONNXRunner::Init(std::string model_file, int num_threads)
{
env_ = new Ort::Env(ORT_LOGGING_LEVEL_WARNING, "atree-prediction-adder");
Ort::SessionOptions session_options;
if (num_threads > 0)
session_options.SetIntraOpNumThreads(num_threads);
session_ = new Ort::Experimental::Session(*env_, model_file, session_options);
auto input_names = session_->GetInputNames();
......
......@@ -18,7 +18,7 @@ class ONNXRunner {
ONNXRunner();
~ONNXRunner() = default;
void Init(std::string model_file);
void Init(std::string model_file, int num_threads = -1);
float PredictSingleInstance(std::vector<float> feature_values);
std::vector<float> PredictBatch(std::vector<float> feature_values);
// Returns the current value of feature_count_.
int GetFeatureCount() {return feature_count_;}
......
......@@ -70,8 +70,10 @@ Specifies the *.onnx file where the model is stored in
Specifies the order and field names of the features which are put into the model in a comma-separated list.
### --o <output-file>
Specifies the name of the output file in which the ROOT tree is stored.
### --t
Specified the name of the tree inside the input and output file where the candidates are stored in.
### --t <tree-name>
Specifies the name of the tree, inside the input and output files, in which the candidates are stored.
### --num_threads <number-of-onnx-threads>
Specifies the number of threads ONNX should use. This option may be necessary when running inside a Slurm environment, because there the number of available cores cannot be determined automatically in the usual way.
# Usage example
In python, given a trained XGBClassifier `model_clf`, we can export it to the *.onnx format using the [hipe4ml converter](https://github.com/fgrosa/hipe4ml_converter) (install with `pip install hipe4ml-converter`):
......@@ -94,6 +96,6 @@ model_conv.dump_model_onnx("xgboost_lambda_classifier.onnx")
```
Next we run `at_tree_prediction_adder` with a filelist containing a file generated with [PFSimple](https://github.com/HeavyIonAnalysis/PFSimple) whose candidates contain all the fields the model needs:
```
./at_tree_prediction_adder -f filelist.txt --ib Candidates_plain --ob Candidates_plainPredicted --m xgboost_lambda_classifier.onnx --features chi2_geo,chi2_prim_first,chi2_prim_second,distance,l_over_dl,mass2_first,mass2_second --o prediction_tree.root
./at_tree_prediction_adder --f filelist.txt --ib Candidates_plain --ob Candidates_plainPredicted --m xgboost_lambda_classifier.onnx --features chi2_geo,chi2_prim_first,chi2_prim_second,distance,l_over_dl,mass2_first,mass2_second --o prediction_tree.root
```
Then you can analyze the resulting file `prediction_tree.root` with AnalysisTreeQA and apply cuts on the new candidate field `onnx_pred` (which should contain the signal probability).
\ No newline at end of file
......@@ -13,6 +13,7 @@ int main(int argc, char** argv)
std::string feature_field_names = "chi2_geo,chi2_prim_first,chi2_prim_second,distance,l_over_dl,mass2_first,mass2_second";
std::string output_file = "prediction_tree.root";
std::string tree_name = "pTree";
int num_threads = -1;
for (int i = 1; i < argc; ++i)
{
......@@ -51,6 +52,11 @@ int main(int argc, char** argv)
tree_name = std::string(argv[++i]);
printf("Tree name: %s\n", tree_name.c_str());
}
if (strcmp(argv[i], "--num_threads") == 0)
{
num_threads = atoi(argv[++i]);
printf("Number of ONNX threads: %d\n", num_threads);
}
}
const bool make_plain_ttree{true};
......@@ -65,6 +71,7 @@ int main(int argc, char** argv)
at_prediction_adder_task->SetOutputBranchName(output_branch_name);
at_prediction_adder_task->SetModelFileName(model_file_name);
at_prediction_adder_task->SetFeatureFieldNames(feature_field_names);
at_prediction_adder_task->SetNumThreads(num_threads);
man->AddTask(at_prediction_adder_task);
man->Init({filename_pfs}, {tree_name});
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment