#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <string.h>
#include <fstream>
#include <iostream>
#include <string>
#include <vector>

#include "TF1.h"
#include "TFile.h"
#include "TH1.h"
#include "TH2.h"
#include "TLorentzVector.h"
#include "TRandom3.h"
#include "TVector3.h"

using namespace std;
 

//FUNCTION PROTOTYPES
void printUsage(); // print usage

TLorentzVector vPho(double e1,TRandom *ranGen,TLorentzVector *eScat);

//MAIN
int main(int argc, char **argv) {
  
  char  inFileName[120];
  char  outFileName[120];
  char *argptr;
  int nToRun = 100000;
  int rSeed = 0;
  //Set the default output file
  sprintf(outFileName,"./myOutFile.root");
  //if (argc == 1) printUsage();
  for (int i=1; i<argc; i++) {
    argptr = argv[i];
    if (*argptr == '-') {
      argptr++;
      switch (*argptr) {
      case 'h':
        printUsage();
        break;
      case 'n':
        nToRun = atoi(++argptr);
        break;
      case 's':
        rSeed = atoi(++argptr);
        break;
      case 'o':
        strcpy(outFileName,++argptr);
        break;
      default:
        printUsage();
        break;
      }
    } else {
      strcpy(inFileName,argptr);
    }
  }



  //Create histograms here
  TH1D * hEGamma = new TH1D("hEGamma","",700,4.0,11.0);

 
  gRandom =  new TRandom3();
  gRandom->SetSeed(rSeed);


  double mPro = 0.938;

  TLorentzVector target(0,0,0,mPro);
  double e1 = 10.6;
  
  for(Int_t i=0;i<nToRun;i++){
    if (i%10000 == 0 && i > 0) cout<<"Events processed = "<<i<<endl;

    
    
    TLorentzVector scatteredElectron;
    TLorentzVector virtualPhoton = vPho(e1,gRandom,&scatteredElectron);

    hEGamma->Fill(virtualPhoton.E());
    //cout<<"q2 = "<<virtualPhoton.Mag2()<<endl;
    //cout<<"theta = "<<virtualPhoton.Theta()*180/3.14159<<endl;


    TLorentzVector pTot4Vec = virtualPhoton + target;

  }
  //Write histograms to file
  TFile *outFile = new TFile(outFileName,"RECREATE");
  outFile->cd();
  
  hEGamma->Write();
  outFile->Write();
  
  return 0;
}


// Print the command-line help text to stderr and terminate the program with
// exit status 0. All text goes to the same stream so redirecting stdout does
// not split the message.
void printUsage(){
  fprintf(stderr,"\nUsage:");
  fprintf(stderr,"\nmyEventGen [-switches]\n");
  fprintf(stderr,"\nSWITCHES:\n");
  fprintf(stderr,"-h\tPrint this message\n");
  fprintf(stderr,"-n<arg>\tNumber of events thrown. Default is 100000\n");
  fprintf(stderr,"-s<arg>\tRandom seed. Default is 0\n");
  // Keep in sync with the default output file name set in main().
  fprintf(stderr,"-o<arg>\tOutFileName. Default is ./myOutFile.root\n\n");

  // Only switches that main() actually parses may appear here; any other
  // letter falls into the default case and just reprints this message.
  fprintf(stderr,"The current default operation is equivalent to the command:\n");
  fprintf(stderr,"myEventGen -s0 -n100000 -o./myOutFile.root\n\n");

  exit(0);
}

// Generate one electron-scattering event and return the virtual-photon
// 4-vector q = k1 - k2.
//
// The beam electron (energy e1) travels along +z. The scattered electron's
// energy e2 and Q^2 are drawn by accept-reject against the relative flux
// (1 + (1-y)^2) / (y * Q^2), with y = (e1 - e2)/e1. Both are restricted to the
// hard-coded acceptance e2 in [0.5, 6.0] and theta in [2.5, 4.5] degrees.
// phi is drawn uniformly in [-pi, pi).
//
// Params:
//   e1     incident electron energy (same units as the hard-coded cuts).
//          Must exceed e2Max (6.0); otherwise yMin <= 0 and the flux bound
//          below is undefined.
//   ranGen random generator used for every draw.
//   eScat  if non-null, receives the scattered-electron 4-vector.
// Returns the virtual-photon 4-vector (beam minus scattered electron).
TLorentzVector vPho(double e1,TRandom *ranGen,TLorentzVector *eScat){
  //me = TDatabasePDG::Instance()->GetParticle("e-")->Mass();
  // Electron mass in GeV (0.511 MeV). This must be floating-point: the
  // expression 511/1000000 is integer division and evaluates to exactly 0.
  const double me = 0.000511;
  //e1 = 11.0;  //HARD CODED INCIDENT ELECTRON ENERGY

  const double e2Min = 0.5; //HARD CODED e2Min
  const double e2Max = 6.0; //HARD CODED e2Max

  const double theta2Min = 2.5*TMath::DegToRad(); //HARD CODED thetaMin
  const double theta2Max = 4.5*TMath::DegToRad(); //HARD CODED thetaMax

  // Q^2 uses the small-angle form e1*e2*theta^2 (of 4*e1*e2*sin^2(theta/2)).
  // Overall Q^2 range spanned by the acceptance:
  const double QsqMin = e1*e2Min*theta2Min*theta2Min;
  const double QsqMax = e1*e2Max*theta2Max*theta2Max;

  // y range: the largest e2 gives the smallest y and vice versa.
  const double yMin = (e1 - e2Max)/e1;
  const double yMax = (e1 - e2Min)/e1;

  // Upper bound of the relative flux. (1+(1-y)^2)/y decreases for 0<y<1 and
  // 1/Q^2 decreases in Q^2, so the bound sits at (yMin, QsqMin).
  const double rFluxMax = (1.0 + (1.0-yMin)*(1.0-yMin))/(yMin*QsqMin);

  double y = 0.0;
  double e2 = 0.0;
  double Qsq = 0.0;
  double rFlux = 0.0;
  double rFluxTest = 2.0; // > rFlux so the loop body runs at least once

  while (rFluxTest>rFlux) {//Monte Carlo the relative flux
    y = ranGen->Uniform(yMin,yMax);
    e2 = e1 -e1*y;
    Qsq = ranGen->Uniform(QsqMin,QsqMax);

    // Q^2 window actually reachable for this e2 within the angular acceptance.
    // Points outside it are rejected without drawing a test value.
    const double QsqMax2 = e1*e2*theta2Max*theta2Max;
    const double QsqMin2 = e1*e2*theta2Min*theta2Min;
    if (Qsq<=QsqMax2 && Qsq>=QsqMin2) {
      rFlux = (1.0 + (1.0-y)*(1.0-y))/(y*Qsq);
      rFluxTest = ranGen->Uniform(0.0,rFluxMax);
    }
  }

  // Sanity check: the bound above should never be exceeded.
  if (rFlux > rFluxMax) {
    std::cout<<"BAD rFluxMax value, rFlux = "<<rFlux<<" rFluxMax = "<<rFluxMax<<std::endl;
    std::cout<<"Qsq = "<<Qsq<<" QsqMax = "<<QsqMax<<std::endl;
  }

  //Obtain theta2 by inverting the small-angle Q^2 relation
  const double theta2 = sqrt(Qsq/(e1*e2));

  //Generate phi2
  const double phi2 = ranGen->Uniform(-TMath::Pi(),TMath::Pi());

  // Build the beam and scattered-electron 4-vectors in the lab frame
  // (beam along +z).
  const double p1 = sqrt(e1*e1 - me*me);
  const double p2 = sqrt(e2*e2 - me*me);
  const TLorentzVector e1vec4(0.0,0.0,p1,e1);
  const TLorentzVector e2vec4(p2*sin(theta2)*cos(phi2),p2*sin(theta2)*sin(phi2),p2*cos(theta2),e2);

  if (eScat) *eScat = e2vec4;
  return e1vec4 - e2vec4;
}
