/* Copyright (C) 2010 Dustin Cartwight
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */

#include <config.h>
#include "solve.h"
#include "deriv.h"
#include <float.h>
#include <math.h>
#include <stdlib.h>
#include <stdio.h>
#include <string.h>

/* With pthread support compiled in, use the real pthread API. Otherwise
 * supply a dummy mutex type and make lock/unlock expand to no-ops, so the
 * rest of this file can call them unconditionally. Note that the no-op
 * macros never evaluate their argument. */
#if HAVE_PTHREAD
#include <pthread.h>
#else
typedef void *pthread_mutex_t;
#define pthread_mutex_lock(a) (void) 0
#define pthread_mutex_unlock(a) (void) 0
#endif

/* solve_once gives up once at least this many EM steps have been taken */
#define MAX_EM_STEPS 10000000
/* Upper bound on the number of ipf_step calls made by one call to ipf */
#define MAX_IPF_STEPS 10000
/* Number of EM steps solve_once takes between convergence/duplicate checks */
#define COMPARE_FREQ 100
/* solve_once clamps the IPF threshold it passes to em_step to the range
 * [MIN_IPF_THRESH, MAX_IPF_THRESH] */
#define MAX_IPF_THRESH 1.01
#define MIN_IPF_THRESH (1.0 + 5e-13)

/* State shared between solve() and every solve_thread() worker. A single
 * instance lives on solve()'s stack and is passed by pointer to each worker. */
typedef struct {
    problem_t       *problem;   /* the problem being solved */
    sol_set_t       *sol;       /* solutions recorded so far */
    solve_params_t  *params;    /* solver settings */
    int             iter;       /* bumped once per pass of the solve_thread loop */
    int             num_threads; /* number of threads currently running */
#if HAVE_PTHREAD
    pthread_cond_t  done;       /* signalled when num_threads drops to zero */
#endif
    pthread_mutex_t mutex;      /* held while updating iter, num_threads, the
                                 * counters below and the contents of *sol */

    int             num_diverged; /* Newton attempts for which newton() returned zero */
    int             num_negative; /* Newton results that had a negative coordinate */
} thread_data_t;

/* Forward declarations for the functions defined later in this file. */

/* Worker loop and the EM / IPF iteration it drives */
void *solve_thread(void *p);
void solve_once(problem_t *problem, sol_set_t *sols, pthread_mutex_t *mutex,
        solve_params_t *params);
void em_step(problem_t *problem, double *x, double ipf_thresh);
void ipf(problem_t *problem, double *term_exps, double *x, double thresh);
double *compute_monoms(problem_t *problem, double *x);
double ipf_step(problem_t *problem, double *suff, double *x);
void make_coeff_totals(problem_t *problem);
void renormalize(problem_t *problem, double *x);

/* Convergence acceleration applied after an EM step */
int boost_converg(problem_t *p, double *x, double *old_x);

/* Newton's-method alternative to solve_once */
void solve_once_newton(problem_t *problem, sol_set_t *sols,
        thread_data_t *data, solve_params_t *params);

/* Small helpers: random start points, divergence, solution comparison */
double *rand_vector(int len);
double kl_diverg(problem_t *problem, double *sol);
int compare_found_solns(double *x, int numvars, sol_set_t *sol,
        double granularity, int num_solns);
double diff_solns(double *x, double *y, int numvars);

/* Fill *params with the default solver settings: one thread, EM iteration
 * (not Newton) with convergence boosting switched on, and no progress
 * output. */
void init_solve_params(solve_params_t *params)
{
    /* how much searching to do */
    params->num_threads = 1;
    params->max_num_solns = 8;
    params->iters_per_soln = 10;

    /* tolerances; the convergence threshold is derived from the granularity */
    params->granularity = 0.05;
    params->thresh = params->granularity / 100.0;

    /* behaviour switches */
    params->newton = FALSE;
    params->boost = TRUE;
    params->progress = FALSE;
}

/* Searches for solutions of problem using the settings in params.
 *
 * The result is a freshly malloc'ed solution set with room for
 * params->max_num_solns solutions. It is returned empty (after an error
 * message on stderr) when problem->rank is not positive, or when Newton's
 * method is requested for a system whose number of variables differs from
 * its number of equations. Returns NULL if the solution set itself cannot be
 * allocated.
 *
 * The calling thread acts as the first worker; with pthread support,
 * solve_thread spawns further workers and this function waits until all of
 * them have finished before returning. */
sol_set_t *solve(problem_t *problem, solve_params_t *params)
{
    thread_data_t   data;
    size_t          x_bytes;
    size_t          descr_bytes;

    if (problem->coeff_totals == NULL)
        make_coeff_totals(problem);

    data.sol = (sol_set_t *)malloc(sizeof(sol_set_t));
    if (data.sol == NULL) {
        fprintf(stderr, "Error: Out of memory\n");
        return NULL;
    }
    data.sol->num_solns = 0;

    /* Widen before multiplying so the byte counts are computed in size_t
     * rather than overflowing in int */
    x_bytes = (size_t) params->max_num_solns * problem->numvars
            * sizeof(double);
    descr_bytes = (size_t) params->max_num_solns * sizeof(sol_descr_t);
    data.sol->x = (double *)malloc(x_bytes);
    data.sol->descr = (sol_descr_t *)malloc(descr_bytes);

    /* malloc(0) may legitimately return NULL, so only a failed non-empty
     * request counts as running out of memory */
    if ((data.sol->x == NULL && x_bytes > 0)
            || (data.sol->descr == NULL && descr_bytes > 0)) {
        fprintf(stderr, "Error: Out of memory\n");
        free(data.sol->x);
        free(data.sol->descr);
        free(data.sol);
        return NULL;
    }

    if (problem->rank <= 0) {
        fprintf(stderr, "Error: Can't solve equations without proper homogeneity\n");
        return data.sol;
    }

    if (params->newton && problem->numvars != problem->numeqns) {
        fprintf(stderr, "Error: Can't use Newton's method when the number of ");
        fprintf(stderr, "variables is not equal to\n");
        fprintf(stderr, "       the number of equations.\n");
        return data.sol;
    }

    data.problem = problem;
    data.params = params;
    data.iter = 0;
    data.num_threads = 1; /* counts the calling thread */
#if HAVE_PTHREAD
    pthread_mutex_init(&data.mutex, NULL);
    pthread_cond_init(&data.done, NULL);
#else
    /* This branch only exists when thread support was left out of the build */
    if (params->num_threads > 1)
        fprintf(stderr, "Warning: Using a single thread because threads were not enabled at compile time\n");
    data.mutex = NULL;
#endif
    data.num_diverged = 0;
    data.num_negative = 0;

    /* The calling thread is the first worker */
    solve_thread(&data);

#if HAVE_PTHREAD
    /* wait for all threads to finish; data lives on this stack frame, so it
     * must outlive every detached worker */
    pthread_mutex_lock(&data.mutex);
    while (data.num_threads > 0)
        pthread_cond_wait(&data.done, &data.mutex);

    pthread_mutex_unlock(&data.mutex);
    pthread_mutex_destroy(&data.mutex);
    pthread_cond_destroy(&data.done);
#endif

    if (params->progress)
        fprintf(stderr, "\n");

    if (params->newton) {
        printf("%d diverged, %d negative\n", data.num_diverged,
                data.num_negative);
    }

    return data.sol;
}

/* Repeatedly finds solutions in a single thread.
 *
 * p points to the shared thread_data_t. solve() calls this directly for the
 * first worker, and it is also the pthread start routine for the workers
 * spawned below. The loop ends once enough iterations have been spent per
 * solution found, or once the solution set is full. Always returns NULL.
 *
 * Locking: the mutex is taken on entry; each pass releases it before calling
 * solve_once / solve_once_newton, both of which return with it locked again. */
void *solve_thread(void *p)
{
    thread_data_t   *data = (thread_data_t *)p;

    pthread_mutex_lock(&data->mutex);
    do {
        data->iter++;

#if HAVE_PTHREAD
        /* Replenish number of threads if it has dropped below what we want */
        while (data->num_threads < data->params->num_threads) {
            pthread_t   thr;

            /* If the thread cannot be created, thr is indeterminate and no
             * worker will ever decrement num_threads for it, so counting it
             * would leave solve() waiting forever. Carry on with the threads
             * we have; the next pass tries again. */
            if (pthread_create(&thr, NULL, solve_thread, data) != 0)
                break;
            pthread_detach(thr);
            data->num_threads++;
        }

        pthread_mutex_unlock(&data->mutex);
#endif
        if (data->params->newton)
            solve_once_newton(data->problem, data->sol, data, data->params);
        else
            solve_once(data->problem, data->sol, &data->mutex, data->params);
        /* the mutex is held again here, so the shared counters are stable */
        if (data->params->progress)
            fprintf(stderr, "\rFound %d solutions, %d/%d iterations",
                    data->sol->num_solns, data->iter - data->num_threads + 1,
                    data->sol->num_solns * data->params->iters_per_soln);
    } while (data->iter < data->params->iters_per_soln * data->sol->num_solns
              && data->sol->num_solns < data->params->max_num_solns);

#if HAVE_PTHREAD
    /* the last worker to leave wakes up solve() */
    if (--data->num_threads <= 0) {
        pthread_cond_signal(&data->done);
    }
    pthread_mutex_unlock(&data->mutex);
#endif
    return NULL;
}

/* Runs EM from a random starting point until it converges (or the step limit
 * is hit) and records the result in sol unless it matches a solution that is
 * already there, in which case that solution's count is bumped instead.
 *
 * Every COMPARE_FREQ steps the iterate is compared against the solutions
 * found so far, so a run heading towards a known solution can stop early.
 *
 * mutex is locked whenever the shared solution set is inspected or modified,
 * and the function always returns with mutex still locked. */
void solve_once(problem_t *problem, sol_set_t *sol, pthread_mutex_t *mutex,
        solve_params_t *params)
{
    double  *prev_x = (double *)malloc(problem->numvars * sizeof(double));
    double  *x = rand_vector(problem->numvars);
    double  ipf_thresh = MAX_IPF_THRESH;
    double  diverg;
    int     steps = 0;
    int     match;
    int     type;

    while (1) {
        double  diff = 0.0;
        int     known;
        int     batch;

        /* one batch of EM steps; diff ends up as the size of the last step */
        for (batch = 0; batch < COMPARE_FREQ; batch++) {
            memcpy(prev_x, x, problem->numvars * sizeof(double));
            em_step(problem, x, ipf_thresh);
            diff = diff_solns(prev_x, x, problem->numvars);

            if (params->boost)
                boost_converg(problem, x, prev_x);

            /* tighten the IPF tolerance as the EM steps get smaller */
            ipf_thresh = 1.0 + diff * diff;
            if (ipf_thresh > MAX_IPF_THRESH)
                ipf_thresh = MAX_IPF_THRESH;
            else if (ipf_thresh < MIN_IPF_THRESH)
                ipf_thresh = MIN_IPF_THRESH;
        }
        steps += COMPARE_FREQ;

        /* converged? (the distance estimate is only computed once the step
         * size is already small) */
        if (diff < params->thresh / 100.0) {
            if (est_dist_min(problem, x) < params->thresh)
                break;
        }

        /* At some point, we decide it's taking too long and give up */
        if (steps >= MAX_EM_STEPS) {
            fprintf(stderr, "Maximum number of iterations achieved\n");
            break;
        }

        /* lock to get a consistent picture of other solutions */
        pthread_mutex_lock(mutex);
        known = sol->num_solns;
        pthread_mutex_unlock(mutex);

        match = compare_found_solns(x, problem->numvars, sol,
                params->granularity, known);
        if (match < 0)
            continue;

        /* found repeat of an old solution */
        free(x);
        free(prev_x);
        pthread_mutex_lock(mutex);
        sol->descr[match].count++;
        return;
    }

    /* classify the result before taking the lock */
    type = is_exact(problem, x);
    if (type == TYPE_EXACT)
        diverg = 0.0;
    else
        diverg = kl_diverg(problem, x);

    pthread_mutex_lock(mutex);

    /* Final duplicate check, done under the lock so that no other thread can
     * add the same solution while we are comparing */
    match = compare_found_solns(x, problem->numvars, sol, params->granularity,
            sol->num_solns);

    if (match >= 0) {
        sol->descr[match].count++;
    } else if (sol->num_solns < params->max_num_solns) {
        /* New solution. If there's no space, we just abandon it */
        int slot = sol->num_solns;

        memcpy(sol->x + slot * problem->numvars, x,
                problem->numvars * sizeof(double));
        sol->descr[slot].count = 1;
        sol->descr[slot].type = type;
        sol->descr[slot].diverg = diverg;
        sol->num_solns = slot + 1;
    }

    free(prev_x);
    free(x);
}

/* Performs one EM step on x in place.
 *
 * For every equation, the right hand side is split among the terms in
 * proportion to each term's current value; the shares are summed per term
 * over all equations and handed to ipf(), which updates x. ipf_thresh is
 * passed through to ipf() unchanged. */
void em_step(problem_t *problem, double *x, double ipf_thresh)
{
    /* value of each monomial at x, without coefficient */
    double  *monoms = compute_monoms(problem, x);
    /* scratch: values of the terms of the equation being processed */
    double  *cur_terms = (double *)malloc(problem->numterms * sizeof(double));
    /* accumulated expected contribution of each term */
    double  *expected = (double *)malloc(problem->numterms * sizeof(double));
    int     eqn, t;

    for (t = 0; t < problem->numterms; t++)
        expected[t] = 0.0;

    for (eqn = 0; eqn < problem->numeqns; eqn++) {
        int     row = eqn * problem->numterms;
        double  eqn_sum = 0.0;

        /* evaluate the terms of this equation and their total */
        for (t = 0; t < problem->numterms; t++) {
            cur_terms[t] = problem->coeffs[row + t] * monoms[t];
            eqn_sum += cur_terms[t];
        }

        /* distribute the right hand side proportionally */
        for (t = 0; t < problem->numterms; t++)
            expected[t] += problem->rhs[eqn] * cur_terms[t] / eqn_sum;
    }

    free(cur_terms);
    free(monoms);

    ipf(problem, expected, x, ipf_thresh);
    free(expected);
}

/* Evaluates every monomial of the problem at x, ignoring coefficients.
 *
 * Returns a malloc'ed array of problem->numterms values which the caller
 * must free. Powers are formed by repeated multiplication, so exponents that
 * are zero or negative contribute a factor of 1. */
double *compute_monoms(problem_t *problem, double *x)
{
    double  *values = (double *)malloc(problem->numterms * sizeof(double));
    int     term;

    for (term = 0; term < problem->numterms; term++) {
        int     row = term * problem->numvars;
        double  product = 1.0;
        int     v;

        for (v = 0; v < problem->numvars; v++) {
            int     remaining = problem->exps[row + v];
            double  base = x[v];

            for (; remaining > 0; remaining--)
                product *= base;
        }
        values[term] = product;
    }

    return values;
}

/* Uses IPF to compute an approximate solution to a monomial system of
 * equations, updating x in place.
 *
 * term_exps holds the target value of each term. ipf_step is applied until
 * the ratio it reports drops below thresh (never taken to be smaller than
 * 1.0 + 5e-13) or MAX_IPF_STEPS steps have been taken. */
void ipf(problem_t *problem, double *term_exps, double *x, double thresh)
{
    const double    min_thresh = 1.0 + 5e-13;
    /* sufficient statistic, one entry per (rank row, variable) pair */
    double  *suff = (double *)malloc(problem->rank * problem->numvars
                                        * sizeof(double));
    double  max_ratio;
    int     row, v, t;
    int     step;

    if (thresh < min_thresh)
        thresh = min_thresh;

    for (row = 0; row < problem->rank; row++) {
        for (v = 0; v < problem->numvars; v++) {
            int     ndx = row * problem->numvars + v;
            double  acc = 0.0;

            for (t = 0; t < problem->numterms; t++) {
                acc += problem->exps[t * problem->numvars + v]
                        * problem->weights[ndx] * term_exps[t];
            }
            suff[ndx] = acc;
        }
    }

    /* Find approximate solution to monomial system of equations with
     * coefficients coeff_totals and right hand sides term_exps */
    step = 0;
    while (step < MAX_IPF_STEPS) {
        max_ratio = ipf_step(problem, suff, x);
        if (max_ratio < thresh)
            break;
        step++;
    }

#if 0
    if (step > 1000)
        fprintf(stderr, "%d IPF iterations to %g\n", step + 1, max_ratio - 1.0);
#endif

    free(suff);
}

/* One step of the IPF algorithm.
 *
 * For each rank row, the terms are re-evaluated at the current x and then
 * every variable is rescaled in turn so that its estimated sufficient
 * statistic moves towards the target in suff. x is updated in place.
 *
 * Returns the largest factor (>= 1.0) by which any estimate was off, taking
 * the reciprocal for estimates that were too large. */
double ipf_step(problem_t *problem, double *suff, double *x)
{
    double  *term_est = (double *)malloc(problem->numterms * sizeof(double));
    double  worst = 1.0;
    int     row;

    for (row = 0; row < problem->rank; row++) {
        int     t, v, p;

        /* Compute values of each term for current value of vector x */
        for (t = 0; t < problem->numterms; t++) {
            double  est = problem->coeff_totals[t];

            for (v = 0; v < problem->numvars; v++) {
                for (p = 0; p < problem->exps[t * problem->numvars + v]; p++)
                    est *= x[v];
            }
            term_est[t] = est;
        }

        for (v = 0; v < problem->numvars; v++) {
            int     ndx = row * problem->numvars + v;
            double  est_stat;
            double  ratio;

            /* Instability arises as a coordinate approaches zero, so snap
             * tiny coordinates to exactly zero and leave them alone */
            if (x[v] <= DBL_MIN * 256) {
                x[v] = 0.0;
                continue;
            }

            /* variables with zero weight in this row are not adjusted */
            if (problem->weights[ndx] == 0)
                continue;

            est_stat = 0.0;
            for (t = 0; t < problem->numterms; t++)
                est_stat += problem->exps[t*problem->numvars + v] * term_est[t];
            est_stat *= problem->weights[ndx];

            ratio = suff[ndx] / est_stat;

            /* skip the pow() call when the exponent would be exactly one */
            if (problem->degs[row] == problem->weights[ndx]) {
                x[v] *= ratio;
            } else {
                x[v] *= pow(ratio, ((double) problem->weights[ndx])
                        / problem->degs[row]);
            }

            if (ratio > worst)
                worst = ratio;
            else if (ratio * worst < 1.0)
                worst = 1.0 / ratio;
        }
    }

    free(term_est);
    return worst;
}

/* Allocates problem->coeff_totals and fills it with, for each term, the sum
 * of that term's coefficients over all equations. */
void make_coeff_totals(problem_t *problem)
{
    double  *totals = (double*)malloc(problem->numterms * sizeof(double));
    int     term, eqn;

    problem->coeff_totals = totals;

    for (term = 0; term < problem->numterms; term++) {
        double  sum = 0.0;

        for (eqn = 0; eqn < problem->numeqns; eqn++)
            sum += problem->coeffs[eqn*problem->numterms + term];
        totals[term] = sum;
    }
}

/* Speeds up convergence by re-applying the componentwise change made by the
 * last step (arg_x[i] / old_x[i]) for as long as doing so strictly lowers
 * the KL-divergence, up to 1000 times.
 *
 * arg_x  current point on entry; on return, the boosted point. It lags one
 *        accepted extrapolation behind the furthest improving point tried.
 * old_x  the point before the last step; it is used as scratch space and its
 *        contents are overwritten.
 *
 * Returns the number of extrapolations that were applied to arg_x. */
int boost_converg(problem_t *p, double *arg_x, double *old_x)
{
    double  *ratios = (double *)malloc(p->numvars * sizeof(double));
    double  kl = kl_diverg(p, arg_x);
    double  *x = (double *)malloc(p->numvars * sizeof(double));
    double  *new_x = old_x; /* candidate point is built in the caller's old_x */
    int     i;
    int     cnt;

    /* Boosting is optional: without scratch space leave arg_x as it is */
    if (ratios == NULL || x == NULL) {
        free(ratios);
        free(x);
        return 0;
    }

    /* coordinates that were zero before the step are left unscaled */
    for (i = 0; i < p->numvars; i++)
        ratios[i] = old_x[i] == 0.0 ? 1.0 : arg_x[i] / old_x[i];

    /* x is the furthest point accepted so far */
    memcpy(x, arg_x, p->numvars * sizeof(double));

    for (cnt = 0; cnt < 1000; cnt++) {
        double  new_kl;

        for (i = 0; i < p->numvars; i++)
            new_x[i] = x[i] * ratios[i];

        new_kl = kl_diverg(p, new_x);
        /* Stop unless the divergence strictly decreased. The test is written
         * as a negated "<" so that a NaN divergence also stops the loop
         * instead of being accepted as an improvement. */
        if (!(new_kl < kl))
            break;

        /* commit the previously accepted point before moving x forward */
        if (cnt > 0)
            memcpy(arg_x, x, p->numvars * sizeof(double));
        memcpy(x, new_x, p->numvars * sizeof(double));

        kl = new_kl;
    }

    free(ratios);
    free(x);
    return cnt == 0 ? 0 : cnt - 1;
}

/* Finds a single non-negative solution by Newton's method.
 *
 * Random starting points are tried until newton() succeeds with no negative
 * coordinate; failed attempts are tallied in data->num_diverged and
 * data->num_negative. The result is then recorded in sols, or the count of
 * the matching existing solution is bumped.
 *
 * Returns with data->mutex locked. */
void solve_once_newton(problem_t *problem, sol_set_t *sols,
        thread_data_t *data, solve_params_t *params)
{
    double  *x = NULL;
    int     type;
    int     match;

    for (;;) {
        int     v;

        /* discard the previous attempt, if any, and start afresh */
        free(x);
        x = rand_vector(problem->numvars);
        /* A single EM step to get the correct scale for the solution */
        em_step(problem, x, MIN_IPF_THRESH);

        if (!newton(problem, x, 1e8)) {
            pthread_mutex_lock(&data->mutex);
            data->num_diverged++;
            pthread_mutex_unlock(&data->mutex);
            continue;
        }

        /* Check for negative entries in x */
        type = TYPE_EXACT;
        for (v = 0; v < problem->numvars; v++) {
            if (x[v] < 0.0) {
                type = TYPE_NEGATIVE;
                break;
            }
        }
        if (type != TYPE_NEGATIVE)
            break;

        pthread_mutex_lock(&data->mutex);
        data->num_negative++;
        pthread_mutex_unlock(&data->mutex);
    }

    pthread_mutex_lock(&data->mutex);
    match = compare_found_solns(x, problem->numvars, sols, 1e-10,
                                        sols->num_solns);
    if (match >= 0) {
        sols->descr[match].count++;
    } else if (sols->num_solns < params->max_num_solns) {
        int slot = sols->num_solns;

        memcpy(sols->x + slot * problem->numvars, x,
                problem->numvars * sizeof(double));
        sols->descr[slot].count = 1;
        sols->descr[slot].diverg = 0.0;
        sols->descr[slot].type = type;
        sols->num_solns = slot + 1;
    }

    free(x);
}

/* Returns a malloc'ed vector of len pseudo-random doubles, drawn with
 * drand48() when available and with rand() / RAND_MAX otherwise. The caller
 * frees the vector. */
double *rand_vector(int len)
{
#if HAVE_PTHREAD
    static pthread_mutex_t rand_mutex = PTHREAD_MUTEX_INITIALIZER;
#endif
    double  *vec = (double *)malloc(len * sizeof(double));
    int     k = 0;

    /* drand48 (and rand) are not thread-safe, so we serialize access */
    pthread_mutex_lock(&rand_mutex);
    while (k < len) {
#if HAVE_DRAND48
        vec[k] = drand48();
#else
        vec[k] = ((double) rand()) / ((double) RAND_MAX);
#endif
        k++;
    }
    pthread_mutex_unlock(&rand_mutex);

    return vec;
}

/* Computes the (generalized) KL-divergence between the right hand sides of
 * problem and the values its equations take at sol. Equations whose right
 * hand side is not positive contribute only the difference term. */
double kl_diverg(problem_t *problem, double *sol)
{
    double  *monoms = compute_monoms(problem, sol);
    double  total = 0.0;
    int     eqn;

    for (eqn = 0; eqn < problem->numeqns; eqn++) {
        int     row = eqn * problem->numterms;
        double  target = problem->rhs[eqn];
        double  lhs = 0.0;
        int     t;

        /* value of the left hand side of this equation at sol */
        for (t = 0; t < problem->numterms; t++)
            lhs += problem->coeffs[row + t] * monoms[t];

        if (target > 0.0)
            total += target * log(target / lhs);
        total += lhs - target;
    }

    free(monoms);
    return total;
}

/* Compares x against the first num_solns solutions stored in sol and returns
 * the index of the first one that lies within granularity of x (as measured
 * by diff_solns), or -1 if there is none. The number of solutions is passed
 * as an argument because sol->num_solns may be being written by other
 * threads at the same time. */
int compare_found_solns(double *x, int numvars, sol_set_t *sol,
        double granularity, int num_solns)
{
    int     found = -1;
    int     k;

    for (k = 0; k < num_solns && found < 0; k++) {
        /* solution k has been found before if it is close enough to x */
        if (diff_solns(sol->x + k * numvars, x, numvars) <= granularity)
            found = k;
    }
    return found;
}

/* Compute the maximum relative distance between two points: the largest
 * coordinate-wise absolute difference divided by the largest positive
 * coordinate of x. When x has no positive coordinate the result is 0.0 if
 * the points coincide and DBL_MAX otherwise. */
double diff_solns(double *x, double *y, int numvars)
{
    double  largest_gap = 0.0;
    double  scale = 0.0;
    int     k;

    for (k = 0; k < numvars; k++) {
        double  gap;

        if (y[k] < x[k])
            gap = x[k] - y[k];
        else
            gap = y[k] - x[k];

        if (gap > largest_gap)
            largest_gap = gap;
        if (x[k] > scale)
            scale = x[k];
    }

    if (scale != 0.0)
        return largest_gap / scale;

    /* nothing to normalise by */
    return largest_gap == 0.0 ? 0.0 : DBL_MAX;
}
