#include <assert.h>
#include <stdio.h>
#include <iostream>
#include <vector>
#include <arrayfire.h>
using namespace af;
static void multi_Sgemv(int iterations, int ngpu, array *AMatrix, int m, const float* X, int n, float *Y)
{
array *YVector = new array[ngpu];
float **YVector_host = new float* [ngpu];
for (int i = 0 ; i < iterations; i++) {
for (int idx = 0; idx < ngpu; idx++) {
deviceset(idx);
YVector[idx] = AMatrix[idx] * array(n/ngpu, X + idx*(n/ngpu), afHostPointer);
}
for (int idx = 0; idx < ngpu; idx++) {
deviceset(idx);
YVector_host[idx] = YVector[idx].host<float>();
}
for (int i = 0; i < m; i++) {
Y[i] = 0;
for (int j = 0; j < ngpu; j++)
Y[i] += YVector_host[j][i];
}
for (int idx = 0; idx < ngpu; idx++)
delete [] YVector_host[idx];
}
delete [] YVector;
delete [] YVector_host;
}
/**
 * Sets the first n floats starting at X to 1.
 *
 * A count of zero or less writes nothing (a count-down loop would run past
 * the buffer for a negative n).
 *
 * @param X destination buffer with room for at least n floats
 * @param n number of elements to set
 */
static void fill_ones(float *X, int n)
{
    for (int i = 0; i < n; ++i)
        X[i] = 1;
}
// Divisor used by mb(): 1024 * 1024.
#define MB (1024 * 1024)
// Converts x (used below with a byte count) to a number of MB-sized units,
// rounding up when there is a remainder; the result is cast to unsigned.
// x must have an integer type because the remainder is taken with %.
#define mb(x) (unsigned)((x) / MB + !!((x) % MB))
int main(int argc, char **argv)
{
try {
printf("Multi-GPU Matrix-Vector Multiply: y = A*x\n\n"
"The system matrix 'A' is distributed across the available devices.\n"
"Each iteration pushes 'x' to the devices, multiplies against the matrix 'A',\n"
"and pulls the result 'y' back to the host.\n\n");
af::info();
int iterations = 1000;
int ngpu = devicecount();
int n = ngpu*9000;
printf("size(A)=[%d,%d] (%u mb)\n", n, n, mb(n * n * sizeof(float)));
printf("benchmarking........\n\n");
float *A = new float[n*n], *X = new float[n], *Y = new float[n];
fill_ones(A, n*n);
fill_ones(X, n*1);
array *AMatrix = new array[ngpu];
for (int idx = 0; idx < ngpu; idx++) {
deviceset(idx);
AMatrix[idx] = array(n, n/ngpu, A + (n*n/ngpu * idx), afHostPointer);
}
delete[] A;
af::sync();
timer::tic();
multi_Sgemv(iterations, ngpu, AMatrix, n, X, n, Y);
af::sync();
printf("elapsed time : %g seconds\n", timer::toc() / iterations);
for (int i = 0; i < n; ++i)
assert(Y[i] == n);
delete[] X; delete[] Y;
} catch (af::exception& e) {
fprintf(stderr, "%s\n", e.what());
}
#ifdef WIN32 // pause in Windows
if (!(argc == 2 && argv[1][0] == '-')) {
printf("hit [enter]...");
getchar();
}
#endif
return 0;
}