/*
   SVD.c
   singular value decomposition of an m-by-n matrix
   this file is part of the Diag library
   last modified 21 Aug 15 th
*/

#include "diag-c.h"


/*
   SVD performs a singular value decomposition.
   Input:	m, n, A = m-by-n matrix.
   Output:	d = nm-vector of singular values, nm = min(m, n),
   for UCOLS=0:	V = nm-by-m left transformation matrix,
		W = nm-by-n right transformation matrix,
   which fulfill d = V^* A W^+,  A = V^T d W,  V^* A = d W,
   for UCOLS=1:	V = m-by-nm left transformation matrix,
		W = n-by-nm right transformation matrix,
   which fulfill d = V^+ A W^*,  A = V d W^T,  A W^* = V d.

   Further arguments:
   ldA, ldV, ldW are not used directly in this function; presumably
   they are the leading dimensions picked up by the A(), VL() and WL()
   index macros of diag-c.h (TODO confirm against the header).
   sort selects the ordering of the singular values in d (and of the
   corresponding entries of V and W):
     sort > 0: ascending,  sort < 0: descending,  sort = 0: unsorted.

   Notes:
   - The work space (3 nx-by-nx complex matrices, nx = max(m, n)) and
     the permutation vector are variable-length arrays on the stack,
     so m and n must be positive and not too large.
   - If the iteration has not converged after 50 sweeps, the message
     "Bad convergence in SVD" is written to stderr and the current
     state is returned nevertheless.
   - NOTE(review): nsweeps is not declared in this function; it is
     assumed to be provided by diag-c.h.
*/

void SVD(cint m, cint n, cComplexType *A, cint ldA,
  RealType *d, ComplexType *V, cint ldV, ComplexType *W, cint ldW,
  cint sort)
{
  int p, q;
  /* rev = 1 if A has fewer rows than columns; the transpose of A is
     then decomposed and the roles of V_ and W_ are exchanged when the
     results are copied out.  (A plain comparison is used instead of a
     shift of the possibly negative difference m - n, which would be
     implementation-defined and look at one particular bit only.) */
  cint rev = m < n;
  cint nm = imin(m, n);
  int pi[nm];
  cint nx = imax(m, n);
  /* reduction factor for the threshold of the first sweeps; nx^4 is
     evaluated in floating point because it overflows an int already
     for nx >= 216 */
  cRealType red = .01/((double)nx*nx*nx*nx);
  /* work space: VW[0] = V_, VW[1] = W_, VW[2] = A_ */
  ComplexType VW[3][nx][nx];
  ComplexType *V_ = (ComplexType *)VW[0];
  ComplexType *W_ = (ComplexType *)VW[1];
  ComplexType *A_ = (ComplexType *)VW[2];
#define ldV_ nx
#define ldW_ nx
#define ldA_ nx

  /* V_ and W_ start out as unit matrices, A_ as a zero-padded
     nx-by-nx copy of A (transposed if rev is set) */
  memset(VW, 0, sizeof VW);
  for( p = 0; p < nx; ++p ) V_(p,p) = 1;
  for( p = 0; p < nx; ++p ) W_(p,p) = 1;

  if( rev )
    for( q = 0; q < m; ++q )
      for( p = 0; p < n; ++p )
        A_(p,q) = A(q,p);
  else
    for( q = 0; q < m; ++q )
      for( p = 0; p < n; ++p )
        A_(q,p) = A(q,p);

  for( nsweeps = 1; nsweeps <= 50; ++nsweeps ) {
    /* sum of the squared off-diagonal elements; written as
       !(thresh > EPS) so that a NaN also terminates the iteration */
    RealType thresh = 0;
    for( q = 1; q < nx; ++q )
      for( p = 0; p < q; ++p )
        thresh += Sq(A_(p,q)) + Sq(A_(q,p));
    if( !(thresh > EPS) ) goto done;

    /* during the first three sweeps only elements above a reduced
       threshold are rotated, afterwards every nonzero one */
    thresh = (nsweeps < 4) ? thresh*red : 0;

    for( q = 1; q < nx; ++q )
      for( p = 0; p < q; ++p ) {
        ComplexType App, Apq, Aqp, Aqq;
        RealType off;
        /* px, qx = rows taken as the "p" and "q" row of the 2-by-2
           block; they are exchanged if the off-diagonal elements
           outweigh the diagonal ones */
        int px = p, qx = q;
        if( Sq(A_(p,p)) + Sq(A_(q,q)) <
            Sq(A_(p,q)) + Sq(A_(q,p)) )
          px = q, qx = p;

        App = A_(px,p);
        Aqq = A_(qx,q);
        Apq = A_(px,q);
        Aqp = A_(qx,p);
        off = Sq(Apq) + Sq(Aqp);

        /* from the fifth sweep on, off-diagonal elements which are
           negligible relative to the diagonal are simply set to zero */
        if( nsweeps > 4 && off < EPS*(Sq(App) + Sq(Aqq)) ) {
          A_(px,q) = 0;
          A_(qx,p) = 0;
        }
        else if( off > thresh ) {
          RealType xv, xw, dv, dw, t, invc;
          ComplexType sv, sw, tv, tw;
          int j;

          xv = Re((App - Aqq)*Conjugate(App + Aqq));
          xw = Re((Apq - Aqp)*Conjugate(Apq + Aqp));
          dv = .5*(xv + xw);
          dw = .5*(xv - xw);

          tv = Conjugate(App)*Aqp + Aqq*Conjugate(Apq);
          tw = Conjugate(App)*Apq + Aqq*Conjugate(Aqp);
          t = sqrt(dw*dw + Sq(tw)) /* = sqrt(dv*dv + Sq(tv)) */;

          /* choose the sign of t according to which of the two sets
             of denominators stays further away from zero */
          xv = min(absr(dv + t), absr(dw + t));
          xw = min(absr(dv - t), absr(dw - t));
          if( xv + xw > DBL_EPS ) {
            t = sign(t, xv - xw);
            tv /= dv + t;
            tw /= dw + t;
          }
          else {
            tv = 0;
            tw = Apq/App;
          }

          /* (sv, tv) parametrize the rotation of the rows (left),
             (sw, tw) the rotation of the columns (right) */
          invc = sqrt(1 + Sq(tv));
          sv = tv/invc;
          tv /= invc + 1;

          invc = sqrt(1 + Sq(tw));
          sw = tw/invc;
          tw /= invc + 1;

          /* rotate columns p, q and then rows px, qx of A_; the row
             results are stored into rows p, q, which carries out a
             pending row exchange at the same time */
          for( j = 0; j < nx; ++j ) {
            ComplexType x = A_(j,p);
            ComplexType y = A_(j,q);
            A_(j,p) = x + Conjugate(sw)*(y - tw*x);
            A_(j,q) = y - sw*(x + Conjugate(tw)*y);
            x = A_(px,j);
            y = A_(qx,j);
            A_(p,j) = x + Conjugate(sv)*(y - tv*x);
            A_(q,j) = y - sv*(x + Conjugate(tv)*y);
          }

          /* the elements of the 2-by-2 block are set explicitly from
             the values saved before the rotation */
          A_(p,p) = invc*(App + Conjugate(sv)*(Aqp - tv*App));
          A_(q,q) = invc*(Aqq - sv*(Apq + Conjugate(tv)*Aqq));
          A_(p,q) = 0;
          A_(q,p) = 0;

          /* accumulate the left rotation in V_ (reading rows px, qx,
             writing rows p, q) ... */
          for( j = 0; j < nx; ++j ) {
            cComplexType x = V_(px,j);
            cComplexType y = V_(qx,j);
            V_(p,j) = x + sv*(y - Conjugate(tv)*x);
            V_(q,j) = y - Conjugate(sv)*(x + tv*y);
          }

          /* ... and the right rotation in W_ */
          for( j = 0; j < nx; ++j ) {
            cComplexType x = W_(p,j);
            cComplexType y = W_(q,j);
            W_(p,j) = x + sw*(y - Conjugate(tw)*x);
            W_(q,j) = y - Conjugate(sw)*(x + tw*y);
          }

          /* rows have already been put into place above,
             skip the explicit exchange below */
          continue;
        }

        /* no rotation was applied but rows p and q were taken in
           exchanged order: swap them in A_ and V_ for real */
        if( p != px ) {
          int j;

          for( j = 0; j < nx; ++j ) {
            cComplexType x = A_(p,j);
            A_(p,j) = A_(q,j);
            A_(q,j) = x;
          }

          for( j = 0; j < nx; ++j ) {
            cComplexType x = V_(p,j);
            V_(p,j) = V_(q,j);
            V_(q,j) = x;
          }
        }
      }
  }

  /* only reached if the sweep limit was exhausted */
  fputs("Bad convergence in SVD\n", stderr);

done:

/* make the diagonal elements nonnegative */

  for( p = 0; p < nm; ++p ) {
    cComplexType App = A_(p,p);
    d[p] = Abs(App);
    /* move the phase of the diagonal element into row p of W_ */
    if( d[p] > DBL_EPS && d[p] != Re(App) ) {
      cComplexType f = App/d[p];
      for( q = 0; q < nm; ++q ) W_(p,q) *= f;
    }
  }

/* sort the singular values */

  /* pi[i] = row of V_ and W_ belonging to the value currently in d[i] */
  for( p = 0; p < nm; ++p ) pi[p] = p;

  /* selection sort: position p receives the smallest (sort > 0) or
     largest (sort < 0) of the remaining values, or keeps its own
     value if sort = 0 */
  for( p = 0; p < nm; ++p ) {
    int j = p;
    RealType t = d[p];
    if( sort )
      for( q = p + 1; q < nm; ++q )
        if( sort*(t - d[q]) > 0 ) t = d[j = q];

    d[j] = d[p];
    d[p] = t;

    q = pi[j];
    pi[j] = pi[p];

    /* copy out the matching rows; for rev the left and right
       transformation matrices change places */
    for( j = 0; j < m; ++j ) VL(p,j) = VW[rev][q][j];
    for( j = 0; j < n; ++j ) WL(p,j) = VW[1-rev][q][j];
  }
}

