kusano 7d535a
kusano 7d535a
/*! @file sgstrs.c
kusano 7d535a
 * \brief Solves a system using LU factorization
kusano 7d535a
 *
kusano 7d535a
 * 
kusano 7d535a
 * -- SuperLU routine (version 3.0) --
kusano 7d535a
 * Univ. of California Berkeley, Xerox Palo Alto Research Center,
kusano 7d535a
 * and Lawrence Berkeley National Lab.
kusano 7d535a
 * October 15, 2003
kusano 7d535a
 *
kusano 7d535a
 * Copyright (c) 1994 by Xerox Corporation.  All rights reserved.
kusano 7d535a
 *
kusano 7d535a
 * THIS MATERIAL IS PROVIDED AS IS, WITH ABSOLUTELY NO WARRANTY
kusano 7d535a
 * EXPRESSED OR IMPLIED.  ANY USE IS AT YOUR OWN RISK.
kusano 7d535a
 *
kusano 7d535a
 * Permission is hereby granted to use or copy this program for any
kusano 7d535a
 * purpose, provided the above notices are retained on all copies.
kusano 7d535a
 * Permission to modify the code and to distribute modified code is
kusano 7d535a
 * granted, provided the above notices are retained, and a notice that
kusano 7d535a
 * the code was modified is included with the above copyright notice.
kusano 7d535a
 * 
kusano 7d535a
 */
kusano 7d535a
kusano 7d535a
#include "slu_sdefs.h"
kusano 7d535a
kusano 7d535a
kusano 7d535a
/*
 * Reference (non-BLAS) dense kernels used by sgstrs() when USE_VENDOR_BLAS
 * is not defined.  Parameter names follow the call sites in sgstrs():
 * ldm is the leading dimension of the column-major block M, ncol/nrow its
 * extents.
 */
void susolve(int ldm, int ncol, float *M, float *rhs);
void slsolve(int ldm, int ncol, float *M, float *rhs);
void smatvec(int ldm, int nrow, int ncol, float *M, float *vec, float *Mxvec);
kusano 7d535a
kusano 7d535a
/*! \brief
kusano 7d535a
 *
kusano 7d535a
 * 
kusano 7d535a
 * Purpose
kusano 7d535a
 * =======
kusano 7d535a
 *
kusano 7d535a
 * SGSTRS solves a system of linear equations A*X=B or A'*X=B
kusano 7d535a
 * with A sparse and B dense, using the LU factorization computed by
kusano 7d535a
 * SGSTRF.
kusano 7d535a
 *
kusano 7d535a
 * See supermatrix.h for the definition of 'SuperMatrix' structure.
kusano 7d535a
 *
kusano 7d535a
 * Arguments
kusano 7d535a
 * =========
kusano 7d535a
 *
kusano 7d535a
 * trans   (input) trans_t
kusano 7d535a
 *          Specifies the form of the system of equations:
kusano 7d535a
 *          = NOTRANS: A * X = B  (No transpose)
kusano 7d535a
 *          = TRANS:   A'* X = B  (Transpose)
kusano 7d535a
 *          = CONJ:    A**H * X = B  (Conjugate transpose; same as TRANS for real data)
kusano 7d535a
 *
kusano 7d535a
 * L       (input) SuperMatrix*
kusano 7d535a
 *         The factor L from the factorization Pr*A*Pc=L*U as computed by
kusano 7d535a
 *         sgstrf(). Use compressed row subscripts storage for supernodes,
kusano 7d535a
 *         i.e., L has types: Stype = SLU_SC, Dtype = SLU_S, Mtype = SLU_TRLU.
kusano 7d535a
 *
kusano 7d535a
 * U       (input) SuperMatrix*
kusano 7d535a
 *         The factor U from the factorization Pr*A*Pc=L*U as computed by
kusano 7d535a
 *         sgstrf(). Use column-wise storage scheme, i.e., U has types:
kusano 7d535a
 *         Stype = SLU_NC, Dtype = SLU_S, Mtype = SLU_TRU.
kusano 7d535a
 *
kusano 7d535a
 * perm_c  (input) int*, dimension (L->ncol)
kusano 7d535a
 *	   Column permutation vector, which defines the 
kusano 7d535a
 *         permutation matrix Pc; perm_c[i] = j means column i of A is 
kusano 7d535a
 *         in position j in A*Pc.
kusano 7d535a
 *
kusano 7d535a
 * perm_r  (input) int*, dimension (L->nrow)
kusano 7d535a
 *         Row permutation vector, which defines the permutation matrix Pr; 
kusano 7d535a
 *         perm_r[i] = j means row i of A is in position j in Pr*A.
kusano 7d535a
 *
kusano 7d535a
 * B       (input/output) SuperMatrix*
kusano 7d535a
 *         B has types: Stype = SLU_DN, Dtype = SLU_S, Mtype = SLU_GE.
kusano 7d535a
 *         On entry, the right hand side matrix.
kusano 7d535a
 *         On exit, the solution matrix if info = 0.
kusano 7d535a
 *
kusano 7d535a
 * stat     (output) SuperLUStat_t*
kusano 7d535a
 *          Records statistics on runtime and floating-point operation counts.
kusano 7d535a
 *          See util.h for the definition of 'SuperLUStat_t'.
kusano 7d535a
 *
kusano 7d535a
 * info    (output) int*
kusano 7d535a
 * 	   = 0: successful exit
kusano 7d535a
 *	   < 0: if info = -i, the i-th argument had an illegal value
kusano 7d535a
 * 
kusano 7d535a
 */
kusano 7d535a
kusano 7d535a
void
kusano 7d535a
sgstrs (trans_t trans, SuperMatrix *L, SuperMatrix *U,
kusano 7d535a
        int *perm_c, int *perm_r, SuperMatrix *B,
kusano 7d535a
        SuperLUStat_t *stat, int *info)
kusano 7d535a
{
kusano 7d535a
kusano 7d535a
#ifdef _CRAY
kusano 7d535a
    _fcd ftcs1, ftcs2, ftcs3, ftcs4;
kusano 7d535a
#endif
kusano 7d535a
    int      incx = 1, incy = 1;
kusano 7d535a
#ifdef USE_VENDOR_BLAS
kusano 7d535a
    float   alpha = 1.0, beta = 1.0;
kusano 7d535a
    float   *work_col;
kusano 7d535a
#endif
kusano 7d535a
    DNformat *Bstore;
kusano 7d535a
    float   *Bmat;
kusano 7d535a
    SCformat *Lstore;
kusano 7d535a
    NCformat *Ustore;
kusano 7d535a
    float   *Lval, *Uval;
kusano 7d535a
    int      fsupc, nrow, nsupr, nsupc, luptr, istart, irow;
kusano 7d535a
    int      i, j, k, iptr, jcol, n, ldb, nrhs;
kusano 7d535a
    float   *work, *rhs_work, *soln;
kusano 7d535a
    flops_t  solve_ops;
kusano 7d535a
    void sprint_soln();
kusano 7d535a
kusano 7d535a
    /* Test input parameters ... */
kusano 7d535a
    *info = 0;
kusano 7d535a
    Bstore = B->Store;
kusano 7d535a
    ldb = Bstore->lda;
kusano 7d535a
    nrhs = B->ncol;
kusano 7d535a
    if ( trans != NOTRANS && trans != TRANS && trans != CONJ ) *info = -1;
kusano 7d535a
    else if ( L->nrow != L->ncol || L->nrow < 0 ||
kusano 7d535a
	      L->Stype != SLU_SC || L->Dtype != SLU_S || L->Mtype != SLU_TRLU )
kusano 7d535a
	*info = -2;
kusano 7d535a
    else if ( U->nrow != U->ncol || U->nrow < 0 ||
kusano 7d535a
	      U->Stype != SLU_NC || U->Dtype != SLU_S || U->Mtype != SLU_TRU )
kusano 7d535a
	*info = -3;
kusano 7d535a
    else if ( ldb < SUPERLU_MAX(0, L->nrow) ||
kusano 7d535a
	      B->Stype != SLU_DN || B->Dtype != SLU_S || B->Mtype != SLU_GE )
kusano 7d535a
	*info = -6;
kusano 7d535a
    if ( *info ) {
kusano 7d535a
	i = -(*info);
kusano 7d535a
	xerbla_("sgstrs", &i);
kusano 7d535a
	return;
kusano 7d535a
    }
kusano 7d535a
kusano 7d535a
    n = L->nrow;
kusano 7d535a
    work = floatCalloc(n * nrhs);
kusano 7d535a
    if ( !work ) ABORT("Malloc fails for local work[].");
kusano 7d535a
    soln = floatMalloc(n);
kusano 7d535a
    if ( !soln ) ABORT("Malloc fails for local soln[].");
kusano 7d535a
kusano 7d535a
    Bmat = Bstore->nzval;
kusano 7d535a
    Lstore = L->Store;
kusano 7d535a
    Lval = Lstore->nzval;
kusano 7d535a
    Ustore = U->Store;
kusano 7d535a
    Uval = Ustore->nzval;
kusano 7d535a
    solve_ops = 0;
kusano 7d535a
    
kusano 7d535a
    if ( trans == NOTRANS ) {
kusano 7d535a
	/* Permute right hand sides to form Pr*B */
kusano 7d535a
	for (i = 0; i < nrhs; i++) {
kusano 7d535a
	    rhs_work = &Bmat[i*ldb];
kusano 7d535a
	    for (k = 0; k < n; k++) soln[perm_r[k]] = rhs_work[k];
kusano 7d535a
	    for (k = 0; k < n; k++) rhs_work[k] = soln[k];
kusano 7d535a
	}
kusano 7d535a
	
kusano 7d535a
	/* Forward solve PLy=Pb. */
kusano 7d535a
	for (k = 0; k <= Lstore->nsuper; k++) {
kusano 7d535a
	    fsupc = L_FST_SUPC(k);
kusano 7d535a
	    istart = L_SUB_START(fsupc);
kusano 7d535a
	    nsupr = L_SUB_START(fsupc+1) - istart;
kusano 7d535a
	    nsupc = L_FST_SUPC(k+1) - fsupc;
kusano 7d535a
	    nrow = nsupr - nsupc;
kusano 7d535a
kusano 7d535a
	    solve_ops += nsupc * (nsupc - 1) * nrhs;
kusano 7d535a
	    solve_ops += 2 * nrow * nsupc * nrhs;
kusano 7d535a
	    
kusano 7d535a
	    if ( nsupc == 1 ) {
kusano 7d535a
		for (j = 0; j < nrhs; j++) {
kusano 7d535a
		    rhs_work = &Bmat[j*ldb];
kusano 7d535a
	    	    luptr = L_NZ_START(fsupc);
kusano 7d535a
		    for (iptr=istart+1; iptr < L_SUB_START(fsupc+1); iptr++){
kusano 7d535a
			irow = L_SUB(iptr);
kusano 7d535a
			++luptr;
kusano 7d535a
			rhs_work[irow] -= rhs_work[fsupc] * Lval[luptr];
kusano 7d535a
		    }
kusano 7d535a
		}
kusano 7d535a
	    } else {
kusano 7d535a
	    	luptr = L_NZ_START(fsupc);
kusano 7d535a
#ifdef USE_VENDOR_BLAS
kusano 7d535a
#ifdef _CRAY
kusano 7d535a
		ftcs1 = _cptofcd("L", strlen("L"));
kusano 7d535a
		ftcs2 = _cptofcd("N", strlen("N"));
kusano 7d535a
		ftcs3 = _cptofcd("U", strlen("U"));
kusano 7d535a
		STRSM( ftcs1, ftcs1, ftcs2, ftcs3, &nsupc, &nrhs, &alpha,
kusano 7d535a
		       &Lval[luptr], &nsupr, &Bmat[fsupc], &ldb);
kusano 7d535a
		
kusano 7d535a
		SGEMM( ftcs2, ftcs2, &nrow, &nrhs, &nsupc, &alpha, 
kusano 7d535a
			&Lval[luptr+nsupc], &nsupr, &Bmat[fsupc], &ldb, 
kusano 7d535a
			&beta, &work[0], &n );
kusano 7d535a
#else
kusano 7d535a
		strsm_("L", "L", "N", "U", &nsupc, &nrhs, &alpha,
kusano 7d535a
		       &Lval[luptr], &nsupr, &Bmat[fsupc], &ldb);
kusano 7d535a
		
kusano 7d535a
		sgemm_( "N", "N", &nrow, &nrhs, &nsupc, &alpha, 
kusano 7d535a
			&Lval[luptr+nsupc], &nsupr, &Bmat[fsupc], &ldb, 
kusano 7d535a
			&beta, &work[0], &n );
kusano 7d535a
#endif
kusano 7d535a
		for (j = 0; j < nrhs; j++) {
kusano 7d535a
		    rhs_work = &Bmat[j*ldb];
kusano 7d535a
		    work_col = &work[j*n];
kusano 7d535a
		    iptr = istart + nsupc;
kusano 7d535a
		    for (i = 0; i < nrow; i++) {
kusano 7d535a
			irow = L_SUB(iptr);
kusano 7d535a
			rhs_work[irow] -= work_col[i]; /* Scatter */
kusano 7d535a
			work_col[i] = 0.0;
kusano 7d535a
			iptr++;
kusano 7d535a
		    }
kusano 7d535a
		}
kusano 7d535a
#else		
kusano 7d535a
		for (j = 0; j < nrhs; j++) {
kusano 7d535a
		    rhs_work = &Bmat[j*ldb];
kusano 7d535a
		    slsolve (nsupr, nsupc, &Lval[luptr], &rhs_work[fsupc]);
kusano 7d535a
		    smatvec (nsupr, nrow, nsupc, &Lval[luptr+nsupc],
kusano 7d535a
			    &rhs_work[fsupc], &work[0] );
kusano 7d535a
kusano 7d535a
		    iptr = istart + nsupc;
kusano 7d535a
		    for (i = 0; i < nrow; i++) {
kusano 7d535a
			irow = L_SUB(iptr);
kusano 7d535a
			rhs_work[irow] -= work[i];
kusano 7d535a
			work[i] = 0.0;
kusano 7d535a
			iptr++;
kusano 7d535a
		    }
kusano 7d535a
		}
kusano 7d535a
#endif		    
kusano 7d535a
	    } /* else ... */
kusano 7d535a
	} /* for L-solve */
kusano 7d535a
kusano 7d535a
#ifdef DEBUG
kusano 7d535a
  	printf("After L-solve: y=\n");
kusano 7d535a
	sprint_soln(n, nrhs, Bmat);
kusano 7d535a
#endif
kusano 7d535a
kusano 7d535a
	/*
kusano 7d535a
	 * Back solve Ux=y.
kusano 7d535a
	 */
kusano 7d535a
	for (k = Lstore->nsuper; k >= 0; k--) {
kusano 7d535a
	    fsupc = L_FST_SUPC(k);
kusano 7d535a
	    istart = L_SUB_START(fsupc);
kusano 7d535a
	    nsupr = L_SUB_START(fsupc+1) - istart;
kusano 7d535a
	    nsupc = L_FST_SUPC(k+1) - fsupc;
kusano 7d535a
	    luptr = L_NZ_START(fsupc);
kusano 7d535a
kusano 7d535a
	    solve_ops += nsupc * (nsupc + 1) * nrhs;
kusano 7d535a
kusano 7d535a
	    if ( nsupc == 1 ) {
kusano 7d535a
		rhs_work = &Bmat[0];
kusano 7d535a
		for (j = 0; j < nrhs; j++) {
kusano 7d535a
		    rhs_work[fsupc] /= Lval[luptr];
kusano 7d535a
		    rhs_work += ldb;
kusano 7d535a
		}
kusano 7d535a
	    } else {
kusano 7d535a
#ifdef USE_VENDOR_BLAS
kusano 7d535a
#ifdef _CRAY
kusano 7d535a
		ftcs1 = _cptofcd("L", strlen("L"));
kusano 7d535a
		ftcs2 = _cptofcd("U", strlen("U"));
kusano 7d535a
		ftcs3 = _cptofcd("N", strlen("N"));
kusano 7d535a
		STRSM( ftcs1, ftcs2, ftcs3, ftcs3, &nsupc, &nrhs, &alpha,
kusano 7d535a
		       &Lval[luptr], &nsupr, &Bmat[fsupc], &ldb);
kusano 7d535a
#else
kusano 7d535a
		strsm_("L", "U", "N", "N", &nsupc, &nrhs, &alpha,
kusano 7d535a
		       &Lval[luptr], &nsupr, &Bmat[fsupc], &ldb);
kusano 7d535a
#endif
kusano 7d535a
#else		
kusano 7d535a
		for (j = 0; j < nrhs; j++)
kusano 7d535a
		    susolve ( nsupr, nsupc, &Lval[luptr], &Bmat[fsupc+j*ldb] );
kusano 7d535a
#endif		
kusano 7d535a
	    }
kusano 7d535a
kusano 7d535a
	    for (j = 0; j < nrhs; ++j) {
kusano 7d535a
		rhs_work = &Bmat[j*ldb];
kusano 7d535a
		for (jcol = fsupc; jcol < fsupc + nsupc; jcol++) {
kusano 7d535a
		    solve_ops += 2*(U_NZ_START(jcol+1) - U_NZ_START(jcol));
kusano 7d535a
		    for (i = U_NZ_START(jcol); i < U_NZ_START(jcol+1); i++ ){
kusano 7d535a
			irow = U_SUB(i);
kusano 7d535a
			rhs_work[irow] -= rhs_work[jcol] * Uval[i];
kusano 7d535a
		    }
kusano 7d535a
		}
kusano 7d535a
	    }
kusano 7d535a
	    
kusano 7d535a
	} /* for U-solve */
kusano 7d535a
kusano 7d535a
#ifdef DEBUG
kusano 7d535a
  	printf("After U-solve: x=\n");
kusano 7d535a
	sprint_soln(n, nrhs, Bmat);
kusano 7d535a
#endif
kusano 7d535a
kusano 7d535a
	/* Compute the final solution X := Pc*X. */
kusano 7d535a
	for (i = 0; i < nrhs; i++) {
kusano 7d535a
	    rhs_work = &Bmat[i*ldb];
kusano 7d535a
	    for (k = 0; k < n; k++) soln[k] = rhs_work[perm_c[k]];
kusano 7d535a
	    for (k = 0; k < n; k++) rhs_work[k] = soln[k];
kusano 7d535a
	}
kusano 7d535a
	
kusano 7d535a
        stat->ops[SOLVE] = solve_ops;
kusano 7d535a
kusano 7d535a
    } else { /* Solve A'*X=B or CONJ(A)*X=B */
kusano 7d535a
	/* Permute right hand sides to form Pc'*B. */
kusano 7d535a
	for (i = 0; i < nrhs; i++) {
kusano 7d535a
	    rhs_work = &Bmat[i*ldb];
kusano 7d535a
	    for (k = 0; k < n; k++) soln[perm_c[k]] = rhs_work[k];
kusano 7d535a
	    for (k = 0; k < n; k++) rhs_work[k] = soln[k];
kusano 7d535a
	}
kusano 7d535a
kusano 7d535a
	stat->ops[SOLVE] = 0;
kusano 7d535a
	for (k = 0; k < nrhs; ++k) {
kusano 7d535a
	    
kusano 7d535a
	    /* Multiply by inv(U'). */
kusano 7d535a
	    sp_strsv("U", "T", "N", L, U, &Bmat[k*ldb], stat, info);
kusano 7d535a
	    
kusano 7d535a
	    /* Multiply by inv(L'). */
kusano 7d535a
	    sp_strsv("L", "T", "U", L, U, &Bmat[k*ldb], stat, info);
kusano 7d535a
	    
kusano 7d535a
	}
kusano 7d535a
	/* Compute the final solution X := Pr'*X (=inv(Pr)*X) */
kusano 7d535a
	for (i = 0; i < nrhs; i++) {
kusano 7d535a
	    rhs_work = &Bmat[i*ldb];
kusano 7d535a
	    for (k = 0; k < n; k++) soln[k] = rhs_work[perm_r[k]];
kusano 7d535a
	    for (k = 0; k < n; k++) rhs_work[k] = soln[k];
kusano 7d535a
	}
kusano 7d535a
kusano 7d535a
    }
kusano 7d535a
kusano 7d535a
    SUPERLU_FREE(work);
kusano 7d535a
    SUPERLU_FREE(soln);
kusano 7d535a
}
kusano 7d535a
kusano 7d535a
/*
 * Debugging aid: prints soln[0..n-1], one entry per line, formatted as
 * "\t<index>: <value>".  Only the first n contiguous entries are shown;
 * nrhs is accepted but not used.
 */
void
sprint_soln(int n, int nrhs, float *soln)
{
    int row = 0;

    (void) nrhs;    /* unused: only the leading n entries are printed */
    while (row < n) {
        printf("\t%d: %.4f\n", row, soln[row]);
        ++row;
    }
}