/* Copyright (C) 2010 Dustin Cartwight
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */

#include <config.h>
#include "pos.h"
#include "homog.h"
#include "input.h"
#include "solve.h"
#include "svd.h"
#include <float.h>
#include <time.h>
#include <stdlib.h>
#include <stdio.h>
#include <string.h>
#include <unistd.h>
#if HAVE_PTHREAD
#include <pthread.h>
#else
/* Without pthreads there is only one thread, so stream locking is a no-op.
 * The macros expand to a plain expression (no trailing semicolon) so that
 * "flockfile(fp);" is exactly one statement and remains valid in contexts
 * such as an unbraced if/else. */
#define flockfile(fp) ((void) 0)
#define funlockfile(fp) ((void) 0)
#endif

/* One slot of the circular output buffer used by the order-preserving batch
 * mode. A slot is claimed in buffer_ndx (which records the label) and filled
 * in by print_or_buffer once the problem has been solved. */
typedef struct {
    /* Label of the problem; printed as "#<label>" and freed after printing */
    char        *label;
    /* Copy of the solved problem, kept so the solutions can be printed later */
    problem_t   problem;
    /* Solutions for the problem, or NULL while nothing is ready to print */
    sol_set_t   *sols;
} sol_buffer_t;

/* State shared by all threads running batch_thread. */
typedef struct {
    /* Solver parameters passed to setup_and_solve for every problem */
    solve_params_t  *params;
    /* Input stream from which labels and systems are read */
    FILE            *fp;
    /* Non-zero if the grading should be printed before solving */
    int             grading;

#if HAVE_PTHREAD
    /* Circular buffer to store solutions out of order. By buffering solutions
     * as they are obtained, we can guarantee that they will be printed out in
     * the same order that they were read in. NULL when output order does not
     * need to be preserved. */
    sol_buffer_t    *buffer;
    /* Index of the next slot to be printed */
    int             buffer_start;
    /* Index of the next slot to be handed out by buffer_ndx */
    int             buffer_end;
    /* Held by buffer_ndx and print_or_buffer while they update the buffer */
    pthread_mutex_t buffer_mutex;
    pthread_cond_t  buffer_cv; /* Signal availability of space */
#endif
} batch_thread_data_t;

/* Number of slots in the circular solution buffer. buffer_ndx treats the
 * buffer as full when advancing buffer_end would make it equal to
 * buffer_start, so at most SOL_BUFFER_SZ - 1 slots are in use at once. */
#define SOL_BUFFER_SZ 64

/* Batch processing: worker loop and the solve step shared with main */
void *batch_thread(void *p);
sol_set_t *setup_and_solve(problem_t *problem, solve_params_t *params,
        int grading);
#if HAVE_PTHREAD
/* Helpers for the order-preserving circular output buffer */
int buffer_ndx(batch_thread_data_t *data, char *label);
void print_or_buffer(batch_thread_data_t *data, int ndx, problem_t *problem,
        sol_set_t *sols);
#endif
/* Output routines; all of them write to stdout */
void print_solutions(problem_t *problem, sol_set_t *sols);

void print_problem(problem_t *problem);
void print_grading(problem_t *problem);

/* Diagnostic printed by the -c option */
double condition_number(problem_t *problem);

/* Release the memory owned by a problem or a solution set (but not the
 * structure itself) */
void free_problem(problem_t *problem);
void free_sol_set(sol_set_t *sols);

/* Format string for the usage line; the single %s is the program name */
#define USAGE_STR "Usage: %s [options] <problem file>\n"

/* Run batch mode: solve every labelled problem remaining in fp.
 *
 * With pthreads, params->num_threads - 1 extra worker threads are started and
 * the calling thread acts as a worker too. If keep_order is non-zero the
 * circular solution buffer is set up so results are printed in input order.
 * Without pthreads the calling thread processes the problems one by one, so
 * keep_order has nothing to do.
 *
 * params->num_threads is overwritten with 1 before any problem is solved.
 *
 * Returns EXIT_SUCCESS, or EXIT_FAILURE if the ordering buffer could not be
 * allocated. */
static int run_batch(FILE *fp, solve_params_t *params, int grading,
        int keep_order)
{
    batch_thread_data_t data;
#if HAVE_PTHREAD
    int                 num_threads = params->num_threads;
    int                 num_started = 0;
    pthread_t           *thr = NULL;
    int                 i;
#endif

    /* Parallelism comes from running several batch threads, so ask the
     * solver for a single thread per problem. */
    params->num_threads = 1;
    data.fp = fp;
    data.params = params;
    data.grading = grading;

#if HAVE_PTHREAD
    data.buffer = NULL;
    if (keep_order) {
        data.buffer = (sol_buffer_t *)malloc(SOL_BUFFER_SZ *
                sizeof(sol_buffer_t));
        if (!data.buffer) {
            fprintf(stderr, "Out of memory\n");
            return EXIT_FAILURE;
        }
        data.buffer_start = 0;
        data.buffer_end = 0;
        pthread_mutex_init(&data.buffer_mutex, NULL);
        pthread_cond_init(&data.buffer_cv, NULL);

        /* A NULL sols pointer marks a slot with nothing to print */
        for (i = 0; i < SOL_BUFFER_SZ; i++)
            data.buffer[i].sols = NULL;
    }

    /* A thread count below 2 (including the 0 that atoi yields for garbage)
     * simply means no extra workers. */
    if (num_threads > 1) {
        thr = (pthread_t *)calloc(num_threads - 1, sizeof(pthread_t));
        if (!thr) {
            fprintf(stderr, "Out of memory; continuing with one thread\n");
            num_threads = 1;
        }
    }

    /* Only threads that were actually created are joined below */
    for (i = 0; i < num_threads - 1; i++) {
        if (pthread_create(thr + num_started, NULL, batch_thread,
                    &data) != 0) {
            fprintf(stderr, "Couldn't create thread %d\n", i + 1);
            break;
        }
        num_started++;
    }
#else
    (void) keep_order;
#endif

    batch_thread(&data);

#if HAVE_PTHREAD
    for (i = 0; i < num_started; i++)
        pthread_join(thr[i], NULL);
    free(thr);

    if (data.buffer) {
        pthread_cond_destroy(&data.buffer_cv);
        pthread_mutex_destroy(&data.buffer_mutex);
        free(data.buffer);
    }
#endif

    return EXIT_SUCCESS;
}

/* Program entry point.
 *
 * Parses the command-line options into a solve_params_t, opens the problem
 * file ("-" selects stdin) and then either
 *   - solves the single system the file contains and prints its solutions, or
 *   - if the first thing after leading white space is a '#', processes the
 *     file in batch mode (see run_batch / batch_thread).
 *
 * Returns EXIT_SUCCESS on success and EXIT_FAILURE on a usage error, when the
 * file cannot be opened, or when it contains no equations. */
int main(int argc, char **argv)
{
    int             grading = FALSE;
    int             keep_order = FALSE;
    int             cond = FALSE;
    int             status = EXIT_SUCCESS;
    const char      *filename;
    FILE            *fp;
    solve_params_t  params;
    problem_t       problem;
    int             c;

    init_solve_params(&params);
    while ((c = getopt(argc, argv, "bcdeg:hi:j:nops:t:")) != -1) {
        switch (c) {
            case 'c':
                cond = TRUE;
                break;
            case 'd':
                grading = TRUE;
                break;
            case 'e':
                params.boost = FALSE;
                break;
            case 'g':
                params.granularity = atof(optarg);
                break;
            case 'h':
                printf(USAGE_STR, argv[0]);
                printf("    -c           print condition number "
                        "(single problem only)\n");
                printf("    -d           print grading\n");
                printf("    -e           disable boosting\n");
                printf("    -g <float>   granularity\n");
                printf("    -h           display this help message\n");
                printf("    -i <int>     iterations per solution\n");
                printf("    -j <int>     number of threads\n");
                printf("    -n           use Newton's method instead of EM\n");
                printf("    -o           preserve order in batch mode\n");
                printf("    -p           show progress\n");
                printf("    -s <int>     maximum number of solutions\n");
                printf("    -t <float>   threshold\n");
                return EXIT_SUCCESS;
            case 'i':
                params.iters_per_soln = atoi(optarg);
                break;
            case 'j':
                params.num_threads = atoi(optarg);
                break;
            case 'n':
                params.newton = TRUE;
                break;
            case 'o':
                keep_order = TRUE;
                break;
            case 'p':
                params.progress = TRUE;
                break;
            case 's':
                params.max_num_solns = atoi(optarg);
                break;
            case 't':
                params.thresh = atof(optarg);
                break;
            default:
                /* Reached for options getopt rejects and for -b, which the
                 * option string accepts but which has no handler. Say how
                 * the program is meant to be invoked instead of exiting
                 * silently. */
                fprintf(stderr, USAGE_STR, argv[0]);
                return EXIT_FAILURE;
        }
    }

    if (argc - optind != 1) {
        fprintf(stderr, USAGE_STR, argv[0]);
        return EXIT_FAILURE;
    }

#if HAVE_DRAND48
    srand48(time(NULL));
#else
    srand(time(NULL));
#endif

    filename = argv[optind];
    if (strcmp(filename, "-") == 0)
        fp = stdin;
    else {
        fp = fopen(filename, "r");
        if (!fp) {
            fprintf(stderr, "Couldn't open %s\n", filename);
            return EXIT_FAILURE;
        }
    }

    flockfile(fp);
    if (read_system(fp, &problem)) {
        sol_set_t   *sols;

        if (cond)
            printf("Condition number: %g\n", condition_number(&problem));
        sols = setup_and_solve(&problem, &params, grading);
        print_solutions(&problem, sols);
    } else if (getc(fp) == '#') {
        /* First non-white space character is a '#', which means we should
         * process it in batch mode. Put the '#' back so the first label can
         * be read normally. */
        ungetc('#', fp);
#if HAVE_PTHREAD
        /* The batch threads take the lock on fp themselves */
        funlockfile(fp);
#endif
        status = run_batch(fp, &params, grading, keep_order);
    } else {
        fprintf(stderr, "File contained no equations\n");
        status = EXIT_FAILURE;
    }

    if (fp != stdin)
        fclose(fp);

    return status;
}

/* Worker loop for batch mode; suitable as a pthread start routine.
 *
 * p points to the shared batch_thread_data_t. The thread repeatedly reads a
 * label and the system that follows it from data->fp, solves the system and
 * prints the result, either directly or, when data->buffer is set, through
 * the order-preserving solution buffer. The loop ends when read_label returns
 * NULL. Always returns NULL.
 *
 * Locking: the lock on data->fp is held whenever this thread reads from the
 * file (and, in ordered mode, while it claims its output slot), and is
 * released while the problem is being solved and printed so that other
 * workers can read the next problem in the meantime. */
void *batch_thread(void *p)
{
    batch_thread_data_t *data = (batch_thread_data_t *)p;
    char                *label;

    flockfile(data->fp);
    while ((label = read_label(data->fp)) != NULL) {
        problem_t   problem;
        sol_set_t   *sols;
#if HAVE_PTHREAD
        /* Claim the output slot while still holding the lock on fp, so that
         * slots are handed out in the order the labels appear in the file. */
        int         ndx = data->buffer ? buffer_ndx(data, label) : 0;
#endif

        if (!read_system(data->fp, &problem)) {
#if HAVE_PTHREAD
            if (data->buffer) {
                /* Relinquish our spot in the output buffer. Note that we don't
                 * need a lock here because the lock on fp prevents a concurrent
                 * call to buffer_ndx, and this update can't cause a race
                 * condition with print_or_buffer. */
                data->buffer_end = ndx;
            }
#endif
            fprintf(stderr, "No equations for %s\n", label);
            free(label);
            /* Keep holding the lock on fp: the loop condition reads the next
             * label from the file straight away. */
            continue;
        }

#if HAVE_PTHREAD
        /* Finished reading this problem. Release the file so that other
         * workers can read while we solve; flockfile counts nested locks, so
         * without this the stream would never become available to them. */
        funlockfile(data->fp);
#endif

        sols = setup_and_solve(&problem, data->params, data->grading);

#if HAVE_PTHREAD
        if (data->buffer) {
            /* print_or_buffer takes over label, problem and sols */
            print_or_buffer(data, ndx, &problem, sols);
        } else {
#endif
            flockfile(stdout);
            printf("#%s\n", label);
            print_solutions(&problem, sols);
            funlockfile(stdout);

            free(label);
            free_problem(&problem);
            free_sol_set(sols);
            free(sols);
#if HAVE_PTHREAD
        }

        /* Re-acquire the file before reading the next label */
        flockfile(data->fp);
#endif
    }

    funlockfile(data->fp);
    return NULL;
}

/* Prepare a freshly read problem and solve it.
 *
 * Runs find_homog on the problem, prints the grading when grading is
 * non-zero, and then returns whatever solve() produces for the given
 * parameters. */
sol_set_t *setup_and_solve(problem_t *problem, solve_params_t *params,
        int grading)
{
    sol_set_t   *result;

    find_homog(problem);

    if (grading) {
        print_grading(problem);
    }

    result = solve(problem, params);
    return result;
}

#if HAVE_PTHREAD
/* Reserve the next slot of the circular solution buffer for the problem with
 * the given label and return its index. Blocks on buffer_cv for as long as
 * the buffer is full. The label pointer is stored in the slot; the string is
 * not copied. */
int buffer_ndx(batch_thread_data_t *data, char *label)
{
    int     slot;

    pthread_mutex_lock(&data->buffer_mutex);

    /* The buffer is full when advancing buffer_end would make it collide with
     * buffer_start; sleep until a slot has been released. */
    for (;;) {
        int     next_end = (data->buffer_end + 1) % SOL_BUFFER_SZ;

        if (data->buffer_start != next_end)
            break;
        pthread_cond_wait(&data->buffer_cv, &data->buffer_mutex);
    }

    slot = data->buffer_end;
    data->buffer_end = (slot + 1) % SOL_BUFFER_SZ;

    pthread_mutex_unlock(&data->buffer_mutex);

    /* The slot is ours now, so it can be filled in without the mutex */
    data->buffer[slot].label = label;

    return slot;
}

/* Store a solved problem in slot ndx of the solution buffer. If that slot is
 * the next one due to be printed, print it together with every consecutive
 * slot after it that already holds solutions, freeing each one as it goes,
 * and advance buffer_start past them. Otherwise the result just stays in the
 * buffer until the thread that owns the earlier slot prints it. */
void print_or_buffer(batch_thread_data_t *data, int ndx, problem_t *problem,
        sol_set_t *sols)
{
    int     cur = ndx;

    data->buffer[cur].problem = *problem;
    pthread_mutex_lock(&data->buffer_mutex);
    data->buffer[cur].sols = sols;

    if (cur != data->buffer_start) {
        /* An earlier problem has not been printed yet */
        pthread_mutex_unlock(&data->buffer_mutex);
        return;
    }

    /* Drain every slot that is ready, in order */
    for (; data->buffer[cur].sols; cur = (cur + 1) % SOL_BUFFER_SZ) {
        flockfile(stdout);
        printf("#%s\n", data->buffer[cur].label);
        print_solutions(&data->buffer[cur].problem, data->buffer[cur].sols);
        funlockfile(stdout);

        free(data->buffer[cur].label);
        free_problem(&data->buffer[cur].problem);
        free_sol_set(data->buffer[cur].sols);
        free(data->buffer[cur].sols);
        data->buffer[cur].sols = NULL;
    }

    /* buffer_start still has its old value here, so this asks whether the
     * buffer was full before we drained it, i.e. whether anyone may have been
     * waiting for more space. */
    if (data->buffer_start == (data->buffer_end + 1) % SOL_BUFFER_SZ)
        pthread_cond_broadcast(&data->buffer_cv);
    data->buffer_start = cur;

    pthread_mutex_unlock(&data->buffer_mutex);
}
#endif

/* Print a solution set to stdout: first the number of solutions, then for
 * each solution a line describing its type (with its divergence unless it is
 * exact or negative) and count, followed by one "name=value" line per
 * variable of the problem. The caller is responsible for locking stdout. */
void print_solutions(problem_t *problem, sol_set_t *sols)
{
    int             s, v;
    const char      *plural = (sols->num_solns == 1) ? "" : "s";

    printf("%d solution%s\n", sols->num_solns, plural);

    for (s = 0; s < sols->num_solns; s++) {
        if (sols->descr[s].type == TYPE_EXACT) {
            fputs("exact", stdout);
        } else if (sols->descr[s].type == TYPE_NEGATIVE) {
            fputs("negative", stdout);
        } else {
            const char  *prefix = "";

            if (sols->descr[s].type == TYPE_NEAR_EXACT)
                prefix = "near exact, ";
            printf("%sdivergence %g", prefix, sols->descr[s].diverg);
        }
        printf(", count %d\n", sols->descr[s].count);

        /* Solution s occupies numvars consecutive entries of sols->x */
        for (v = 0; v < problem->numvars; v++)
            printf("  %s=%.10g\n", problem->varnames[v],
                    sols->x[s * problem->numvars + v]);
    }
}

/* Print the system of equations to stdout in a readable form: the number of
 * terms, then one line per equation of the form
 *     <coeff> <var>^<exp>*... + ... = <rhs>;
 * Terms with a zero coefficient are skipped, a coefficient of exactly 1 is
 * not printed, and neither are variables with exponent 0 or exponents equal
 * to 1. stdout is locked for the duration of the output. */
void print_problem(problem_t *problem)
{
    int     eqn, term, var;

    flockfile(stdout);
    printf("%d terms\n", problem->numterms);

    for (eqn = 0; eqn < problem->numeqns; eqn++) {
        /* Separator emitted before a term; empty until one has been printed */
        const char  *term_sep = "";

        for (term = 0; term < problem->numterms; term++) {
            double      coeff = problem->coeffs[eqn * problem->numterms + term];
            const char  *var_sep = "";

            if (coeff == 0.0)
                continue;

            fputs(term_sep, stdout);
            term_sep = " + ";

            if (coeff != 1.0)
                printf("%g ", coeff);

            for (var = 0; var < problem->numvars; var++) {
                int     power = problem->exps[term * problem->numvars + var];

                if (power == 0)
                    continue;

                fputs(var_sep, stdout);
                var_sep = "*";

                fputs(problem->varnames[var], stdout);
                if (power != 1)
                    printf("^%d", power);
            }
        }
        printf(" = %g;\n", problem->rhs[eqn]);
    }
    funlockfile(stdout);
}

/* Print the grading of the problem to stdout: for every variable a line
 *     deg <name> = (w_1, ..., w_rank)
 * built from problem->weights, followed by one "deg terms = (...)" line
 * listing problem->degs. stdout is locked while printing so the output of
 * concurrent threads does not interleave. */
void print_grading(problem_t *problem)
{
    int         var, row;
    const char  *sep;

    flockfile(stdout);

    for (var = 0; var < problem->numvars; var++) {
        printf("deg %s = (", problem->varnames[var]);

        /* weights holds rank rows of numvars entries; walk down column var */
        sep = "";
        for (row = 0; row < problem->rank; row++) {
            printf("%s%d", sep, problem->weights[row * problem->numvars + var]);
            sep = ", ";
        }
        printf(")\n");
    }

    printf("deg terms = (");
    sep = "";
    for (row = 0; row < problem->rank; row++) {
        printf("%s%d", sep, problem->degs[row]);
        sep = ", ";
    }
    printf(")\n");

    funlockfile(stdout);
}

/* Compute the condition number of the matrix of coefficients, i.e. the ratio
 * of its largest singular value to its smallest. This may or may not be
 * relevant for the difficulty of the problem.
 *
 * The coefficients are copied into a numterms x numeqns matrix (the transpose
 * of the row-per-equation layout of problem->coeffs) before the SVD is taken.
 * NOTE(review): the GSL SVD routines expect at least as many rows as columns,
 * so this presumably relies on numterms >= numeqns -- confirm. */
double condition_number(problem_t *problem)
{
    gsl_matrix  *coeffs_t = gsl_matrix_alloc(problem->numterms,
            problem->numeqns);
    gsl_matrix  *v = gsl_matrix_alloc(problem->numeqns, problem->numeqns);
    gsl_vector  *sing = gsl_vector_alloc(problem->numeqns);
    double      smallest = DBL_MAX;
    double      largest = 0;
    int         eqn, term, k;

    /* Fill in the transpose: entry (term, eqn) gets equation eqn's
     * coefficient for that term. */
    for (eqn = 0; eqn < problem->numeqns; eqn++)
        for (term = 0; term < problem->numterms; term++)
            coeffs_t->data[term * coeffs_t->tda + eqn] =
                problem->coeffs[eqn * problem->numterms + term];

    gsl_linalg_SV_decomp_jacobi(coeffs_t, v, sing);

    /* Scan for the extreme singular values */
    for (k = 0; k < problem->numeqns; k++) {
        if (sing->data[k] < smallest)
            smallest = sing->data[k];
        if (sing->data[k] > largest)
            largest = sing->data[k];
    }

    gsl_matrix_free(coeffs_t);
    gsl_matrix_free(v);
    gsl_vector_free(sing);

    return largest / smallest;
}

/* Release every array owned by the problem, including each variable name.
 * The problem_t structure itself is not freed. */
void free_problem(problem_t *problem)
{
    int     v;

    free(problem->exps);
    free(problem->coeffs);
    free(problem->rhs);

    /* These may be NULL; free(NULL) is a no-op, so no check is needed */
    free(problem->weights);
    free(problem->degs);
    free(problem->coeff_totals);

    if (!problem->varnames)
        return;

    for (v = 0; v < problem->numvars; v++)
        free(problem->varnames[v]);
    free(problem->varnames);
}

/* Release the arrays owned by a solution set. The sol_set_t structure itself
 * is left alone; callers in this file free it separately afterwards. */
void free_sol_set(sol_set_t *sols)
{
    free(sols->descr);
    free(sols->x);
}
