#include "chainmanager.h"
#include <boost/math/special_functions/gamma.hpp>
#include <boost/random/mersenne_twister.hpp>
#include <fstream>
#include "stabledistribution.h"
#include "global.h"
#include <omp.h>
#include <boost/algorithm/string/split.hpp>
#include <boost/algorithm/string/classification.hpp>
#include <set>

/**
 * Creates `numchains` chains, each built from the same tree, data and output
 * base path, and loads the Newick tree that logL() evaluates.
 *
 * A Mersenne Twister seeded with `seed` supplies one draw per chain (passed as
 * the last Chain constructor argument, presumably that chain's RNG seed), so a
 * run is reproducible from `seed`.
 *
 * Failures are reported through errorMsg (which go() checks) rather than by
 * throwing:
 *  - the first chain's errorMessage() is copied;
 *  - an error is recorded if no chain was requested;
 *  - an error is recorded if the tree file cannot be opened.
 *
 * Progress statistics are written to "<outputBasePath>.progress".
 */
ChainManager::ChainManager(long seed, int stages, int numchains, const std::string & treePath, const std::string & dataPath, const std::string & outputBasePath, int numGens, int thinning, double igammashape, double igammascale, bool brownian, bool brownianstrict)
: numChains(numchains),
  numStages(stages),
  errorMsg(""),
  tree(0), 
  taxa(0),
  numgens(numGens),
  printfreq(thinning),
  outPath(outputBasePath),
  ofs((outputBasePath + ".progress").data())
{
	boost::mt19937 mt(seed);
	for (int i = 0; i < numChains; ++i) {
		iters.push_back(0);
		Ls.push_back(0);
		chains.push_back(new Chain(this, i, brownian, brownianstrict, treePath, dataPath, outputBasePath, igammashape, igammascale, mt()));
		chains[i]->setCycleLength(printfreq);
	}

	// Everything below reads chains[0]; with no chains that would be out of bounds.
	if (chains.empty()) {
		errorMsg = "At least one chain is required.";
		return;
	}

	// Prior hyperparameters are taken from the first chain (presumably identical for all chains).
	igamma_shape = chains[0]->igamShape();
	igamma_scale = chains[0]->igamScale();
	// Constant lgamma(shape) term of the inverse-gamma log-prior, cached for logL().
	lgamshape = boost::math::lgamma(igamma_shape);
	errorMsg = chains[0]->errorMessage();
	if (errorMsg.size() > 0)
		return;

	std::ifstream t(treePath.c_str());
	if (!t.is_open()) {
		// Without a tree, go() would dereference a null pointer; report instead.
		errorMsg = "Could not open tree file: " + treePath;
		return;
	}
	// Concatenate all lines so a Newick string spread over several lines is parsed whole.
	std::string newick;
	std::string line;
	while (std::getline(t, line))
		newick += line;
	tree = new Tree(newick, chains[0]->getTaxa());
}

/**
 * Releases the tree, the taxa object and every chain owned by this manager.
 * Deleting a null pointer is a no-op, so members that were never allocated
 * need no explicit check.
 */
ChainManager::~ChainManager()
{
	delete tree;
	delete taxa;
	for (auto chain : chains)
		delete chain;
}

/**
 * Records the latest progress of chain `id` and redraws a one-line status
 * display of the form "1|<iter> 2|<iter> ..." on stdout. The line starts with
 * a carriage return so that it overwrites the previous one.
 *
 * The body is an OpenMP critical section. Chains reporting from parallel
 * threads therefore do not interleave their updates to iters/Ls or their
 * output.
 *
 * @param id   zero-based chain index into iters and Ls
 * @param iter iteration the chain has reached
 * @param L    value stored in Ls[id] (callers pass the chain's logL()); not printed
 */
void ChainManager::reportProgress(int id, int iter, double L)
{
	#pragma omp critical
	{
		iters[id] = iter;
		Ls[id] = L;
		std::stringstream ss;
		ss << "\r";
		for (std::size_t i = 0; i < iters.size(); ++i)
			ss << i+1 << "|" << iters[i] << " ";
		std::string str = ss.str();
		// Pad to 80 columns to blank out the rest of a previous, longer line.
		// Pad only when shorter: "80 - str.size()" is unsigned and would wrap
		// to a huge length (and throw) once the status line exceeds 80 chars.
		if (str.size() < 80)
			str.append(80 - str.size(), ' ');
		std::cout << str;
		std::cout.flush();
	}
}

double ChainManager::logL(double stable_alpha, double stable_scale) const
{
    double result = 0;
    Tree::Node::iterator it = tree->begin();
    ++it;
    StableDistribution sd(stable_alpha, stable_scale);
    while (it != tree->end()) {
       result += sd.logPDF(it->value() - it->parent()->value(), it->parentEdge()->length());
       ++it;
    }
    double x = POW(stable_scale, stable_alpha);
    result +=  -(igamma_scale/x) + igamma_shape * LOG(igamma_scale/x) - LOG(x) - lgamshape;
    return result;
}

bool ChainManager::go()
{
	if (errorMsg.size() > 0)
		return false;
		
	for (int i = 0; i < chains.size(); ++i) {
		reportProgress(i, 0, chains[i]->logL());
	}	
	
	ofs << "Iteration\t";
	for (int i = 0; i < chains.size(); ++i) {
		std::stringstream ss;
		ss << "Chain" << i+1 << "_DIC" << "\t";
		ss << "Chain" << i+1 << "_PBIC" << "\t";
		ss << "Chain" << i+1 << "_pd" << "\t";
		ss << "Chain" << i+1 << "_dhat" << "\t";
		ofs << ss.str();
	}
	ofs << "max_PSRF" << std::endl;
		
	int totalSamples = numgens/printfreq;
	int b = totalSamples/numStages;

	int stage = 0;
	for (int j = 0; j < numgens; j += b*printfreq) {
		++stage;
	
		#ifdef MULTITHREADED
			#pragma omp parallel for
		#endif
	    for (int i = 0; i < chains.size(); ++i)
	        chains[i]->run(j,b*printfreq);
	    
		std::stringstream rep;
		rep << "\rStage " << stage << " of " << numStages << " is complete.";
		std::string str = rep.str();
		str += std::string(80-str.size(), ' ');
		std::cout << str << std::endl;
		
		int margin = (int)(0.025*(stage * b)/2) ; // number of items on either side of the 2-tailed 95% confidence bound
		if (margin == 0)
			margin = 1;
		
		std::vector<std::vector<double> > chainmeans; 
		std::vector<std::vector<double> > chainSumSqr;
	
		for (int i = 0; i < chains.size(); ++i) {
			chainmeans.push_back(std::vector<double>());
			chainSumSqr.push_back(std::vector<double>());
			for (int k = 0; k < 3; ++k) { // likelihood, stable alpha and stable scale
				chainmeans[chainmeans.size()-1].push_back(0);
				chainSumSqr[chainSumSqr.size()-1].push_back(0);
			}
			auto it = tree->begin();
			while (it != tree->end()) {
                if (it->taxon() == 0 || it->taxon()->value() == DBL_MAX) {
					chainmeans[chainmeans.size()-1].push_back(0);
					chainSumSqr[chainSumSqr.size()-1].push_back(0);
				}
				++it;
			}
		}
		
		ofs << j+b*printfreq << "\t";
		
		int count = 0;
		for (int i = 0; i < chains.size(); ++i) {
			std::stringstream ss;
			ss << outPath << ".chain" << i+1 << ".log";
			std::ifstream ifs(ss.str().data());
			if (ifs.is_open()) {
				std::string line;
			  	count = 0;
			  	if (ifs.good()) {
			  		getline (ifs,line);
			  	}
			  	while (count < (stage*b)/2 && ifs.good()) {
			  		getline (ifs,line);
			  		++count;
			  	}
			  	count = 0;
			  	char * ptr = 0;
				while ( ifs.good() ) {
					getline (ifs,line);
				  	if (line.size() > 0) {
				  		++count;
				  		std::vector<std::string> split;
				  		boost::split(split, line, boost::is_any_of("\t\n"), boost::token_compress_on);
				  		for (int k = 0; k < chainmeans[i].size(); ++k) {
				  			double d = strtod(split[k+1].data(), &ptr);
				  			chainmeans[i][k] += d/(((double)stage*b)/2);
				  		}
				  	}
				}
				ifs.close();
			}
			auto it = tree->begin();
		    int pos = 3;
		    while (it != tree->end()) {
                if (it->taxon() == 0 || it->taxon()->value() == DBL_MAX) {
		    		it->value() = chainmeans[i][pos];
		    		++pos;
		    	}
		    	else {
		    		it->value() = it->taxon()->value();
		    	}
		    	++it;
		    }
		    
		    double dhat = -2*logL(chainmeans[i][1], chainmeans[i][2]); // deviance of the mean parameters
		    double pd = -2*chainmeans[i][0] - dhat;					// mean deviance - deviance of the mean
		    double DIC = pd + dhat;
		    double PBIC = -2*chainmeans[i][0] + 2 * pd;
		    
			ofs << DIC << "\t" << PBIC << "\t" << pd << "\t" << dhat << "\t";
			std::cout << "Chain " << i+1 << ": DIC = " << DIC << "; PBIC = " << PBIC << std::endl;
		}
		
		count = 0;
		for (int i = 0; i < chains.size(); ++i) {
			std::stringstream ss;
			ss << outPath << ".chain" << i+1 << ".log";
			std::ifstream ifs(ss.str().data());
			if (ifs.is_open()) {
				std::string line;
			  	count = 0;
			  	if (ifs.good()) {
			  		getline (ifs,line);
			  	}
			  	while (count < (stage*b)/2 && ifs.good()) {
			  		getline (ifs,line);
			  		++count;
			  	}
			  	count = 0;
			  	char * ptr = 0;
				while ( ifs.good() ) {
					getline (ifs,line);
				  	if (line.size() > 0) {
				  		++count;
				  		std::vector<std::string> split;
				  		boost::split(split, line, boost::is_any_of("\t\n"), boost::token_compress_on);
				  		for (int k = 0; k < chainmeans[i].size(); ++k) {
				  			double d = strtod(split[k+1].data(), &ptr);
				  			chainSumSqr[i][k] += (d-chainmeans[i][k])*(d-chainmeans[i][k]);
				  		}
				  	}
				}
				ifs.close();
			}
		    
		}
		
		double maxrhat = 0;
		for (int param = 1; param<chainmeans[0].size(); ++param) {
			double B_n = 0;
			double W = 0;
			for (int i = 0; i < chains.size(); ++i) {
				double grandmean = 0;
				for (int k = 0; k < chains.size(); ++k) {
					grandmean += chainmeans[k][param]/chains.size();
				}
				B_n += (chainmeans[i][param]-grandmean)*(chainmeans[i][param]-grandmean);
				W += chainSumSqr[i][param];
			}
			B_n /= chains.size()-1;
			W /= ((((double)stage*b)/2)-1)*chains.size();
			double s2plus = (W*((((double)stage*b)/2)-1))/(((double)stage*b)/2) + B_n;
			double rhat = ((s2plus/W)*(chains.size()+1))/chains.size() - ((double)(((double)stage*b)/2)-1)/(chains.size()*(((double)stage*b)/2));
			if (rhat > maxrhat)
				maxrhat=rhat;

		}	
		
		std::cout << "Max PSRF: " << maxrhat << std::endl;
		ofs << maxrhat << std::endl;
		ofs.flush();
		
		//std::cout << (((double)stage*b)/2) << " " << count << std::endl;
	
	
	}
	ofs.close();
	return true;
}
