//  GREKO Chess Engine
//  (c) 2002-2016 Vladimir Medvedev <vrm@bk.ru>
//  http://greko.su

//  learn.cpp: evaluation learning
//  modified: 30-June-2016

#include <math.h>
#include <map>

#include "eval.h"
#include "learn.h"
#include "notation.h"
#include "utils.h"

void Pgn2Fen(const std::string& srcFile, const std::string& dstFile)
{
	FILE* src = fopen(srcFile.c_str(), "rt");
	if (src == NULL)
	{
		out("Can't open file '%s'\n", srcFile.c_str());
		return;
	}

	FILE* dst = fopen(dstFile.c_str(), "wt");
	if (src == NULL)
	{
		out("Can't open file '%s'\n", dstFile.c_str());
		fclose(src);
		return;
	}

	int games = 0;
	char buf[1024];
	std::vector<std::string> lines;
	std::vector<U32> moves;
	std::vector<int> plies;

	Position pos;
	std::string result = "*";
	size_t len = 0;

	while (fgets(buf, sizeof(buf), src))
	{
		if (buf[strlen(buf) - 1] == '\n')
			buf[strlen(buf) - 1] = 0;

		if (buf[0] == '[')
		{
			if (strstr(buf, "[Event"))
			{
				if (!lines.empty())
				{
					len = pos.Ply();
					for (size_t i = 0; i < lines.size() && i < moves.size(); ++i)
					{
						fprintf(dst, "%s result \"%s\"; mv %d; ply %d; len %d;\n",
							lines[i].c_str(), result.c_str(), moves[i], plies[i], int(len));
					}
				}

				lines.clear();
				moves.clear();
				plies.clear();

				pos.SetInitial();
				lines.push_back(pos.Fen());
				result = "*";

				out("Games: %d\r", ++games);
			}
			else if (strstr(buf, "[Result"))
			{
				if (strstr(buf, "1-0"))
					result = "1-0";
				else if (strstr(buf, "0-1"))
					result = "0-1";
				else if (strstr(buf, "1/2-1/2"))
					result = "1/2-1/2";
				else
					result = "*";
			}
		}
		else
		{
			const char* sep = " +#";
			char* tk = strtok(buf, sep);
			while (tk != NULL)
			{
				if (strstr(tk, sep) == NULL)
				{
					Move mv = StrToMove(tk, pos);
					std::string fen = pos.Fen();
					int ply = pos.Ply();

					if (mv && pos.MakeMove(mv))
					{
						if (!mv.Captured() && !pos.InCheck())
						{
							lines.push_back(fen);
							moves.push_back((U32)mv);
							plies.push_back(ply);
						}
					}
				}
				tk = strtok(NULL, sep);
			}
		}
	}

	if (!lines.empty())
	{
		len = pos.Ply();
		for (size_t i = 0; i < lines.size() && i < moves.size() && i < plies.size(); ++i)
		{
			fprintf(dst, "%s result \"%s\"; mv %d; ply %d; len %d;\n",
				lines[i].c_str(), result.c_str(), moves[i], plies[i], int(len));
		}
	}

	out("\n");

	fclose(src);
	fclose(dst);
}

// Interface for an objective function over an integer weight vector.
// MyFunctor implements it; Learn() minimizes its value.
class Functor
{
public:
	// Virtual so that a derived functor is destroyed correctly when deleted
	// through a Functor pointer.
	virtual ~Functor() {}

	// Returns the objective value for the weight vector x.
	virtual double operator() (const std::vector<int>& x) = 0;
};

// One training/validation sample parsed from a Pgn2Fen output line.
struct PosInfo
{
	// Zero the numeric fields so that a sample whose line lacks the "ply" or
	// "len" tag does not carry indeterminate values into the error function,
	// which computes (len - ply).
	PosInfo() : result(0), ply(0), len(0) {}

	std::string fen;  // the input line; the FEN is its leading part
	double result;    // 1 for "1-0", 0.5 for "1/2-1/2", 0 for "0-1"
	Move bm;          // move parsed from the "mv" field
	int ply;          // value of the "ply" field
	int len;          // value of the "len" field
};

// Texel-style objective: the weighted RMS error between the evaluation's
// predicted score probability and the recorded game results. Positions are
// loaded from a Pgn2Fen output file and split into training and validation
// sets.
class MyFunctor : public Functor
{
public:
	// Loads positions from `file`. Explicit so that a std::string is never
	// silently converted into a (costly to build) functor.
	explicit MyFunctor(const std::string& file);

	// Objective value on the training set for weights x.
	virtual double operator() (const std::vector<int>& x) { return f(m_training, x); }

	// Objective value on the validation set for weights x.
	double Validation(const std::vector<int>& x) { return f(m_validation, x); }

	// Sets the scale that maps an evaluation score to a probability.
	void SetDivisor(double x) { m_divisor = x; }
private:
	double f(const std::vector<PosInfo>& data, const std::vector<int>& x);
	std::vector<PosInfo> m_training;
	std::vector<PosInfo> m_validation;
	double m_divisor;
};

// Loads samples from a file in the format written by Pgn2Fen.
//
// A line is used only if it contains one of the results "1-0", "1/2-1/2" or
// "0-1", mapped to 1, 0.5 and 0. The optional "mv", "ply" and "len" fields are
// parsed with atoi(). Each sample goes to the training set with probability
// 8000/10000 (via Rand32()); the rest go to the validation set. If the file
// cannot be opened, both sets stay empty.
MyFunctor::MyFunctor(const std::string& file)
{
	m_divisor = 150.;

	FILE* src = fopen(file.c_str(), "rt");
	if (src == NULL)
		return;

	printf("Loading positions...\n");
	char buf[1024];
	while (fgets(buf, sizeof(buf), src))
	{
		double result = -1;
		if (strstr(buf, "1-0"))
			result = 1;
		else if (strstr(buf, "1/2-1/2"))
			result = 0.5;
		else if (strstr(buf, "0-1"))
			result = 0;

		if (result == -1)
			continue;  // no recognizable result: not a sample line

		// Value-initialize so that ply and len are zero, not indeterminate,
		// when the line lacks the corresponding tag (f() reads len - ply).
		PosInfo posInfo = PosInfo();
		posInfo.fen = buf;
		posInfo.result = result;

		char* ptrMv = strstr(buf, "mv");
		if (ptrMv != NULL)
			posInfo.bm = atoi(ptrMv + 3);   // skip "mv "

		char* ptrPly = strstr(buf, "ply");
		if (ptrPly != NULL)
			posInfo.ply = atoi(ptrPly + 4); // skip "ply "

		char* ptrLen = strstr(buf, "len");
		if (ptrLen != NULL)
			posInfo.len = atoi(ptrLen + 4); // skip "len "

		if (Rand32() % 10000 < 8000)
			m_training.push_back(posInfo);
		else
			m_validation.push_back(posInfo);

		// Report progress once per 1000 loaded samples (not once per
		// validation sample whenever the training count is a multiple of 1000).
		size_t loaded = m_training.size() + m_validation.size();
		if (loaded % 1000 == 0)
			printf("%d / %d\r", int(m_training.size()), int(m_validation.size()));
	}
	printf("Training set: %d\n", int(m_training.size()));
	printf("Validation set: %d\n", int(m_validation.size()));

	fclose(src);
}

double MyFunctor::f(const std::vector<PosInfo>& data, const std::vector<int>& x)
{
	double r = 0;
	int n = 0;
	Position pos;

	InitEval(x);

	for (size_t i = 0; i < data.size(); ++i)
	{
		const PosInfo& pi = data[i];
		if (pos.SetFen(pi.fen))
		{
			EVAL e = Evaluate(pos, -INFINITY_SCORE, INFINITY_SCORE);
			if (pos.Side() == BLACK) e = -e;

			double prediction = 1 / (1 + exp(-e / m_divisor));
			double result = pi.result;
			double sqError = (prediction - result) * (prediction - result);
			double scale = exp(-(pi.len - pi.ply) / 20.);

			r += sqError * scale;
			++n;
		}
	}

	return (n > 0)? sqrt(r / n) : 0;
}

// Returns the current local date and time formatted with "%Y-%m-%d %X"
// (e.g. "2016-06-30 12:34:56" in the C locale), as used to timestamp learning
// log lines. Returns an empty string if the local time cannot be obtained or
// formatted, instead of dereferencing a NULL localtime() result or returning
// an indeterminate buffer.
const std::string currentDateTime()
{
	time_t now = time(0);
	const struct tm* local = localtime(&now);
	if (local == NULL)
		return std::string();

	char buf[80];
	if (strftime(buf, sizeof(buf), "%Y-%m-%d %X", local) == 0)
		return std::string();
	return buf;
}

// Prints a message line to stdout and appends it to the learning log.
static void LearnLog(FILE* dst, const char* msg)
{
	printf("%s\n", msg);
	fprintf(dst, "%s\n", msg);
}

// Returns a copy of x with x[param] shifted by delta and clamped to
// [x_min[param], x_max[param]].
static std::vector<int> LearnShiftParam(const std::vector<int>& x, int param, int delta)
{
	std::vector<int> shifted = x;
	shifted[param] += delta;
	if (shifted[param] < x_min[param]) shifted[param] = x_min[param];
	if (shifted[param] > x_max[param]) shifted[param] = x_max[param];
	return shifted;
}

// Prints the final weights to stdout and to the learning log. Writes directly
// with printf/fprintf, so a long parameter name cannot overflow a fixed buffer.
static void LearnPrintWeights(FILE* dst, const std::vector<int>& x)
{
	const char* header = "\n*** NEW SET OF WEIGHTS: ***\n\n";
	printf("%s", header);
	fprintf(dst, "%s", header);

	for (int j = 0; j < NUM_WEIGHTS && j < int(x.size()); ++j)
	{
		std::string name = ParamName(j);
		printf("%s: %d\n", name.c_str(), x[j]);
		fprintf(dst, "%s: %d\n", name.c_str(), x[j]);
	}

	const char* footer = "\nSaved to file 'weights.txt', old values in file 'weights.old'\n\n";
	printf("%s", footer);
	fprintf(dst, "%s", footer);
}

// Tunes evaluation weights firstParam..lastParam (inclusive, 0-based) by
// coordinate descent on the MyFunctor objective built from `file`.
//
// Weights are read from "weights.txt", and the originals are copied to
// "weights.old". Each parameter is moved by +/-1 (clamped to
// [x_min, x_max]) for as long as that lowers the training error. Every
// improvement is saved to "weights.txt" and logged to "learning.log".
// Optimization stops when:
//  - a full round of parameters brings no improvement, or
//  - a sweep does not lower the validation error.
// The final weights are then printed.
void Learn(const std::string& file, int firstParam, int lastParam)
{
	RandSeed32((unsigned int)time(0));

	FILE* src = fopen(file.c_str(), "rt");
	if (src == NULL)
	{
		printf("Can't open file '%s'\n", file.c_str());
		return;
	}
	fclose(src);

	MyFunctor f(file);

	std::vector<int> x;
	LoadVector("weights.txt", x);

	// Everything below indexes x, x_min and x_max with param in
	// [firstParam, lastParam]; refuse ranges that would read out of bounds
	// (e.g. when weights.txt is missing or short).
	if (firstParam < 0 || lastParam >= int(x.size()) || lastParam >= NUM_WEIGHTS)
	{
		out("Invalid parameter range [%d, %d] for %d weights\n",
			firstParam, lastParam, int(x.size()));
		return;
	}

	if (SaveVector("weights.old", x))
		out("Old values saved in file 'weights.old'\n");
	else
		out("Failed to save old values in file 'weights.old'\n");

	printf("Start optimization...\n");
	std::string learningLogFile = "learning.log";
	FILE* dst = fopen(learningLogFile.c_str(), "at");
	if (dst == NULL)
	{
		out("Failed to open file '%s'\n", learningLogFile.c_str());
		return;
	}

	double y = f(x);
	double validation = f.Validation(x);
	double validationPrev = validation;

	std::string ts = currentDateTime();

	int lineCounter = 0;
	int iteration = 0;
	bool improved = true;
	bool converged = false;

	// Number of consecutive parameters whose line search found no improvement.
	// After an improvement it is reset to 0: the parameter that just improved
	// is already at a minimum along its own axis, so checking the other
	// (N - 1) parameters is enough to detect convergence. It starts at -1 so
	// that on the first sweep all N parameters are tried; starting at 0
	// skipped the last parameter when the others did not improve.
	int unchanged = -1;

	printf("%d %.12lf %.12lf %s\n", lineCounter, y, validation, ts.c_str());
	fprintf(dst, "%d %.12lf %.12lf %s\n", lineCounter, y, validation, ts.c_str());
	fflush(dst);

	while (improved)
	{
		improved = false;
		++iteration;

		for (int param = firstParam; param <= lastParam; ++param)
		{
			printf("Parameter %d of %d: %s = %d\n", param + 1, int(x.size()), ParamName(param).c_str(), x[param]);
			fflush(stdout);

			++unchanged;

			int step = 1;
			while (step > 0)
			{
				std::vector<int> x1 = LearnShiftParam(x, param, step);
				double y1 = f(x1);

				std::vector<int> x2 = LearnShiftParam(x, param, -step);
				double y2 = f(x2);

				if (y1 >= y && y2 >= y)
				{
					step /= 2;
					continue;
				}

				if (y1 < y2)
				{
					x = x1;
					y = y1;
				}
				else
				{
					x = x2;
					y = y2;
				}
				improved = true;
				unchanged = 0;

				++lineCounter;
				validation = f.Validation(x);

				std::string now = currentDateTime();
				printf("%d %.12lf %.12lf %s [%d.%d] %d\n", lineCounter, y, validation, now.c_str(), iteration, param + 1, x[param]);
				fprintf(dst, "%d %.12lf %.12lf %s [%d.%d] %d\n", lineCounter, y, validation, now.c_str(), iteration, param + 1, x[param]);
				fflush(dst);

				SaveVector("weights.txt", x);
			}

			if (unchanged >= lastParam - firstParam)
			{
				LearnLog(dst, "Break by no progress on learning set");
				converged = true;
				break;
			}
		}

		// Converged on the training set: skip the validation check.
		if (converged)
			break;

		if (validation < validationPrev)
			validationPrev = validation;
		else
		{
			LearnLog(dst, "Break by no progress on validation set");
			break;
		}
	}

	LearnPrintWeights(dst, x);

	fclose(dst);
}
