Skip to content

Commit f62f438

Browse files
committed
Add trsm CPU support for Tensor LAPACK
1 parent b978cae commit f62f438

File tree

3 files changed

+73
-0
lines changed

3 files changed

+73
-0
lines changed

source/module_base/module_container/ATen/kernels/lapack.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,24 @@ struct lapack_geqrf<T, DEVICE_CPU> {
199199
}
200200
};
201201

202+
template <typename T>
203+
struct lapack_trsm<T, DEVICE_CPU> {
204+
void operator()(
205+
char side,
206+
char uplo,
207+
char transA,
208+
char diag,
209+
int m,
210+
int n,
211+
T alpha,
212+
T* A,
213+
int lda,
214+
T* B,
215+
int ldb)
216+
{
217+
lapackConnector::trsm(side, uplo, transA, diag, m, n, alpha, A, lda, B, ldb);
218+
}
219+
};
202220

203221
template struct set_matrix<float, DEVICE_CPU>;
204222
template struct set_matrix<double, DEVICE_CPU>;

source/module_base/module_container/ATen/kernels/lapack.h

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -285,6 +285,56 @@ struct lapack_geqrf {
285285
int lwork);
286286
};
287287

288+
/**
289+
* @brief Functor for solving a system of linear equations with a triangular matrix using LAPACK's TRSM routine.
290+
*
291+
* TRSM: triangular solve matrix
292+
* Solves one of the following matrix equations:
293+
* - op(A) * X = alpha * B
294+
* - X * op(A) = alpha * B
295+
* where op(A) is either A, A^T, A^H, or A^T.
296+
*
297+
*/
298+
template <typename T, typename Device>
299+
struct lapack_trsm {
300+
/**
301+
* @brief Solve a system of linear equations with a triangular matrix.
302+
*
303+
* Solves one of the following matrix equations:
304+
* - op(A) * X = alpha * B
305+
* - X * op(A) = alpha * B
306+
* where op(A) is either A, A^T, A^H, or A^T.
307+
*
308+
* @param side Specifies whether op(A) multiplies B from the left or right.
309+
* 'L' or 'l' for left, 'R' or 'r' for right.
310+
* @param uplo Specifies whether the matrix A is an upper or lower triangular matrix.
311+
* 'U' or 'u' for upper, 'L' or 'l' for lower.
312+
* @param transA Specifies the form of op(A) to be used in the matrix multiplication.
313+
* 'N' or 'n' for no transpose, 'T' or 't' for transpose, 'C' or 'c' for conjugate transpose.
314+
* @param diag Specifies whether or not A is unit triangular.
315+
* 'U' or 'u' for unit triangular, 'N' or 'n' for non-unit triangular.
316+
* @param m The number of rows of the matrix B. m >= 0.
317+
* @param n The number of columns of the matrix B. n >= 0.
318+
* @param alpha Scalar multiplier applied to op(A) * B.
319+
* @param A Pointer to the matrix A.
320+
* @param lda Leading dimension of A. lda >= max(1, m) if side == 'L' or lda >= max(1, n) if side == 'R'.
321+
* @param B Pointer to the matrix B.
322+
* @param ldb Leading dimension of B. ldb >= max(1, m).
323+
*/
324+
void operator()(
325+
char side,
326+
char uplo,
327+
char transA,
328+
char diag,
329+
int m,
330+
int n,
331+
T alpha,
332+
T* A,
333+
int lda,
334+
T* B,
335+
int ldb);
336+
};
337+
288338

289339
#if defined(__CUDA) || defined(__ROCM)
290340
// TODO: Use C++ singleton to manage the GPU handles

source/module_base/module_container/base/third_party/lapack.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,11 @@ void sgeqrf_(const int* m, const int* n, float* a, const int* lda, float* tau, f
124124
void dgeqrf_(const int* m, const int* n, double* a, const int* lda, double* tau, double* work, const int* lwork, int* info);
125125
void cgeqrf_(const int* m, const int* n, std::complex<float>* a, const int* lda, std::complex<float>* tau, std::complex<float>* work, const int* lwork, int* info);
126126
void zgeqrf_(const int* m, const int* n, std::complex<double>* a, const int* lda, std::complex<double>* tau, std::complex<double>* work, const int* lwork, int* info);
127+
128+
void strsm_(const char* side, const char* uplo, const char* transa, const char* diag, const int* m, const int* n, const float* alpha, const float* a, const int* lda, float* b, const int* ldb);
129+
void dtrsm_(const char* side, const char* uplo, const char* transa, const char* diag, const int* m, const int* n, const double* alpha, const double* a, const int* lda, double* b, const int* ldb);
130+
void ctrsm_(const char* side, const char* uplo, const char* transa, const char* diag, const int* m, const int* n, const std::complex<float>* alpha, const std::complex<float>* a, const int* lda, std::complex<float>* b, const int* ldb);
131+
void ztrsm_(const char* side, const char* uplo, const char* transa, const char* diag, const int* m, const int* n, const std::complex<double>* alpha, const std::complex<double>* a, const int* lda, std::complex<double>* b, const int* ldb);
127132
}
128133

129134
// Class LapackConnector provide the connector to fortran lapack routine.

0 commit comments

Comments
 (0)