#define _USE_MATH_DEFINES
#include <algorithm>
#include <cfloat>
#include <cmath>
#include <cstddef>
#include <cstdlib>
#include <ctime>
#include <fstream>
#include <iostream>
#include <memory>
#include <set>
#include <sstream>
#include <string>
#include <vector>

#include <boost/algorithm/string/classification.hpp>
#include <boost/algorithm/string/split.hpp>
#include <boost/generator_iterator.hpp>
#include <boost/math/distributions/normal.hpp>
#include <boost/random/mersenne_twister.hpp>
#include <boost/random/uniform_real.hpp>
#include <boost/random/variate_generator.hpp>

#include "../tclap/CmdLine.h"
#include "../treedef.h"


int main(int argc, char** argv)
{

	Taxa * taxa = 0;
	Tree * tree = 0;
	std::cout << std::endl;
	
	boost::mt19937 rng((long int)time(0));
	boost::uniform_real<double> dist(-1,1);
	boost::variate_generator<boost::mt19937, boost::uniform_real<double> > vargen(rng, dist);	
	
	try {  

		TCLAP::CmdLine cmd("StableTraitsSum", ' ', "1.4");
		
        TCLAP::ValueArg<std::string> pathArg("p", "path", "Path to folder containing output files, with base name of output files appended", true, "", "string", cmd);
		TCLAP::ValueArg<int> fromArg("f", "from", "Generation at which analysis begins (i.e. size of burnin in generations, not samples)", false, 0, "integer", cmd);
		TCLAP::ValueArg<int> toArg("t", "to", "Generation at which analysis ends (set to zero to use all samples)", false, 0, "integer", cmd);
		TCLAP::ValueArg<double> hillArg("c", "hillclimblimit", "Hill climbing optimization of branch lengths will end when the improvement in fit across the tree as a whole does not exceed this value for <iterlimit> iterations", false, 0.001, "double", cmd);
		TCLAP::ValueArg<int> itersArg("i", "iterlimit", "Hill climbing optimization of branch lengths will end when the improvement in fit across the tree as a whole does not exceed <hillclimblimit> for this number of iterations", false, 25, "integer", cmd);
	TCLAP::SwitchArg bScaleFlag("b","brownian","Try to find a tree for which the Brownian reconstruction is the same as the stable distribution reconstruction on the original tree", cmd, false);
		
		cmd.parse( argc, argv );
		
		std::ifstream treefile((pathArg.getValue() + ".tree").data());
		if (treefile) {
			std::string newick;
			std::string line;
		    while (getline(treefile, line))
		        newick += line;
		    treefile.close();
		    taxa = new Taxa;
		    tree = new Tree(newick, taxa);
		    std::cout << "Loaded tree from '" << (pathArg.getValue()+".tree") << "'" << std::endl;
		} else {
			std::cout << "Error - could not find tree at '" << (pathArg.getValue()+".tree") << "'" << std::endl;
			return -1;
		}
		
		std::set<Tree::Node*> fixedNodes;
		std::ifstream datafile((pathArg.getValue() + ".fixed_taxa").data());
		if (datafile) {
			auto it = tree->begin();
			while (it != tree->end()) {
				it->taxon()->setValue(DBL_MAX);
				++it;
			}
			std::string line; 
			char * ptr = 0;
			while (datafile.good()) {
				getline(datafile, line);
				if (line.size() > 0) {
					std::vector<std::string> split;
			  		boost::split(split, line, boost::is_any_of("\t\n"), boost::token_compress_on);
			  		taxa->find(split[0])->setValue(strtod(split[1].data(), &ptr));
			  		fixedNodes.insert(tree->node(split[0]));
				}
			}
			
		} else {
			std::cout << "Error - could not find data at '" << (pathArg.getValue()+".fixed_taxa") << "'" << std::endl;
			return -1;
		}
		
		
		NodeValueAccessor nv;
		TaxonValueAccessor tv;
    	PhyCpp::BrownianMotion::ReconstructionResult bm = PhyCpp::BrownianMotion::reconstruct(tree, tv, nv, DBL_MAX);
		
		std::vector<std::vector<double> > values;
		values.push_back(std::vector<double>()); //stable alpha
		values.push_back(std::vector<double>()); //stable scale
		auto it = tree->begin();
		while (it != tree->end()) {
			if (it->taxon()->value() == DBL_MAX)
				values.push_back(std::vector<double>());
			++it;
		}
		
		
		std::string chainpath = pathArg.getValue() + ".chain1.log";
		std::ifstream * chainfile = new std::ifstream(chainpath.data());
		int pos = 1;
		while (*chainfile) {
		
			std::cout << "Reading sample from '" << chainpath << "'..." << std::endl; ;
		
			char * ptr = 0;
			std::string line;
			getline (*chainfile,line);
			//int count = 0;
			while (chainfile->good()) {
				getline(*chainfile, line);
				if (line.size() > 0) {
					std::vector<std::string> split;
			  		boost::split(split, line, boost::is_any_of("\t\n"), boost::token_compress_on);
			  		double s = strtod(split[0].data(), &ptr);
			  		if (s < fromArg.getValue())
			  			continue;
			  		if (toArg.getValue() > 0 && s > toArg.getValue())
			  			 break;
			  		for (int k = 0; k < values.size(); ++k) {
			  			double d = strtod(split[k+2].data(), &ptr);
			  			values[k].push_back(d);
			  		}
				}
				//if (count%1000 == 0)
				//std::cout << ".";
				//std::cout.flush();
				//++count;
			}
		
			++pos;
			std::stringstream ss;
			ss << pathArg.getValue() << ".chain" << pos << ".log";
			chainpath = ss.str();
			delete chainfile;
			chainfile = new std::ifstream(chainpath.data());
		}
		delete chainfile;
		
		std::cout << "Calculating median ancestral states and 95% credible intervals..." << std::endl;
		
		std::ofstream medout((pathArg.getValue() + ".ancstates").data());
		medout << "Parameter\tBrownian\tMedian\t95%CI_Low\t95%CI_High\n";
		
		std::sort(values[0].begin(), values[0].end());
		double median = (values[0].size()%2 == 0) ? (values[0][values[0].size()/2-1] + values[0][values[0].size()/2]) / 2 : values[0][values[0].size()/2];
		int margin = (int)(values[0].size() * 0.025);
		medout << "stable_alpha" << "\t" << "2\t" << median << "\t" << values[0][margin] << "\t" << values[0][values[0].size()-margin-1] << std::endl;

		std::sort(values[1].begin(), values[1].end());
		median = (values[1].size()%2 == 0) ? (values[1][values[1].size()/2-1] + values[1][values[1].size()/2]) / 2 : values[1][values[1].size()/2];
		margin = (int)(values[1].size() * 0.025);
		medout << "stable_scale" << "\t" << sqrt(bm.rate/2) << "\t" << median << "\t" << values[1][margin] << "\t" << values[1][values[1].size()-margin-1] << std::endl;

		it = tree->begin();
		pos = 2;
		while (it != tree->end()) {
			if (it->taxon()->value() == DBL_MAX) {
				std::sort(values[pos].begin(), values[pos].end());
				double median = (values[pos].size()%2 == 0) ? (values[pos][values[pos].size()/2-1] + values[pos][values[pos].size()/2]) / 2 : values[pos][values[pos].size()/2];
				it->taxon()->setValue(median);
				int margin = (int)(values[pos].size() * 0.025);
				medout << it->taxon()->label() << "\t" << it->value() << "\t" << median << "\t" << values[pos][margin] << "\t" << values[pos][values[pos].size()-margin-1] << std::endl;
				++pos;
			} else {
				medout << it->taxon()->label() << "\t" << it->taxon()->value() << "\t" << it->taxon()->value() << "\t" << it->taxon()->value() << "\t" << it->taxon()->value() << std::endl;
			}
			
			++it;
		}
		
		medout.close();
		
		std::cout << "Saved to " << pathArg.getValue() << ".ancstates" << std::endl;
		
		
		
		
		std::vector<double> origLengths;
		std::vector<double> nodeHeight;
		it = tree->begin();
		++it;
		double L = 0;
		while (it != tree->end()) {
			double h = 0;
			Tree::Node * q = *it;
			while (q->parent() != 0) {
				h += q->parentEdge()->length();
				q = q->parent();
			}
			h -= it->parentEdge()->length()/2;
			nodeHeight.push_back(h);
			origLengths.push_back(it->parentEdge()->length());
			L += it->parentEdge()->length();
			++it; 
		}
		
		std::vector<double> newLengths;
		if (bScaleFlag.getValue()) {
		
			it = tree->begin();
			std::vector<double> medValues;
			while (it != tree->end()) {
				medValues.push_back(it->taxon()->value());
				if (fixedNodes.find(*it) == fixedNodes.end()) 
					it->taxon()->setValue(DBL_MAX);
				++it;
			}
			bm = PhyCpp::BrownianMotion::reconstruct(tree, tv, nv, DBL_MAX);
			it = tree->begin();
			double d = 0;
			pos =0;
			while (it != tree->end()) {
				d += fabs(it->value()-medValues[pos]);
				++it;
				++pos;
			}

		
			std::cout << "Optimizing branch lengths to make the Brownian motion ancestral state reconstruction match the stable reconstruction..." << std::endl;
			int noChange = 0;
		
			do {
				double improvement = 0;
				auto kit = tree->begin();
				++kit;
				while (kit != tree->end()) {
				
					double oldLen = kit->parentEdge()->length();
					kit->parentEdge()->setLength(kit->parentEdge()->length() + kit->parentEdge()->length()*vargen()/20);
				
					bm = PhyCpp::BrownianMotion::reconstruct(tree, tv, nv, DBL_MAX);
					double D = 0;
					pos = 0;
					it = tree->begin();
					while (it != tree->end()) {
						D += fabs(it->value()-medValues[pos]);
						++it;
						++pos;
					}
			
					if (D >= d) {

						kit->parentEdge()->setLength(oldLen);
					} else {
						improvement += d - D;
						d = D;
					}
					++kit;
				}
				
				double dd = 0;
				it = tree->begin();
				pos = 0;
				while (it != tree->end()) {
					dd += fabs(it->value()-medValues[pos]);
					++it;
					++pos;
				}
				
				
				if (improvement <= hillArg.getValue()) 
					++noChange;
				else
					noChange = 0; 
				

			
				std::cout << "Improvement: " << improvement << "; total error: " << dd;
				if (noChange > 0)
					std::cout << " (No change for " << noChange << " iterations)";
				std::cout << std::endl;
			
			} while (noChange < itersArg.getValue());
		
		
		
		
		
		
		
		
		
		
		
		
		
		
			double L2 = 0;
			it = tree->begin();
			++it;
			while (it != tree->end()) {
				L2 += it->parentEdge()->length();
				++it;
			}
			it = tree->begin();
			++it;
			while (it != tree->end()) {
				it->parentEdge()->setLength(it->parentEdge()->length() * L/L2);
				newLengths.push_back(it->parentEdge()->length());
				++it;
			}
		
			std::ofstream treeout((pathArg.getValue() + ".scaled_tree").data());
			treeout << tree->toString();
			treeout.close();
			std::cout << "Saved scaled tree to '" << pathArg.getValue() << ".scaled_tree'" << std::endl;

		}
		
		
		std::vector<double> rateLengths;
		it = tree->begin();
		++it;
		double L2 = 0;
		pos = 0;
		while (it != tree->end()) {
			double x = fabs(it->taxon()->value() - it->parent()->taxon()->value())/sqrt(origLengths[pos]);
			it->parentEdge()->setLength(x); 
			L2 += x;
			++it; 
			++pos;
		}
		it = tree->begin();
		++it;
		while (it != tree->end()) {
			it->parentEdge()->setLength(it->parentEdge()->length() * L/L2);
			rateLengths.push_back(it->parentEdge()->length());
			++it;	
		}
		
		std::ofstream treeout((pathArg.getValue() + ".rates_tree").data());
		treeout << tree->toString();
		treeout.close();
		std::cout << "Saved rates tree to '" << pathArg.getValue() << ".ratestree'" << std::endl;
		
		std::ofstream branchOut((pathArg.getValue()+".brlens").data());
		it = tree->begin();
		++it;
		pos = 0;
		branchOut << "Branch\tBranch_midpoint_height\tOrig_len\tRate";
		if (bScaleFlag.getValue()) 
			branchOut << "\tScaled_len";
		branchOut << std::endl;
			
		while (it != tree->end()) {
			branchOut << it->taxon()->label() << "->" << it->parent()->taxon()->label() << "\t" << nodeHeight[pos] << "\t"
				<< origLengths[pos] << "\t" << rateLengths[pos];
				if (bScaleFlag.getValue()) {
					branchOut << "\t" << newLengths[pos];
				}
			branchOut << std::endl;
			++it;
			++pos;
		}
		std::cout << "Saved branch length data to '" << pathArg.getValue() << ".brlens'" << std::endl;
		
		
		it = tree->begin();
		++it;
		L2 = 0;
		while (it != tree->end()) {
			double x = pow(it->taxon()->value() - it->parent()->taxon()->value(),2);
			it->parentEdge()->setLength(x); 
			L2 += x;
			++it; 
		}
		it = tree->begin();
		++it;
		while (it != tree->end()) {
			it->parentEdge()->setLength(it->parentEdge()->length() * L/L2);
			++it;	
		}
		
		std::ofstream sqtreeout((pathArg.getValue() + ".sqchange_tree").data());
		sqtreeout << tree->toString();
		sqtreeout.close();
		std::cout << "Saved squared change tree to '" << pathArg.getValue() << ".sqchange_tree'" << std::endl;


	} catch (TCLAP::ArgException &e) { 
		std::cerr << "Error: " << e.error() << " for command line argument " << e.argId() << std::endl;
		return -1;
	}
	
	
	
	delete taxa;
	delete tree;
	std::cout << std::endl;

	return 0;	
}
