/* Copyright (C) 1996-2001, 2004, 2007, 2009 Gerard Jungman, Brian Gough
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 3 of the License, or (at
 * your option) any later version.
 * 
 * This program is distributed in the hope that it will be useful, but
 * WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * General Public License for more details.
 * 
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, see <http://www.gnu.org/licenses/>.
 */

/* Derived from GNU Scientific Library, version 1.13, files as noted below */

#include "lu.h"
#include "gsl.h"
#include <math.h>
#include <stdarg.h>
#include <stdio.h>
#include <stdlib.h>

/* From vector/gsl_vector_double.h */

int gsl_vector_memcpy (gsl_vector * dest, const gsl_vector * src);

/* From matrix/gsl_matrix_double.h */

int gsl_matrix_swap_rows(gsl_matrix * m, const size_t i, const size_t j);

/* From permutation/gsl_permutation.h */

void gsl_permutation_init (gsl_permutation * p);

int gsl_permutation_swap (gsl_permutation * p, const size_t i, const size_t j);

/* From cblas/gsl_cblas.h */

/* Storage layout of a matrix passed to a CBLAS routine. */
enum CBLAS_ORDER {CblasRowMajor=101, CblasColMajor=102};
/* Whether the routine operates on A or on its transpose.  For the real
   (double) routine in this file CblasConjTrans is handled exactly like
   CblasTrans. */
enum CBLAS_TRANSPOSE {CblasNoTrans=111, CblasTrans=112, CblasConjTrans=113};
/* Which triangle of the matrix holds the triangular factor. */
enum CBLAS_UPLO {CblasUpper=121, CblasLower=122};
/* CblasNonUnit: the stored diagonal is used (results are divided by it).
   CblasUnit: the diagonal is taken to be all ones and is never read. */
enum CBLAS_DIAG {CblasNonUnit=131, CblasUnit=132};
/* Side selector; only referenced by the CBLAS_SIDE_t typedef in this file. */
enum CBLAS_SIDE {CblasLeft=141, CblasRight=142};

void cblas_dtrsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
                 const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
                 const int N, const double *A, const int lda, double *X,
                 const int incX);

void cblas_xerbla(int p, const char *rout, const char *form, ...);

/* From cblas/cblas.h */

/* Index type used by the CBLAS kernels. */
#define INDEX int
/* Offset of the first logical element of a strided vector of length N.
   A non-positive increment means the vector is stored back to front, so
   logical element 0 sits at position (N - 1) * |incX|. */
#define OFFSET(N, incX) ((incX) > 0 ?  0 : ((N) - 1) * (-(incX)))
/* Report a CBLAS usage error through cblas_xerbla.  The expansion is a
   single expression WITHOUT a trailing semicolon, so that
   `BLAS_ERROR("...");` behaves like an ordinary statement, including
   directly before an `else`. */
#define BLAS_ERROR(x)  cblas_xerbla(0, __FILE__, x)

/* From blas/gsl_blas_types.h */

typedef  enum CBLAS_ORDER       CBLAS_ORDER_t;
typedef  enum CBLAS_TRANSPOSE   CBLAS_TRANSPOSE_t;
typedef  enum CBLAS_UPLO        CBLAS_UPLO_t;
typedef  enum CBLAS_DIAG        CBLAS_DIAG_t;
typedef  enum CBLAS_SIDE        CBLAS_SIDE_t;

/* From blas/gsl_blas.h */

int  gsl_blas_dtrsv (CBLAS_UPLO_t Uplo,
                     CBLAS_TRANSPOSE_t TransA, CBLAS_DIAG_t Diag,
                     const gsl_matrix * A,
                     gsl_vector * X);

/* From permutation/gsl_permute_vector_double.h */

int gsl_permute_vector (const gsl_permutation * p, gsl_vector * v);

/* From linalg/lu.c */

#define REAL double

/* Factorise a general N x N matrix A into,
 *
 *   P A = L U
 *
 * where P is a permutation matrix, L is unit lower triangular and U
 * is upper triangular.
 *
 * L is stored in the strict lower triangular part of the input
 * matrix. The diagonal elements of L are unity and are not stored.
 *
 * U is stored in the diagonal and upper triangular part of the
 * input matrix.
 *
 * P is stored in the permutation p. Column j of P is column k of the
 * identity matrix, where k = permutation->data[j]
 *
 * signum gives the sign of the permutation, (-1)^n, where n is the
 * number of interchanges in the permutation.
 *
 * Parameters:
 *   A       square matrix, overwritten in place with L and U
 *   p       permutation of length N, reset to the identity and then
 *           updated with every row interchange
 *   signum  output, set to +1 or -1
 *
 * Returns GSL_SUCCESS once the factorisation has been carried out.  A
 * non-square A is reported through GSL_ERROR with GSL_ENOTSQR, and a
 * permutation of the wrong length with GSL_EBADLEN.
 *
 * A column whose pivot is exactly zero is left untouched (no division
 * takes place); the zero stays on the diagonal of U, where
 * gsl_singular() will find it.
 *
 * See Golub & Van Loan, Matrix Computations, Algorithm 3.4.1 (Gauss
 * Elimination with Partial Pivoting).
 */

int
gsl_linalg_LU_decomp (gsl_matrix * A, gsl_permutation * p, int *signum)
{
  if (A->size1 != A->size2)
    {
      GSL_ERROR ("LU decomposition requires square matrix", GSL_ENOTSQR);
    }
  else if (p->size != A->size1)
    {
      GSL_ERROR ("permutation length must match matrix size", GSL_EBADLEN);
    }
  else
    {
      const size_t N = A->size1;
      size_t i, j, k;

      *signum = 1;
      gsl_permutation_init (p);

      /* The last column needs no elimination.  The bound is written as
         j + 1 < N rather than j < N - 1 because N is unsigned: with
         N == 0 the expression N - 1 would wrap around to SIZE_MAX. */

      for (j = 0; j + 1 < N; j++)
        {
          /* Find maximum in the j-th column */

          REAL ajj, max = fabs (gsl_matrix_get (A, j, j));
          size_t i_pivot = j;

          for (i = j + 1; i < N; i++)
            {
              REAL aij = fabs (gsl_matrix_get (A, i, j));

              if (aij > max)
                {
                  max = aij;
                  i_pivot = i;
                }
            }

          /* Bring the pivot row into position j, recording the swap */

          if (i_pivot != j)
            {
              gsl_matrix_swap_rows (A, j, i_pivot);
              gsl_permutation_swap (p, j, i_pivot);
              *signum = -(*signum);
            }

          ajj = gsl_matrix_get (A, j, j);

          /* Skip the elimination for an exactly zero pivot, which
             would otherwise divide by zero */

          if (ajj != 0.0)
            {
              for (i = j + 1; i < N; i++)
                {
                  /* Multiplier l_ij, stored below the diagonal */

                  REAL aij = gsl_matrix_get (A, i, j) / ajj;
                  gsl_matrix_set (A, i, j, aij);

                  /* Subtract l_ij times row j from the rest of row i */

                  for (k = j + 1; k < N; k++)
                    {
                      REAL aik = gsl_matrix_get (A, i, k);
                      REAL ajk = gsl_matrix_get (A, j, k);
                      gsl_matrix_set (A, i, k, aik - aij * ajk);
                    }
                }
            }
        }
      
      return GSL_SUCCESS;
    }
}

/* Solve A x = b using a factorisation (LU, p) of A.
 *
 * b is left untouched: it is copied into x and the system is then
 * solved in place on x by gsl_linalg_LU_svx().
 *
 * Dimension mismatches are reported through GSL_ERROR with GSL_ENOTSQR
 * or GSL_EBADLEN, and a factor for which gsl_singular() returns true
 * with GSL_EDOM.  Otherwise the status of gsl_linalg_LU_svx() is
 * returned.
 */
int
gsl_linalg_LU_solve (const gsl_matrix * LU, const gsl_permutation * p, const gsl_vector * b, gsl_vector * x)
{
  const size_t rows = LU->size1;
  const size_t cols = LU->size2;

  if (rows != cols)
    {
      GSL_ERROR ("LU matrix must be square", GSL_ENOTSQR);
    }
  else if (p->size != rows)
    {
      GSL_ERROR ("permutation length must match matrix size", GSL_EBADLEN);
    }
  else if (b->size != rows)
    {
      GSL_ERROR ("matrix size must match b size", GSL_EBADLEN);
    }
  else if (x->size != cols)
    {
      GSL_ERROR ("matrix size must match solution size", GSL_EBADLEN);
    }
  else if (gsl_singular (LU))
    {
      GSL_ERROR ("matrix is singular", GSL_EDOM);
    }
  else
    {
      /* Seed x with the right-hand side, then solve in place */

      gsl_vector_memcpy (x, b);

      return gsl_linalg_LU_svx (LU, p, x);
    }
}

/* Solve A x = b in place using a factorisation (LU, p) of A.
 *
 * On entry x holds the right-hand side b; on successful return it
 * holds the solution.  The vector is permuted by p and then passed
 * through two triangular solves against the factors stored in LU.
 *
 * Dimension mismatches are reported through GSL_ERROR with GSL_ENOTSQR
 * or GSL_EBADLEN, and a factor for which gsl_singular() returns true
 * with GSL_EDOM.  Returns GSL_SUCCESS otherwise.
 */
int
gsl_linalg_LU_svx (const gsl_matrix * LU, const gsl_permutation * p, gsl_vector * x)
{
  const size_t order = LU->size1;

  if (order != LU->size2)
    {
      GSL_ERROR ("LU matrix must be square", GSL_ENOTSQR);
    }
  else if (p->size != order)
    {
      GSL_ERROR ("permutation length must match matrix size", GSL_EBADLEN);
    }
  else if (x->size != order)
    {
      GSL_ERROR ("matrix size must match solution/rhs size", GSL_EBADLEN);
    }
  else if (gsl_singular (LU))
    {
      GSL_ERROR ("matrix is singular", GSL_EDOM);
    }
  else
    {
      /* Step 1: reorder the right-hand side, x <- P b */
      gsl_permute_vector (p, x);

      /* Step 2: forward-substitution with the unit lower factor,
         L c = P b (the diagonal of L is implicit) */
      gsl_blas_dtrsv (CblasLower, CblasNoTrans, CblasUnit, LU, x);

      /* Step 3: back-substitution with the upper factor, U x = c */
      gsl_blas_dtrsv (CblasUpper, CblasNoTrans, CblasNonUnit, LU, x);

      return GSL_SUCCESS;
    }
}

/* Report whether an LU factor has an exactly zero diagonal entry.
 *
 * Walks the first LU->size1 diagonal elements and returns 1 as soon as
 * one compares equal to zero, 0 if none does.
 */
int
gsl_singular (const gsl_matrix * LU)
{
  const size_t order = LU->size1;
  size_t d = 0;

  while (d < order)
    {
      const double diag = gsl_matrix_get (LU, d, d);

      if (diag == 0)
        {
          return 1;
        }

      d++;
    }

  return 0;
}

/* From permutation/init.c */

/* Allocate a permutation of length n.
 *
 * The index array is NOT initialised; call gsl_permutation_init() to
 * set it to the identity.  The result is released with
 * gsl_permutation_free().
 *
 * Failures are reported through GSL_ERROR_VAL with a null result:
 * GSL_EDOM when n is zero, GSL_ENOMEM when the storage cannot be
 * obtained (including an n so large that its byte count does not fit
 * in a size_t).
 */
gsl_permutation *
gsl_permutation_alloc (const size_t n)
{
  gsl_permutation * p;

  if (n == 0)
    {
      GSL_ERROR_VAL ("permutation length n must be positive integer",
                        GSL_EDOM, 0);
    }

  /* n * sizeof (size_t) must not wrap around: a wrapped product would
     yield a buffer far smaller than the n entries recorded in p->size */

  if (n > ((size_t) -1) / sizeof (size_t))
    {
      GSL_ERROR_VAL ("failed to allocate space for permutation data",
                        GSL_ENOMEM, 0);
    }

  p = (gsl_permutation *) malloc (sizeof (gsl_permutation));

  if (p == 0)
    {
      GSL_ERROR_VAL ("failed to allocate space for permutation struct",
                        GSL_ENOMEM, 0);
    }

  p->data = (size_t *) malloc (n * sizeof (size_t));

  if (p->data == 0)
    {
      free (p);         /* exception in constructor, avoid memory leak */

      GSL_ERROR_VAL ("failed to allocate space for permutation data",
                        GSL_ENOMEM, 0);
    }

  p->size = n;

  return p;
}

/* Reset a permutation to the identity: entry k is set to k for every
   k below p->size. */
void
gsl_permutation_init (gsl_permutation * p)
{
  const size_t count = p->size;
  size_t idx = 0;

  while (idx < count)
    {
      p->data[idx] = idx;
      idx++;
    }
}

/* Release a permutation: frees the index array first, then the struct
   itself.  The pointer is first passed through RETURN_IF_NULL
   (defined outside this file; presumably an early return on a null
   argument -- confirm against its definition). */
void
gsl_permutation_free (gsl_permutation * p)
{
  RETURN_IF_NULL (p);
  free (p->data);
  free (p);
}

/* From permutation/permutation.c */

/* Exchange entries i and j of a permutation.
 *
 * Either index at or beyond p->size is reported through GSL_ERROR with
 * GSL_EINVAL (i is checked before j).  Swapping an entry with itself
 * is a no-op.  Returns GSL_SUCCESS otherwise.
 */
int
gsl_permutation_swap (gsl_permutation * p, const size_t i, const size_t j)
{
  const size_t limit = p->size;

  if (!(i < limit))
    {
      GSL_ERROR("first index is out of range", GSL_EINVAL);
    }

  if (!(j < limit))
    {
      GSL_ERROR("second index is out of range", GSL_EINVAL);
    }

  if (i == j)
    {
      return GSL_SUCCESS;
    }

  {
    const size_t held = p->data[j];

    p->data[j] = p->data[i];
    p->data[i] = held;
  }

  return GSL_SUCCESS;
}

/* From blas/blas.c */

#define INT(X) ((int)(X))

/* Triangular solve on GSL objects: overwrite X with the solution of
 * op(A) x = X, where A is read as a row-major triangular matrix.
 *
 * Uplo, TransA and Diag are handed straight to cblas_dtrsv(); the
 * dimension, the row stride (A->tda) and the vector stride are
 * narrowed to int for that call.
 *
 * A non-square A is reported through GSL_ERROR with GSL_ENOTSQR, a
 * length mismatch between A and X with GSL_EBADLEN.  Returns
 * GSL_SUCCESS otherwise.
 */
int
gsl_blas_dtrsv (CBLAS_UPLO_t Uplo, CBLAS_TRANSPOSE_t TransA,
                CBLAS_DIAG_t Diag, const gsl_matrix * A, gsl_vector * X)
{
  const size_t rows = A->size1;
  const size_t cols = A->size2;

  if (rows != cols)
    {
      GSL_ERROR ("matrix must be square", GSL_ENOTSQR);
    }
  else if (X->size != cols)
    {
      GSL_ERROR ("invalid length", GSL_EBADLEN);
    }

  {
    const int order = INT (cols);
    const int row_stride = INT (A->tda);
    const int vec_stride = INT (X->stride);

    cblas_dtrsv (CblasRowMajor, Uplo, TransA, Diag, order, A->data,
                 row_stride, X->data, vec_stride);
  }

  return GSL_SUCCESS;
}

/* From cblas/dtrsv.c and cblas/source_trsv_r.h */

/* Solve a triangular system in place: x := inv(op(A)) * x.
 *
 *   order   storage layout of A
 *   Uplo    which triangle of A holds the coefficients
 *   TransA  solve with A (CblasNoTrans) or with its transpose;
 *           CblasConjTrans is treated as CblasTrans
 *   Diag    CblasNonUnit divides by the stored diagonal, CblasUnit
 *           never reads the diagonal
 *   N       order of the system; N == 0 returns immediately
 *   A       matrix elements, addressed as A[lda * i + j]
 *   lda     distance between consecutive rows (or columns, for the
 *           column-major cases) of A
 *   X       right-hand side on entry, solution on return
 *   incX    stride between elements of X; for a non-positive stride
 *           OFFSET() moves the starting position to the far end so
 *           that logical element k is X[OFFSET(N, incX) + k * incX]
 *
 * Each of the eight (order, transpose, triangle) combinations maps onto
 * one of four loops: forward or backward substitution, reading A either
 * as A[i][j] or as A[j][i].  Any other combination is reported through
 * BLAS_ERROR.  The diagonal is not tested for zero before dividing.
 */
void
cblas_dtrsv (const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
             const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
             const int N, const double *A, const int lda, double *X,
             const int incX)
{
#define BASE double
  const int nonunit = (Diag == CblasNonUnit);
  INDEX ix, jx;
  INDEX i, j;
  /* Real data: a conjugate transpose is just a transpose */
  const int Trans = (TransA != CblasConjTrans) ? TransA : CblasTrans;

  if (N == 0)
    return;

  /* form  x := inv( A )*x */

  if ((order == CblasRowMajor && Trans == CblasNoTrans && Uplo == CblasUpper)
      || (order == CblasColMajor && Trans == CblasTrans && Uplo == CblasLower)) {
    /* backsubstitution */
    /* Start at the last logical element of X; it only needs the diagonal */
    ix = OFFSET(N, incX) + incX * (N - 1);
    if (nonunit) {
      X[ix] = X[ix] / A[lda * (N - 1) + (N - 1)];
    }
    ix -= incX;
    /* Runs the body for i = N-2 down to 0: the test checks i > 0 and
       then post-decrements, so i == 0 is still processed */
    for (i = N - 1; i > 0 && i--;) {
      BASE tmp = X[ix];
      jx = ix + incX;
      /* Subtract the contributions of the already-solved x[i+1..N-1] */
      for (j = i + 1; j < N; j++) {
        const BASE Aij = A[lda * i + j];
        tmp -= Aij * X[jx];
        jx += incX;
      }
      if (nonunit) {
        X[ix] = tmp / A[lda * i + i];
      } else {
        X[ix] = tmp;
      }
      ix -= incX;
    }
  } else if ((order == CblasRowMajor && Trans == CblasNoTrans && Uplo == CblasLower)
             || (order == CblasColMajor && Trans == CblasTrans && Uplo == CblasUpper)) {

    /* forward substitution */
    /* Start at the first logical element of X; it only needs the diagonal */
    ix = OFFSET(N, incX);
    if (nonunit) {
      X[ix] = X[ix] / A[lda * 0 + 0];
    }
    ix += incX;
    for (i = 1; i < N; i++) {
      BASE tmp = X[ix];
      jx = OFFSET(N, incX);
      /* Subtract the contributions of the already-solved x[0..i-1] */
      for (j = 0; j < i; j++) {
        const BASE Aij = A[lda * i + j];
        tmp -= Aij * X[jx];
        jx += incX;
      }
      if (nonunit) {
        X[ix] = tmp / A[lda * i + i];
      } else {
        X[ix] = tmp;
      }
      ix += incX;
    }
  } else if ((order == CblasRowMajor && Trans == CblasTrans && Uplo == CblasUpper)
             || (order == CblasColMajor && Trans == CblasNoTrans && Uplo == CblasLower)) {

    /* form  x := inv( A' )*x */

    /* forward substitution */
    /* Same loop as the previous branch, but A is read as A[j][i] */
    ix = OFFSET(N, incX);
    if (nonunit) {
      X[ix] = X[ix] / A[lda * 0 + 0];
    }
    ix += incX;
    for (i = 1; i < N; i++) {
      BASE tmp = X[ix];
      jx = OFFSET(N, incX);
      for (j = 0; j < i; j++) {
        const BASE Aji = A[lda * j + i];
        tmp -= Aji * X[jx];
        jx += incX;
      }
      if (nonunit) {
        X[ix] = tmp / A[lda * i + i];
      } else {
        X[ix] = tmp;
      }
      ix += incX;
    }
  } else if ((order == CblasRowMajor && Trans == CblasTrans && Uplo == CblasLower)
             || (order == CblasColMajor && Trans == CblasNoTrans && Uplo == CblasUpper)) {

    /* backsubstitution */
    /* Same loop as the first branch, but A is read as A[j][i] */
    ix = OFFSET(N, incX) + (N - 1) * incX;
    if (nonunit) {
      X[ix] = X[ix] / A[lda * (N - 1) + (N - 1)];
    }
    ix -= incX;
    /* i runs from N-2 down to 0, as in the first branch */
    for (i = N - 1; i > 0 && i--;) {
      BASE tmp = X[ix];
      jx = ix + incX;
      for (j = i + 1; j < N; j++) {
        const BASE Aji = A[lda * j + i];
        tmp -= Aji * X[jx];
        jx += incX;
      }
      if (nonunit) {
        X[ix] = tmp / A[lda * i + i];
      } else {
        X[ix] = tmp;
      }
      ix -= incX;
    }
  } else {
    /* No valid (order, transpose, triangle) combination matched */
    BLAS_ERROR("unrecognized operation");
  }
#undef BASE
}

/* From cblas/xerbla.c */

/* CBLAS error handler: print a diagnostic to stderr and abort.
 *
 *   p     position of the offending parameter; when non-zero a line
 *         naming it and the routine is printed first
 *   rout  routine name used in that line
 *   form  printf-style format for the message, followed by its
 *         arguments
 *
 * This function does not return: it always ends in abort().
 */
void
cblas_xerbla (int p, const char *rout, const char *form, ...)
{
  va_list args;

  if (p != 0)
    {
      fprintf (stderr, "Parameter %d to routine %s was incorrect\n", p, rout);
    }

  va_start (args, form);
  vfprintf (stderr, form, args);
  va_end (args);

  abort ();
}

/* From templates_on.h */

/* Type parameters for the template-style sources that follow, fixed
   here to the double-precision instantiation.  Of these, only ATOMIC
   and MULTIPLICITY are referenced by the code below in this file. */
#define BASE double
#define SHORT
/* Scalar type in which element data is stored and moved */
#define ATOMIC double
/* Number of ATOMIC values that make up one element */
#define MULTIPLICITY 1
#define FP 1
#define IN_FORMAT "%lg"
#define OUT_FORMAT "%g"
#define ATOMIC_IO ATOMIC
#define ZERO 0.0
#define ONE 1.0
#define BASE_EPSILON GSL_DBL_EPSILON

/* Paste a and b together with an underscore in between.  CONCAT2 adds
   one level of indirection so that its arguments are macro-expanded
   before being pasted. */
#define CONCAT2x(a,b) a ## _ ## b 
#define CONCAT2(a,b) CONCAT2x(a,b)

/* Name builders: FUNCTION (gsl_vector, memcpy) yields gsl_vector_memcpy,
   and TYPE (x) is simply x for this instantiation. */
#define FUNCTION(dir,name) CONCAT2(dir,name)
#define TYPE(dir) dir
#define VIEW(dir,name) CONCAT2(dir,name)
#define QUALIFIED_TYPE(dir) TYPE(dir)
#define QUALIFIED_VIEW(dir,name) CONCAT2(dir,name)

/* From vector/copy_source.c */

/* Copy the elements of src into dest, honouring the stride of each
 * vector.
 *
 * Vectors of different length are reported through GSL_ERROR with
 * GSL_EBADLEN.  Returns GSL_SUCCESS otherwise.
 */
int
FUNCTION (gsl_vector, memcpy) (TYPE (gsl_vector) * dest,
                               const TYPE (gsl_vector) * src)
{
  const size_t count = src->size;

  if (dest->size != count)
    {
      GSL_ERROR ("vector lengths are not equal", GSL_EBADLEN);
    }

  {
    const size_t from_step = src->stride;
    const size_t to_step = dest->stride;
    size_t elem;
    size_t part;

    for (elem = 0; elem < count; elem++)
      {
        /* one element is MULTIPLICITY consecutive atomic values */
        for (part = 0; part < MULTIPLICITY; part++)
          {
            dest->data[MULTIPLICITY * to_step * elem + part]
              = src->data[MULTIPLICITY * from_step * elem + part];
          }
      }
  }

  return GSL_SUCCESS;
}


/* From matrix/swap_source.c */

/* Exchange rows i and j of a matrix in place.
 *
 * A row index at or beyond m->size1 is reported through GSL_ERROR with
 * GSL_EINVAL (i is checked before j).  Swapping a row with itself is a
 * no-op.  Returns GSL_SUCCESS otherwise.
 */
int
FUNCTION (gsl_matrix, swap_rows) (TYPE (gsl_matrix) * m,
                                 const size_t i, const size_t j)
{
  const size_t rows = m->size1;
  const size_t cols = m->size2;

  if (!(i < rows))
    {
      GSL_ERROR ("first row index is out of range", GSL_EINVAL);
    }

  if (!(j < rows))
    {
      GSL_ERROR ("second row index is out of range", GSL_EINVAL);
    }

  if (i == j)
    {
      return GSL_SUCCESS;
    }

  {
    /* rows start m->tda elements apart */
    ATOMIC *first = m->data + MULTIPLICITY * i * m->tda;
    ATOMIC *second = m->data + MULTIPLICITY * j * m->tda;
    size_t col;

    for (col = 0; col < MULTIPLICITY * cols; col++)
      {
        const ATOMIC held = first[col];

        first[col] = second[col];
        second[col] = held;
      }
  }

  return GSL_SUCCESS;
}

/* From permutation/permute_source.c */

/* In-place Permutations 

   permute:    OUT[i]       = IN[perm[i]]     i = 0 .. N-1
   invpermute: OUT[perm[i]] = IN[i]           i = 0 .. N-1

   PERM is an index map, i.e. a vector which contains a permutation of
   the integers 0 .. N-1.

   From Knuth "Sorting and Searching", Volume 3 (3rd ed), Section 5.2
   Exercise 10 (answers), p 617

   FIXME: these have not been fully tested.
*/

/* Apply the permutation p in place to n strided elements, so that
 * afterwards element i holds what element p[i] held before.
 *
 * The permutation is processed one cycle at a time.  Each cycle is
 * handled exactly once, when the outer loop reaches its smallest
 * index; the elements of the cycle are then shifted along it using a
 * single temporary.  Always returns GSL_SUCCESS.
 */
int
TYPE (gsl_permute) (const size_t * p, ATOMIC * data, const size_t stride, const size_t n)
{
  size_t start;

  for (start = 0; start < n; start++)
    {
      size_t cur = p[start];
      size_t next;

      /* Follow the cycle until it reaches an index not above start */

      while (cur > start)
        {
          cur = p[cur];
        }

      /* A smaller index in this cycle means it was already done */

      if (cur < start)
        {
          continue;
        }

      /* Here cur == start, the least index in its cycle */

      next = p[cur];

      /* Fixed point: nothing to move */

      if (next == start)
        {
          continue;
        }

      {
        ATOMIC saved[MULTIPLICITY];
        unsigned int part;

        /* Keep the first element, then pull each successor back */

        for (part = 0; part < MULTIPLICITY; part++)
          {
            saved[part] = data[start*stride*MULTIPLICITY + part];
          }

        for (; next != start; cur = next, next = p[cur])
          {
            for (part = 0; part < MULTIPLICITY; part++)
              {
                data[cur*stride*MULTIPLICITY + part]
                  = data[next*stride*MULTIPLICITY + part];
              }
          }

        /* The last slot of the cycle receives the saved element */

        for (part = 0; part < MULTIPLICITY; part++)
          {
            data[cur*stride*MULTIPLICITY + part] = saved[part];
          }
      }
    }

  return GSL_SUCCESS;
}

/* Apply a permutation in place to the elements of a vector.
 *
 * A length mismatch between v and p is reported through GSL_ERROR with
 * GSL_EBADLEN.  Otherwise the work is delegated to gsl_permute() using
 * the vector's data pointer and stride, and GSL_SUCCESS is returned.
 */
int
TYPE (gsl_permute_vector) (const gsl_permutation * p, TYPE (gsl_vector) * v)
{
  const size_t length = v->size;

  if (p->size != length)
    {
      GSL_ERROR ("vector and permutation must be the same length", GSL_EBADLEN);
    }

  TYPE (gsl_permute) (p->data, v->data, v->stride, length);

  return GSL_SUCCESS;
}
