/***************************************************************************
 *                                                                         *
 *   hky85m.cpp     (begin: Feb 20 2003)                                   *
 *                                                                         *
 *   Parallel IQPNNI - Important Quartet Puzzle with NNI                   *
 *                                                                         *
 *   Copyright (C) 2005 by Le Sy Vinh, Bui Quang Minh, Arndt von Haeseler  *
 *   Copyright (C) 2003-2004 by Le Sy Vinh, Arndt von Haeseler             *
 *   {vinh,minh}@cs.uni-duesseldorf.de                                     *
 *                                                                         *
 *   This program is free software; you can redistribute it and/or modify  *
 *   it under the terms of the GNU General Public License as published by  *
 *   the Free Software Foundation; either version 2 of the License, or     *
 *   (at your option) any later version.                                   *
 *                                                                         *
 *   This program is distributed in the hope that it will be useful,       *
 *   but WITHOUT ANY WARRANTY; without even the implied warranty of        *
 *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the         *
 *   GNU General Public License for more details.                          *
 *                                                                         *
 *   You should have received a copy of the GNU General Public License     *
 *   along with this program; if not, write to the                         *
 *   Free Software Foundation, Inc.,                                       *
 *   59 Temple Place - Suite 330, Boston, MA  02111-1307, USA.             *
 ***************************************************************************/

#include <math.h>
#include <iostream>

#include "hky85m.h"
#include "brent.h"
#include "opturtree.h"
#include "outstream.h"
#include "interface.h"

#ifdef PARALLEL
#include <mpi.h>
#endif

// Default constructor. Nothing is computed here: the rate terms, the
// coefficient1/2/3 tables and total_num_subst are filled in by init().
HKY85M::HKY85M ()
{
}

//--------------------------------------------------------------------
// Substitution rate per unit time of the HKY85 model:
//   2*(piA+piG)*(piC+piT)  +  tsTvRatio_ * 2*(piA*piG + piC*piT)
// The frequency-only factors are the values cached by init() in
// sub_rate_base (transversion part) and sub_rate_coefficent (transition part).
double HKY85M::cmpSubRate () {
	const double transversionPart = sub_rate_base;
	const double transitionPart = tsTvRatio_ * sub_rate_coefficent;
	return transversionPart + transitionPart;
}

//--------------------------------------------------------------------
// Precompute every quantity that depends only on the base frequencies.
// With pi_j = stateFrqArr_[j] and Pi_j the total frequency of j's class
// (purines A+G or pyrimidines C+T):
//   coefficient1[j] = pi_j * (1/Pi_j - 1)
//   coefficient2[j] = (Pi_j - pi_j) / Pi_j
//   coefficient3[j] = pi_j / Pi_j
// These are the constant factors of the transition-probability formulas.
// Finally total_num_subst is set from the current tsTvRatio_.
void HKY85M::init () {
	const double purineFrq = stateFrqArr_[BASE_A] + stateFrqArr_[BASE_G];
	const double pyrimidineFrq = stateFrqArr_[BASE_C] + stateFrqArr_[BASE_T];

	sub_rate_base = 2 * purineFrq * pyrimidineFrq;
	sub_rate_coefficent = 2 * (stateFrqArr_[BASE_A]*stateFrqArr_[BASE_G] + stateFrqArr_[BASE_C]*stateFrqArr_[BASE_T]);

	// NOTE(review): assumes the bases are numbered so that BASE_A..BASE_G
	// covers all NUM_BASE states (the probability routines index coefficients
	// by 0..NUM_BASE-1) -- confirm against the header.
	for (int j = BASE_A; j <= BASE_G; ++j) {
		const double classFrq = (j == BASE_A || j == BASE_G) ? purineFrq : pyrimidineFrq;
		const double frq = stateFrqArr_[j];

		coefficient1[j] = frq * (1.0/classFrq - 1.0);
		coefficient2[j] = (classFrq - frq) / classFrq;
		coefficient3[j] = frq/classFrq;
	}

	total_num_subst = cmpSubRate ();
}

//--------------------------------------------------------------------
// Refresh total_num_subst after tsTvRatio_ has changed. The frequency-only
// terms cached by init() are reused as-is, so this does not pick up changes
// to stateFrqArr_.
void HKY85M::reCmpSubRate () {
	const double rate = cmpSubRate ();
	total_num_subst = rate;
}

//--------------------------------------------------------------------
/*
Probability of ending in nucleotide stateNo2 after starting in stateNo1 on a
branch of length brLen. Model time is t = brLen / total_num_subst.
With pi = stateFrqArr_[stateNo2], Pi the frequency of stateNo2's class
(A+G or C+T) and ax = 1 + Pi * (tsTvRatio_ - 1):
  same state   : pi + coefficient1*e^{-t} + coefficient2*e^{-ax t}
  transition   : pi + coefficient1*e^{-t} - coefficient3*e^{-ax t}
  transversion : pi * (1 - e^{-t})
A zero-length branch uses 1 in place of every exponential, so exp() is not
called and the transversion probability is exactly 0.
NOTE(review): transitions are recognised by stateNo1 + stateNo2 == TS_SIGN,
which assumes the base encoding makes A+G and C+T share that sum.
*/
double HKY85M::cmpProbChange (const int stateNo1, const int stateNo2, const double brLen) {
	const bool zeroLen = (brLen == 0.0);
	const double time_ = brLen  / total_num_subst;
	const double frq = stateFrqArr_[stateNo2];

	// transversion (purine <-> pyrimidine): only the e^{-t} term survives
	if (stateNo1 != stateNo2 && stateNo1 + stateNo2 != TS_SIGN)
		return zeroLen ? 0.0 : frq * (1.0 - exp (-time_));

	// frequency of the class (purines or pyrimidines) that stateNo2 belongs to
	const double classFrq = (stateNo2 == BASE_A || stateNo2 == BASE_G)
		? stateFrqArr_[BASE_A] + stateFrqArr_[BASE_G]
		: stateFrqArr_[BASE_C] + stateFrqArr_[BASE_T];
	const double ax_ = 1.0 + classFrq * (tsTvRatio_ - 1.0);

	const double decay = zeroLen ? 1.0 : exp (-time_);
	const double classDecay = zeroLen ? 1.0 : exp (-time_ * ax_);
	const double shared = frq + coefficient1[stateNo2] * decay;

	if (stateNo1 == stateNo2)
		return shared + coefficient2[stateNo2] * classDecay;

	// transition (purine -> purine, or pyrimidine -> pyrimidine)
	return shared - coefficient3[stateNo2] * classDecay;
}

//--------------------------------------------------------------------
/*
Fill probMat[i][j] with the probability of changing from nucleotide i into
nucleotide j on a branch of length brLen (model time brLen / total_num_subst),
for all i, j in [0, NUM_BASE). Each entry uses the same-state, transition or
transversion formula of the single-pair cmpProbChange(). The values depend
only on the column j and the kind of change, so they are computed once per
column. A zero-length branch uses 1 in place of every exponential.
*/
void HKY85M::cmpProbChange (const double brLen, DMat20 &probMat) {
	const bool zeroLen = (brLen == 0.0);
	const double time_ = brLen  / total_num_subst;
	const double decay = zeroLen ? 1.0 : exp (-time_);
	const double purineDecay = zeroLen ? 1.0 :
		exp (-time_ * ( 1.0 + (stateFrqArr_[BASE_A] + stateFrqArr_[BASE_G]) * (tsTvRatio_ - 1.0)));
	const double pyrimidineDecay = zeroLen ? 1.0 :
		exp (-time_ * ( 1.0 + (stateFrqArr_[BASE_C] + stateFrqArr_[BASE_T]) * (tsTvRatio_ - 1.0)));

	for (int col = 0; col < NUM_BASE; col++) {
		const double classDecay = (col == BASE_A || col == BASE_G) ? purineDecay : pyrimidineDecay;
		const double shared = stateFrqArr_[col] + coefficient1[col] * decay;

		const double sameState = shared + coefficient2[col] * classDecay;
		// transition: purine -> purine, or pyrimidine -> pyrimidine
		const double transition = shared - coefficient3[col] * classDecay;
		// transversion: purine <-> pyrimidine
		const double transversion = stateFrqArr_[col] * (1.0 - decay);

		for (int row = 0; row < NUM_BASE; row++) {
			if (row == col)
				probMat[row][col] = sameState;
			else if (row + col == TS_SIGN)
				probMat[row][col] = transition;
			else
				probMat[row][col] = transversion;
		}
	}
}


//--------------------------------------------------------------------
/*
Fill probMat with the transition probabilities for branch length brLen, and
derv1 / derv2 with their first and second derivatives. Entries are indexed
[from][to] over [0, NUM_BASE).
With t = brLen / total_num_subst, every entry has the form
  pi + sum_k c_k * e^{-r_k t}
and its derivatives are taken term by term (d/dt gives -r_k c_k e^{-r_k t}).
NOTE(review): the derivatives are therefore with respect to t, not brLen.
d/dbrLen would carry an extra 1/total_num_subst factor, squared for derv2.
Confirm that callers account for this.
*/
void HKY85M::cmpProbChangeDerivatives (const double brLen, DMat20 &probMat, DMat20 &derv1, DMat20 &derv2) {
	const bool zeroLen = (brLen == 0.0);
	const double time_ = brLen  / total_num_subst;
	const double purineAx = 1.0 + (stateFrqArr_[BASE_A] + stateFrqArr_[BASE_G]) * (tsTvRatio_ - 1.0);
	const double pyrimidineAx = 1.0 + (stateFrqArr_[BASE_C] + stateFrqArr_[BASE_T]) * (tsTvRatio_ - 1.0);
	const double decay = zeroLen ? 1.0 : exp(-time_);
	const double purineDecay = zeroLen ? 1.0 : exp (-time_ * purineAx);
	const double pyrimidineDecay = zeroLen ? 1.0 : exp (-time_ * pyrimidineAx);

	for (int col = 0; col < NUM_BASE; col++) {
		const bool purine = (col == BASE_A || col == BASE_G);
		const double ax_ = purine ? purineAx : pyrimidineAx;
		const double classDecay = purine ? purineDecay : pyrimidineDecay;
		const double frq = stateFrqArr_[col];

		// e^{-t} term shared by the same-state and transition formulas
		const double value1 = coefficient1[col] * decay;
		// same-state e^{-ax t} term and its rate-weighted slope
		const double value2 = coefficient2[col] * classDecay;
		const double sameSlope = value2 * ax_;
		// transition e^{-ax t} term and its rate-weighted slope
		const double value3 = coefficient3[col] * classDecay;
		const double tsSlope = value3 * ax_;
		// transversion e^{-t} term
		const double tvValue = frq * decay;

		for (int row = 0; row < NUM_BASE; row++) {
			if (row == col) {
				derv1[row][col] = -value1 - sameSlope;
				derv2[row][col] = value1 + sameSlope * ax_;
				probMat[row][col] = frq + value1 + value2;
			} else if (row + col == TS_SIGN) {
				// transition: purine -> purine, or pyrimidine -> pyrimidine
				derv1[row][col] = -value1 + tsSlope;
				derv2[row][col] = value1 - tsSlope * ax_;
				probMat[row][col] = frq + value1 - value3;
			} else {
				// transversion: purine <-> pyrimidine
				derv1[row][col] = tvValue;
				derv2[row][col] = -tvValue;
				probMat[row][col] = frq - tvValue;
			}
		}
	}
}

//--------------------------------------------------------------------
/*Objective function for optimizing the transition/transversion parameter:
returns the negative log likelihood reported by opt_urtree for the candidate
value tsTvRatio. The candidate is clamped to [MIN_TS_TV_RATIO, MAX_TS_TV_RATIO]
and stored in tsTvRatio_ (the model keeps it after the call). total_num_subst
is refreshed before the likelihood is recomputed.
*/
double HKY85M::cmpNegLogLi (double tsTvRatio) {
	double clamped = tsTvRatio;
	if (clamped < MIN_TS_TV_RATIO)
		clamped = MIN_TS_TV_RATIO;
	if (clamped > MAX_TS_TV_RATIO)
		clamped = MAX_TS_TV_RATIO;
	tsTvRatio_ = clamped;

	reCmpSubRate ();
	opt_urtree.cmpLiNd ();

	return -opt_urtree.getLogLi ();
}

//--------------------------------------------------------------------
//optimize the tsTvRatio basing Brent method
bool HKY85M::optPam () {
	double fx_, error_;
	if (isMasterProc())
		std::cout <<"Optimizing transition/transversion ratio ..." << endl;
	Brent::turnOnOptedPam ();

	double oriTsTvRatio_ = tsTvRatio_;
	
	tsTvRatio_= optOneDim (MIN_TS_TV_RATIO, tsTvRatio_, MAX_TS_TV_RATIO,
	                       EPS_MODEL_PAM_ERROR, &fx_, &error_);

	//double logLi_ = -cmpNegLogLi (tsTvRatio_);


	//	if (isMasterProc())
	//	cout << "Time optimizing parameter(s): " << end_time - start_time << endl;

	Brent::turnOffOptedPam ();

	//	double logLi_ = -cmpNegLogLiTsTvRatio ();

	std::cout.precision (10);
	//  std::cout <<"Log likelihood: " << logLi_ << endl;

	std::cout.precision (5);
	if (isMasterProc())
		std::cout << "Transition/transversion ratio: " << tsTvRatio_ / 2.0 << endl;
		
	return (fabs (tsTvRatio_ - oriTsTvRatio_) > EPS_MODEL_PAM_ERROR);
}


//--------------------------------------------------------------------
// Release resources owned by this object. HKY85M currently holds nothing that
// needs explicit freeing, so this is a no-op; the destructor still calls it.
void HKY85M::release ()
{
	// intentionally empty
}

//--------------------------------------------------------------------
// Destructor: all cleanup is delegated to release().
HKY85M::~HKY85M ()
{
	release ();
}

