/* Copyright (C) 2010 Dustin Cartwight
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */

#include "homog.h"
#include <stdio.h>
#include <stdlib.h>

int try_same_deg(problem_t *problem);
int try_bilinear(problem_t *problem);
int *find_kernel(int *matrix, int rows, int cols, int *rows_kernel);
int reduce_matrix(int *matrix, int rows, int cols);
int *kernel_reduced(int *matrix, int rows, int cols);
int find_gcd(int a, int b, int *pMultA, int *pMultB);
void print_matrix(int *matrix, int rows, int cols);

/* Choose a grading for the problem. The bilinear split of the variables is
 * tried first; if that does not apply, a single grading under which every
 * monomial has the same degree is tried. If neither works, an error is
 * printed and the program exits. */
void find_homog(problem_t *problem)
{
    if (try_bilinear(problem))
        return;

    if (try_same_deg(problem))
        return;

    fprintf(stderr, "Error: Couldn't find appropriate grading.\n");
    exit(EXIT_FAILURE);
}

/* Try to find a single grading under which every monomial of the problem has
 * the same degree: positive integer weights w_1..w_n and a degree d with
 * sum_j w_j * e_ij == d for every term i (e_ij = exponent of variable j in
 * term i).
 *
 * The vectors (d, -w_1, ..., -w_n) are exactly the kernel vectors of the
 * numterms x (numvars + 1) matrix whose row i is [1, e_i1, ..., e_in]. The
 * grading is accepted only if that kernel is one-dimensional and can be
 * oriented so that d > 0 and every weight is > 0.
 *
 * On success: sets problem->rank = 1, allocates problem->weights (numvars
 * entries) and problem->degs (one entry), and returns TRUE.
 * On failure: returns FALSE with problem->rank = 0 and nothing left allocated
 * by this call. */
int try_same_deg(problem_t *problem)
{
    int     rows = problem->numterms;
    int     cols = problem->numvars + 1;
    int     *matrix = (int *)malloc(rows * cols * sizeof(int));
    int     *kernel;
    int     sign;
    int     i, j;

    for (i = 0; i < rows; i++) {
        matrix[i * cols] = 1;
        for (j = 0; j < problem->numvars; j++)
            matrix[i * cols + j + 1] = problem->exps[i * problem->numvars + j];
    }

    kernel = find_kernel(matrix, rows, cols, &problem->rank);
    free(matrix);

    if (problem->rank != 1 || kernel[0] == 0) {
        /* If there's a multigrading, we give up: such a system won't be
         * zero-dimensional anyway. */
        problem->rank = 0;
        free(kernel);
        return FALSE;
    }

    /* A basis vector of a one-dimensional kernel is only determined up to
     * sign, so orient it to make the degree positive instead of relying on
     * whatever sign convention find_kernel happens to use. */
    sign = (kernel[0] > 0) ? 1 : -1;

    problem->weights = (int *)malloc(problem->numvars * sizeof(int));

    for (i = 0; i < problem->numvars; i++) {
        int w = -sign * kernel[i + 1];
        if (w <= 0) {
            free(kernel);
            free(problem->weights);
            problem->weights = NULL;
            /* Same failure state as the early return above. */
            problem->rank = 0;
            return FALSE;
        }
        problem->weights[i] = w;
    }

    problem->degs = (int *)malloc(sizeof(int));
    problem->degs[0] = sign * kernel[0];

    free(kernel);
    return TRUE;
}

/* See if the problem is bilinear, splitting the variables into two groups by
 * name: those whose name starts with 'x', and all the others. The problem
 * qualifies when no monomial has degree greater than one in either group,
 * i.e. every term has bidegree at most (1, 1).
 *
 * On success: sets problem->rank = 2, allocates problem->weights as a
 * 2 x numvars row-major array (row 0 is 1 on the 'x' variables and 0
 * elsewhere, row 1 is the complement), allocates problem->degs = {1, 1}, and
 * returns TRUE.
 * On failure: returns FALSE without modifying *problem. */
int try_bilinear(problem_t *problem)
{
    int     term, var;

    for (term = 0; term < problem->numterms; term++) {
        int     deg_in_x = 0;
        int     deg_in_rest = 0;

        for (var = 0; var < problem->numvars; var++) {
            int     *deg = (problem->varnames[var][0] == 'x')
                           ? &deg_in_x : &deg_in_rest;

            *deg += problem->exps[term * problem->numvars + var];
        }

        if (deg_in_x > 1 || deg_in_rest > 1)
            return FALSE;
    }

    problem->rank = 2;
    problem->weights = (int *)malloc(2 * problem->numvars * sizeof(int));
    problem->degs = (int *)malloc(2 * sizeof(int));

    for (var = 0; var < problem->numvars; var++) {
        int     in_x = (problem->varnames[var][0] == 'x');

        problem->weights[var] = in_x ? 1 : 0;
        problem->weights[problem->numvars + var] = in_x ? 0 : 1;
    }

    problem->degs[0] = 1;
    problem->degs[1] = 1;
    return TRUE;
}

/* Finds an integer basis for the kernel (null space) of the rows x cols
 * row-major matrix `matrix`. The matrix is reduced in place by reduce_matrix,
 * so its contents are clobbered. Sets *rows_kernel to the kernel dimension k
 * and returns a heap-allocated k x cols row-major matrix holding one basis
 * vector per row; the caller must free() it. */
int *find_kernel(int *matrix, int rows, int cols, int *rows_kernel)
{
    int     rank;
    int     *basis;

    rank = reduce_matrix(matrix, rows, cols);
    basis = kernel_reduced(matrix, rank, cols);
    *rows_kernel = cols - rank;

    return basis;
}

/* Reduces an integer matrix in place to row echelon form using only integer
 * arithmetic.
 *
 * Pivot entries are not normalized to 1. The pivot of each column is the gcd
 * (as returned by find_gcd, hence positive) of that column's entries in the
 * rows not yet reduced, so each elimination step below divides exactly. The
 * row moved into the pivot position always keeps a non-zero multiple of
 * itself, so the row space over the rationals, and hence the kernel, is
 * unchanged.
 *
 * `matrix` is rows x cols, row-major. Returns the number r of non-zero rows:
 * rows 0..r-1 have strictly increasing pivot columns, and rows r..rows-1 are
 * zero. No overflow checking is done, and entries can grow quickly. */
int reduce_matrix(int *matrix, int rows, int cols)
{
    int     r = 0;
    /* Bezout coefficients: the sum over i >= r of mults[i] * matrix[i][c]
     * equals the gcd of column c over the unreduced rows. */
    int     *mults = (int *)malloc(rows * sizeof(int));
    int     c;
    int     i, j;

    /* Once every row has a pivot, later columns cannot produce another one. */
    for (c = 0; c < cols && r < rows; c++) {
        int     gcd = 0;
        int     nonzero = -1;   /* a row whose coefficient is non-zero */

        for (i = r; i < rows; i++) {
            int     m;
            /* new gcd = mults[i] * entry + m * old gcd, so the coefficients
             * of the rows already folded in must be scaled by m */
            gcd = find_gcd(matrix[i * cols + c], gcd, mults + i, &m);
            for (j = r; j < i; j++)
                mults[j] *= m;
            if (mults[i] != 0)
                nonzero = i;
        }

        if (gcd != 0) {
            int     tmp;
            /* Swap rows r and nonzero, so that row r's own coefficient is
             * non-zero and the combination below is invertible. Columns
             * before c are already zero in both rows. */
            for (i = c; i < cols; i++) {
                tmp = matrix[r * cols + i];
                matrix[r * cols + i] = matrix[nonzero * cols + i];
                matrix[nonzero * cols + i] = tmp;
            }
            tmp = mults[r];
            mults[r] = mults[nonzero];
            mults[nonzero] = tmp;

            /* Replace row r by sum(mults[i] * row i); its entry in column c
             * becomes gcd. */
            for (i = c; i < cols; i++) {
                matrix[r * cols + i] *= mults[r];
            }
            for (i = r + 1; i < rows; i++) {
                for (j = c; j < cols; j++) {
                    matrix[r * cols + j] += mults[i] * matrix[i * cols + j];
                }
            }

            /* Clear column c below the pivot; gcd divides every entry of the
             * column, so the ratio is exact. */
            for (i = r + 1; i < rows; i++) {
                int     ratio = matrix[i * cols + c] / gcd;
                for (j = c; j < cols; j++) {
                    matrix[i * cols + j] -= ratio * matrix[r * cols + j];
                }
            }

            r++;
        }
    }

    free(mults);

    return r;
}

/* Computes an integer basis for the kernel of a matrix in the echelon form
 * produced by reduce_matrix: `rows` non-zero rows with strictly increasing
 * pivot columns (pivot entries need not be 1), stored row-major with `cols`
 * columns.
 *
 * Returns a heap-allocated (cols - rows) x cols row-major matrix; the caller
 * must free() it. Row i is the basis vector for the i-th free (non-pivot)
 * column counted from the left. The vector for free column c is built as
 * follows:
 *   - its entry at c is -(product of the pivots in rows whose pivot column
 *     is < c), taking the empty product as 1;
 *   - it is zero at every other free column;
 *   - its pivot entries are found by back-substitution.
 * With rows == 0 every column is free, so the result is -1 times the identity.
 *
 * Prints an error and exits if some row has no non-zero entry at or after the
 * column following the previous row's pivot. */
int *kernel_reduced(int *matrix, int rows, int cols)
{
        /* location of the pivot column in each row */
    int     *pivot_col = (int *)malloc(rows * sizeof(int));
        /* prod_pivots[i] is the product of the pivot entries in rows <= i */
    int     *prod_pivots = (int *)malloc(rows * sizeof(int));
    int     *kernel;
    int     prod = 1;
    int     r, c;
    int     i, j;

    for (r = 0, c = 0; r < rows; r++, c++) {
        /* Test the bound before reading, so an all-zero row cannot run off
         * the end of the matrix. */
        while (c < cols && matrix[r * cols + c] == 0)
            c++;
        if (c >= cols) {
            fprintf(stderr, "Matrix has row of all zeros\n");
            exit(EXIT_FAILURE);
        }

        prod_pivots[r] = prod *= matrix[r * cols + c];
        pivot_col[r] = c;
    }

    /* The scan above guarantees rows <= cols, so the size is non-negative.
     * calloc zeroes every entry; back-substitution relies on that. */
    kernel = (int *)calloc((size_t)(cols - rows) * cols, sizeof(int));

    /* Walk the columns left to right. r is the last row whose pivot column
     * is < c (-1 if none), so exactly c - r - 1 free columns precede c. */
    for (r = -1, c = 0; c < cols; c++) {
        int     *k;

        if (r + 1 < rows && pivot_col[r + 1] == c) {
            r++;
            continue;
        }

        /* index into the appropriate row of the kernel matrix */
        k = kernel + (c - r - 1) * cols;

        /* Scaling by the product of the pivots in rows 0..r makes every
         * division below exact. */
        k[c] = (r >= 0) ? -prod_pivots[r] : -1;
        for (i = r; i >= 0; i--) {
            int     p = pivot_col[i];

            /* Row i reads: pivot * k[p] + sum_{j > p} m[i][j] * k[j] == 0.
             * Entries beyond c are all zero. */
            for (j = p + 1; j <= c; j++)
                k[p] -= matrix[i * cols + j] * k[j];
            k[p] /= matrix[i * cols + p];
        }
    }

    free(prod_pivots);
    free(pivot_col);
    return kernel;
}

/* Extended Euclidean algorithm. Returns g = gcd(|a|, |b|) (always >= 0, and
 * 0 only when a == b == 0) and stores Bezout coefficients with
 * (*pMultA) * a + (*pMultB) * b == g.
 *
 * Iterative form: the loop keeps r0 = s0*|a| + t0*|b| and
 * r1 = s1*|a| + t1*|b|, and the signs of a and b are folded back into the
 * coefficients at the end. */
int find_gcd(int a, int b, int *pMultA, int *pMultB)
{
    int     r0 = (a < 0) ? -a : a;
    int     r1 = (b < 0) ? -b : b;
    int     s0 = 1, s1 = 0;     /* coefficients of |a| */
    int     t0 = 0, t1 = 1;     /* coefficients of |b| */

    while (r1 != 0) {
        int     q = r0 / r1;
        int     next;

        next = r0 - q * r1;
        r0 = r1;
        r1 = next;

        next = s0 - q * s1;
        s0 = s1;
        s1 = next;

        next = t0 - q * t1;
        t0 = t1;
        t1 = next;
    }

    *pMultA = (a < 0) ? -s0 : s0;
    *pMultB = (b < 0) ? -t0 : t0;
    return r0;
}

/* Debugging aid: writes a "matrix:" header and then the rows x cols row-major
 * matrix to stdout, one row per line, with each entry followed by a space. */
void print_matrix(int *matrix, int rows, int cols)
{
    const int   *row = matrix;
    int         r, c;

    printf("matrix:\n");
    for (r = 0; r < rows; r++, row += cols) {
        for (c = 0; c < cols; c++)
            printf("%d ", row[c]);
        putchar('\n');
    }
}
