/***************************************************************************
 *                                                                         *
 *                  (begin: Feb 20 2003)                                   *
 *                                                                         *
 *   Parallel IQPNNI - Important Quartet Puzzle with NNI                   *
 *                                                                         *
 *   Copyright (C) 2005 by Le Sy Vinh, Bui Quang Minh, Arndt von Haeseler  *
 *   Copyright (C) 2003-2004 by Le Sy Vinh, Arndt von Haeseler             *
 *   {vinh,minh}@cs.uni-duesseldorf.de                                     *
 *                                                                         *
 *   This program is free software; you can redistribute it and/or modify  *
 *   it under the terms of the GNU General Public License as published by  *
 *   the Free Software Foundation; either version 2 of the License, or     *
 *   (at your option) any later version.                                   *
 *                                                                         *
 *   This program is distributed in the hope that it will be useful,       *
 *   but WITHOUT ANY WARRANTY; without even the implied warranty of        *
 *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the         *
 *   GNU General Public License for more details.                          *
 *                                                                         *
 *   You should have received a copy of the GNU General Public License     *
 *   along with this program; if not, write to the                         *
 *   Free Software Foundation, Inc.,                                       *
 *   59 Temple Place - Suite 330, Boston, MA  02111-1307, USA.             *
 ***************************************************************************/

/***************************************************************************
                          libr.cpp  -  description
                             -------------------
    begin                : Thu Apr 24 2003
    copyright            : (C) 2003 by 
    email                : vinh@cs.uni-duesseldorf.de
 ***************************************************************************/

/***************************************************************************
 *                                                                         *
 *   This program is free software; you can redistribute it and/or modify  *
 *   it under the terms of the GNU General Public License as published by  *
 *   the Free Software Foundation; either version 2 of the License, or     *
 *   (at your option) any later version.                                   *
 *                                                                         *
 ***************************************************************************/
#include <math.h>


#include <time.h>
#include <stdlib.h>
#include <stdio.h>
#include <string.h>

#include "constant.h"
#include "libr.h"
#include "ptnls.h"
#include "rate.h"
#include "model.h"
#include "opturtree.h"
#include "dmat20.h"
#include "ptnratecube.h"
#include "interface.h"

#ifdef PARALLEL
#include <mpi.h>
#endif
//extern int nto[NUM_CHAR];

#ifdef _OPENMP
#include <omp.h>
#endif

#ifdef WIN32
#include <float.h>
#define isnan(x) _isnan(x)
#endif

const double VERY_SMALL_POSITIVE = 1e-200;

//#define PRINT_NEGATIVE_INFO

/**
	Default constructor: allocate the two end-point LiNd objects of this
	branch and mark the per-pattern log-likelihood buffer as absent.
	NOTE(review): ~LiBr only calls release() and does not delete these
	LiNd objects; setTailLiNd may also replace them — confirm ownership.
*/
LiBr::LiBr () {
	ptn_logl = NULL;
	smaLiNd_ = new LiNd;
	greLiNd_ = new LiNd;
}

/**
	Copy the id, end nodes, external flag, length and orientation of a plain
	branch into this likelihood branch, then (re)initialise both LiNds.
	The great LiNd is always flagged IN; the small LiNd is flagged EX when
	the branch is external (isEx_ == 1) and IN otherwise.
	@param br source branch
*/
void LiBr::copyBase (Br<double> &br) {
	id_      = br.getId ();
	smaNdNo_ = br.getSmaNd ();
	greNdNo_ = br.getGreNd ();
	isEx_    = br.isEx ();
	len_     = br.getLen ();
	isStG_   = br.isStG ();

	// each LiNd gets: its node, its kind, the opposite node, this branch id
	greLiNd_->set (greNdNo_, IN, smaNdNo_, id_);
	smaLiNd_->set (smaNdNo_, (isEx_ == 1) ? EX : IN, greNdNo_, id_);
}


/**
	@return 1 if this branch is the outgroup branch of opt_urtree
	        (id_ == opt_urtree.outGrpBrNo_), 0 otherwise
*/
int LiBr::isOutGrp () {
	return (id_ == opt_urtree.outGrpBrNo_) ? 1 : 0;
}

/**
	@param existNdNo id of the node at one end of this branch
	@return the LiNd at the other end: the great LiNd when existNdNo is the
	        small node, the small LiNd for any other value
*/
LiNd &LiBr::getRemLiNd (int existNdNo) {
	return (existNdNo == smaNdNo_) ? *greLiNd_ : *smaLiNd_;
}

/**
	@return the LiNd attached to the small end (smaNdNo_) of this branch
*/
LiNd &LiBr::getSmaLiNd () {
	LiNd &smallEnd = *smaLiNd_;
	return smallEnd;
}

/**
	@return the LiNd attached to the great end (greNdNo_) of this branch
*/
LiNd &LiBr::getGreLiNd () {
	LiNd &greatEnd = *greLiNd_;
	return greatEnd;
}

/**
	print information 
*/
void LiBr::writeInf () {
	std::cout << id_ << " / " << smaNdNo_ << " / " << greNdNo_ << " / " << isStG_ << endl;
	getHeadLiNd ().writeInf ();
	getTailLiNd ().writeInf ();
}

/**
	set the new tail LiNd
	Replaces the LiNd at the current tail end with newTailLiNd, then
	re-normalises the branch: smaLiNd_ ends up holding the smaller node id,
	smaNdNo_/greNdNo_ are refreshed, isEx_ is recomputed from the new tail
	node (isExNd), isStG_ is set to 1 when the tail is the great node, and
	both LiNds get their parent node/branch (setParNdBr) set to the
	opposite node and this branch.
	NOTE(review): the previous tail LiNd pointer is overwritten without being
	released or deleted here — presumably owned elsewhere; confirm.
	@param newTailLiNd new tail LiNd
*/
void LiBr::setTailLiNd (LiNd *newTailLiNd) {
	// replace whichever end is currently the tail (isStG_ == 1: tail is the great end)
	if (isStG_ == 1)
		greLiNd_ = newTailLiNd;
	else
		smaLiNd_ = newTailLiNd;
	int tailNdNo_ = newTailLiNd->getId ();

	isEx_ = isExNd (tailNdNo_);

	// restore the invariant smaLiNd_->getId () <= greLiNd_->getId ()
	if (smaLiNd_->getId () > greLiNd_->getId () )
		Utl::swap (smaLiNd_, greLiNd_);
	smaNdNo_ = smaLiNd_->getId ();
	greNdNo_ = greLiNd_->getId ();
	// after a possible swap the tail may now sit on the other side
	if (greNdNo_ == tailNdNo_)
		isStG_ = 1;
	else
		isStG_ = 0;
	// point each end at the opposite node through this branch
	smaLiNd_->setParNdBr (greNdNo_, id_);
	greLiNd_->setParNdBr (smaNdNo_, id_);
}

/**
	@return the tail LiNd of this branch: the great LiNd when isStG_ == 1,
	        otherwise the small LiNd
*/
LiNd &LiBr::getTailLiNd () {
	return (isStG_ == 1) ? *greLiNd_ : *smaLiNd_;
}

/**
	@return the head LiNd of this branch: the small LiNd when isStG_ == 1,
	        otherwise the great LiNd
*/
LiNd &LiBr::getHeadLiNd () {
	return (isStG_ == 1) ? *smaLiNd_ : *greLiNd_;
}

/**
	@param ndNo node ID
	@return the small LiNd when ndNo is the small node; for any other
	        value the great LiNd
*/
LiNd &LiBr::getLiNd (int ndNo) {
	return (ndNo == smaNdNo_) ? *smaLiNd_ : *greLiNd_;
}

/**
	Compute the likelihood of the LiNd whose node id equals ndNo.
	The great end is tested first, then the small end; the two tests are
	independent, and nothing is computed if ndNo matches neither end.
	@param ndNo node ID
*/
void LiBr::cmpLiNd (int ndNo) {
	if (greNdNo_ == ndNo)
		greLiNd_->cmpLi ();
	if (smaNdNo_ == ndNo)
		smaLiNd_->cmpLi ();
}

/**
	compute the likelihood for tail liNd of this branch
*/
void LiBr::cmpTailLiNd () {
	if (getTailNd () == greNdNo_) //the case, the tail node is greNdNo_
		greLiNd_->cmpLi ();
	else // the case the tail node is small node
		smaLiNd_->cmpLi ();
}

/**
	compute the likelihood for head liNd of this branch
*/
void LiBr::cmpHeadLiNd () {
	if (getHeadNd () == greNdNo_) //the case, the head node is greNdNo_
		greLiNd_->cmpLi ();
	else // the case the head node is small node
		smaLiNd_->cmpLi ();

}

/**
	reCompute the likelihood for tail liNd of this branch
*/
void LiBr::reCmpTailLiNd () {
	if (getTailNd () == greNdNo_) //the case, the tail node is greNdNo_
		greLiNd_->reCmpLi ();
	else
		smaLiNd_->reCmpLi ();
}

/**
	reCompute the likelihood for head liNd of this branch
*/
void LiBr::reCmpHeadLiNd () {
	if (getHeadNd () == greNdNo_) //the case, the head node is greNdNo_
		greLiNd_->reCmpLi ();
	else
		smaLiNd_->reCmpLi ();
}

/**
	create the external descendant for all liNds of this branch
*/
void LiBr::createExDesLiNd () {
	greLiNd_->createExDesNd ();
	smaLiNd_->createExDesNd ();
}



/***********************************************************************
* compute the probability change matrix times base frequency
***********************************************************************/


/**
	Fill probMat with mymodel.cmpProbChange(brLen) and then scale every
	row i by the state frequency stateFrqArr_[i].
	@param brLen branch length (evolutionary time)
	@param probMat (OUT) frequency-weighted probability change matrix
*/
inline void LiBr::cmpProbChangeFrq (const double brLen, DMat20 &probMat) {
	mymodel.cmpProbChange (brLen, probMat);
	const int nState = mymodel.getNState ();
	int row = 0;
	while (row < nState) {
		for (int col = 0; col < nState; col ++)
			probMat[row][col] *= mymodel.stateFrqArr_[row];
		row ++;
	}
}

/**
	Class-specific variant: fill probMat with
	mymodel.cmpProbChange(brLen, classNo) and scale row i by stateFrqArr_[i].
	Prints "NaN" for every matrix entry that is NaN before scaling.
	@param brLen branch length (evolutionary time)
	@param classNo rate / Ns-Sy class index passed to the model
	@param probMat (OUT) frequency-weighted probability change matrix
*/
inline void LiBr::cmpProbChangeFrq (const double brLen, const int classNo, DMat20 &probMat) {
	mymodel.cmpProbChange (brLen, classNo, probMat);
	const int nState = mymodel.getNState ();
	for (int row = 0; row < nState; row ++) {
		for (int col = 0; col < nState; col ++) {
			if (isnan (probMat[row][col]))
				cout << "NaN" << endl;
			probMat[row][col] *= mymodel.stateFrqArr_[row];
		}
	}
}


/**
	Fill probMat, derv1 and derv2 with mymodel.cmpProbChangeDerivatives(brLen)
	and scale row i of all three by the state frequency stateFrqArr_[i].
	@param brLen branch length (evolutionary time)
	@param probMat (OUT) frequency-weighted P(t)
	@param derv1 (OUT) frequency-weighted 1st derivative of P(t)
	@param derv2 (OUT) frequency-weighted 2nd derivative of P(t)
*/
inline void LiBr::cmpProbChangeDerivativesFrq (const double brLen, DMat20 &probMat,
        DMat20 &derv1, DMat20 &derv2) {
	mymodel.cmpProbChangeDerivatives (brLen, probMat, derv1, derv2);
	const int nState = mymodel.getNState ();
	for (int row = 0; row < nState; row ++) {
		for (int col = 0; col < nState; col ++) {
			probMat[row][col] *= mymodel.stateFrqArr_[row];
			derv1[row][col]   *= mymodel.stateFrqArr_[row];
			derv2[row][col]   *= mymodel.stateFrqArr_[row];
		}
	}
}
/**
	Class-specific variant: compute the probability change matrix P(t) and
	its 1st and 2nd derivatives for class classNo via
	mymodel.cmpProbChangeDerivatives, then scale row i of each by the state
	frequency stateFrqArr_[i].
	@param brLen branch length (or evolutionary time)
	@param classNo rate / Ns-Sy class index passed to the model
	@param probMat (OUT) the frequency-weighted prob. change matrix
	@param derv1 (OUT) frequency-weighted 1st derivative of P(t)
	@param derv2 (OUT) frequency-weighted 2nd derivative of P(t)
*/
void LiBr::cmpProbChangeDerivativesFrq (const double brLen, int classNo, DMat20 &probMat,
                                        DMat20 &derv1, DMat20 &derv2) {
	mymodel.cmpProbChangeDerivatives (brLen, classNo, probMat, derv1, derv2);
	const int nState = mymodel.getNState ();
	for (int row = 0; row < nState; row ++) {
		for (int col = 0; col < nState; col ++) {
			probMat[row][col] *= mymodel.stateFrqArr_[row];
			derv1[row][col]   *= mymodel.stateFrqArr_[row];
			derv2[row][col]   *= mymodel.stateFrqArr_[row];
		}
	}
}


/**********************************************************
***********************************************************
*    start codes to compute log likelihood and derivatives
***********************************************************
**********************************************************/

// File-scope scratch matrices for the uniform / site-specific rate routines:
// proChangeMat_ receives P(t) scaled by state frequencies, derv1Mat/derv2Mat
// its 1st/2nd derivatives. Because they are shared globals, those routines
// are presumably not safe to run concurrently on different branches — confirm.
DMat20 derv1Mat, derv2Mat, proChangeMat_;
// Per-OpenMP-thread copies used when isSiteSpec() (each pattern needs its
// own matrix); allocated lazily with new[] by the routines below and not
// freed in the code shown here.
DMat20 *derv1Mat_OMP = NULL;
DMat20 * derv2Mat_OMP = NULL;
DMat20 * proChangeMat_OMP = NULL;

/**********************************************************
*
*    UNIFORM and SPECIFIC rate
*
**********************************************************/

//--------------------------------------------------------------------
/**
	compute the log likelihood of this tree across this branch
	using a uniform rate, or -- when isSiteSpec() holds -- a separate rate
	per site pattern taken from myrate.getPtnRate(ptnNo_).
	Each pattern's log-likelihood is also added into ptn_logl[ptnNo_] when
	ptn_logl is non-NULL.
	@return tree log-likelihood, including the liscale terms of both LiNds
*/
double LiBr::cmpLogLiUniformRate () {
	opt_urtree.nCmpInLogLi_ ++;

	LiNd *headLi_;
	LiNd *tailLi_;

	// external branch: the small node is the leaf (its state comes from
	// ptnlist), the great node is the internal end; otherwise use the
	// current head/tail orientation
	if (isEx_) {
		headLi_ = greLiNd_;
		tailLi_ = smaLiNd_;
	} else {
		headLi_ = &getHeadLiNd ();
		tailLi_ = &getTailLiNd ();
	}

#ifdef _OPENMP
	// site-specific rates need one matrix per thread; allocate lazily once
	if (omp_threads > 1 && !proChangeMat_OMP && isSiteSpec()) {
		proChangeMat_OMP = new DMat20[omp_threads];
	}
#endif

	// single shared rate: one frequency-weighted P(len_) serves all patterns
	if (!isSiteSpec()) cmpProbChangeFrq (len_, proChangeMat_);
	int nState_ = mymodel.getNState ();
	double logLi_ = 0.0;
	int nPtn_ = ptnlist.getNPtn ();
	
#ifdef _OPENMP
	#pragma omp parallel for  reduction(+: logLi_)
#endif
	for (int ptnNo_ = 0; ptnNo_ < nPtn_; ptnNo_ ++) {
		// site-specific mode: patterns for which atBound(MIN_PTN_RATE, rate,
		// MAX_PTN_RATE) holds are skipped entirely (contribute nothing)
		if (isSiteSpec() && atBound(MIN_PTN_RATE, myrate.getPtnRate(ptnNo_), MAX_PTN_RATE)) 
			continue;

		DMat20 *proChangeMatRef = &proChangeMat_;

#ifdef _OPENMP
		if (isSiteSpec() && omp_threads > 1) {
			int thread_num = omp_get_thread_num();
			proChangeMatRef = &proChangeMat_OMP[thread_num];
		} 
#endif
		// site-specific mode: P(len_ * pattern rate) for this pattern only
		if (isSiteSpec()) 
			cmpProbChangeFrq (len_ * myrate.getPtnRate(ptnNo_), *proChangeMatRef);

		LDOUBLE ptnLi_ = 0.0;
		int start_addr = ptnNo_ * nState_;
		int exStateNo_ = -1;
		if (isEx_) exStateNo_ = ptnlist.getBase(ptnNo_, smaNdNo_);


		if (exStateNo_ >= 0 && exStateNo_ < nState_) {
			// leaf state known: sum over the internal-end states only
			for (int inStateNo_ = 0; inStateNo_ < nState_; inStateNo_ ++)
				ptnLi_ += greLiNd_->liNdCube_[start_addr + inStateNo_] *
					(*proChangeMatRef)[exStateNo_][inStateNo_];
		} else
		// general case: sum over head states of head partial likelihood times
		// the frequency-weighted transition to the tail partial likelihood;
		// a BS_UNKNOWN leaf contributes just the state frequency
		for (int headStateNo_ = 0; headStateNo_ < nState_; headStateNo_ ++) 
		//if (headLi_->liNdCube_[start_addr + headStateNo_] != 0.0)
		{
			LDOUBLE stateLi_ = 0.0;
			if (exStateNo_ == BS_UNKNOWN)
				stateLi_ = mymodel.stateFrqArr_[headStateNo_];
			else
			for (int tailStateNo_ = 0; tailStateNo_ < nState_; tailStateNo_ ++)
				stateLi_ += tailLi_->liNdCube_[start_addr + tailStateNo_] *
				    (*proChangeMatRef)[headStateNo_][tailStateNo_];
			ptnLi_ += stateLi_ * headLi_->liNdCube_[start_addr + headStateNo_];
		}

		// non-positive likelihood: clamp so log() stays finite and report
		if (ptnLi_ <= 0.0) {
			ptnLi_ = VERY_SMALL_POSITIVE;
			Utl::announceError (BAD_CMP_LI);
		}

		double logPtnLi_ = log (ptnLi_);
		if (ptn_logl) {
			ptn_logl[ptnNo_] += logPtnLi_;
		}
		logLi_ += logPtnLi_ * ptnlist.weightArr_[ptnNo_];
	}

	// add back the scaling factors accumulated in both end LiNds
	logLi_ += headLi_->liscale + tailLi_->liscale;

	return logLi_;
}


// COMBINED VERSION OF In and Ex
/**
	compute the 1st and 2nd derivatives of the tree log-likelihood with
	respect to this branch's length len_, using a uniform rate or -- when
	isSiteSpec() holds -- a per-pattern rate from myrate.getPtnRate(ptnNo_).
	Per pattern: d logL = L'/L and d2 logL = L''/L - (L'/L)^2, weighted by
	ptnlist.weightArr_.
	@param logli_derv1_ret (OUT) 1st derivative of log-likelihood
	@param logli_derv2_ret (OUT) 2nd derivative of log-likelihood
	@param calc_logli true if you want to compute log-likelihood
	@return log-likelihood if calc_logli==true; otherwise only the liscale
	        terms of both end LiNds
*/
double LiBr::cmpLogLiDerivativeUniformRate (double &logli_derv1_ret, double
&logli_derv2_ret, bool calc_logli) {
	//opt_urtree.nCmpInLogLi_ ++;

	LiNd *headLi_;
	LiNd *tailLi_;

	// external branch: small node is the leaf, great node the internal end
	if (isEx_) {
		headLi_ = greLiNd_;
		tailLi_ = smaLiNd_;
	} else {
		headLi_ = &getHeadLiNd ();
		tailLi_ = &getTailLiNd ();
	}
#ifdef _OPENMP
	// site-specific rates need per-thread matrices; allocate lazily once
	if (omp_threads > 1 && !derv1Mat_OMP && isSiteSpec()) {
		if (!proChangeMat_OMP) proChangeMat_OMP = new DMat20[omp_threads];
		derv1Mat_OMP = new DMat20[omp_threads];
		derv2Mat_OMP = new DMat20[omp_threads];
	}
#endif

	// single shared rate: one set of matrices serves all patterns
	if (!isSiteSpec()) cmpProbChangeDerivativesFrq (len_, proChangeMat_, derv1Mat, derv2Mat);

	int nState_ = mymodel.getNState ();
	double logli = 0.0;
	double logli_derv1 = 0.0;
	double logli_derv2 = 0.0;
	int nPtn_ = ptnlist.getNPtn ();

#ifdef _OPENMP
	#pragma omp parallel for reduction(+: logli, logli_derv1, logli_derv2)
#endif
	for (int ptnNo_ = 0; ptnNo_ < nPtn_; ptnNo_ ++) {
		// site-specific mode: skip patterns for which atBound(...) holds
		if (isSiteSpec() && atBound(MIN_PTN_RATE, myrate.getPtnRate(ptnNo_), MAX_PTN_RATE))
			continue;

		DMat20 *proChangeMatRef = &proChangeMat_;
		DMat20 *derv1MatRef = &derv1Mat;
		DMat20 *derv2MatRef = &derv2Mat;

#ifdef _OPENMP
		if (isSiteSpec() && omp_threads > 1) {
			int thread_num = omp_get_thread_num();
			proChangeMatRef = &proChangeMat_OMP[thread_num];
			derv1MatRef = &derv1Mat_OMP[thread_num];
			derv2MatRef = &derv2Mat_OMP[thread_num];
		} 
#endif
		if (isSiteSpec()) {
			// P(len_ * r): chain rule gives factors r and r^2 for d/d len_
			double ptnRate_ = myrate.getPtnRate(ptnNo_);
			cmpProbChangeDerivativesFrq (len_ * ptnRate_, *proChangeMatRef, *derv1MatRef, *derv2MatRef);
			double rate_sqr = ptnRate_ * ptnRate_;
			for (int i = 0; i < nState_; i++)
				for (int j = 0; j < nState_; j++) {
					(*derv1MatRef)[i][j] *= ptnRate_;
					(*derv2MatRef)[i][j] *= rate_sqr;
				}
		}

		LDOUBLE ptnLi_ = 0.0;
		LDOUBLE ptnLi_1 = 0.0;
		LDOUBLE ptnLi_2 = 0.0;
		int start_addr = ptnNo_ * nState_;
		int exStateNo_ = -1;
		if (isEx_) exStateNo_ = ptnlist.getBase(ptnNo_, smaNdNo_);

		if (exStateNo_ >= 0 && exStateNo_ < nState_) {
			// leaf state known: sum over the internal-end states only
			for (int inStateNo_ = 0; inStateNo_ < nState_; inStateNo_ ++) {
				ptnLi_ += greLiNd_->liNdCube_[start_addr + inStateNo_] *
				            (*proChangeMatRef)[exStateNo_][inStateNo_];
				ptnLi_1 += greLiNd_->liNdCube_[start_addr + inStateNo_] *
				             (*derv1MatRef)[exStateNo_][inStateNo_];
				ptnLi_2 += greLiNd_->liNdCube_[start_addr + inStateNo_] *
				             (*derv2MatRef)[exStateNo_][inStateNo_];
			}
		} else
		// general case; a BS_UNKNOWN leaf contributes the state frequency to
		// the likelihood and nothing to the derivatives
		for (int headStateNo_ = 0; headStateNo_ < nState_; headStateNo_ ++) 
		//if (headLi_->liNdCube_[start_addr + headStateNo_] != 0.0)
		{

			LDOUBLE stateLi_ = 0.0;
			LDOUBLE stateLi_1 = 0.0;
			LDOUBLE stateLi_2 = 0.0;
			if (exStateNo_ == BS_UNKNOWN)
				stateLi_ = mymodel.stateFrqArr_[headStateNo_];
			else
			for (int tailStateNo_ = 0; tailStateNo_ < nState_; tailStateNo_ ++) {
				stateLi_ += (*proChangeMatRef)[headStateNo_][tailStateNo_] *
				    tailLi_->liNdCube_[start_addr + tailStateNo_];
				stateLi_1 += (*derv1MatRef)[headStateNo_][tailStateNo_] *
				    tailLi_->liNdCube_[start_addr + tailStateNo_];
				stateLi_2 += (*derv2MatRef)[headStateNo_][tailStateNo_] *
				    tailLi_->liNdCube_[start_addr + tailStateNo_];
			}

			LDOUBLE temp = headLi_->liNdCube_[start_addr + headStateNo_];
			ptnLi_ += stateLi_ * temp;
			ptnLi_1 += stateLi_1 * temp;
			ptnLi_2 += stateLi_2 * temp;
		}

		// non-positive likelihood: clamp so the divisions/log stay finite
		if (ptnLi_ <= 0.0) {
			ptnLi_ = VERY_SMALL_POSITIVE;
			Utl::announceError (BAD_CMP_LI);
		}
			
		// d logL = L'/L ; d2 logL = L''/L - (L'/L)^2
		double temp1 = ptnLi_1 / ptnLi_;
		double temp2 = ptnLi_2 / ptnLi_;
		if (isnan(temp1) || isnan(temp2)) {
			Utl::announceError("NaN in cmpLogLiDerivativeUniformRate");
		}
		logli_derv1 += temp1 * ptnlist.weightArr_[ptnNo_];
		logli_derv2 += (temp2 - temp1 * temp1) * ptnlist.weightArr_[ptnNo_];
		if (calc_logli) {
			double logPtnLi_ = log (ptnLi_);
			logli += logPtnLi_ * ptnlist.weightArr_[ptnNo_];
		}
	}
	logli_derv1_ret = logli_derv1;
	logli_derv2_ret = logli_derv2;
	// liscale terms are added even when calc_logli is false
	logli += headLi_->liscale + tailLi_->liscale;
	return logli;
}

/**********************************************************
*
*    GAMMA + YANG CODON MODEL 
*
**********************************************************/

// File-scope per-category scratch matrices for the gamma / Ns-Sy class
// routines: P(t) scaled by state frequencies (probChangeCube_) and its 1st /
// 2nd derivatives, one DMat20 per rate category or class (up to MAX_NUM_RATE).
// Note: empiricalBayesGammaRate reads probChangeCube_ without refilling it.
DMat20 probChangeCube_[MAX_NUM_RATE];
DMat20 prob_cube_derv1[MAX_NUM_RATE];
DMat20 prob_cube_derv2[MAX_NUM_RATE];

/**
	compute the log likelihood of this tree 
	using gamma rate
	@return tree log-likelihood
*/

double LiBr::cmpLogLiGammaRate () {
	opt_urtree.nCmpInLogLi_ ++;

	LiNd *headLi_;
	LiNd *tailLi_;
	if (isEx_) {
		headLi_ = greLiNd_;
		tailLi_ = smaLiNd_;
	} else {
		headLi_ = &getHeadLiNd ();
		tailLi_ = &getTailLiNd ();
	}

	int nState_ = mymodel.getNState ();
	int nRate_ = myrate.getNRate ();
	int rate_state = nRate_ * nState_;
	double probRate_ = myrate.getProb ();

	int rateNo_;
	for (rateNo_ = 0; rateNo_ < nRate_; rateNo_ ++) {
		if (myrate.isNsSyHeterogenous()) {
			cmpProbChangeFrq (len_, rateNo_, probChangeCube_[rateNo_]);
		} else {
			cmpProbChangeFrq (len_ * myrate.getRate (rateNo_), probChangeCube_[rateNo_]);
		}
	}

	double logLi_ = 0.0;
	int nPtn_ = ptnlist.getNPtn ();
#ifdef _OPENMP
	#pragma omp parallel for  private(rateNo_) reduction(+: logLi_)
#endif
	for (int ptnNo_ = 0; ptnNo_ < nPtn_; ptnNo_ ++) {
		LDOUBLE ptnLi_ = 0.0;
		int start_addr = ptnNo_ * rate_state;
		int exStateNo_ = -1;
		if (isEx_) exStateNo_ = ptnlist.getBase(ptnNo_, smaNdNo_);
		LDOUBLE ratePtnLi_ = 0.0;
		for (rateNo_ = 0; rateNo_ < nRate_; rateNo_ ++) {
			
			int local_addr = start_addr + rateNo_ * nState_;

			if (exStateNo_ >= 0 && exStateNo_ < nState_) {
/*
				if (myrate.use_invar_site && rateNo_ == nRate_ - 1)
					ratePtnLi_ = greLiNd_->liNdCube_[local_addr + exStateNo_] *
					               probChangeCube_[rateNo_][exStateNo_][exStateNo_];
				else*/
					for (int inStateNo_ = 0; inStateNo_ < nState_; inStateNo_ ++)
						ratePtnLi_ += greLiNd_->liNdCube_[local_addr + inStateNo_] *
						                probChangeCube_[rateNo_][exStateNo_][inStateNo_];
			} else
			for (int headStateNo_ = 0; headStateNo_ < nState_; headStateNo_ ++) 
			//if (headLi_->liNdCube_[local_addr + headStateNo_] != 0.0) 
			{
				LDOUBLE stateLi_ = 0.0;

				if (exStateNo_ == BS_UNKNOWN)
					stateLi_ = mymodel.stateFrqArr_[headStateNo_];
				else /*
				if (myrate.use_invar_site && rateNo_ == nRate_ - 1) {
					stateLi_ = probChangeCube_[rateNo_][headStateNo_][headStateNo_] *
						tailLi_->liNdCube_[local_addr + headStateNo_];
				} else*/ {
					for (int tailStateNo_ = 0; tailStateNo_ < nState_; tailStateNo_ ++)
						stateLi_ += probChangeCube_[rateNo_][headStateNo_][tailStateNo_] *
						    tailLi_->liNdCube_[local_addr + tailStateNo_];

				}
				ratePtnLi_ += stateLi_ * headLi_->liNdCube_[local_addr + headStateNo_];
			}

		}

		if (myrate.isNsSyHeterogenous())
			ptnLi_ = ratePtnLi_ * mymodel.getClassProb(rateNo_);
		else {/*if (!myrate.use_invar_site || rateNo_ < nRate_ - 1)*/
			ptnLi_ = ratePtnLi_  * probRate_;

			if (ptnlist.getPtn(ptnNo_).is_const && ptnlist.getBase(ptnNo_, 0) < nState_)
				ptnLi_ += myrate.prob_invar_site * mymodel.stateFrqArr_[ptnlist.getBase(ptnNo_, 0)];
		}


		if (ptnLi_ <= 0.0) {
			ptnLi_ = VERY_SMALL_POSITIVE;
			Utl::announceError (BAD_CMP_LI);
		}

		double logPtnLi_ = log (ptnLi_);
		if (ptn_logl) ptn_logl[ptnNo_] += logPtnLi_;
		logLi_ += logPtnLi_ * ptnlist.weightArr_[ptnNo_];
	}

	logLi_ += headLi_->liscale + tailLi_->liscale;

	return logLi_;
}


/**
	compute the log likelihood of this tree 
	using gamma rate
	@param calc_logli true if you want to compute log-likelihood
	@param logli_derv1 (OUT) 1st derivative of log-likelihood
	@param logli_derv2 (OUT) 2nd derivative of log-likelihood
	@return log-likelihood if calc_logli==true
*/
double LiBr::cmpLogLiDerivativeGammaRate (double &logli_derv1_ret, double
&logli_derv2_ret, bool calc_logli) {
	opt_urtree.nCmpInLogLi_ ++;

	LiNd *headLi_;
	LiNd *tailLi_;

	if (isEx_) {
		headLi_ = greLiNd_;
		tailLi_ = smaLiNd_;
	} else {
		headLi_ = &getHeadLiNd ();
		tailLi_ = &getTailLiNd ();
	}

	int nState_ = mymodel.getNState ();
	int nRate_ = myrate.getNRate ();
	int rate_state = nRate_ * nState_;
	double probRate_ = myrate.getProb ();

	int rateNo_;
	for (rateNo_ = 0; rateNo_ < nRate_; rateNo_ ++) {
		if (myrate.isNsSyHeterogenous()) 
			cmpProbChangeDerivativesFrq (len_, rateNo_, probChangeCube_[rateNo_],
				prob_cube_derv1[rateNo_], prob_cube_derv2[rateNo_]);
		else {
			double rate_ = myrate.getRate (rateNo_);
			cmpProbChangeDerivativesFrq (len_ * rate_, probChangeCube_[rateNo_],
				prob_cube_derv1[rateNo_], prob_cube_derv2[rateNo_]);
			double rate_sqr = rate_ * rate_;
			for (int i = 0; i < nState_; i++)
				for (int j = 0; j < nState_; j++) {
					prob_cube_derv1[rateNo_][i][j] *= rate_;
					prob_cube_derv2[rateNo_][i][j] *= rate_sqr;
				}
		}
	}

	double logli = 0.0;
	double logli_derv1 = 0.0;
	double logli_derv2 = 0.0;
	int nPtn_ = ptnlist.getNPtn ();

#ifdef _OPENMP
	#pragma omp parallel for  private(rateNo_) reduction(+: logli, logli_derv1, logli_derv2)
#endif
	for (int ptnNo_ = 0; ptnNo_ < nPtn_; ptnNo_ ++) {
		LDOUBLE ptnLi_ = 0.0;
		LDOUBLE ptnLi_1 = 0.0;
		LDOUBLE ptnLi_2 = 0.0;
		int start_addr = ptnNo_ * rate_state;
		int exStateNo_ = -1;
		if (isEx_) exStateNo_ = ptnlist.getBase(ptnNo_, smaNdNo_);
		LDOUBLE ratePtnLi_ = 0.0;
		LDOUBLE ratePtnLi_1 = 0.0;
		LDOUBLE ratePtnLi_2 = 0.0;

		for (rateNo_ = 0; rateNo_ < nRate_; rateNo_ ++) {
			//bool invar_calc = (myrate.use_invar_site && rateNo_ == nRate_ - 1);
			int local_addr = start_addr + rateNo_ * nState_;

			if (exStateNo_ >= 0 && exStateNo_ < nState_) {
				/*if (invar_calc)
					ratePtnLi_ = greLiNd_->liNdCube_[local_addr + exStateNo_] *
					               probChangeCube_[rateNo_][exStateNo_][exStateNo_];*/
				for (int inStateNo_ = 0; inStateNo_ < nState_; inStateNo_ ++) {
					//if (!invar_calc)
						ratePtnLi_ += greLiNd_->liNdCube_[local_addr + inStateNo_] *
						                probChangeCube_[rateNo_][exStateNo_][inStateNo_];
					ratePtnLi_1 += greLiNd_->liNdCube_[local_addr + inStateNo_] *
					                 prob_cube_derv1[rateNo_][exStateNo_][inStateNo_];
					ratePtnLi_2 += greLiNd_->liNdCube_[local_addr + inStateNo_] *
					                 prob_cube_derv2[rateNo_][exStateNo_][inStateNo_];
				}

			} else 
			for (int headStateNo_ = 0; headStateNo_ < nState_; headStateNo_ ++) 
			//if (headLi_->liNdCube_[local_addr + headStateNo_] != 0.0) 
			{
				LDOUBLE stateLi_ = 0.0;
				LDOUBLE stateLi_1 = 0.0;
				LDOUBLE stateLi_2 = 0.0;
				if (exStateNo_ == BS_UNKNOWN)
					stateLi_ = mymodel.stateFrqArr_[headStateNo_];
				else {
					/*if (invar_calc)
						stateLi_ = probChangeCube_[rateNo_][headStateNo_][headStateNo_] *
							tailLi_->liNdCube_[local_addr + headStateNo_];*/
					for (int tailStateNo_ = 0; tailStateNo_ < nState_; tailStateNo_ ++) {
						//if (!invar_calc)
							stateLi_ += probChangeCube_[rateNo_][headStateNo_][tailStateNo_] *
								tailLi_->liNdCube_[local_addr + tailStateNo_];
						stateLi_1 += prob_cube_derv1[rateNo_][headStateNo_][tailStateNo_] *
							tailLi_->liNdCube_[local_addr + tailStateNo_];
						stateLi_2 += prob_cube_derv2[rateNo_][headStateNo_][tailStateNo_] *
							tailLi_->liNdCube_[local_addr + tailStateNo_];
					}
				}

				LDOUBLE temp = headLi_->liNdCube_[local_addr + headStateNo_];

				ratePtnLi_ += stateLi_ * temp;
				ratePtnLi_1 += stateLi_1 * temp;
				ratePtnLi_2 += stateLi_2 * temp;
			}
		}

		if (myrate.isNsSyHeterogenous()) {
			ptnLi_ = ratePtnLi_  * mymodel.getClassProb(rateNo_);
			ptnLi_1 = ratePtnLi_1 * mymodel.getClassProb(rateNo_);
			ptnLi_2 = ratePtnLi_2 * mymodel.getClassProb(rateNo_);
		} else	/*if (!myrate.use_invar_site || rateNo_ < nRate_ - 1) */{
			ptnLi_ = ratePtnLi_  * probRate_;
			ptnLi_1 = ratePtnLi_1 * probRate_;
			ptnLi_2 = ratePtnLi_2 * probRate_;
			if (ptnlist.getPtn(ptnNo_).is_const && ptnlist.getBase(ptnNo_, 0) < nState_) {
				ptnLi_ += myrate.prob_invar_site * mymodel.stateFrqArr_[ptnlist.getBase(ptnNo_, 0)];
			}

		}/* else {
			ptnLi_ += ratePtnLi_ * myrate.prob_invar_site;
			ptnLi_1 += ratePtnLi_1 * myrate.prob_invar_site;
			ptnLi_2 += ratePtnLi_2 * myrate.prob_invar_site;
		}*/

		if (ptnLi_ <= 0.0) {
			ptnLi_ = VERY_SMALL_POSITIVE;
			Utl::announceError (BAD_CMP_LI);
		}

		double temp1 = ptnLi_1 / ptnLi_;
		double temp2 = ptnLi_2 / ptnLi_;
		logli_derv1 += temp1 * ptnlist.weightArr_[ptnNo_];
		logli_derv2 += (temp2 - temp1 * temp1) * ptnlist.weightArr_[ptnNo_];
		if (calc_logli)
			logli += log (ptnLi_) * ptnlist.weightArr_[ptnNo_];
	}
	logli_derv1_ret = logli_derv1;
	logli_derv2_ret = logli_derv2;
	logli += headLi_->liscale + tailLi_->liscale;
	return logli;
}


/**********************************************************
*
*    Computing the GAMMA site-rate based on empirical Bayesian analysis
*
**********************************************************/

/**
	Empirical Bayes estimate of a rate for every site pattern under the
	gamma model: for each pattern, sum_k probRate_ * L_k * r_k divided by
	sum_k probRate_ * L_k, where L_k is the likelihood across this branch
	in rate category k and r_k = myrate.getRate(k). The results are written
	into the vector returned by myrate.getPtnRate() (resized via set()).
	NOTE(review): probChangeCube_ is read but not recomputed here, so this
	presumably relies on a preceding cmpLogLiGammaRate / derivative call for
	this branch having filled it — confirm.
	NOTE(review): the invariant-site class is not part of the denominator
	(its code is commented out), and `out` is unused.
	NOTE(review): the small end always uses smaLiNd_->liNdCube_, even when
	the branch is external — confirm that cube is populated for leaves.
	@param out output stream (currently unused)
*/
void LiBr::empiricalBayesGammaRate (ostream &out) {

	int nState_ = mymodel.getNState ();
	int nRate_ = myrate.getNRate ();
	int rate_state = nRate_ * nState_;
	double probRate_ = myrate.getProb ();
	int ptnNo_;

	int rateNo_;
	int nPtn_ = ptnlist.getNPtn ();
	Vec<double> *ptn_rate = myrate.getPtnRate();
	ptn_rate->set(nPtn_, nPtn_);


	for (ptnNo_ = 0; ptnNo_ < nPtn_; ptnNo_ ++) {
		LDOUBLE exPtnLi_ = 0.0;
		int start_addr = ptnNo_ * rate_state;
		(*ptn_rate)[ptnNo_] = 0.0;

		for (rateNo_ = 0; rateNo_ < nRate_; rateNo_ ++) {
			int local_addr = start_addr + rateNo_ * nState_;

			// likelihood of this pattern in rate category rateNo_
			LDOUBLE rateExPtnLi_ = 0.0;
			for (int inStateNo_ = 0; inStateNo_ < nState_; inStateNo_ ++) {
				double exStateLi_ = 0.0;

				/*if (myrate.use_invar_site && rateNo_ == nRate_ - 1)
					exStateLi_ = probChangeCube_[rateNo_][inStateNo_][inStateNo_] *
						smaLiNd_->liNdCube_[local_addr + inStateNo_];
				else*/ for (int exStateNo_ = 0; exStateNo_ < nState_; exStateNo_ ++)
					exStateLi_ += probChangeCube_[rateNo_][inStateNo_][exStateNo_] *
						smaLiNd_->liNdCube_[local_addr + exStateNo_];

				rateExPtnLi_ += greLiNd_->liNdCube_[local_addr + inStateNo_] * exStateLi_;
			}

			double li_rate;

			// accumulate numerator (weighted by the category rate) and denominator
			/*if (!myrate.use_invar_site || rateNo_ < nRate_ - 1) {*/
			li_rate = rateExPtnLi_ * probRate_;
			(*ptn_rate)[ptnNo_] += li_rate * myrate.getRate(rateNo_);
			/*} else {
				li_rate = rateExPtnLi_ * myrate.prob_invar_site;
			}*/
			exPtnLi_ += li_rate;
		}

		// non-positive likelihood: clamp so the division stays finite and report
		if (exPtnLi_ <= 0.0) {
			exPtnLi_ = VERY_SMALL_POSITIVE;
			Utl::announceError (BAD_CMP_LI);
		}

		(*ptn_rate)[ptnNo_] /= exPtnLi_;
	}
}


//-----------------------------------------------
//-  codon-based Nielsen Yang Model
//-----------------------------------------------
/**
	Empirical Bayes detection of positively selected codon sites for the
	codon model with heterogeneous Ns/Sy ratio classes, computed across this
	branch; the observed codon of each site is read from the sequence of
	the small node (alignment.getSeq(smaNdNo_)).
	For every site the class-weighted likelihoods
	    classProb(c) * L_c(site)
	give Pr(w>1) (mass of classes with getClassRatio(c) > 1), the posterior
	mean w, and the site log-likelihood. Sites with Pr(w>1) > 0.5 are listed
	first, then a detailed table of all sites is written.
	Only the master process reports; other processes return immediately.
	If the class probabilities do not sum to 1, the last one is reset to
	1 - (sum of the others) in the model.
	@param out stream receiving the report
*/
void LiBr::empiricalBayesAnalysis(ostream &out) {

	if (!isMasterProc())
		return;

	cout << "Inferring sites under positive selection..." << endl;

	const int nState_ = mymodel.getNState ();
	const int nClass_ = myrate.getNRate ();
	const int class_state = nClass_ * nState_;
	char line[100];

	out << endl << "Positively selected codon sites (*: P>95%; **: P>99%)" << endl;
	out << " Site   Pr(w>1)     Mean w" << endl;

	const double threshold0 = 0.5;
	const double threshold1 = 0.95;
	const double threshold2 = 0.99;

	Seq *seq_ = &alignment.getSeq (smaNdNo_);

	// frequency-weighted P(t) per class; accumulate all but the last class prob
	double sum = 0;
	for (int classNo_ = 0; classNo_ < nClass_; classNo_ ++) {
		if (classNo_ < nClass_ - 1)
			sum += mymodel.getClassProb(classNo_);
		cmpProbChangeFrq (len_, classNo_, probChangeCube_[classNo_]);
	}
	if (fabs(sum + mymodel.getClassProb(nClass_ - 1) - 1.0) > ZERO) {
		cout << "sum = " << sum + mymodel.getClassProb(nClass_ - 1) << "!!!!" << endl;
		((CodonNY98*) mymodel.model_)->nsSyProbVec[nClass_-1] = 1.0 - sum;
	}

	const int nSite_ = alignment.getNSite();

	LDOUBLE *rateExPtnLi_ = new LDOUBLE [nClass_];
	double *logExPtnLi_ = new double [nSite_];
	double *mean_w = new double [nSite_];
	double *pos_probability = new double [nSite_];

	for (int siteNo_ = 0; siteNo_ < nSite_; siteNo_ ++) {

		const int ptnNo_ = alignment.getPtn(siteNo_);
		const int exStateNo_ = (*seq_).items_[siteNo_];
		LDOUBLE exPtnLi_ = 0.0;
		LDOUBLE max_value = 0.0;
		const int start_addr = ptnNo_ * class_state;

		for (int classNo_ = 0; classNo_ < nClass_; classNo_ ++) {
			const int local_addr = start_addr + classNo_ * nState_;
			rateExPtnLi_[classNo_] = 0.0;
			if (exStateNo_ >= 0 && exStateNo_ < nState_) {
				// observed codon known: sum over the internal-end states
				for (int inStateNo_ = 0; inStateNo_ < nState_; inStateNo_ ++)
					rateExPtnLi_[classNo_] += greLiNd_->liNdCube_[local_addr + inStateNo_] *
					                          probChangeCube_[classNo_][exStateNo_][inStateNo_];
			} else {
				for (int inStateNo_ = 0; inStateNo_ < nState_; inStateNo_ ++) {
					LDOUBLE exStateLi_ = 0.0;

					if (exStateNo_ == BS_UNKNOWN) {
						exStateLi_ = mymodel.stateFrqArr_[inStateNo_];
					} else {
						for (int leafStateNo = 0; leafStateNo < nState_; leafStateNo ++)
							exStateLi_ += probChangeCube_[classNo_][inStateNo_][leafStateNo] *
							              smaLiNd_->liNdCube_[local_addr + leafStateNo];
					}

					rateExPtnLi_[classNo_] += greLiNd_->liNdCube_[local_addr + inStateNo_] * exStateLi_;
				}
			}

			rateExPtnLi_[classNo_] *= mymodel.getClassProb(classNo_);
			if (mymodel.getClassRatio(classNo_) > 1.0)
				max_value += rateExPtnLi_[classNo_];
			exPtnLi_ += rateExPtnLi_[classNo_];
		}

		if (exPtnLi_ < -ZERO) {
			cout << "len = " << len_ << endl;
			Utl::announceError (BAD_CMP_LI);
		}
		// guard log() and the divisions below against zero or round-off
		// negatives, as the other likelihood routines in this file do
		if (exPtnLi_ <= 0.0)
			exPtnLi_ = VERY_SMALL_POSITIVE;

		logExPtnLi_[siteNo_] = log (exPtnLi_);

		pos_probability[siteNo_] = max_value / exPtnLi_;
		mean_w[siteNo_] = 0.0;
		for (int cNo = 0; cNo < nClass_; cNo++)
			mean_w[siteNo_] += (rateExPtnLi_[cNo] / exPtnLi_) * mymodel.getClassRatio(cNo);

		if (pos_probability[siteNo_] > threshold0) {
			sprintf(line, "%5d %8.3f", alignment.getOriginSite(siteNo_) + 1, pos_probability[siteNo_]);
			out << line;
			out << ((pos_probability[siteNo_] > threshold1) ? "*" : " ");
			out << ((pos_probability[siteNo_] > threshold2) ? "*" : " ");
			sprintf(line, "%10.3f", mean_w[siteNo_]);
			out << line << endl;
		}
	}


	out << endl << "Detailed information for all codon sites";
	// a renumbered last site means some original sites were dropped
	if (nSite_ > 0 && alignment.getOriginSite(nSite_-1) != nSite_-1) {
		out << " (Some sites were discarded due to gaps/stop codons)";
	}
	out << ":" << endl;
	out << " Site   Pr(w>1)     Mean w     LogL" << endl;


	for (int siteNo_ = 0; siteNo_ < nSite_; siteNo_ ++) {
		sprintf(line, "%5d %8.3f %11.3f %16.10f", alignment.getOriginSite(siteNo_) + 1, 
			pos_probability[siteNo_], mean_w[siteNo_], logExPtnLi_[siteNo_]);
		out << line << endl;
	}

	// arrays were allocated with new[], so they must be released with delete[]
	delete [] pos_probability;
	delete [] mean_w;
	delete [] logExPtnLi_;
	delete [] rateExPtnLi_;
}


/**********************************************************
*
*    General functions to compute log likelihood and derivatives
*
**********************************************************/

// true when the rate type is SITE_SPECIFIC and the per-site rates have
// already been optimized (myrate.isOptedSpecificRate ())
inline bool LiBr::isSiteSpec() {
	if (myrate.getType () != SITE_SPECIFIC)
		return false;
	return myrate.isOptedSpecificRate ();
}



/**
	compute the log-likelihood from external branch
	@return tree log-likelihood
*/

double LiBr::cmpLogLi () {

	RATE_TYPE rateType_ = myrate.getType ();

	if (myrate.isNsSyHeterogenous() || rateType_ == GAMMA)
		return cmpLogLiGammaRate ();
	else 
		return cmpLogLiUniformRate ();
}


// compute the log-likelihood from external branchd
double LiBr::cmpLogLiDerivatives (double &logli_derv1, double &logli_derv2, bool calc_logli) {

	RATE_TYPE rateType_ = myrate.getType ();
	if (myrate.isNsSyHeterogenous() || rateType_ == GAMMA)
		return cmpLogLiDerivativeGammaRate (logli_derv1, logli_derv2, calc_logli);
	else 
		return cmpLogLiDerivativeUniformRate (logli_derv1, logli_derv2, calc_logli);

}

//-------------------------------------------------------------------
//set the two liNds to be not cmped
void LiBr::openLiNd () {
	greLiNd_->turnOffIsCmped ();
	smaLiNd_->turnOffIsCmped ();
}

//-------------------------------------------------------------------
//set the two liNds to be not cmped
void LiBr::openTailLiNd () {
	if (getTailNd () == greNdNo_ )
		greLiNd_->turnOffIsCmped ();
	else
		smaLiNd_->turnOffIsCmped ();
}

//-------------------------------------------------------------------
//set the two liNds to be not cmped
void LiBr::openHeadLiNd () {
	if (getHeadNd () == greNdNo_ )
		greLiNd_->turnOffIsCmped ();
	else
		smaLiNd_->turnOffIsCmped ();
}


//======================================================
// set memory for this liBr -- currently a no-op
void LiBr::setLimit () {
}

//======================================================
//release all memory of this class
void LiBr::release () {
	smaLiNd_->release ();
	greLiNd_->release ();
}

//======================================================
//the destructor
// Calls release() on both end-point LiNds.
// NOTE(review): the LiNd objects allocated with new in the constructor are
// not deleted here; since setTailLiNd can swap in LiNds supplied by the
// caller, ownership needs to be confirmed before adding a delete.
LiBr::~LiBr () {
	release ();
}
