/* Copyright 2009 Regents of the University of California
 * All rights reserved.
 */
#include "bilinear.h"
#include "solutions.h"
#include <stdlib.h>
#include <time.h>
#include <math.h>
#include <unistd.h>

#define NUM_STARTS 20
#define MAX_NUM_ITERS 1000000
#define MAX_IPF_ITERS 100
#define EM_EPS 1.0e-8

/* Normalizes each of the b.k functionals so that its b.n * b.m coefficients
 * sum to one. The coefficients are divided in place through b.functionals.
 * Returns a freshly malloc'ed array of b.k entries holding the sums that were
 * divided out; the caller owns (and must free) that array. */
double *rescale(bilinear_problem_t b)
{
    int     func, cell;
    double  *rescales = malloc(b.k * sizeof(double));

    for (func = 0; func < b.k; ++func) {
        double  total = 0.0;

        /* First pass: total of this functional's coefficients. */
        for (cell = 0; cell < b.n * b.m; ++cell) {
            total += b.functionals[func * b.n * b.m + cell];
        }
        rescales[func] = total;

        /* Second pass: divide every coefficient by that total. */
        for (cell = 0; cell < b.n * b.m; ++cell) {
            b.functionals[func * b.n * b.m + cell] /= total;
        }
    }

    return rescales;
}

/* Divides each of the b.k observed values, in place, by the matching entry
 * of rescales (the array returned by rescale()). */
void rescale_observeds(bilinear_problem_t b, double *rescales)
{
    int     idx = 0;

    while (idx < b.k) {
        b.observeds[idx] /= rescales[idx];
        ++idx;
    }
}

/* Performs one iteration of the IPF algorithm, updating the estimates xs and
 * ys in place. Each xs[i] is multiplied by sqrt(suff_x[i] / est), where est is
 * the sum over j of marg_functionals[i * b.m + j] * xs[i] * ys[j]. The ys are
 * then updated the same way against suff_y, using the already-updated xs. */
void ipf_iter(bilinear_problem_t b, double *marg_functionals, double *suff_x,
              double *suff_y, double *xs, double *ys)
{
    int     row, col;

    /* Row pass: pull each xs[row] towards its sufficient statistic. */
    for (row = 0; row < b.n; row++) {
        double  fitted = 0.0;

        for (col = 0; col < b.m; col++) {
            fitted += marg_functionals[row * b.m + col] * xs[row] * ys[col];
        }
        xs[row] *= sqrt(suff_x[row] / fitted);
    }

    /* Column pass: same correction for each ys[col]. */
    for (col = 0; col < b.m; col++) {
        double  fitted = 0.0;

        for (row = 0; row < b.n; row++) {
            fitted += marg_functionals[row * b.m + col] * xs[row] * ys[col];
        }
        ys[col] *= sqrt(suff_y[col] / fitted);
    }
}

/* Performs one iteration of the "alternative IPF algorithm". Rather than
 * applying a multiplicative correction, it exploits the bilinearity of the
 * model to solve for each coordinate directly: first every xs[i] is replaced
 * by suff_x[i] / sum_j(marg_functionals[i * b.m + j] * ys[j]), then every
 * ys[j] is replaced by suff_y[j] / sum_i(marg_functionals[i * b.m + j] * xs[i])
 * using the freshly computed xs. */
void ipf_iter_alt(bilinear_problem_t b, double *marg_functionals,
                  double *suff_x, double *suff_y, double *xs, double *ys)
{
    int     row, col;

    /* Solve for the xs with the ys held fixed. */
    for (row = 0; row < b.n; row++) {
        double  denom = 0.0;

        for (col = 0; col < b.m; col++) {
            denom += marg_functionals[row * b.m + col] * ys[col];
        }
        xs[row] = suff_x[row] / denom;
    }

    /* Solve for the ys with the new xs held fixed. */
    for (col = 0; col < b.m; col++) {
        double  denom = 0.0;

        for (row = 0; row < b.n; row++) {
            denom += marg_functionals[row * b.m + col] * xs[row];
        }
        ys[col] = suff_y[col] / denom;
    }
}

/* Measures how far the parameters moved between two iterations. Returns the
 * sum over all (i, j) of |xs[i] * ys[j] - old_xs[i] * old_ys[j]|, divided by
 * the sum of the current products xs[i] * ys[j] and by the number of entries
 * b.n * b.m. */
double diff_parameters(bilinear_problem_t b, double *xs, double *ys,
                       double *old_xs, double *old_ys)
{
    int     row, col;
    double  abs_change = 0.0;
    double  total = 0.0;

    for (row = 0; row < b.n; ++row) {
        for (col = 0; col < b.m; ++col) {
            double  current = xs[row] * ys[col];

            abs_change += fabs(current - old_xs[row] * old_ys[col]);
            total += current;
        }
    }

    return abs_change / total / (b.n * b.m);
}

/* Fits xs and ys (updated in place) to the expected counts exps, a b.n by b.m
 * row-major table. The row and column sums of exps serve as the sufficient
 * statistics; ipf_iter_alt() is then applied until diff_parameters() drops to
 * EM_EPS or below, or MAX_IPF_ITERS iterations have been performed. */
void ipf(bilinear_problem_t b, double *marg_functionals, double *exps,
         double *xs, double *ys)
{
    int     row, col, iter;
    double  change = 1.0;
    double  *prev_xs = malloc(b.n * sizeof(double));
    double  *prev_ys = malloc(b.m * sizeof(double));
    /* Sufficient statistics: row sums and column sums of exps. */
    double  *row_sums = malloc(b.n * sizeof(double));
    double  *col_sums = malloc(b.m * sizeof(double));

    for (col = 0; col < b.m; col++) {
        col_sums[col] = 0.0;
    }
    /* A single sweep over the table feeds both sets of sums. */
    for (row = 0; row < b.n; row++) {
        row_sums[row] = 0.0;
        for (col = 0; col < b.m; col++) {
            double  cell = exps[row * b.m + col];

            row_sums[row] += cell;
            col_sums[col] += cell;
        }
    }

    /* :TODO: Right threshold for iteration here? */
    iter = 0;
    while (iter < MAX_IPF_ITERS && change > EM_EPS) {
        /* Snapshot the parameters so the step size can be measured. */
        for (row = 0; row < b.n; row++) {
            prev_xs[row] = xs[row];
        }
        for (col = 0; col < b.m; col++) {
            prev_ys[col] = ys[col];
        }

        ipf_iter_alt(b, marg_functionals, row_sums, col_sums, xs, ys);
        change = diff_parameters(b, xs, ys, prev_xs, prev_ys);
        ++iter;
    }

    free(col_sums);
    free(row_sums);
    free(prev_ys);
    free(prev_xs);
}

/* Computes the expected counts, marginalized across the l dimension, into
 * exps (a b.n by b.m row-major table). For every l, the weights
 * xs[i] * ys[j] * b.functionals[(l * b.n + i) * b.m + j] are normalized by
 * their total, scaled by b.observeds[l], and accumulated into exps. */
void compute_exps(bilinear_problem_t b, double *xs, double *ys, double *exps)
{
    int     l, row, col, cell;
    double  *weights = malloc(b.n * b.m * sizeof(double));

    for (cell = 0; cell < b.n * b.m; cell++) {
        exps[cell] = 0.0;
    }

    for (l = 0; l < b.k; ++l) {
        double  norm = 0.0;

        /* Unnormalized weight of every (row, col) cell for this l. */
        for (row = 0; row < b.n; ++row) {
            for (col = 0; col < b.m; ++col) {
                cell = row * b.m + col;
                weights[cell] = xs[row] * ys[col]
                        * b.functionals[(l * b.n + row) * b.m + col];
                norm += weights[cell];
            }
        }

        /* Share observeds[l] among the cells in proportion to the weights. */
        for (cell = 0; cell < b.n * b.m; ++cell) {
            exps[cell] += b.observeds[l] * weights[cell] / norm;
        }
    }

    free(weights);
}

/* One iteration of the "EM" algorithm: compute the expected counts for the
 * current xs and ys, refit xs and ys to them with ipf(), and finally rescale
 * so that the xs sum to one (the ys are multiplied by the same factor, which
 * leaves every product xs[i] * ys[j] mathematically unchanged). */
void em_iter(bilinear_problem_t b, double *marg_functionals, double *xs,
             double *ys)
{
    int     idx;
    double  x_total;
    double  *exps = malloc(b.n * b.m * sizeof(double));

    compute_exps(b, xs, ys, exps);
    ipf(b, marg_functionals, exps, xs,  ys);
    /* The expected counts are not needed once ipf() has consumed them. */
    free(exps);

#if 1
    /* :TODO: is this the right way to do renormalization? */
    x_total = 0.0;
    for (idx = 0; idx < b.n; ++idx) {
        x_total += xs[idx];
    }
    for (idx = 0; idx < b.n; ++idx) {
        xs[idx] /= x_total;
    }
    for (idx = 0; idx < b.m; ++idx) {
        ys[idx] *= x_total;
    }
#endif
}

/* Sums the b.k functionals entry by entry: on return, marg_functionals[i]
 * holds the sum over l of b.functionals[l * b.n * b.m + i] for each of the
 * b.n * b.m entries. */
void compute_marg_functionals(bilinear_problem_t b, double *marg_functionals)
{
    int     cell, l;

    for (cell = 0; cell < b.n * b.m; cell++) {
        double  total = 0.0;

        for (l = 0; l < b.k; l++) {
            total += b.functionals[l * b.n * b.m + cell];
        }
        marg_functionals[cell] = total;
    }
}

/* Runs the EM algorithm from a random starting point until it converges.
 * xs (b.n entries) and ys (b.m entries) are overwritten with drand48() draws
 * and then refined by em_iter() until diff_parameters() falls below EM_EPS or
 * MAX_NUM_ITERS iterations have been done. When print_iterations is nonzero
 * the iteration count is written to stderr. Returns nonzero when the loop
 * stopped before reaching MAX_NUM_ITERS iterations, and zero otherwise. */
int em(bilinear_problem_t b, double *xs, double *ys, int print_iterations)
{
    int     iter, idx;
    double  change = 1.0;
    double  *marg_functionals = malloc(b.n * b.m * sizeof(double));
    double  *prev_xs = malloc(b.n * sizeof(double));
    double  *prev_ys = malloc(b.m * sizeof(double));

    compute_marg_functionals(b, marg_functionals);

    /* Random start; the xs are drawn before the ys. */
    for (idx = 0; idx < b.n; ++idx) {
        xs[idx] = drand48();
    }
    for (idx = 0; idx < b.m; ++idx) {
        ys[idx] = drand48();
    }

    iter = 0;
    while (iter < MAX_NUM_ITERS && change >= EM_EPS) {
        /* Keep the previous parameters to measure how far this step moved. */
        for (idx = 0; idx < b.n; ++idx) {
            prev_xs[idx] = xs[idx];
        }
        for (idx = 0; idx < b.m; ++idx) {
            prev_ys[idx] = ys[idx];
        }

        em_iter(b, marg_functionals, xs, ys);
        change = diff_parameters(b, xs, ys, prev_xs, prev_ys);
        ++iter;
    }

    if (print_iterations) {
        fprintf(stderr, "Num iterations: %d\n", iter);
    }

    free(prev_ys);
    free(prev_xs);
    free(marg_functionals);

    return iter < MAX_NUM_ITERS;
}

/* Adds a constant background to every observed value, in place. The constant
 * is level times the mean of the b.k observed values. */
void add_background(bilinear_problem_t b, double level)
{
    int     idx;
    double  sum = 0.0;
    double  offset;

    for (idx = 0; idx < b.k; idx++) {
        sum += b.observeds[idx];
    }

    /* Same amount for every entry: (mean of the observeds) * level. */
    offset = (sum / b.k) * level;

    for (idx = 0; idx < b.k; idx++) {
        b.observeds[idx] += offset;
    }
}

/* Counts the observed values that are exactly zero and, if there are any,
 * prints a warning (naming the problem via b.name) to stderr. */
void check_zeros(bilinear_problem_t b)
{
    int     idx;
    int     zero_count = 0;

    for (idx = 0; idx < b.k; idx++) {
        zero_count += (b.observeds[idx] == 0.0) ? 1 : 0;
    }

    /* Nothing to report when every observed value is nonzero. */
    if (zero_count <= 0) {
        return;
    }

    fprintf(stderr, "Warning: %s has %d observeds equal to zero\n", b.name,
            zero_count);
    fprintf(stderr, "    Use the -b option to add background noise\n");
}

/* Program entry point.
 *
 * Parses the command line (-b <num>, -h, -i, -v and an optional input file;
 * stdin is used when no file is given), reads the problem via
 * read_bilinear_coeffs()/read_to_observed(), normalizes the functionals, and
 * then, for every record delivered by read_observed_line(), runs em() from up
 * to NUM_STARTS random starting points and hands the results to
 * report_solns().
 *
 * Returns EXIT_SUCCESS after processing the input (or after -h alone), and
 * EXIT_FAILURE on a usage error, an unreadable input file, or an allocation
 * failure. */
int main(int argc, char **argv)
{
    bilinear_problem_t  b;
    double              *xs;
    double              *ys;
    double              *rescales;
    FILE                *fp;
    /* getopt() returns an int; storing it in a plain char breaks the
     * comparison with -1 on platforms where char is unsigned. */
    int                 flag;
    char                *end;
    double              background = 0.0;
    /* Usage is only shown on request (-h) or after a command-line error. */
    int                 print_help = FALSE;
    int                 usage_status = EXIT_FAILURE;
    int                 print_iterations = FALSE;
    int                 verbose = FALSE;
    
    srand48(time(NULL));
    
    while ((flag = getopt(argc, argv, "b:hiv")) != -1) {
        switch(flag) {
            case 'b':
                /* Reject arguments that are not entirely a number instead
                 * of silently treating them as 0. */
                background = strtod(optarg, &end);
                if (end == optarg || *end != '\0')
                    print_help = TRUE;
                break;
            case 'h':
                print_help = TRUE;
                usage_status = EXIT_SUCCESS;
                break;
            case 'i':
                print_iterations = TRUE;
                break;
            case 'v':
                verbose = TRUE;
                break;
            default:
                /* Unknown option or missing option argument. */
                print_help = TRUE;
                break;
        }
    }
    
    if (argc - optind > 1 || print_help) {
        fprintf(stderr, "Usage: %s [options] [<file>]\n", argv[0]);
        fprintf(stderr, "  -b <num>  Add background\n");
        fprintf(stderr, "  -h        Print this help message\n");
        fprintf(stderr, "  -i        Print the numbers of iterations to stderr\n");
        fprintf(stderr, "  -v        Print additional information about solutions\n");
        return usage_status;
    } else if (argc - optind == 1) {
        char    *filename = argv[optind];
        
        fp = fopen(filename, "r");
        if (!fp) {
            fprintf(stderr, "Can't open %s\n", filename);
            return EXIT_FAILURE;
        }
    } else {
        fp = stdin;
    }
    
    /* NOTE(review): these readers are defined elsewhere; presumably they fill
     * in b and advance fp to the first observed record -- confirm. */
    read_bilinear_coeffs(&b, fp);
    read_to_observed(fp);
    rescales = rescale(b);
    
    /* One slot of b.n (resp. b.m) doubles per random start. */
    xs = (double *)malloc(b.n * NUM_STARTS * sizeof(double));
    ys = (double *)malloc(b.m * NUM_STARTS * sizeof(double));
    if (!xs || !ys) {
        fprintf(stderr, "Out of memory\n");
        free(xs);
        free(ys);
        free(rescales);
        if (fp != stdin)
            fclose(fp);
        return EXIT_FAILURE;
    }
    
    while (read_observed_line(&b, fp)) {
        int     j;
        
        if (background > 0.0) {
            add_background(b, background);
        }
        check_zeros(b);
        /* The functionals were normalized by rescale(); apply the same
         * factors to the freshly read observed values. */
        rescale_observeds(b, rescales);
        
        for (j = 0; j < NUM_STARTS; ++j) {
            if (!em(b, xs + j * b.n, ys + j * b.m, print_iterations)) {
                fprintf(stderr, "Giving up\n");
                break;
            }
        }
        /* j is the number of starts for which em() reported success. */
        report_solns(b, xs, ys, j, verbose);
    }
    
    free(xs);
    free(ys);
    free(rescales);
    if (fp != stdin)
        fclose(fp);
    
    return EXIT_SUCCESS;
}
