// Must precede every include: it only takes effect if defined before the first
// <cmath>/<math.h>, which chain.h may already pull in (MSVC M_PI etc.).
#define _USE_MATH_DEFINES

#include "chain.h"

#include <algorithm>
#include <cfloat>
#include <cmath>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <iostream>
#include <iterator>
#include <limits>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

#include <boost/math/special_functions/gamma.hpp>
#include <boost/random/gamma_distribution.hpp>

#include "stabledistribution.h"
#include "global.h"
#include "chainmanager.h"

/**
 * Build one MCMC chain.
 *
 * Steps: read the Newick tree and label unnamed nodes, read the tip values,
 * check that every tip has a value and every non-root branch has positive
 * length, seed the unknown node values uniformly between the smallest and
 * largest known value, draw starting alpha/scale, open the per-chain log
 * "<outputBasePath>.chain<id+1>.log" and, for chain 0 only, write the
 * ".tree", ".fixed_taxa" and ".prior" side files.
 *
 * Failures are reported by appending a message to errorMsg and returning
 * early. Such a chain is only partly initialised (cval/clen may still be
 * null), so it must not be run; callers are expected to check errorMsg.
 *
 * @param man            manager that run() reports progress to
 * @param id             zero-based chain index (used in the log file name)
 * @param b              brownian: alpha fixed at 2, scale still sampled
 * @param bstrict        brownianstrict: alpha fixed at 2, scale fixed at igamma_mode
 * @param treePath       Newick tree file
 * @param dataPath       lines of "<taxon> <value>"; lines without exactly two
 *                       whitespace-separated fields are ignored
 * @param outputBasePath prefix for every output file
 * @param igammashape    shape of the inverse-gamma prior on scale^alpha
 * @param igammascale    scale of that prior; 0 means derive it from bm.rate of
 *                       PhyCpp::BrownianMotion::reconstruct
 * @param seed           seed for rng
 */
Chain::Chain(ChainManager * man, int id, bool b, bool bstrict, const std::string & treePath, const std::string & dataPath, const std::string & outputBasePath, double igammashape, double igammascale, long seed)
    : myid(id),
      tree(0),
      taxa(0),
      igamma_shape(igammashape),
      igamma_scale(igammascale),
      stable_alpha(0),
      stable_scale(0),
      rng(seed),
      unif01(rng, boost::uniform_real<double>(0,1)),
      expdist1(rng, boost::exponential_distribution<double>(1.0)),
      lgamshape(boost::math::lgamma(igamma_shape)),
      brownian(b),
      brownianstrict(bstrict),
      cval(0),
      clen(0),
      errorMsg(""),
      manager(man)
{
    setCycleLength(100);

    // --- Tree ---------------------------------------------------------------
    // Read the whole Newick file; line breaks are simply dropped.
    std::ifstream t;
    t.open(treePath.data());
    std::string newick;
    if (t.is_open()) {
        std::string line;
        while (getline(t, line))
            newick += line;
        t.close();
        taxa = new Taxa;
        tree = new Tree(newick, taxa);
        // Give every unlabelled node a fresh "n<k>" label (skipping labels that
        // already exist) and mark every node and taxon value as unknown.
        auto it = tree->begin();
        int counter = 0;
        while (it != tree->end()) {
            if (it->taxon() == 0) {
                std::stringstream cs;
                cs << "n" << counter;
                while (taxa->find(cs.str()) != 0) {
                    ++counter;
                    cs.str("");
                    cs << "n" << counter;
                }
                it->setTaxon(cs.str());
                ++counter;
            }
            it->taxon()->value() = it->value() = DBL_MAX;
            ++it;
        }
    } else {
        errorMsg += "Failed to open tree file.\n";
        return;
    }

    // Every non-root branch must have a strictly positive length.
    auto eit = tree->begin();
    while (eit != tree->end()) {
        if (eit->parentEdge() != 0) {
            if (eit->parentEdge()->length() <= 0 && *eit != tree->root()) {
                errorMsg += "Tree contained zero or negative branch lengths, which are not allowed.\n";
                return;
            }
        }
        ++eit;
    }

    // --- Tip data -----------------------------------------------------------
    std::ifstream df;
    std::vector<std::string> data;
    df.open(dataPath.data());
    if (df.is_open()) {
        std::string line;
        while (getline(df, line))
            data.push_back(line);
        df.close();
        for (const std::string & row : data) {
            std::vector<std::string> datum;
            std::istringstream iss(row);
            std::copy(std::istream_iterator<std::string>(iss),
                      std::istream_iterator<std::string>(),
                      std::back_inserter<std::vector<std::string> >(datum));
            if (datum.size() != 2)
                continue;
            Tree::Taxon * q = taxa->find(datum[0]);
            if (q == 0) {
                errorMsg += "Error while loading data: taxon '";
                errorMsg += datum[0].data();
                errorMsg += "' could not be found in the tree.\n";
                return;
            }
            // Reject values that are not entirely numeric (e.g. "NA"); strtod
            // would otherwise silently turn them into 0.
            const char * text = datum[1].c_str();
            char * end = nullptr;
            const double value = strtod(text, &end);
            if (end == text || *end != '\0') {
                errorMsg += "Error while loading data: value '";
                errorMsg += datum[1];
                errorMsg += "' for taxon '";
                errorMsg += datum[0];
                errorMsg += "' is not a number.\n";
                return;
            }
            q->value() = value;
        }
    } else {
        errorMsg += "Failed to open data file.\n";
        return;
    }

    // Every tip must have received a value from the data file.
    auto testit = tree->begin();
    std::string taxaWarn;
    while (testit != tree->end()) {
        if (testit->isTip() && testit->taxon()->value() == DBL_MAX)
            taxaWarn += "    " + testit->taxon()->label() + "\n";
        ++testit;
    }
    if (taxaWarn.size() > 0) {
        errorMsg += ("Some tips had no character states: \n" + taxaWarn);
        return;
    }

    NodeValueAccessor nv;
    TaxonValueAccessor tv;
    PhyCpp::BrownianMotion::ReconstructionResult bm = PhyCpp::BrownianMotion::reconstruct(tree, tv, nv, DBL_MAX);
    if (igamma_scale == 0) {
        igamma_scale = (1.0+igamma_shape) * bm.rate/2;
    }

    // --- Random starting node values ------------------------------------------
    // Unknown nodes start uniformly between the smallest and largest known value.
    bool started = false;
    double minval = 9999999;
    double maxval = 0;
    auto mit = tree->begin();
    while (mit != tree->end()) {
        if (mit->taxon()->value() != DBL_MAX) {
            if (started) {
                minval = std::min(minval, mit->taxon()->value());
                maxval = std::max(maxval, mit->taxon()->value());
            } else {
                started = true;
                minval = mit->taxon()->value();
                maxval = mit->taxon()->value();
            }
        }
        ++mit;
    }

    auto it2 = tree->rbegin();
    while (it2 != tree->rend()) {
        if (it2->taxon()->value() == DBL_MAX) {
            it2->value() = minval + unif01()*(maxval-minval);
        }
        ++it2;
    }

    igamma_mode = sqrt(bm.rate/2);

    // --- Starting parameters --------------------------------------------------
    // Redraw alpha/scale until the starting state has a finite log-likelihood.
    // The gamma generator binds rng by reference: a boost::variate_generator
    // whose engine type is not a reference works on a private copy, so rng
    // would never advance. In brownian mode nothing else draws inside this
    // loop, so every retry then reproduced the same scale and the loop could
    // spin forever. brownianstrict has no random component here at all, hence
    // the explicit cap on attempts.
    const int maxInitAttempts = 10000;
    int initAttempts = 0;
    do {
        if (++initAttempts > maxInitAttempts) {
            errorMsg += "Could not find starting values for alpha and scale with a finite likelihood.\n";
            return;
        }
        stable_alpha = (brownian || brownianstrict) ? 2 : 1 + unif01();
        boost::gamma_distribution<double> gam(igamma_shape);
        boost::variate_generator<boost::mt19937 &, boost::gamma_distribution<double> > gamvgen(rng, gam);
        stable_scale = brownianstrict ? igamma_mode : pow(igamma_scale/gamvgen(), 1.0/stable_alpha);
    } while (logL() == -INFINITY);

    // --- Per-chain log --------------------------------------------------------
    std::stringstream ss;
    ss << outputBasePath << ".chain" << (1+myid) << ".log";
    output.open(ss.str().data());
    if (!output.is_open()) {
        errorMsg += "Failed to open output file ";
        errorMsg += ss.str();
        errorMsg += ".\n";
        return;
    }
    // Header: one column per node whose value is sampled (taxon value unknown).
    output << "Gen\tLogL\talpha\tscale\t";
    auto it = tree->begin();
    while (it != tree->end()) {
        if (it->taxon()->value() == DBL_MAX) {
            output << it->taxon()->label() << "\t";
        }
        ++it;
    }
    output << std::endl;
    output << 0 << "\t" << logL() << "\t" << stable_alpha << "\t" << stable_scale << "\t";
    // While writing generation 0, also find the largest child count so the
    // neighbour buffers can hold all children plus the parent.
    it = tree->begin();
    int maxcon = 0;
    while (it != tree->end()) {
        int thiscon = 0;
        Tree::Node * n = it->leftChild();
        while (n != 0) {
            ++thiscon;
            n = n->rightSibling();
        }
        maxcon = std::max(thiscon, maxcon);
        if (it->taxon()->value() == DBL_MAX) {
            output << it->value() << "\t";
        }
        ++it;
    }
    output << std::endl;

    cval = new double[maxcon + 1];
    clen = new double[maxcon + 1];

    // --- Shared side files (written once, by the first chain) -----------------
    if (myid == 0) {
        std::ofstream treeoutput;
        treeoutput.open((outputBasePath + ".tree").data());
        treeoutput << tree->toString();
        treeoutput.close();

        std::ofstream taxoutput;
        taxoutput.open((outputBasePath + ".fixed_taxa").data());
        it = tree->begin();
        while (it != tree->end()) {
            if (it->taxon()->value() != DBL_MAX)
                taxoutput << it->taxon()->label() << "\t" << it->taxon()->value() << std::endl;
            ++it;
        }
        taxoutput.close();

        std::ofstream proutput;
        proutput.open((outputBasePath + ".prior").data());
        proutput << "igamma_shape " << igamma_shape << std::endl;
        proutput << "igamma_scale " << igamma_scale << std::endl;
        proutput.close();
    }
}

/**
 * Release the tree, the taxa table and the neighbour buffers.
 *
 * delete / delete[] on a null pointer is a no-op, so members left null by an
 * early return in the constructor need no guard.
 */
Chain::~Chain()
{
    delete tree;
    delete taxa;
    delete [] cval;
    delete [] clen;
}

double Chain::logL() const
{   
    double result = 0;
    Tree::Node::iterator it = tree->begin();
    ++it;
    StableDistribution sd(stable_alpha, stable_scale);
    while (it != tree->end()) {
       result += sd.logPDF(it->value() - it->parent()->value(), it->parentEdge()->length());
       ++it;
    }
    double x = POW(stable_scale, stable_alpha);
    result +=  -(igamma_scale/x) + igamma_shape * LOG(igamma_scale/x) - LOG(x) - lgamshape;
    return result;
}

/**
 * Conditional log-density of a single node taking value x.
 *
 * @param x        candidate value for the node
 * @param numCons  number of connected neighbours stored in cval/clen
 *                 (values and connecting branch lengths)
 * @return sum over neighbours k of logPDF(cval[k] - x, clen[k])
 */
double Chain::logCondL(double x, int numCons)
{
    StableDistribution dist(stable_alpha, stable_scale);
    double total = 0;
    int k = 0;
    while (k < numCons) {
        total += dist.logPDF(cval[k] - x, clen[k]);
        ++k;
    }
    return total;
}

/**
 * One Metropolis-Hastings update of the stable index alpha, confined to [1, 2].
 *
 * Proposals are uniform on curVal +/- 0.005 and are redrawn until they fall in
 * [1, 2]. That truncation makes the proposal asymmetric near the bounds, so the
 * acceptance ratio includes the matching Hastings term.
 *
 * @param logP  logL() of the current state
 * @return logL() of the state after the update (the new one if accepted,
 *         otherwise logP with alpha restored)
 */
double Chain::mcmcSampleAlpha(double logP)
{
    const double lower = 1.0;
    const double upper = 2.0;
    const double halfWidth = 0.005;
    // Length of the proposal window around a that lies inside [lower, upper].
    // Written as two mins so it is exactly 2*halfWidth in the interior.
    auto windowInside = [&](double a) {
        return std::min(halfWidth, upper - a) + std::min(halfWidth, a - lower);
    };

    double curVal = stable_alpha;
    do {
        setAlpha(curVal + unif01()*0.01 - 0.005);	//todo: tune step size
    } while (stable_alpha > upper || stable_alpha < lower);
    double L2 = logL();
    // Redrawing out-of-range proposals gives the proposal density 1/windowInside(a)
    // from a. The ratio q(cur|new)/q(new|cur) = windowInside(cur)/windowInside(new)
    // restores detailed balance near the bounds; it is log(1) = 0 in the interior.
    double logHastings = std::log(windowInside(curVal) / windowInside(stable_alpha));
    if (unif01() < std::exp(L2 - logP + logHastings)) {
        return L2;
    }
    setAlpha(curVal);
    return logP;
}

/**
 * One Metropolis-Hastings update of the stable scale, confined to scale > 0.
 *
 * Proposals are uniform on curVal +/- igamma_mode/20 and are redrawn until
 * positive. Near 0 that truncation makes the proposal asymmetric, so the
 * acceptance ratio includes the matching Hastings term.
 *
 * @param logP  logL() of the current state
 * @return logL() of the state after the update (the new one if accepted,
 *         otherwise logP with the scale restored)
 */
double Chain::mcmcSampleScale(double logP)
{
    const double halfWidth = igamma_mode/20;
    // Length of the proposal window around s that lies inside (0, inf):
    // min(s + h, 2h), which is exactly 2*halfWidth once s >= halfWidth.
    auto windowInside = [&](double s) {
        return std::min(s, halfWidth) + halfWidth;
    };

    double curVal = stable_scale;
    do {
        setScale(curVal + unif01()*(igamma_mode/10) - (igamma_mode/20)); //todo: tune step size
    } while (stable_scale <= 0);
    double L2 = logL();
    // Hastings ratio q(cur|new)/q(new|cur) for the truncated window. When
    // halfWidth is not positive every proposal equals curVal and no
    // correction applies.
    double logHastings = 0;
    if (halfWidth > 0)
        logHastings = std::log(windowInside(curVal) / windowInside(stable_scale));
    if (unif01() < std::exp(L2 - logP + logHastings)) {
        return L2;
    }
    setScale(curVal);
    return logP;
}

/**
 * Draw a new value for one node by slice sampling its conditional density.
 *
 * The target slice is { x : logCondL(x, numCons) >= logP }. run() passes the
 * current conditional log-density minus an Exp(1) draw as logP.
 *
 * @param curVal   current node value; returned if no point in the slice is found
 * @param logP     slice height (log scale)
 * @param numCons  number of neighbours in cval/clen
 * @return the sampled value, or curVal
 */
double Chain::sliceSampleTraitValue(double curVal, double logP, int numCons)
{
    // Initial bracket: the range spanned by the neighbour values.
    double mins = cval[0];
    double maxs = cval[0];
    for (int i = 1; i < numCons; ++i) {
        mins = std::min(mins, cval[i]);
        maxs = std::max(maxs, cval[i]);
    }
    if (maxs == mins) {
        // Degenerate bracket: widen it around the common value.
        if (maxs != 0) {
            maxs *= 2;
            mins /= 2;
            // For a negative value v this yields mins = v/2 > maxs = 2v, and the
            // stepping-out loops below would then move the ends inwards.
            if (maxs < mins)
                std::swap(mins, maxs);
        } else {
            maxs = 1;
            mins = -1;
            std::cout << "Error 7f\n";
        }
    }

    // Step out (doubling the width each time) until both ends leave the slice.
    while (logCondL(mins, numCons) > logP)
        mins -= (maxs - mins);
    while (logCondL(maxs, numCons) > logP)
        maxs += (maxs - mins);

    // Sample uniformly inside the bracket until a point lands in the slice.
    const int maxRoughAttempts = 500;
    double result = curVal;
    double F = logP - 1;
    for (int guess = 0; F < logP && guess < maxRoughAttempts; ++guess) {
        result = mins + unif01() * (maxs - mins);
        F = logCondL(result, numCons);
    }

    // Judge success by F itself: counting attempts would also throw away a
    // valid point found on the very last attempt.
    if (F < logP) {
        // Nothing found: keep the current value. (A Newton-Raphson bracketing
        // fallback was planned here but is not implemented.)
        result = curVal;
    }
    return result;
}

void Chain::run(int pos, int iters)
{
	for (int i = 0; i < iters; ++i) {
       	 auto it = tree->rbegin();
         while (it != tree->rend()) {
             if (it->taxon()->value() == DBL_MAX) {
                 int qq = 0;
                 Tree::Node * n = it->leftChild();
                 while (n != 0) {
                     cval[qq] = (n->value());
                     clen[qq] = (n->parentEdge()->length());
                     n = n->rightSibling();
                     ++qq;
                 }
                 if (*it != tree->root()) {
                     cval[qq] = (it->parent()->value());
                     clen[qq] = (it->parentEdge()->length());
                     ++qq;
                 }
                 double logp = logCondL(it->value(), qq);
                 double Q = expdist1();
                 while (Q != Q || Q == std::numeric_limits<double>::infinity())
                     Q = expdist1();
                 it->value() = sliceSampleTraitValue(it->value(), logp - Q, qq);
             }
             ++it;
        }
       double LL = logL();
       if (!brownian && !brownianstrict) {
            LL = mcmcSampleAlpha(LL);
       }
       if (!brownianstrict) 
        LL = mcmcSampleScale(LL);
        if (pos+i != 0 && (pos+i)%cycleLength == 0) {
			output << (pos+i) << "\t" << LL << "\t" << stable_alpha << "\t" << stable_scale;
			auto oit = tree->begin();
			while (oit != tree->end()) {
				if (oit->taxon()->value() == DBL_MAX) {
					output << "\t" << oit->value();
				}
				++oit;
			}
			output << std::endl;
			manager->reportProgress(myid, pos+i, LL);
		}
	}
}
