kusano 7d535a
/*
kusano 7d535a
 * -- SuperLU routine (version 4.0) --
kusano 7d535a
 * Univ. of California Berkeley, Xerox Palo Alto Research Center,
kusano 7d535a
 * and Lawrence Berkeley National Lab.
kusano 7d535a
 * June 30, 2009
kusano 7d535a
 *
kusano 7d535a
 */
kusano 7d535a
#include <stdio.h>
kusano 7d535a
#include "mex.h"
kusano 7d535a
#include "slu_ddefs.h"
kusano 7d535a
kusano 7d535a
kusano 7d535a
/*
 * MatlabMatrix names the Matlab array type used throughout this file:
 * mxArray when V5 is defined, Matrix otherwise.
 */
#ifdef V5
kusano 7d535a
#define  MatlabMatrix mxArray
kusano 7d535a
#else    /* V4 */
kusano 7d535a
#define  MatlabMatrix Matrix
kusano 7d535a
#endif
kusano 7d535a
kusano 7d535a
kusano 7d535a
kusano 7d535a
/* Aliases for input and output arguments */
kusano 7d535a
#define A_in		prhs[0]
kusano 7d535a
#define Pc_in		prhs[1]
kusano 7d535a
#define L_out    	plhs[0]
kusano 7d535a
#define U_out          	plhs[1]
kusano 7d535a
#define Pr_out     	plhs[2]
kusano 7d535a
#define Pc_out   	plhs[3]
kusano 7d535a
kusano 7d535a
/*
 * Copy the factors held in the SuperLU matrices L and U into the
 * caller-provided compressed-column arrays (values, row indices, column
 * pointers), skipping entries whose value is exactly zero.  On return
 * *snnzL and *snnzU hold the number of entries actually stored.
 */
void LUextract(SuperMatrix *L, SuperMatrix *U, double *Lval, int *Lrow,
	       int *Lcol, double *Uval, int *Urow, int *Ucol, int *snnzL,
	       int *snnzU);
kusano 7d535a
kusano 7d535a
/*
 * Diagnostic verbosity tests (verbose, babble, burble: increasing detail).
 * Each macro expands to a comparison on a variable named SPUMONI, so it can
 * only be used where such a variable is in scope.
 */
#define verbose (SPUMONI>0)
kusano 7d535a
#define babble  (SPUMONI>1)
kusano 7d535a
#define burble  (SPUMONI>2)
kusano 7d535a
kusano 7d535a
void mexFunction(
kusano 7d535a
    int          nlhs,           /* number of expected outputs */
kusano 7d535a
    MatlabMatrix *plhs[],        /* matrix pointer array returning outputs */
kusano 7d535a
    int          nrhs,           /* number of inputs */
kusano 7d535a
#ifdef V5
kusano 7d535a
    const MatlabMatrix *prhs[]   /* matrix pointer array for inputs */
kusano 7d535a
#else /* V4 */
kusano 7d535a
    MatlabMatrix *prhs[]         /* matrix pointer array for inputs */
kusano 7d535a
#endif
kusano 7d535a
    )
kusano 7d535a
{
kusano 7d535a
    int SPUMONI;             /* ... as should the sparse monitor flag */
kusano 7d535a
#ifdef V5
kusano 7d535a
    double FlopsInSuperLU;   /* ... as should the flop counter */
kusano 7d535a
#else
kusano 7d535a
    Real FlopsInSuperLU;     /* ... as should the flop counter */
kusano 7d535a
#endif
kusano 7d535a
    extern flops_t LUFactFlops(SuperLUStat_t *);
kusano 7d535a
    
kusano 7d535a
    /* Arguments to C dgstrf(). */
kusano 7d535a
    SuperMatrix A;
kusano 7d535a
    SuperMatrix Ac;        /* Matrix postmultiplied by Pc */
kusano 7d535a
    SuperMatrix L, U;
kusano 7d535a
    int	   	m, n, nnz;
kusano 7d535a
    double      *val;
kusano 7d535a
    int       	*rowind;
kusano 7d535a
    int		*colptr;
kusano 7d535a
    int    	*etree, *perm_r, *perm_c;
kusano 7d535a
    int         panel_size, relax;
kusano 7d535a
    double      thresh = 1.0;       /* diagonal pivoting threshold */
kusano 7d535a
    int		info;
kusano 7d535a
    MatlabMatrix *X, *Y;            /* args to calls back to Matlab */
kusano 7d535a
    int         i, mexerr;
kusano 7d535a
    double      *dp;
kusano 7d535a
    double      *Lval, *Uval;
kusano 7d535a
    int         *Lrow, *Urow;
kusano 7d535a
    int         *Lcol, *Ucol;
kusano 7d535a
    int         nnzL, nnzU, snnzL, snnzU;
kusano 7d535a
    superlu_options_t options;
kusano 7d535a
    SuperLUStat_t stat;
kusano 7d535a
kusano 7d535a
    /* Check number of arguments passed from Matlab. */
kusano 7d535a
    if (nrhs != 2) {
kusano 7d535a
	mexErrMsgTxt("SUPERLU requires 2 input arguments.");
kusano 7d535a
    } else if (nlhs != 4) {
kusano 7d535a
      	mexErrMsgTxt("SUPERLU requires 4 output arguments.");
kusano 7d535a
    }   
kusano 7d535a
kusano 7d535a
    /* Read the Sparse Monitor Flag */
kusano 7d535a
    X = mxCreateString("spumoni");
kusano 7d535a
    mexerr = mexCallMATLAB(1, &Y, 1, &X, "sparsfun");
kusano 7d535a
    SPUMONI = mxGetScalar(Y);
kusano 7d535a
#ifdef V5
kusano 7d535a
    mxDestroyArray(Y);
kusano 7d535a
    mxDestroyArray(X);
kusano 7d535a
#else
kusano 7d535a
    mxFreeMatrix(Y);
kusano 7d535a
    mxFreeMatrix(X);
kusano 7d535a
#endif
kusano 7d535a
kusano 7d535a
    m = mxGetM(A_in);
kusano 7d535a
    n = mxGetN(A_in);
kusano 7d535a
    etree = (int *) mxCalloc(n, sizeof(int));
kusano 7d535a
    perm_r = (int *) mxCalloc(m, sizeof(int));
kusano 7d535a
    perm_c = mxGetIr(Pc_in); 
kusano 7d535a
    val = mxGetPr(A_in);
kusano 7d535a
    rowind = mxGetIr(A_in);
kusano 7d535a
    colptr = mxGetJc(A_in);
kusano 7d535a
    nnz = colptr[n];
kusano 7d535a
    dCreate_CompCol_Matrix(&A, m, n, nnz, val, rowind, colptr,
kusano 7d535a
			   SLU_NC, SLU_D, SLU_GE);
kusano 7d535a
    panel_size = sp_ienv(1);
kusano 7d535a
    relax      = sp_ienv(2);
kusano 7d535a
    thresh     = 1.0;
kusano 7d535a
    FlopsInSuperLU      = 0;
kusano 7d535a
kusano 7d535a
    set_default_options(&options);
kusano 7d535a
    StatInit(&stat);
kusano 7d535a
kusano 7d535a
    if ( verbose ) mexPrintf("Apply column perm to A and compute etree...\n");
kusano 7d535a
    sp_preorder(&options, &A, perm_c, etree, &Ac);
kusano 7d535a
kusano 7d535a
    if ( verbose ) {
kusano 7d535a
	mexPrintf("LU factorization...\n");
kusano 7d535a
	mexPrintf("\tpanel_size %d, relax %d, diag_pivot_thresh %.2g\n",
kusano 7d535a
		  panel_size, relax, thresh);
kusano 7d535a
    }
kusano 7d535a
kusano 7d535a
    dgstrf(&options, &Ac, relax, panel_size, etree,
kusano 7d535a
	   NULL, 0, perm_c, perm_r, &L, &U, &stat, &info);
kusano 7d535a
kusano 7d535a
    if ( verbose ) mexPrintf("INFO from dgstrf %d\n", info);
kusano 7d535a
kusano 7d535a
#if 0 /* FLOPS is not available in the new Matlab. */
kusano 7d535a
    /* Tell Matlab how many flops we did. */
kusano 7d535a
    FlopsInSuperLU += LUFactFlops(&stat);
kusano 7d535a
    if (verbose) mexPrintf("SUPERLU flops: %.f\n", FlopsInSuperLU);
kusano 7d535a
    mexerr = mexCallMATLAB(1, &X, 0, NULL, "flops");
kusano 7d535a
    *(mxGetPr(X)) += FlopsInSuperLU;
kusano 7d535a
    mexerr = mexCallMATLAB(1, &Y, 1, &X, "flops");
kusano 7d535a
#ifdef V5
kusano 7d535a
    mxDestroyArray(Y);
kusano 7d535a
    mxDestroyArray(X);
kusano 7d535a
#else
kusano 7d535a
    mxFreeMatrix(Y);
kusano 7d535a
    mxFreeMatrix(X);
kusano 7d535a
#endif
kusano 7d535a
#endif
kusano 7d535a
	
kusano 7d535a
    /* Construct output arguments for Matlab. */
kusano 7d535a
    if ( info >= 0 && info <= n ) {
kusano 7d535a
#ifdef V5
kusano 7d535a
	Pr_out = mxCreateDoubleMatrix(m, 1, mxREAL);
kusano 7d535a
#else
kusano 7d535a
	Pr_out = mxCreateFull(m, 1, REAL);
kusano 7d535a
#endif
kusano 7d535a
	dp = mxGetPr(Pr_out);
kusano 7d535a
	for (i = 0; i < m; *dp++ = (double) perm_r[i++]+1);
kusano 7d535a
#ifdef V5
kusano 7d535a
	Pc_out = mxCreateDoubleMatrix(n, 1, mxREAL);
kusano 7d535a
#else
kusano 7d535a
	Pc_out = mxCreateFull(n, 1, REAL);
kusano 7d535a
#endif
kusano 7d535a
	dp = mxGetPr(Pc_out);
kusano 7d535a
	for (i = 0; i < n; *dp++ = (double) perm_c[i++]+1);
kusano 7d535a
	
kusano 7d535a
	/* Now for L and U */
kusano 7d535a
	nnzL = ((SCformat*)L.Store)->nnz; /* count diagonals */
kusano 7d535a
   	nnzU = ((NCformat*)U.Store)->nnz;
kusano 7d535a
kusano 7d535a
#ifdef V5
kusano 7d535a
	L_out = mxCreateSparse(m, n, nnzL, mxREAL);
kusano 7d535a
#else
kusano 7d535a
	L_out = mxCreateSparse(m, n, nnzL, REAL);
kusano 7d535a
#endif
kusano 7d535a
	Lval = mxGetPr(L_out);
kusano 7d535a
	Lrow = mxGetIr(L_out);
kusano 7d535a
	Lcol = mxGetJc(L_out);
kusano 7d535a
kusano 7d535a
#ifdef V5
kusano 7d535a
	U_out = mxCreateSparse(m, n, nnzU, mxREAL);
kusano 7d535a
#else
kusano 7d535a
	U_out = mxCreateSparse(m, n, nnzU, REAL);
kusano 7d535a
#endif
kusano 7d535a
	Uval = mxGetPr(U_out);
kusano 7d535a
	Urow = mxGetIr(U_out);
kusano 7d535a
	Ucol = mxGetJc(U_out);
kusano 7d535a
kusano 7d535a
	LUextract(&L, &U, Lval, Lrow, Lcol, Uval, Urow, Ucol, &snnzL, &snnzU);
kusano 7d535a
	
kusano 7d535a
        Destroy_CompCol_Permuted(&Ac);
kusano 7d535a
	Destroy_SuperNode_Matrix(&L);
kusano 7d535a
	Destroy_CompCol_Matrix(&U);
kusano 7d535a
kusano 7d535a
	if (babble) mexPrintf("factor nonzeros: %d unsqueezed, %d squeezed.\n",
kusano 7d535a
			      nnzL + nnzU, snnzL + snnzU);
kusano 7d535a
    } else {
kusano 7d535a
	mexErrMsgTxt("Error returned from C dgstrf().");
kusano 7d535a
    }
kusano 7d535a
kusano 7d535a
    mxFree(etree);
kusano 7d535a
    mxFree(perm_r);
kusano 7d535a
    StatFree(&stat);
kusano 7d535a
    return;
kusano 7d535a
}
kusano 7d535a
kusano 7d535a
void
kusano 7d535a
LUextract(SuperMatrix *L, SuperMatrix *U, double *Lval, int *Lrow,
kusano 7d535a
	  int *Lcol, double *Uval, int *Urow, int *Ucol, int *snnzL,
kusano 7d535a
	  int *snnzU)
kusano 7d535a
{
kusano 7d535a
    int         i, j, k;
kusano 7d535a
    int         upper;
kusano 7d535a
    int         fsupc, istart, nsupr;
kusano 7d535a
    int         lastl = 0, lastu = 0;
kusano 7d535a
    SCformat    *Lstore;
kusano 7d535a
    NCformat    *Ustore;
kusano 7d535a
    double      *SNptr;
kusano 7d535a
kusano 7d535a
    Lstore = L->Store;
kusano 7d535a
    Ustore = U->Store;
kusano 7d535a
    Lcol[0] = 0;
kusano 7d535a
    Ucol[0] = 0;
kusano 7d535a
    
kusano 7d535a
    /* for each supernode */
kusano 7d535a
    for (k = 0; k <= Lstore->nsuper; ++k) {
kusano 7d535a
	
kusano 7d535a
	fsupc = L_FST_SUPC(k);
kusano 7d535a
	istart = L_SUB_START(fsupc);
kusano 7d535a
	nsupr = L_SUB_START(fsupc+1) - istart;
kusano 7d535a
	upper = 1;
kusano 7d535a
	
kusano 7d535a
	/* for each column in the supernode */
kusano 7d535a
	for (j = fsupc; j < L_FST_SUPC(k+1); ++j) {
kusano 7d535a
	    SNptr = &((double*)Lstore->nzval)[L_NZ_START(j)];
kusano 7d535a
kusano 7d535a
	    /* Extract U */
kusano 7d535a
	    for (i = U_NZ_START(j); i < U_NZ_START(j+1); ++i) {
kusano 7d535a
		Uval[lastu] = ((double*)Ustore->nzval)[i];
kusano 7d535a
 		/* Matlab doesn't like explicit zero. */
kusano 7d535a
		if (Uval[lastu] != 0.0) Urow[lastu++] = U_SUB(i);
kusano 7d535a
	    }
kusano 7d535a
	    for (i = 0; i < upper; ++i) { /* upper triangle in the supernode */
kusano 7d535a
		Uval[lastu] = SNptr[i];
kusano 7d535a
 		/* Matlab doesn't like explicit zero. */
kusano 7d535a
		if (Uval[lastu] != 0.0) Urow[lastu++] = L_SUB(istart+i);
kusano 7d535a
	    }
kusano 7d535a
	    Ucol[j+1] = lastu;
kusano 7d535a
kusano 7d535a
	    /* Extract L */
kusano 7d535a
	    Lval[lastl] = 1.0; /* unit diagonal */
kusano 7d535a
	    Lrow[lastl++] = L_SUB(istart + upper - 1);
kusano 7d535a
	    for (i = upper; i < nsupr; ++i) {
kusano 7d535a
		Lval[lastl] = SNptr[i];
kusano 7d535a
 		/* Matlab doesn't like explicit zero. */
kusano 7d535a
		if (Lval[lastl] != 0.0) Lrow[lastl++] = L_SUB(istart+i);
kusano 7d535a
	    }
kusano 7d535a
	    Lcol[j+1] = lastl;
kusano 7d535a
kusano 7d535a
	    ++upper;
kusano 7d535a
	    
kusano 7d535a
	} /* for j ... */
kusano 7d535a
	
kusano 7d535a
    } /* for k ... */
kusano 7d535a
kusano 7d535a
    *snnzL = lastl;
kusano 7d535a
    *snnzU = lastu;
kusano 7d535a
}