#pragma once

// Remove any function-like min/max macros (presumably from <windows.h>) so
// that the af::min / af::max declarations later in this header are not
// mangled by the preprocessor.
#undef min
#undef max

// Only define NOMINMAX when the build has not already done so.  A build that
// passes e.g. /DNOMINMAX defines it as 1, and an unconditional empty
// "#define NOMINMAX" would then be a conflicting redefinition (warning).
#ifndef NOMINMAX
#define NOMINMAX
#endif

namespace af {

// Forward declarations so these types can be named before their full
// definitions (class array is defined later in this header; dim4 and
// ginternal are presumably defined elsewhere in the library).
class dim4;
class ginternal;
class array;
}
00015
#include <cuComplex.h>
// Shorter alias for CUDA's double-precision complex type (from cuComplex.h).
typedef cuDoubleComplex cuDComplex;
#include "dims.h"
#include "index.h"
#include "gfor.h"

// Trailing parameters shared by the pointer-taking array constructors: the
// location of the source data (host by default) and an ngfor count (0 by
// default; presumably the gfor iteration count — confirm).
// It is #undef'd again inside class array after those constructors.
#define REST af_source_t src=afHostPointer, unsigned ngfor=0

// Lets a comma-separated parameter list travel through a macro as a single
// argument (used by the CONSTR invocations inside class array).
// NOTE(review): this very generic name is never #undef'd in this header, so it
// leaks into every translation unit that includes it; confirm no other
// library header relies on it before undefining it.
#define group(...) __VA_ARGS__

00027
00028 namespace af {
00029
00032 AFAPI array mul(const array &, const array &);
00033 AFAPI array mul(const array &, const double scalar);
00034 AFAPI array mul(const double scalar, const array &);
00035
00036
00040 AFAPI array dot(const array &, const double);
00041 AFAPI array dot(const double, const array &);
00042 AFAPI array dot(const array &, const array &);
00044
// Element types an array can hold.  The names appear to follow the usual
// f = float, c = complex, b = boolean, s/u = signed/unsigned convention with
// the bit width as suffix (NOTE(review): confirm against the implementation).
// Enumerators keep their implicit values (f32 == 0 ... u32 == 6); the private
// array constructor takes the type as an unsigned, so do not reorder them.
// No trailing comma after the last enumerator: that is ill-formed in C++03.
typedef enum {
    f32,
    c32,
    f64,
    c64,
    b8,
    s32,
    u32
} dtype;

// Where the memory passed to a pointer-taking constructor lives.
typedef enum {
    afDevicePointer,
    afHostPointer
} af_source_t;

// Sparse storage layouts (compressed sparse row / column, coordinate list).
typedef enum {
    AF_SP_CSR,
    AF_SP_CSC,
    AF_SP_COO
} af_sparse_t;
00068
00069 class AFAPI array {
00070
00071 public:
00073 array copy() const;
00074
00075 array(ginternal*);
00076 ~array();
00077
00078 #define CONSTR(what, dims, ...) \
00079 \
00080 template<typename ty> array(dims, const ty __VA_ARGS__);
00081
00084
00085 array();
00086
00088 array(const dim4& dims, dtype ty=f32);
00090 array(int d0, int d1, dtype ty=f32);
00092 array(int d0, int d1, int d2, dtype ty=f32);
00094 array(int d0, int d1, int d2, int d3, dtype ty=f32);
00095
00096 array(const array&);
00097
00103 array(const seq& s);
00104
00105 CONSTR(column vector, unsigned dim0, *pointer, REST)
00106 CONSTR(matrix, group(unsigned dim0, unsigned dim1), *pointer, REST)
00107 CONSTR(volume, group(unsigned dim0, unsigned dim1, unsigned dim2), *pointer, REST)
00108 CONSTR(4D array, group(unsigned dim0, unsigned dim1, unsigned dim2, unsigned dim3), *pointer, REST)
00109 CONSTR(array, const dim4 &dims, *pointer, REST)
00110
00112 template<typename ty> static array scalar(const ty);
00113
00114
00115 #undef CONSTR
00116 #undef REST
00117
00118
00120
00121
00124
00126 dim4 dims() const;
00128 int dims(unsigned) const;
00130 int ndims() const;
00132 int elements() const;
00133
00135 unsigned ngfor() const;
00137 size_t bytes() const;
00139 bool isempty() const;
00141 bool isscalar() const;
00143 bool isvector() const;
00145 bool isrow() const;
00147 bool iscolumn() const;
00148
00150 dtype type() const;
00152 bool iscomplex() const;
00154 inline bool isreal() const { return iscomplex() == false; }
00155
00156
00157 bool issparse() const;
00158
00159
00161
00163 void eval();
00164
00166 void unlock() const;
00167
00170
00171 array row(int i) const { return operator()(i,span,span); }
00172 array col(int i) const { return operator()(span,i,span); }
00173 array slice(int i) const { return operator()(span,span,i); }
00174
00175 array rows(int first, int last) const { return operator()(seq(first,last), span); }
00176 array cols(int first, int last) const { return operator()(span, seq(first,last)); }
00177 array slices(int first, int last) const { return operator()(span,span,seq(first,last)); }
00178
00181 array operator()(int x) const { return operator()(seq(x,x)); }
00182 array operator()(const seq&) const;
00183 array operator()(array indices) const;
00184
00185
00188 array operator()(const seq&, const seq&) const;
00189 array operator()(int row, int col) const { return operator()(seq(row,row), seq(col,col)); }
00190 array operator()(int row, const seq &y) const { return operator()(seq(row,row),y); }
00191 array operator()(int row, array columns) const { return operator()(seq(row,row),columns); }
00192 array operator()(const seq &x, int column) const { return operator()(x,seq(column,column)); }
00193 array operator()(const seq &rows, array columns) const;
00194 array operator()(array rows, int column) const { return operator()(rows,seq(column,column)); }
00195 array operator()(array rows, const seq& columns) const;
00196 array operator()(array rows, array columns) const;
00197
00198
00199
00202 array operator()(const seq&, const seq&, const seq&) const;
00203 array operator()(int x, int y, int z) const { return operator()(seq(x,x),seq(y,y),seq(z,z)); }
00204
00205 array operator()(const seq &x, int y, int z) const { return operator()(x,seq(y,y),seq(z,z)); }
00206 array operator()(int x, const seq &y, int z) const { return operator()(seq(x,x),y,seq(z,z)); }
00207 array operator()(int x, int y, const seq &z) const { return operator()(seq(x,x),seq(y,y),z); }
00208
00209 array operator()(int x, const seq &y, const seq &z) const { return operator()(seq(x,x),y,z); }
00210 array operator()(const seq &x, const seq &y, int z) const { return operator()(x,y,seq(z,z)); }
00211 array operator()(const seq &x, int y, const seq &z) const { return operator()(x,seq(y,y),z); }
00212
00213 array operator()(array rows, array cols, const seq &slices) const;
00214 array operator()(array rows, array cols, int slice) const { return operator()(rows,cols,seq(slice,slice)); }
00215 array operator()(array rows, const seq &cols, const seq &slices) const;
00216 array operator()(array rows, const seq &cols, int slice) const { return operator()(rows,cols,seq(slice,slice)); }
00217
00218 array operator()(const seq &rows, array cols, const seq &slices) const;
00219 array operator()(const seq&, const seq&, array) const;
00220
00221
00224 array operator()(const seq &w, const seq &x, const seq &y, const seq &z) const;
00225 array operator()(const seq &w, const seq &x, const seq &y, int z) const { return operator()(w,x,y,seq(z,z)); }
00226 array operator()(const seq &w, const seq &x, int y, int z) const { return operator()(w,x,seq(y,y),seq(z,z)); }
00227
00228
00230
00231
00234
00239 template<typename T> T scalar() const;
00240
00274 template<typename T> T *device() const;
00275
00288 template<typename T> T *host() const;
00289
00291 template<typename T> static void hostFree(const T *);
00292
00294
00297
00309 template<typename T> static T *alloc(size_t elements);
00310
00312 static void free(void*);
00313
00314
00316
00317
00319 #define SELF(op) \
00320 array& operator op(const array&); \
00321 array& operator op(const double&)
00322
00324 #define BIN(op) \
00325 array operator op(const array&) const; \
00326 array operator op(const double&) const; \
00327 AFAPI friend array operator op(const double&, const array&)
00328
00330 #define LOGIC(op) \
00331 array operator op(const array&) const; \
00332 array operator op##op(const array&) const; \
00333 array operator op##op(const bool&) const; \
00334 array operator op##op(const int&) const; \
00335 array operator op##op(const unsigned&) const; \
00336 array operator op##op(const double&) const; \
00337 AFAPI friend array operator op##op(const bool&, const array&); \
00338 AFAPI friend array operator op##op(const int&, const array&); \
00339 AFAPI friend array operator op##op(const unsigned&, const array&)
00340
00342 #define COMPARISON(op) \
00343 array operator op(const array&) const; \
00344 array operator op(const bool&) const; \
00345 array operator op(const int&) const; \
00346 array operator op(const double&) const; \
00347 AFAPI friend array operator op(const bool&, const array&); \
00348 AFAPI friend array operator op(const int&, const array&); \
00349 AFAPI friend array operator op(const double&, const array&)
00350
00353 array operator-() const;
00354 array operator!() const;
00355
00356
00359
00393 inline array operator*(const array &rhs) const
00394 {
00395 #if defined(AF_TIMES_ELEMENTWISE)
00396 return mul(*this, rhs);
00397 #elif defined(AF_TIMES_MATMUL)
00398 return dot(*this, rhs);
00399 #else
00400 return dot(*this, rhs);
00401 #endif
00402 }
00403
00405
00408 array operator*(const double &scalar) const;
00409 AFAPI friend array operator*(const double &scalar, const array&);
00410
00411 BIN(+);
00412 BIN(-);
00413 BIN(/);
00414 BIN(%);
00415
00416 SELF(+=);
00417 SELF(-=);
00418 SELF(*=);
00419 SELF(/=);
00420 SELF(%=);
00421
00422 array& operator++();
00423 array& operator--();
00424
00425 LOGIC(&);
00426 LOGIC(|);
00427 array operator ^(const array&) const;
00428
00429 COMPARISON(==);
00430 COMPARISON(!=);
00431 COMPARISON(<);
00432 COMPARISON(<=);
00433 COMPARISON(>);
00434 COMPARISON(>=);
00435
00437 #undef COMPARISON
00438 #undef BIN
00439 #undef SELF
00440 #undef LOGIC
00441
00442
00444
00447 template<typename ty> array& operator=(const ty);
00448 array& operator=(const array&);
00449
00450
00453
00471 array T() const;
00472
00493 array H() const;
00495
00496
00497
00498 friend class ginternal;
00499 mutable ginternal *m_internal;
00500 private:
00501 array(unsigned ty, const dim4& dims, const void *pointer,
00502 af_source_t src=afHostPointer, unsigned ngfor=0);
00503 };
00504
00505
00507 #define GEN(what, ...) \
00508 \
00509 \
00510 AFAPI array ones(__VA_ARGS__, dtype ty=f32); \
00511 AFAPI array zeros(__VA_ARGS__, dtype ty=f32); \
00512 AFAPI array identity(__VA_ARGS__, dtype ty=f32); \
00513 AFAPI array randu(__VA_ARGS__, dtype ty=f32); \
00514 AFAPI array randn(__VA_ARGS__, dtype ty=f32); \
00515
00522 GEN(column vector, unsigned nx)
00523 GEN(matrix, unsigned nx, unsigned ny)
00524 GEN(volume, unsigned nx, unsigned ny, unsigned nz)
00525 GEN(4D array, unsigned d0, unsigned d1, unsigned d2, unsigned d3)
00526 GEN(array, const dim4 &dims)
00527
00528
00529
00530 #undef GEN
00531
00532
00534 #define SPARSE(ty) \
00535 AFAPI array sparse(int rows, int cols, int nnz, \
00536 const ty *values, \
00537 const int *rowptr, \
00538 const int *colind, \
00539 af_source_t src=afHostPointer, \
00540 unsigned ngfor=0); \
00541
00542
00543
00544
00545
00546
00547
00548
00549
00550
00551
00552
00553
00554
00555
00556 SPARSE(float);
00557 SPARSE(double);
00558 SPARSE(cuComplex);
00559 SPARSE(cuDComplex);
00561 #undef SPARSE
00562
00567 AFAPI array sparse(array dense);
00569
00578 AFAPI array sparse(array& I, array& J, array& val, int rows=-1,
00579 int cols=-1, af_sparse_t format=AF_SP_COO);
00583
00591 AFAPI array dense(array sparse);
00593
00601 AFAPI afError where(array& I, array& J, array& Val, array sp, af_sparse_t format=AF_SP_COO);
00604
00605 AFAPI array getrow(int nrows, int *rows_idx, array sp);
00606 AFAPI array getcol(int ncols, int *cols_idx, array sp);
00607
00608
00610 #define UNARY(fn) AFAPI array fn(const array& x)
00611
00612 #define UNARY_REAL(fn) \
00613 \
00614 AFAPI array fn(const array& x, bool isreal=false);
00615
00617 #define BINARY_EXTRA(fn, a, b, comment, ...) \
00618 \
00619 AFAPI array fn(const array &a, const array &b __VA_ARGS__); \
00620 \
00621 AFAPI array fn(const array &a, const double b __VA_ARGS__); \
00622 \
00623 AFAPI array fn(const double a, const array &b __VA_ARGS__)
00624 #define BINARY_NAMED(fn,a,b,comment) BINARY_EXTRA(fn,a,b,comment,) // HACK
00625
00627 #define BINARY(fn,comment) \
00628 \
00629 AFAPI array fn(const array &, const array &); \
00630 \
00631 AFAPI array fn(const array &, const double); \
00632 \
00633 AFAPI array fn(const double, const array &)
00634
00636 #define EXTREMUM(fn, OP) \
00637 \
00638 AFAPI array fn(const array &x, const array &y); \
00639 \
00640 AFAPI array fn(const double x, const array &y)
00641
00642
00645
00646 AFAPI array convert(const array &, dtype type);
00647
00648 UNARY(sin);
00649 UNARY(sinh);
00650 UNARY(asin);
00651 UNARY(asinh);
00652 UNARY(cos);
00653 UNARY(cosh);
00654 UNARY(acos);
00655 UNARY(acosh);
00656 UNARY(tan);
00657 UNARY(tanh);
00658 UNARY(atan);
00659 UNARY(atanh);
00660
00661 UNARY(isFinite);
00662 UNARY(isInfinite);
00663 UNARY(isNan);
00664 UNARY(sign);
00665
00666 UNARY_REAL(sqrt)
00667 BINARY_NAMED(root, radicand, n, Calculate \p n-th root of real-valued \p radicand);
00668 UNARY(pow2);
00669 BINARY_EXTRA(pow, base, power, \p base raised to \p power (exponent). If \p isreal is true then bypass checks for complexity (faster)., ,bool isreal=false);
00670 UNARY(ceil);
00671 UNARY(floor);
00672 UNARY(round);
00673 UNARY(trunc);
00674 UNARY(factorial);
00675
00676 EXTREMUM(min, Minimum);
00677 EXTREMUM(max, Maximum);
00678
00679 UNARY_REAL(log)
00680 UNARY(log2);
00681 UNARY(log10);
00682 UNARY(log1p);
00683 UNARY(exp);
00684 UNARY(expm1);
00685 UNARY(gamma);
00686 UNARY(gammaln);
00687 UNARY(epsilon);
00688
00689 UNARY(erf);
00690 UNARY(erfc);
00691 UNARY(erfinv);
00692 UNARY(erfcinv);
00693
00694 UNARY(abs);
00695 UNARY(arg);
00696 UNARY(conj);
00697
00698 UNARY(real);
00699 UNARY(imag);
00700 UNARY(complex);
00701 BINARY_NAMED(complex, real, imaginary, Form a complex result from \p real and \p imaginary parts);
00702
00703 BINARY(atan2, arc tangent function of two variables);
00704 BINARY(hypot, Euclidean distance function without undue overflow or underflow during intermediate steps);
00705 BINARY(rem, remainder);
00706 BINARY_NAMED(mod, x, y, Compute <tt>x-n*y</tt> where \c n is quotient of <tt>x/y</tt>. Round toward zero.);
00708
00709
00712
00729
00730 #define print(exp) disp(exp, #exp)
00731
00732
00736 AFAPI void disp(const array exp, const char *expstr=NULL);
00737
00738
00740
00743
00750 AFAPI array lower(const array& input, int diagonal=0);
00751
00758 AFAPI array upper(const array& input, int diagonal=0);
00759
00766 AFAPI array diagonal(const array& input, int diag=0);
00767
00792 AFAPI array join(const array& A, const array& B, int dim=0);
00793
00794
00801 AFAPI array newdims(const array &input, const dim4 &newdims);
00802
00812 AFAPI array newdims(const array &input, int dim0, int dim1=1, int dim2=1, int dim3=1);
00813
00818 AFAPI array flipv(const array& in);
00819
00824 AFAPI array fliph(const array& in);
00825
00832 AFAPI array flipdim(const array& in, unsigned dim);
00833
00834
00836
// Drop helper macros so they do not leak into code that includes this
// header.  #undef of a name that is not defined is harmless, so names that
// this header itself never defines are listed too.
// NOTE(review): print() is intentionally kept (public API) and group() is
// not listed here — see its definition near the top of the header.

// Math declaration helpers.
#undef UNARY
#undef UNARY_REAL
#undef BINARY
#undef BINARY_NAMED
#undef BINARY_EXTRA
#undef EXTREMUM

// Names not defined in this header (kept for safety).
#undef ALLOC
#undef BINARY_OPERATION
#undef COMPARISON
#undef CONSTRUCTOR
#undef DATA
#undef GENERATOR
#undef INDEX
#undef OPERATOR
#undef REAL_CONSTRUCTOR
#undef ROOT_CONSTRUCTOR
#undef UNARY_OPERATION
00854
00855
00857 AFAPI void sync();
00858
00860 inline array eval(array a) { a.eval(); return a; }
00861 inline void eval(array a,array b) { eval(a); b.eval(); }
00862 inline void eval(array a,array b,array c) { eval(a,b); c.eval(); }
00863 inline void eval(array a,array b,array c,array d) { eval(a,b,c); d.eval(); }
00864 inline void eval(array a,array b,array c,array d,array e) { eval(a,b,c,d); e.eval(); }
00865 inline void eval(array a,array b,array c,array d,array e,array f) { eval(a,b,c,d,e); f.eval(); }
00866
00867 };