Actual source code: mpimatmatmatmult.c

  1: /*
  2:   Defines matrix-matrix-matrix product routines for MPIAIJ matrices
  3:           D = A * B * C
  4: */
  5: #include <../src/mat/impls/aij/mpi/mpiaij.h>

  7: #if defined(PETSC_HAVE_HYPRE)
  8: PETSC_INTERN PetscErrorCode MatTransposeMatMatMultSymbolic_AIJ_AIJ_AIJ_wHYPRE(Mat,Mat,Mat,PetscReal,Mat);
  9: PETSC_INTERN PetscErrorCode MatTransposeMatMatMultNumeric_AIJ_AIJ_AIJ_wHYPRE(Mat,Mat,Mat,Mat);

 11: PETSC_INTERN PetscErrorCode MatProductNumeric_ABC_Transpose_AIJ_AIJ(Mat RAP)
 12: {
 13:   Mat_Product    *product = RAP->product;
 14:   Mat            Rt,R=product->A,A=product->B,P=product->C;

 16:   MatTransposeGetMat(R,&Rt);
 17:   MatTransposeMatMatMultNumeric_AIJ_AIJ_AIJ_wHYPRE(Rt,A,P,RAP);
 18:   return 0;
 19: }

 21: PETSC_INTERN PetscErrorCode MatProductSymbolic_ABC_Transpose_AIJ_AIJ(Mat RAP)
 22: {
 23:   Mat_Product    *product = RAP->product;
 24:   Mat            Rt,R=product->A,A=product->B,P=product->C;
 25:   PetscBool      flg;

 27:   /* local sizes of matrices will be checked by the calling subroutines */
 28:   MatTransposeGetMat(R,&Rt);
 29:   PetscObjectTypeCompareAny((PetscObject)Rt,&flg,MATSEQAIJ,MATSEQAIJMKL,MATMPIAIJ,NULL);
 31:   MatTransposeMatMatMultSymbolic_AIJ_AIJ_AIJ_wHYPRE(Rt,A,P,product->fill,RAP);
 32:   RAP->ops->productnumeric = MatProductNumeric_ABC_Transpose_AIJ_AIJ;
 33:   return 0;
 34: }

 36: PETSC_INTERN PetscErrorCode MatProductSetFromOptions_Transpose_AIJ_AIJ(Mat C)
 37: {
 38:   Mat_Product *product = C->product;

 40:   if (product->type == MATPRODUCT_ABC) {
 41:     C->ops->productsymbolic = MatProductSymbolic_ABC_Transpose_AIJ_AIJ;
 42:   } else SETERRQ(PetscObjectComm((PetscObject)C),PETSC_ERR_SUP,"MatProduct type %s is not supported for Transpose, AIJ and AIJ matrices",MatProductTypes[product->type]);
 43:   return 0;
 44: }
 45: #endif

 47: PetscErrorCode MatMatMatMultSymbolic_MPIAIJ_MPIAIJ_MPIAIJ(Mat A,Mat B,Mat C,PetscReal fill,Mat D)
 48: {
 49:   Mat            BC;
 50:   PetscBool      scalable;
 51:   Mat_Product    *product;

 53:   MatCheckProduct(D,4);
 55:   product = D->product;
 56:   MatProductCreate(B,C,NULL,&BC);
 57:   MatProductSetType(BC,MATPRODUCT_AB);
 58:   PetscStrcmp(product->alg,"scalable",&scalable);
 59:   if (scalable) {
 60:     MatMatMultSymbolic_MPIAIJ_MPIAIJ(B,C,fill,BC);
 61:     MatZeroEntries(BC); /* initialize value entries of BC */
 62:     MatMatMultSymbolic_MPIAIJ_MPIAIJ(A,BC,fill,D);
 63:   } else {
 64:     MatMatMultSymbolic_MPIAIJ_MPIAIJ_nonscalable(B,C,fill,BC);
 65:     MatZeroEntries(BC); /* initialize value entries of BC */
 66:     MatMatMultSymbolic_MPIAIJ_MPIAIJ_nonscalable(A,BC,fill,D);
 67:   }
 68:   MatDestroy(&product->Dwork);
 69:   product->Dwork = BC;

 71:   D->ops->matmatmultnumeric = MatMatMatMultNumeric_MPIAIJ_MPIAIJ_MPIAIJ;
 72:   return 0;
 73: }

 75: PetscErrorCode MatMatMatMultNumeric_MPIAIJ_MPIAIJ_MPIAIJ(Mat A,Mat B,Mat C,Mat D)
 76: {
 77:   Mat_Product    *product;
 78:   Mat            BC;

 80:   MatCheckProduct(D,4);
 82:   product = D->product;
 83:   BC = product->Dwork;
 85:   (*BC->ops->matmultnumeric)(B,C,BC);
 87:   (*D->ops->matmultnumeric)(A,BC,D);
 88:   return 0;
 89: }

 91: /* ----------------------------------------------------- */
 92: PetscErrorCode MatDestroy_MPIAIJ_RARt(void *data)
 93: {
 94:   Mat_RARt       *rart = (Mat_RARt*)data;

 96:   MatDestroy(&rart->Rt);
 97:   if (rart->destroy) {
 98:     (*rart->destroy)(rart->data);
 99:   }
100:   PetscFree(rart);
101:   return 0;
102: }

104: PetscErrorCode MatProductNumeric_RARt_MPIAIJ_MPIAIJ(Mat C)
105: {
106:   Mat_RARt       *rart;
107:   Mat            A,R,Rt;

109:   MatCheckProduct(C,1);
111:   rart = (Mat_RARt*)C->product->data;
112:   A    = C->product->A;
113:   R    = C->product->B;
114:   Rt   = rart->Rt;
115:   MatTranspose(R,MAT_REUSE_MATRIX,&Rt);
116:   if (rart->data) C->product->data = rart->data;
117:   (*C->ops->matmatmultnumeric)(R,A,Rt,C);
118:   C->product->data = rart;
119:   return 0;
120: }

122: PetscErrorCode MatProductSymbolic_RARt_MPIAIJ_MPIAIJ(Mat C)
123: {
124:   Mat            A,R,Rt;
125:   Mat_RARt       *rart;

127:   MatCheckProduct(C,1);
129:   A    = C->product->A;
130:   R    = C->product->B;
131:   MatTranspose(R,MAT_INITIAL_MATRIX,&Rt);
132:   /* product->Dwork is used to store A*Rt in MatMatMatMultSymbolic_MPIAIJ_MPIAIJ_MPIAIJ() */
133:   MatMatMatMultSymbolic_MPIAIJ_MPIAIJ_MPIAIJ(R,A,Rt,C->product->fill,C);
134:   C->ops->productnumeric = MatProductNumeric_RARt_MPIAIJ_MPIAIJ;

136:   /* create a supporting struct */
137:   PetscNew(&rart);
138:   rart->Rt      = Rt;
139:   rart->data    = C->product->data;
140:   rart->destroy = C->product->destroy;
141:   C->product->data    = rart;
142:   C->product->destroy = MatDestroy_MPIAIJ_RARt;
143:   return 0;
144: }