/* Copyright 2009 Regents of the University of California
 * All rights reserved.
 */
#include "solutions.h"
#include <math.h>
#include <stdlib.h>
#include <string.h>

/* Computes the observed values that would result from the given xs and ys.
 * For every functional l in [0, b.k):
 *     sim_data[l] = sum over i < b.n, j < b.m of
 *                   b.functionals[(l * b.n + i) * b.m + j] * xs[i] * ys[j]
 * sim_data must have room for b.k values; it is overwritten, not accumulated. */
void forward(bilinear_problem_t b, double *xs, double *ys, double *sim_data)
{
    int     obs, xi, yj;
    /* Flat index into b.functionals. Because the loops visit (obs, xi, yj) in
     * row-major order, this is always equal to (obs * b.n + xi) * b.m + yj. */
    int     coef = 0;

    for (obs = 0; obs < b.k; obs++) {
        double  *out = sim_data + obs;

        *out = 0.0;
        for (xi = 0; xi < b.n; xi++) {
            for (yj = 0; yj < b.m; yj++) {
                *out += b.functionals[coef] * xs[xi] * ys[yj];
                coef++;
            }
        }
    }
}

/* Prints one solution on a single line to standard output: the problem name,
 * the b.n xs values, a " |" separator, the b.m ys values, and finally the
 * error in %e format prefixed by '*'. */
void print_soln(bilinear_problem_t b, double *xs, double *ys, double err)
{
    const double *val;
    int     remaining;

    printf("%s", b.name);

    for (val = xs, remaining = b.n; remaining > 0; --remaining) {
        printf(" %f", *val);
        ++val;
    }

    printf(" |");

    for (val = ys, remaining = b.m; remaining > 0; --remaining) {
        printf(" %f", *val);
        ++val;
    }

    printf(" *%e\n", err);
}

/* Prints per-solution diagnostics to standard output:
 *   - "    zeros:" followed by every index j whose ys[j] is below 0.001
 *     (negative values included); the line is omitted when there are none;
 *   - one "    o NAME: observed/predicted=ratio" line per functional, where
 *     predicted comes from forward() and ratio = observed / predicted
 *     (a zero prediction yields an inf/nan ratio in the output).
 * If the temporary prediction buffer cannot be allocated, a single note line
 * replaces the per-functional lines instead of dereferencing NULL. */
void print_soln_extras(bilinear_problem_t b, double *xs, double *ys)
{
    const double zero_tol = 0.001;   /* ys below this are listed as zeros */
    double  *pred;
    int     i;
    int     zeros_printed = 0;

    for (i = 0; i < b.m; ++i) {
        if (ys[i] < zero_tol) {
            if (!zeros_printed) {
                printf("    zeros:");
                zeros_printed = 1;
            }
            printf(" %d", i);
        }
    }
    if (zeros_printed)
        printf("\n");

    /* No functionals: nothing to predict or print (same output as before). */
    if (b.k <= 0)
        return;

    pred = malloc((size_t)b.k * sizeof *pred);
    if (pred == NULL) {
        printf("    o predictions unavailable (out of memory)\n");
        return;
    }

    forward(b, xs, ys, pred);
    for (i = 0; i < b.k; i++)
        printf("    o %s: %f/%f=%f\n", b.functional_names[i], b.observeds[i],
               pred[i], b.observeds[i] / pred[i]);
    free(pred);
}

/* Returns the Euclidean (L2) norm of the residual between the values predicted
 * by forward() for (xs, ys) and b.observeds:
 *     sqrt(sum over l < b.k of (pred[l] - b.observeds[l])^2)
 * Returns 0.0 when b.k <= 0. If the temporary prediction buffer cannot be
 * allocated, returns HUGE_VAL so the start ranks no better than any finite
 * error instead of the prediction being written through a NULL pointer. */
double compute_err(bilinear_problem_t b, double *xs, double *ys)
{
    double  sum_sq = 0.0;
    double  *pred;
    int     i;

    if (b.k <= 0)
        return 0.0;

    pred = malloc((size_t)b.k * sizeof *pred);
    if (pred == NULL)
        return HUGE_VAL;

    forward(b, xs, ys, pred);
    for (i = 0; i < b.k; i++) {
        double  diff = pred[i] - b.observeds[i];
        sum_sq += diff * diff;
    }
    free(pred);
    return sqrt(sum_sq);
}

/* Squared distance used by combine_repeats() between a candidate solution
 * (x, y) and a reference solution (x_ref, y_ref): the squared Euclidean
 * distance of the ys divided by the square of the sum of y_ref, plus the plain
 * squared Euclidean distance of the xs. x/x_ref hold b.n values and y/y_ref
 * hold b.m values. */
static double repeat_dist_sq(bilinear_problem_t b, const double *x_ref,
                             const double *y_ref, const double *x,
                             const double *y)
{
    double  sum_sq = 0.0;
    double  y_sum = 0.0;
    int     j;

    for (j = 0; j < b.m; ++j) {
        double  diff = y[j] - y_ref[j];
        y_sum += y_ref[j];
        sum_sq += diff * diff;
    }
    /* Put the y part on a scale relative to the reference's total. Skip the
     * division when the ys agree exactly: otherwise a reference whose ys sum
     * to zero gives 0/0 = NaN, and identical solutions would never merge. */
    if (sum_sq != 0.0)
        sum_sq /= y_sum * y_sum;
    for (j = 0; j < b.n; ++j) {
        double  diff = x[j] - x_ref[j];
        sum_sq += diff * diff;
    }
    return sum_sq;
}

/* Takes num_starts putative solutions to the bilinear problem and merges those
 * that are close to each other. Start i occupies xs[i*b.n .. i*b.n+b.n-1],
 * ys[i*b.m .. i*b.m+b.m-1] and errs[i].
 *
 * Two solutions count as repeats when repeat_dist_sq() (normalised squared
 * y-distance plus squared x-distance) is at most (b.n + b.m) * eps^2, with
 * eps = 0.01. Each start, in order, is compared against the current
 * representative of every group found so far and joins the first group it
 * matches; otherwise it starts a new group. A group's representative is the
 * member with the strictly smallest err (ties keep the earlier start).
 *
 * The representatives are copied, in group order, to the first rows of xs, ys
 * and errs, and the number of groups is returned; rows at or beyond that count
 * are left unmodified. If the bookkeeping array cannot be allocated, nothing
 * is merged or moved and num_starts is returned. */
int combine_repeats(bilinear_problem_t b, double *xs, double *ys, double *errs,
                    int num_starts)
{
    const double eps = 0.01;
    double  threshold;
    int     *refs;          /* refs[k] = start index representing group k */
    int     num_unique = 0;
    int     i, k;

    if (num_starts <= 0)
        return 0;

    refs = malloc((size_t)num_starts * sizeof *refs);
    if (refs == NULL)
        return num_starts;  /* cannot deduplicate: every start stays distinct */

    threshold = (b.n + b.m) * eps * eps;
    for (i = 0; i < num_starts; ++i) {
        for (k = 0; k < num_unique; ++k) {
            double  dist_sq = repeat_dist_sq(b, xs + refs[k] * b.n,
                                             ys + refs[k] * b.m,
                                             xs + i * b.n, ys + i * b.m);

            if (dist_sq <= threshold) {
                if (errs[refs[k]] > errs[i])
                    refs[k] = i;
                break;
            }
        }

        if (k == num_unique)
            refs[num_unique++] = i;
    }

    /* Move representatives to the front. Group k is created by some start
     * i >= k and its representative only ever moves to a later start, so
     * refs[k] >= k. Copying in increasing k therefore never overwrites a row
     * that a later group still reads, and source and destination rows are
     * distinct whenever a copy happens. */
    for (k = 0; k < num_unique; ++k) {
        int     src = refs[k];

        if (src != k) {
            memcpy(xs + k * b.n, xs + src * b.n, b.n * sizeof *xs);
            memcpy(ys + k * b.m, ys + src * b.m, b.m * sizeof *ys);
            errs[k] = errs[src];
        }
    }

    free(refs);
    return num_unique;
}

/* Scores every start with compute_err(), merges repeated solutions with
 * combine_repeats() (which rewrites the leading rows of xs and ys), and prints
 * each surviving solution with print_soln(), prefixed by '+' when more than
 * one survives. When print_extras is non-zero, each solution is followed by
 * print_soln_extras(). If the error buffer cannot be allocated, every start is
 * printed unmerged, with its error computed on the fly, rather than writing
 * through a NULL pointer. Does nothing when num_starts <= 0. */
void report_solns(bilinear_problem_t b, double *xs, double *ys, int num_starts,
        int print_extras)
{
    double  *errs;
    int     i;
    int     num_shown;

    if (num_starts <= 0)
        return;

    errs = malloc((size_t)num_starts * sizeof *errs);
    if (errs != NULL) {
        for (i = 0; i < num_starts; ++i)
            errs[i] = compute_err(b, xs + i * b.n, ys + i * b.m);
        num_shown = combine_repeats(b, xs, ys, errs, num_starts);
    } else {
        num_shown = num_starts;     /* fallback: report every start as-is */
    }

    for (i = 0; i < num_shown; ++i) {
        double  err = (errs != NULL)
                      ? errs[i]
                      : compute_err(b, xs + i * b.n, ys + i * b.m);

        if (num_shown > 1)
            printf("+");
        print_soln(b, xs + i * b.n, ys + i * b.m, err);
        if (print_extras)
            print_soln_extras(b, xs + i * b.n, ys + i * b.m);
    }

    free(errs);
}
