/* Copyright (C) 2010 Dustin Cartwight
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */

#include <config.h>
#include "input.h"
#include "pos.h"
#include <ctype.h>
#include <limits.h>
#include <math.h>
#include <stdlib.h>
#include <string.h>

/* Fixed capacity limits used by the parser below. */
#define MAX_LABEL_LEN       256     /* label buffer size, including the NUL */
#define MAX_NUM_EQNS        256     /* equations read per system */
#define MAX_NUM_VARS        256     /* distinct variables per system */
#define INITIAL_VAR_SPACE   6       /* initial capacity of a term's vars array */
#define MAX_VAR_LEN         16      /* variable-name buffer size, including the NUL */

/* Structures which are used temporarily in parsing the input system of
 * equations. */

/* One monomial: coeff times the product of the variables listed in vars.
 * A variable raised to the power k is listed k times (see parse_term). */
typedef struct term_st term_t;
struct term_st {
    double  coeff;      /* numeric coefficient (1.0 when none was written) */
    int     numvars;    /* number of entries in vars */
    int     *vars;      /* malloc'd variable indices into a var_table_t */
    term_t  *next;      /* next term of the same equation, or NULL */
};

/* One equation: the sum of its terms equals rhs. */
typedef struct {
    term_t  *first_term;    /* singly linked list of terms with nonzero coeff */
    int     numterms;       /* length of the first_term list */
    double  rhs;            /* right hand side constant */
} eqn_t;

/* One entry of the variable table. */
typedef struct {
    char    *name;      /* malloc'd variable name */
    int     old_ndx;    /* index assigned when first seen, kept across sorting */
} var_t;

/* Table of distinct variables seen while parsing one system. */
typedef struct {
    int     numvars;    /* number of entries in use */
    var_t   *vars;      /* entries; read_system allocates MAX_NUM_VARS of them */
} var_table_t;

/* Equation and term parsing. */
int parse_eqn(FILE *fp, eqn_t *eqn, var_table_t *vars);
int parse_term(FILE *fp, term_t *term, var_table_t *vars);
double parse_coeff(FILE *fp);

/* Low-level lexing helpers operating directly on the stream. */
int read_var(FILE *fp, var_table_t *vars);
int parse_double(FILE *fp, double *p);
int parse_signed_int(FILE *fp, int *p);
int parse_int(FILE *fp);
int skip_spaces(FILE *fp);

/* Exponent-vector construction, sorting, de-duplication and lookup. */
void exp_vector(term_t *term, int numvars, int *exp);
int compare_exp_vectors(const void *p1, const void *p2);
int compare_exp_vectors_r(const int *p1, const int *p2, int numvars);
int remove_dup_exps(int *exps, int numvars, int numterms);
int *permute_exp_vectors(int *old_exps, int numterms, var_table_t *vars);
int find_exp(term_t *term, problem_t *problem);

/* qsort comparator for var_t entries. */
int compare_vars(const void *p1, const void *p2);

/* Length of the exponent vectors being sorted, read by the qsort comparator
 * compare_exp_vectors (a qsort comparator receives only the two element
 * pointers, so the length has to be passed out of band).  read_system sets it
 * and relies on its caller holding the stream lock (flockfile) on fp to
 * serialise that use.  Because the lock belongs to one FILE, concurrent calls
 * with different FILE pointers are not protected. */
static int numvars_compare;

/* Reads the first system of equations from the named file into *problem.
 * Exits with an error message if the file cannot be opened or if it contains
 * no equations. */
void read_problem(const char *file, problem_t *problem)
{
    FILE        *fp = fopen(file, "r");

    if (fp == NULL) {
        perror(file);
        exit(EXIT_FAILURE);
    }

#if HAVE_PTHREAD
    /* read_system expects the stream to be locked already, and it calls
     * funlockfile itself after a successful parse. */
    flockfile(fp);
#endif

    if (!read_system(fp, problem)) {
        fprintf(stderr, "No equations in %s\n", file);
        exit(EXIT_FAILURE);
    }
    fclose(fp);
}

/* Reads a label line of the form "#text" from fp.  Leading whitespace is
 * skipped and the '#' is consumed.  The rest of the line, without the
 * newline, is returned in a newly malloc'd string that the caller must free.
 * A label longer than MAX_LABEL_LEN - 1 characters is truncated, with a
 * single warning.  Returns NULL at end of file.  Exits if the next
 * non-space character is not '#'. */
char *read_label(FILE *fp)
{
    int     result = skip_spaces(fp);
    char    *label;
    int     i = 0;
    int     truncated = 0;

    if (result == EOF)
        return NULL;
    else if (result != '#') {
        fprintf(stderr, "Expected # to begin label\n");
        exit(EXIT_FAILURE);
    }

    label = (char *)malloc(MAX_LABEL_LEN);
    if (label == NULL) {
        fprintf(stderr, "Out of memory\n");
        exit(EXIT_FAILURE);
    }

    while ((result = getc(fp)) != '\n' && result != EOF) {
        if (i < MAX_LABEL_LEN - 1)
            label[i++] = (char) result;
        else if (!truncated) {
            /* Warn only once; the remaining characters are discarded. */
            label[i] = '\0';
            fprintf(stderr, "Warning: Label %s too long, truncating\n", label);
            truncated = 1;
        }
    }

    label[i] = '\0';
    return label;
}

/* Reads a single system of equations from the stream into *problem, and
 * returns whether or not it succeeded.  It returns FALSE when no equation
 * precedes the next '#' label or EOF.
 *
 * Locking: it assumes that flockfile(fp) has already been called.  When built
 * with HAVE_PTHREAD and a system is read successfully, it unlocks the stream
 * once parsing is finished; on failure the stream is left locked.  It is only
 * reentrant for calls with the same FILE pointer fp, because it uses fp's
 * lock to protect the static numvars_compare.
 *
 * On success *problem receives newly allocated exps, coeffs, rhs and varnames
 * arrays, and takes ownership of the variable-name strings.  Variables are
 * sorted by name with strcmp, and terms with identical monomials are merged
 * by adding their coefficients.  At most MAX_NUM_EQNS equations are read; if
 * another equation follows, a message is printed and the rest of the input
 * is left in the stream. */
int read_system(FILE *fp, problem_t *problem)
{
    eqn_t       *eqns = (eqn_t *)malloc(MAX_NUM_EQNS * sizeof(eqn_t));
    int         numeqns = 0;
    int         totalnumterms = 0;
    var_table_t vars;
    int         i, j;
    term_t      *term;

    vars.numvars = 0;
    vars.vars = (var_t *)malloc(MAX_NUM_VARS * sizeof(var_t));

    while (numeqns < MAX_NUM_EQNS && parse_eqn(fp, eqns + numeqns, &vars)) {
        totalnumterms += eqns[numeqns].numterms;
        numeqns++;
    }

    if (numeqns == MAX_NUM_EQNS) {
        /* Exactly MAX_NUM_EQNS equations is fine.  Only complain when another
         * equation, rather than EOF or the next '#' label, follows.  The loop
         * skips whitespace the same way skip_spaces does. */
        int     c;

        do {
            c = getc(fp);
        } while (isspace(c));
        if (c != EOF) {
            ungetc(c, fp);
            if (c != '#')
                fprintf(stderr, "Too many equations. Maximum is %d\n",
                        MAX_NUM_EQNS);
        }
    }

    if (numeqns == 0) {
        /* Nothing to hand back, so release the tables instead of leaking
         * them.  parse_eqn may have registered variables before hitting EOF. */
        for (i = 0; i < vars.numvars; i++)
            free(vars.vars[i].name);
        free(vars.vars);
        free(eqns);
        return FALSE;
    }

    /* Translate problem to problem_t structure */
    problem->numvars = vars.numvars;
    problem->numeqns = numeqns;
    problem->exps = (int *)malloc(totalnumterms * vars.numvars * sizeof(int));
    problem->rank = 0;
    problem->weights = NULL;
    problem->degs = NULL;
    problem->varnames = (char **)malloc(vars.numvars * sizeof(char *));
    problem->coeff_totals = NULL;

    /* Create set of exponents: one row of vars.numvars ints per term */
    j = 0;
    for (i = 0; i < numeqns; i++) {
        for (term = eqns[i].first_term; term != NULL; term = term->next)
            exp_vector(term, vars.numvars, problem->exps + j++ * vars.numvars);
    }
    if (j != totalnumterms) {
        fprintf(stderr, "Internal error: number of terms went from %d to %d\n",
                totalnumterms, j);
        exit(EXIT_FAILURE);
    }

    /* Remove duplicate exponents.  Sorting first makes equal rows adjacent,
     * and find_exp below depends on this sorted order. */
    numvars_compare = vars.numvars;
    qsort(problem->exps, totalnumterms, vars.numvars * sizeof(int),
          compare_exp_vectors);
    problem->numterms = remove_dup_exps(problem->exps, vars.numvars,
            totalnumterms);

#if HAVE_PTHREAD
    /* The qsort above is not reentrant, because compare_exp_vectors reads the
     * static numvars_compare.  The caller's flockfile(fp) serialises that
     * section.  This works only so long as read_system is only called with
     * the same FILE pointer. */
    funlockfile(fp);
#endif

    /* Create coefficients and right hand side. */
    problem->coeffs = (double *)malloc(problem->numeqns * problem->numterms
                                        * sizeof(double));
    problem->rhs = (double *)malloc(problem->numeqns * sizeof(double));

    for (i = 0; i < problem->numeqns * problem->numterms; i++)
        problem->coeffs[i] = 0.0;

    for (i = 0; i < numeqns; i++) {
        for (term = eqns[i].first_term; term != NULL; term = term->next) {
            int     term_ndx = find_exp(term, problem);

            problem->coeffs[i * problem->numterms + term_ndx] += term->coeff;
        }
        problem->rhs[i] = eqns[i].rhs;
    }

    /* re-order variables alphabetically */
    qsort(vars.vars, vars.numvars, sizeof(var_t), compare_vars);
    problem->exps = permute_exp_vectors(problem->exps, problem->numterms,&vars);

    /* The name strings now belong to problem; only the table array is freed
     * below. */
    for (i = 0; i < vars.numvars; i++)
        problem->varnames[i] = vars.vars[i].name;

    /* Free memory */
    for (i = 0; i < numeqns; i++) {
        for (term = eqns[i].first_term; term != NULL; ) {
            term_t  *next_term = term->next;
            free(term->vars);
            free(term);
            term = next_term;
        }
    }
    free(vars.vars);
    free(eqns);

    return TRUE;
}

/* Parses a single equation from a file into *eqn, and returns TRUE if it
 * parsed an equation successfully.
 *
 * Syntax: one or more terms (see parse_term) separated by '+', then '=' or
 * '-', then a right hand side number (see parse_double), then ';'.  Terms
 * whose coefficient is exactly 0.0 are dropped.
 *
 * It returns FALSE, having consumed only whitespace, when the stream is at
 * EOF or at a '#' label.  It also returns FALSE when EOF follows a term; in
 * that case the partial equation is discarded and its memory freed.  It exits
 * on any other syntax error.
 *
 * NOTE(review): '-' is accepted in place of '=', so "x - 3;" parses exactly
 * like "x = 3;".  Confirm this is intended. */
int parse_eqn(FILE *fp, eqn_t *eqn, var_table_t *vars)
{
    int     result;

    eqn->first_term = NULL;
    eqn->numterms = 0;

    result = skip_spaces(fp);
    if (result == EOF)
        return FALSE;
    ungetc(result, fp);
    if (result == '#')
        return FALSE;

    do {
        term_t  *new_term = (term_t *)malloc(sizeof(term_t));

        if (new_term == NULL) {
            fprintf(stderr, "Out of memory\n");
            exit(EXIT_FAILURE);
        }

        result = parse_term(fp, new_term, vars);

        if (new_term->coeff == 0.0) {
            /* No need to include terms which have coefficient 0.0 */
            free(new_term->vars);
            free(new_term);
        } else {
            new_term->next = eqn->first_term;
            eqn->first_term = new_term;
            eqn->numterms++;
        }
    } while (result == '+');

    if (result == EOF) {
        /* The input ended part-way through an equation.  Free the terms
         * collected so far, since FALSE means no equation was produced. */
        term_t  *term = eqn->first_term;

        while (term != NULL) {
            term_t  *next_term = term->next;

            free(term->vars);
            free(term);
            term = next_term;
        }
        eqn->first_term = NULL;
        eqn->numterms = 0;
        return FALSE;
    } else if (result != '=' && result != '-') {
        fprintf(stderr, "Expected =, got %c\n", (char) result);
        exit(EXIT_FAILURE);
    }

    if (!parse_double(fp, &eqn->rhs)) {
        fprintf(stderr, "Expected right hand side\n");
        exit (EXIT_FAILURE);
    }

    result = skip_spaces(fp);
    if (result != ';') {
        fprintf(stderr, "Expected ;, got %c\n", (char) result);
        exit (EXIT_FAILURE);
    }
    return TRUE;
}

/* Parses a single term and returns the next non-space character read after
 * it.
 *
 * A term is an optional coefficient (see parse_coeff) followed by one or more
 * factors separated by '*'.  Each factor is a variable name, optionally
 * followed by '^' and a non-negative integer exponent.
 *
 * Fills in term->coeff, term->numvars and term->vars.  vars is a malloc'd
 * array in which each variable index is repeated exponent times.
 * term->next is not touched.  Exits on malformed input, on an exponent too
 * large to store, or on allocation failure. */
int parse_term(FILE *fp, term_t *term, var_table_t *vars)
{
    int     c;

    term->coeff = parse_coeff(fp);
    term->numvars = 0;
    term->vars = (int *)malloc(INITIAL_VAR_SPACE * sizeof(int));
    if (term->vars == NULL) {
        fprintf(stderr, "Out of memory\n");
        exit(EXIT_FAILURE);
    }

    do {
        int var = read_var(fp, vars);
        int exp = 1;
        int new_numvars;

        if (var == -1) {
            fprintf(stderr, "Unexpected EOF in the middle of term\n");
            exit(EXIT_FAILURE);
        }

        if ((c = skip_spaces(fp)) == '^') {
            exp = parse_int(fp);
            /* parse_int returns -1 when no digits follow the '^'. */
            if (exp < 0) {
                fprintf(stderr, "Expected exponent after ^\n");
                exit(EXIT_FAILURE);
            }
            c = skip_spaces(fp);
        }

        /* Bound the total so neither the int sum nor the byte count passed
         * to realloc can overflow. */
        if (exp > INT_MAX / (int) sizeof(int) - term->numvars) {
            fprintf(stderr, "Exponent too large\n");
            exit(EXIT_FAILURE);
        }
        new_numvars = term->numvars + exp;
        if (new_numvars > INITIAL_VAR_SPACE) {
            int *grown = (int *)realloc(term->vars,
                                        (size_t) new_numvars * sizeof(int));

            if (grown == NULL) {
                fprintf(stderr, "Out of memory\n");
                exit(EXIT_FAILURE);
            }
            term->vars = grown;
        }

        while (--exp >= 0)
            term->vars[term->numvars++] = var;
    } while (c == '*');

    return c;
}

/* Parses an optional numeric coefficient at the start of a term.  If a number
 * is present it is returned, and a '*' following it (after optional
 * whitespace) is consumed.  Any other character after the number is pushed
 * back.  If no number is present, 1.0 is returned. */
double parse_coeff(FILE *fp)
{
    double  value;
    int     next;

    if (!parse_double(fp, &value))
        return 1.0;

    next = skip_spaces(fp);
    if (next != '*')
        ungetc(next, fp);
    return value;
}

/* Attempts to read a single variable name from the stream and returns its
 * index in vars.  A name is a letter followed by letters and digits, after
 * optional leading whitespace.  A name seen for the first time gets a new
 * entry holding a malloc'd copy of it.
 *
 * Returns -1 if it encounters EOF before reading a variable name.  Exits if
 * the next character cannot start a name, if the name is longer than
 * MAX_VAR_LEN - 1 characters, or if the table already holds MAX_NUM_VARS
 * variables. */
int read_var(FILE *fp, var_table_t *vars)
{
    char    var[MAX_VAR_LEN];
    int     varlen = 0;
    int     c;
    int     i;
    char    *name;

    c = skip_spaces(fp);
    if (!isalpha(c)) {
        if (c == EOF)
            return -1;
        fprintf(stderr, "Invalid initial character for variable: %c\n",
                (char) c);
        exit(EXIT_FAILURE);
    }

    do {
        /* Check before storing, so that var always keeps room for the
         * terminating NUL (at most MAX_VAR_LEN - 1 name characters). */
        if (varlen >= MAX_VAR_LEN - 1) {
            var[varlen] = '\0';
            fprintf(stderr, "Variable beginning %s is too long\n", var);
            exit(EXIT_FAILURE);
        }
        var[varlen++] = (char) c;
        c = getc(fp);
    } while (isalnum(c));
    ungetc(c, fp);
    var[varlen] = '\0';

    /* Lookup variable in table */
    for (i = 0; i < vars->numvars; i++) {
        if (strcmp(vars->vars[i].name, var) == 0)
            return i;
    }

    /* Variable not found, create new entry */
    if (vars->numvars >= MAX_NUM_VARS) {
        fprintf(stderr, "Too many variables: %s\n", var);
        exit(EXIT_FAILURE);
    }
    name = (char *)malloc(varlen + 1);
    if (name == NULL) {
        fprintf(stderr, "Out of memory\n");
        exit(EXIT_FAILURE);
    }
    memcpy(name, var, varlen + 1);
    vars->vars[vars->numvars].old_ndx = vars->numvars;
    vars->vars[vars->numvars].name = name;
    return vars->numvars++;
}

/* Attempts to read a single non-negative floating point number from the
 * stream into *p, and returns TRUE if successful.
 *
 * Format: digits (leading whitespace is skipped by parse_int), optionally a
 * '.' followed by at least one digit, optionally 'e' or 'E' followed by an
 * exponent parsed by parse_signed_int.  Returns FALSE, leaving *p untouched,
 * if there is no leading digit.  Exits on a malformed fractional part or
 * exponent.
 *
 * NOTE: this parses the number in the stream rather than using strtod, so in
 * rare cases the last bit may differ from strtod's result. */
int parse_double(FILE *fp, double *p)
{
    int     whole = parse_int(fp);
    double  result;
    int     c;

    if (whole == -1)
        return FALSE;
    result = whole;

    /* Read fractional part.  The digits are gathered into an integer-valued
     * numerator and a power-of-ten denominator, then divided once.  This
     * rounds once instead of accumulating the error of repeated
     * multiplication by the inexact 0.1, which made e.g. "0.3" parse as
     * 0.30000000000000004. */
    if ((c = getc(fp)) == '.') {
        double  num = 0.0;
        double  den = 1.0;

        c = getc(fp);
        if (!isdigit(c)) {
            fprintf(stderr, "Expected fractional part, got %c\n", (char) c);
            exit(EXIT_FAILURE);
        }
        do {
            /* Use at most 15 significant digits, so that num stays an
             * exactly representable integer.  Later digits could affect
             * only the last bit or so. */
            if (num < 1e14) {
                num = num * 10.0 + (c - '0');
                den *= 10.0;
            }
            c = getc(fp);
        } while (isdigit(c));
        result += num / den;
    }

    /* Read exponent.  A negative power of ten is applied by dividing by the
     * positive power, which is exact for moderate exponents.  This is more
     * accurate than multiplying by the inexact 10^-n. */
    if (c == 'e' || c == 'E') {
        int     exp;

        if (!parse_signed_int(fp, &exp)) {
            fprintf(stderr, "Invalid exponent\n");
            exit(EXIT_FAILURE);
        }
        if (exp < 0)
            result /= pow(10.0, -(double) exp);
        else
            result *= pow(10.0, exp);
    } else
        ungetc(c, fp);

    *p = result;
    return TRUE;
}

/* Reads a decimal integer with an optional leading sign ('-' or '+') from
 * the stream into *p.  Returns TRUE on success, or FALSE if no digits follow
 * the optional sign.  A sign character, if present, is consumed either way.
 * Previously sign was -1 on both branches, so "1e5" was read as 1e-5. */
int parse_signed_int(FILE *fp, int *p)
{
    int sign = 1;
    int mag;
    int c = getc(fp);

    if (c == '-')
        sign = -1;
    else if (c != '+')
        ungetc(c, fp);

    mag = parse_int(fp);
    if (mag == -1)
        return FALSE;

    *p = sign * mag;
    return TRUE;
}

/* Reads a single non-negative decimal integer from the stream, after skipping
 * leading whitespace, and returns it.  Returns -1 on error, meaning the next
 * non-space character is not a digit; that character is pushed back.  Exits
 * with an error if the value does not fit in an int. */
int parse_int(FILE *fp)
{
    int     c;
    int     result = 0;

    c = skip_spaces(fp);
    if (!isdigit(c)) {
        ungetc(c, fp);
        return -1;
    }

    do {
        int     digit = c - '0';

        /* Check before multiplying: signed overflow is undefined behaviour. */
        if (result > (INT_MAX - digit) / 10) {
            fprintf(stderr, "Integer too large\n");
            exit(EXIT_FAILURE);
        }
        result = result * 10 + digit;
        c = getc(fp);
    } while (isdigit(c));

    ungetc(c, fp);
    return result;
}

/* Consumes a run of zero or more whitespace characters (as classified by
 * isspace) and returns the first character after them, or EOF.  The returned
 * character has been read from the stream; callers push it back with ungetc
 * when they need to. */
int skip_spaces(FILE *fp)
{
    int     ch = getc(fp);

    while (isspace(ch))
        ch = getc(fp);
    return ch;
}

/* Writes the exponent vector of term into exp[0..numvars-1]: entry v is the
 * number of times variable index v appears in term->vars.  An index outside
 * [0, numvars) is reported as an internal error and terminates the program. */
void exp_vector(term_t *term, int numvars, int *exp)
{
    int     n = numvars;
    int     k;

    /* Start from the zero vector... */
    while (n > 0)
        exp[--n] = 0;

    /* ...and add one for each occurrence of each variable. */
    for (k = 0; k < term->numvars; k++) {
        int     v = term->vars[k];

        if (!(v >= 0 && v < numvars)) {
            fprintf(stderr, "Internal error: Variable index out of range: %d\n",
                    v);
            exit(EXIT_FAILURE);
        }
        exp[v] += 1;
    }
}

/* qsort comparator for exponent vectors whose length is numvars_compare.
 * That static must be set before sorting.  The comparison itself is done by
 * compare_exp_vectors_r. */
int compare_exp_vectors(const void *p1, const void *p2)
{
    return compare_exp_vectors_r((const int *) p1, (const int *) p2,
                                 numvars_compare);
}

/* Compares two exponent vectors of length numvars lexicographically.
 * Returns 0 if they are equal.  Otherwise it returns exp2[i] - exp1[i] for
 * the first index i where they differ, which is negative when exp1 has the
 * larger exponent there.  Used as a qsort comparator, this therefore sorts
 * larger exponents first. */
int compare_exp_vectors_r(const int *exp1, const int *exp2, int numvars)
{
    int     k = 0;

    while (k < numvars && exp1[k] == exp2[k])
        k++;

    return (k < numvars) ? exp2[k] - exp1[k] : 0;
}

/* Removes duplicate exponent vectors from exps.  exps holds numterms rows of
 * numvars ints; rows must already be ordered so that equal rows are
 * adjacent (read_system qsorts them first).  Unique rows are compacted to the
 * front in their existing order.  Returns the number of unique rows, which is
 * 0 when numterms is 0. */
int remove_dup_exps(int *exps, int numvars, int numterms)
{
    size_t  rowsize = (size_t) numvars * sizeof(int);
    int     src, dst = 0;

    /* Without this guard an empty array would be reported as holding one
     * row, and callers would then read past the end of exps. */
    if (numterms <= 0)
        return 0;

    for (src = 1; src < numterms; src++) {
        /* Rows of plain ints are equal exactly when their bytes are equal. */
        if (memcmp(exps + src * numvars, exps + dst * numvars, rowsize) != 0) {
            if (++dst != src)
                memcpy(exps + dst * numvars, exps + src * numvars, rowsize);
        }
    }

    return dst + 1;
}

/* Returns a newly malloc'd copy of old_exps (numterms rows of vars->numvars
 * ints) with its columns reordered to match vars->vars.  Column j of the
 * result is column vars->vars[j].old_ndx of the input.  old_exps is freed. */
int *permute_exp_vectors(int *old_exps, int numterms, var_table_t *vars)
{
    int     width = vars->numvars;
    int     *exps = (int *)malloc(numterms * vars->numvars * sizeof(int));
    int     row, col;

    for (row = 0; row < numterms; row++) {
        int     base = row * width;

        for (col = 0; col < width; col++)
            exps[base + col] = old_exps[base + vars->vars[col].old_ndx];
    }

    free(old_exps);
    return exps;
}

/* Returns the row index in problem->exps whose exponent vector equals that of
 * term.  problem->exps holds problem->numterms rows of problem->numvars ints,
 * which must be sorted in compare_exp_vectors_r order; the row is found by
 * binary search.  The vector must be present.  If it is not, an internal
 * error is reported and the program exits.  No row is ever added. */
int find_exp(term_t *term, problem_t *problem)
{
    int     width = problem->numvars;
    int     lo = 0;
    int     hi = problem->numterms - 1;
    int     *target = (int *)malloc(problem->numvars * sizeof(int));

    exp_vector(term, width, target);

    while (lo <= hi) {
        int     mid = (lo + hi) / 2;
        int     order = compare_exp_vectors_r(target,
                                              problem->exps + mid * width,
                                              width);

        if (order == 0) {
            free(target);
            return mid;
        }
        if (order < 0)
            hi = mid - 1;
        else
            lo = mid + 1;
    }

    fprintf(stderr, "Internal error: couldn't find exponent\n");
    exit(EXIT_FAILURE);
    return -1;
}

int compare_vars(const void *p1, const void *p2)
{
    return strcmp(((const var_t *) p1)->name, ((const var_t *) p2)->name);
}
