00001
00002
00003
00004
00005
00006
00007
00008
00009 #pragma once
00010
00011 #include <cuComplex.h>
00012 #include "constants.h"
00013 /// ZERO: default-argument marker appended to the optional trailing batch_* parameters declared below.
00014 #ifdef __cplusplus
00015 #define ZERO =0 // C++: batch_* arguments default to 0 (no batch/tiles)
00016 #else
00017 #define ZERO    // C has no default arguments: expands to nothing, so C callers must pass batch_* explicitly
00018 #endif // NOTE(review): generic macro name may clash with an includer's own ZERO; it is #undef'd after the cuBLAS wrapper declarations
00019
00020 namespace af {
00021
00022 class array;
00023
00026
00046 template<typename ty> ty norm(const array& in, float p = af::nan);
00048
00049
00052
00053
00061 template<typename ty> ty inner(const array& x, const array& y);
00062
00070 AFAPI array inner(const array& X, const array& Y, int dim = -1);
00071
00072
00074
00075 }
00076
00077 #ifdef __cplusplus
00078 extern "C" {
00079 #endif
00080
00081 /// af_gemm_<T><T>: general matrix multiply on a single element type.
00082 ///   Suffix: SS = float, DD = double, CC = cuComplex, ZZ = cuDoubleComplex.
00083 ///   The argument list mirrors BLAS xGEMM (TRANSA/TRANSB, M/N/K, ALPHA, A/LDA,
00084 ///   B/LDB, BETA, C/LDC).
00085 ///   NOTE(review): presumably computes C = ALPHA*op(A)*op(B) + BETA*C with the
00086 ///   usual BLAS meaning of the TRANSA/TRANSB characters -- confirm against the
00087 ///   implementation.
00088 ///   d_A and d_B are const inputs; d_C is the non-const output.
00089 ///   NOTE(review): the d_ prefix suggests device pointers -- confirm.
00090 ///   batch_A/batch_B are optional in C++ (default 0: "no batch/tiles"); C callers
00091 ///   must pass them explicitly.
00092 ///   Returns an afError status code.
00100
00101 /// Single precision (float).
00102 AFAPI afError af_gemm_SS(char TRANSA, char TRANSB, int M, int N, int K, float ALPHA, const float *d_A, int LDA, const float *d_B, int LDB, float BETA, float *d_C, int LDC, unsigned batch_A ZERO, unsigned batch_B ZERO);
00103 /// Double precision (double).
00104 AFAPI afError af_gemm_DD(char TRANSA, char TRANSB, int M, int N, int K, double ALPHA, const double *d_A, int LDA, const double *d_B, int LDB, double BETA, double *d_C, int LDC, unsigned batch_A ZERO, unsigned batch_B ZERO);
00105 /// Single-precision complex (cuComplex).
00106 AFAPI afError af_gemm_CC(char TRANSA, char TRANSB, int M, int N, int K, cuComplex ALPHA, const cuComplex *d_A, int LDA, const cuComplex *d_B, int LDB, cuComplex BETA, cuComplex *d_C, int LDC, unsigned batch_A ZERO, unsigned batch_B ZERO);
00107 /// Double-precision complex (cuDoubleComplex).
00108 AFAPI afError af_gemm_ZZ(char TRANSA, char TRANSB, int M, int N, int K, cuDoubleComplex ALPHA, const cuDoubleComplex *d_A, int LDA, const cuDoubleComplex *d_B, int LDB, cuDoubleComplex BETA, cuDoubleComplex *d_C, int LDC, unsigned batch_A ZERO, unsigned batch_B ZERO);
00109 /// af_matmul_<A><B>(d_A, d_B, d_C, M, N, K[, batch_A, batch_B]): presumably d_C = d_A * d_B (confirm layout and M/N/K roles); letters give input types (S=float, D=double, C=cuComplex, Z=cuDoubleComplex); d_C is complex if either input is complex and single precision if either input is single precision; batch_* default to 0 in C++ only; returns afError.
00110 AFAPI afError af_matmul_SS(const float *d_A, const float *d_B,
00111 float *d_C, int M, int N, int K,
00112 unsigned batch_A ZERO, unsigned batch_B ZERO);
00113 AFAPI afError af_matmul_DS(const double *d_A, const float *d_B,
00114 float *d_C, int M, int N, int K,
00115 unsigned batch_A ZERO, unsigned batch_B ZERO);
00116 AFAPI afError af_matmul_SD(const float *d_A, const double *d_B,
00117 float *d_C, int M, int N, int K,
00118 unsigned batch_A ZERO, unsigned batch_B ZERO);
00119 AFAPI afError af_matmul_DD(const double *d_A, const double *d_B,
00120 double *d_C, int M, int N, int K,
00121 unsigned batch_A ZERO, unsigned batch_B ZERO);
00122 AFAPI afError af_matmul_CS(const cuComplex *d_A, const float *d_B,
00123 cuComplex *d_C, int M, int N, int K,
00124 unsigned batch_A ZERO, unsigned batch_B ZERO);
00125 AFAPI afError af_matmul_ZS(const cuDoubleComplex *d_A, const float *d_B,
00126 cuComplex *d_C, int M, int N, int K,
00127 unsigned batch_A ZERO, unsigned batch_B ZERO);
00128 AFAPI afError af_matmul_CD(const cuComplex *d_A, const double *d_B,
00129 cuComplex *d_C, int M, int N, int K,
00130 unsigned batch_A ZERO, unsigned batch_B ZERO);
00131 AFAPI afError af_matmul_ZD(const cuDoubleComplex *d_A, const double *d_B,
00132 cuDoubleComplex *d_C, int M, int N,
00133 int K, unsigned batch_A ZERO,
00134 unsigned batch_B ZERO);
00135 AFAPI afError af_matmul_SC(const float *d_A, const cuComplex *d_B,
00136 cuComplex *d_C, int M, int N, int K,
00137 unsigned batch_A ZERO, unsigned batch_B ZERO);
00138 AFAPI afError af_matmul_DC(const double *d_A, const cuComplex *d_B,
00139 cuComplex *d_C, int M, int N, int K,
00140 unsigned batch_A ZERO, unsigned batch_B ZERO);
00141 AFAPI afError af_matmul_SZ(const float *d_A, const cuDoubleComplex *d_B,
00142 cuComplex *d_C, int M, int N, int K,
00143 unsigned batch_A ZERO, unsigned batch_B ZERO);
00144 AFAPI afError af_matmul_DZ(const double *d_A, const cuDoubleComplex *d_B,
00145 cuDoubleComplex *d_C, int M, int N,
00146 int K, unsigned batch_A ZERO,
00147 unsigned batch_B ZERO);
00148 AFAPI afError af_matmul_CC(const cuComplex *d_A, const cuComplex *d_B,
00149 cuComplex *d_C, int M, int N, int K,
00150 unsigned batch_A ZERO, unsigned batch_B ZERO);
00151 AFAPI afError af_matmul_ZC(const cuDoubleComplex *d_A,
00152 const cuComplex *d_B, cuComplex *d_C,
00153 int M, int N, int K,
00154 unsigned batch_A ZERO, unsigned batch_B ZERO);
00155 AFAPI afError af_matmul_CZ(const cuComplex *d_A,
00156 const cuDoubleComplex *d_B,
00157 cuComplex *d_C, int M, int N, int K,
00158 unsigned batch_A ZERO, unsigned batch_B ZERO);
00159 AFAPI afError af_matmul_ZZ(const cuDoubleComplex *d_A,
00160 const cuDoubleComplex *d_B,
00161 cuDoubleComplex *d_C, int M, int N,
00162 int K, unsigned batch_A ZERO,
00163 unsigned batch_B ZERO);
00166
00167 /// af_norm_vector_<S|D>: norm of the numel elements at d_src; the result goes to the
00168 ///   non-const scalar *h_dst. Returns an afError status code.
00169 ///   NOTE(review): which norm (presumably Euclidean) and the h_/d_ host/device pointer
00170 ///   convention are not visible here -- confirm.
00175
00176 AFAPI afError af_norm_vector_S(float* h_dst,
00177 unsigned numel, const float* d_src);
00178 /// Double-precision variant of af_norm_vector_S.
00179 AFAPI afError af_norm_vector_D(double* h_dst,
00180 unsigned numel, const double* d_src);
00181 /// af_norm_<S|D>: norm of d_src, an array with ndims dimensions whose extents are in dims,
00182 ///   presumably reduced along dimension dim into d_dst (confirm dim indexing and norm type).
00183
00184 /// Single precision (float).
00185 AFAPI afError af_norm_S(float* d_dst,
00186 unsigned ndims, const unsigned* dims,
00187 const float* d_src, int dim);
00188 /// Double precision (double).
00189 AFAPI afError af_norm_D(double* d_dst,
00190 unsigned ndims, const unsigned* dims,
00191 const double* d_src, int dim);
00192
00193 /// af_inner_<S|D|C|Z>: inner (dot) product of the len-element vectors d_A and d_B;
00194 ///   the result goes to the non-const scalar *h_res. Returns an afError status code.
00195 ///   Suffix: S = float, D = double, C = cuComplex, Z = cuDoubleComplex.
00196 ///   NOTE(review): for the complex variants, whether d_A is conjugated is not visible
00197 ///   here (cf. the separate cublasCdotuN/cublasCdotcN wrappers) -- confirm.
00203 AFAPI afError af_inner_S(float *h_res, unsigned len, const float *d_A, const float *d_B);
00204 AFAPI afError af_inner_D(double *h_res, unsigned len, const double *d_A, const double *d_B);
00205 AFAPI afError af_inner_C(cuComplex *h_res, unsigned len,
00206 const cuComplex *d_A, const cuComplex *d_B);
00207 AFAPI afError af_inner_Z(cuDoubleComplex *h_res, unsigned len,
00208 const cuDoubleComplex *d_A, const cuDoubleComplex *d_B);
00209
00212
00213 #ifdef __cplusplus
00214 } // extern "C"
00215 #endif
00216
00217 // Included before (not inside) the extern "C" block: cublas_v2.h manages its own linkage, and
00218 // C++-only declarations pulled in by CUDA headers can be ill-formed under C linkage.
00219 #include <cublas_v2.h>
00220 #ifdef __cplusplus
00221 extern "C" {
00222 #endif
00223 /// cublas<S|D|C|Z>gemmN: cuBLAS-style gemm wrappers (cublasOperation_t transa/transb) with optional batch_A/batch_B (default 0 in C++ only); returns afError. NOTE(review): exact batching semantics of batch_* not visible here -- confirm.
00224 AFAPI afError cublasSgemmN(cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k,
00225 float alpha, const float *d_A, int lda, const float *d_B, int ldb, float beta, float *d_C, int ldc,
00226 unsigned batch_A ZERO, unsigned batch_B ZERO);
00227 AFAPI afError cublasDgemmN(cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k,
00228 double alpha, const double *d_A, int lda, const double *d_B, int ldb, double beta, double *d_C, int ldc,
00229 unsigned batch_A ZERO, unsigned batch_B ZERO);
00230 AFAPI afError cublasCgemmN(cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k,
00231 cuComplex alpha, const cuComplex *d_A, int lda, const cuComplex *d_B, int ldb, cuComplex beta, cuComplex *d_C, int ldc,
00232 unsigned batch_A ZERO, unsigned batch_B ZERO);
00233 AFAPI afError cublasZgemmN(cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k,
00234 cuDoubleComplex alpha, const cuDoubleComplex *d_A, int lda, const cuDoubleComplex *d_B, int ldb, cuDoubleComplex beta, cuDoubleComplex *d_C, int ldc,
00235 unsigned batch_A ZERO, unsigned batch_B ZERO);
00236
00237 /// cublas<S|D|C|Z>gemvN: presumably BLAS xGEMV-style matrix-vector wrappers (A, x -> y) with optional batch_A/batch_x (default 0 in C++ only); returns afError.
00238 AFAPI afError cublasSgemvN(cublasOperation_t TRANSA, int M, int N,
00239 float ALPHA, const float *d_A, int LDA, const float *d_x, int incx, float BETA, float *d_y, int incy,
00240 unsigned batch_A ZERO, unsigned batch_x ZERO);
00241 AFAPI afError cublasDgemvN(cublasOperation_t TRANSA, int M, int N,
00242 double ALPHA, const double *d_A, int LDA, const double *d_x, int incx, double BETA, double *d_y, int incy,
00243 unsigned batch_A ZERO, unsigned batch_x ZERO);
00244 AFAPI afError cublasCgemvN(cublasOperation_t TRANSA, int M, int N,
00245 cuComplex ALPHA, const cuComplex *d_A, int LDA, const cuComplex *d_x, int incx, cuComplex BETA, cuComplex *d_y, int incy,
00246 unsigned batch_A ZERO, unsigned batch_x ZERO);
00247 AFAPI afError cublasZgemvN(cublasOperation_t TRANSA, int M, int N,
00248 cuDoubleComplex ALPHA, const cuDoubleComplex *d_A, int LDA, const cuDoubleComplex *d_x, int incx, cuDoubleComplex BETA, cuDoubleComplex *d_y, int incy,
00249 unsigned batch_A ZERO, unsigned batch_x ZERO);
00250 /// cublas<S|D>gerN / cublas<C|Z>geruN: presumably BLAS xGER/xGERU-style rank-1 updates of d_A from x and y ('u' = unconjugated) with optional batch_x/batch_y (default 0 in C++ only); returns afError.
00251 AFAPI afError cublasSgerN(int M, int N,
00252 float ALPHA, const float *d_x, int incx, const float *d_y, int incy, float *d_A, int lda,
00253 unsigned batch_x ZERO, unsigned batch_y ZERO);
00254 AFAPI afError cublasDgerN(int M, int N,
00255 double ALPHA, const double *d_x, int incx, const double *d_y, int incy, double *d_A, int lda,
00256 unsigned batch_x ZERO, unsigned batch_y ZERO);
00257 AFAPI afError cublasCgeruN(int M, int N,
00258 cuComplex ALPHA, const cuComplex *d_x, int incx, const cuComplex *d_y, int incy, cuComplex *d_A, int lda,
00259 unsigned batch_x ZERO, unsigned batch_y ZERO);
00260 AFAPI afError cublasZgeruN(int M, int N,
00261 cuDoubleComplex ALPHA, const cuDoubleComplex *d_x, int incx, const cuDoubleComplex *d_y, int incy, cuDoubleComplex *d_A, int lda,
00262 unsigned batch_x ZERO, unsigned batch_y ZERO);
00263
00264 /// cublas<S|D>dotN, cublas<C|Z>dotuN (presumably unconjugated) and cublas<C|Z>dotcN (presumably conjugated): dot products of d_x and d_y written to *h_dot, with optional batch_x/batch_y (default 0 in C++ only); return afError.
00265 AFAPI afError cublasSdotN( int N, const float *d_x, int incx, const float *d_y, int incy, float *h_dot, unsigned batch_x ZERO, unsigned batch_y ZERO);
00266 AFAPI afError cublasDdotN( int N, const double *d_x, int incx, const double *d_y, int incy, double *h_dot, unsigned batch_x ZERO, unsigned batch_y ZERO);
00267 AFAPI afError cublasCdotuN(int N, const cuComplex *d_x, int incx, const cuComplex *d_y, int incy, cuComplex *h_dot, unsigned batch_x ZERO, unsigned batch_y ZERO);
00268 AFAPI afError cublasZdotuN(int N, const cuDoubleComplex *d_x, int incx, const cuDoubleComplex *d_y, int incy, cuDoubleComplex *h_dot, unsigned batch_x ZERO, unsigned batch_y ZERO);
00269 AFAPI afError cublasCdotcN(int N, const cuComplex *d_x, int incx, const cuComplex *d_y, int incy, cuComplex *h_dot, unsigned batch_x ZERO, unsigned batch_y ZERO);
00270 AFAPI afError cublasZdotcN(int N, const cuDoubleComplex *d_x, int incx, const cuDoubleComplex *d_y, int incy, cuDoubleComplex *h_dot, unsigned batch_x ZERO, unsigned batch_y ZERO);
00271 /// ZERO was only needed for the default-argument markers above; undefine it so it is not visible to code including this header.
00272 #undef ZERO
00273
00274 #ifdef __cplusplus
00275 } // extern "C"
00276 #endif
00277
00278 #ifdef __cplusplus
00279 /// C++ overloads (C++ linkage, not extern "C") of af_norm_vector_<S|D> / af_norm_<S|D>, selected by element type. NOTE(review): presumably thin forwards to the suffixed C functions -- confirm. This one: float (cf. af_norm_vector_S).
00280 AFAPI afError af_norm_vector(float* h_dst, unsigned numel, const float* d_src);
00281 /// double overload of af_norm_vector (cf. af_norm_vector_D).
00282 AFAPI afError af_norm_vector(double* h_dst, unsigned numel, const double* d_src);
00283 /// float overload of af_norm (cf. af_norm_S).
00284 AFAPI afError af_norm(float* d_dst, unsigned ndims, const unsigned* dims, const float* d_src, int dim);
00285 /// double overload of af_norm (cf. af_norm_D).
00286 AFAPI afError af_norm(double* d_dst, unsigned ndims, const unsigned* dims, const double* d_src, int dim);
00287 #endif // __cplusplus