From 4556630efec7859d80c6e57cc2e8f83584f3cfe8 Mon Sep 17 00:00:00 2001 From: Zhigang Gong Date: Sat, 1 Apr 2017 02:31:53 +0800 Subject: [PATCH 01/33] set MKL_USE_SINGLE_DYNAMIC_LIBRARY as disable. Otherwise it may cause some issues. Signed-off-by: Zhigang Gong --- cmake/Modules/FindMKL.cmake | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmake/Modules/FindMKL.cmake b/cmake/Modules/FindMKL.cmake index 8ac6fc0c1e3..81a23fd8a71 100644 --- a/cmake/Modules/FindMKL.cmake +++ b/cmake/Modules/FindMKL.cmake @@ -14,7 +14,7 @@ # ---[ Options -caffe_option(MKL_USE_SINGLE_DYNAMIC_LIBRARY "Use single dynamic library interface" ON) +caffe_option(MKL_USE_SINGLE_DYNAMIC_LIBRARY "Use single dynamic library interface" OFF) caffe_option(MKL_USE_STATIC_LIBS "Use static libraries" OFF IF NOT MKL_USE_SINGLE_DYNAMIC_LIBRARY) caffe_option(MKL_MULTI_THREADED "Use multi-threading" ON IF NOT MKL_USE_SINGLE_DYNAMIC_LIBRARY) From c5ebc9ed6ab3cd2d18bfe375ee2e6bb9320b5284 Mon Sep 17 00:00:00 2001 From: Zhigang Gong Date: Sat, 1 Apr 2017 02:35:34 +0800 Subject: [PATCH 02/33] Add optimized GEMM/GEMV into caffe greentea math library. The image interface is much faster than buffer interface GEMM in ISAAC. Before we add new image based interface in ISAAC, we may have to keep this implementation. 
Signed-off-by: Zhigang Gong --- include/caffe/greentea/greentea_math_functions.hpp | 11 +- src/caffe/greentea/cl_kernels.cpp | 1644 +++++++++++++++- src/caffe/greentea/cl_kernels.sh | 13 +- src/caffe/greentea/cl_kernels/gemm.cl | 1958 ++++++++++++++++++++ src/caffe/greentea/cl_kernels/matvec_mul.cl | 143 ++ src/caffe/greentea/greentea_math_functions.cpp | 1135 ++++++++++-- 6 files changed, 4706 insertions(+), 198 deletions(-) create mode 100644 src/caffe/greentea/cl_kernels/gemm.cl create mode 100644 src/caffe/greentea/cl_kernels/matvec_mul.cl diff --git a/include/caffe/greentea/greentea_math_functions.hpp b/include/caffe/greentea/greentea_math_functions.hpp index 58364dca3ed..1462b9b7c48 100644 --- a/include/caffe/greentea/greentea_math_functions.hpp +++ b/include/caffe/greentea/greentea_math_functions.hpp @@ -53,7 +53,16 @@ void greentea_gpu_gemm(const int_tp ctx_id, const CBLAS_TRANSPOSE TransA, const int_tp N, const int_tp K, const Dtype alpha, const cl_mem A, const int_tp offA, const cl_mem B, const int_tp offB, const Dtype beta, cl_mem C, - const int_tp offC); + const int_tp offC , const bool is_image_a = false, + const bool is_image_b = false); + +void greentea_gpu_gemm_copy_buffer_to_image(int_tp ctx_id, + cl_mem *image, cl_mem buffer, int offset, + bool is_matrix_a, bool transpose, + bool padding, int padded_height, + int padded_width, int height, + int width, int wait_list_size, + cl_event *wait_list, cl_event *event); template void greentea_gpu_gemv(const int_tp ctx_id, const CBLAS_TRANSPOSE TransA, diff --git a/src/caffe/greentea/cl_kernels.cpp b/src/caffe/greentea/cl_kernels.cpp index 9135a32da8f..ddae5f6f898 100644 --- a/src/caffe/greentea/cl_kernels.cpp +++ b/src/caffe/greentea/cl_kernels.cpp @@ -3012,6 +3012,1491 @@ static std::vector> cl_kernels{ "#include \"header.cl\"", // NOLINT "#endif", // NOLINT "", // NOLINT +"#define TILE_M 32", // NOLINT +"#define TILE_K 8", // NOLINT +"#define TILE_N 8", // NOLINT +"", // NOLINT +"// common block to 
calculate (alpha * AxB + beta * C) and output to destination image.", // NOLINT +"", // NOLINT +"//#define USE_IMAGE_C", // NOLINT +"#ifdef USE_IMAGE_C", // NOLINT +"#define BLOCKC_READ8( _C, _coordC ) as_float8( intel_sub_group_block_read8( _C, _coordC ) )", // NOLINT +"#define BLOCKC_WRITE8( _C, _coordC, _val ) intel_sub_group_block_write8( _C, _coordC, as_uint8( _val ) )", // NOLINT +"#define MATC_PARAMETER __read_only image2d_t C, __write_only image2d_t dst", // NOLINT +"#define GEMM_OUTPUT(ALPHA1, BETA_NOT0) GEMM_OUTPUT_EXT(ALPHA1, BETA_NOT0, C, dst, sizeof(uint))", // NOLINT +"#else", // NOLINT +"#define BLOCKC_READ8( _C, _coordC ) (float8) ( (_coordC.x + get_local_id(0) < N && _coordC.y < M) ? _C[ _coordC.y * N + _coordC.x + get_local_id(0) ] : 0, (_coordC.x + get_local_id(0) < N && _coordC.y + 1 < M) ? _C[ ( _coordC.y + 1 ) * N + _coordC.x + get_local_id(0) ] : 0, (_coordC.x + get_local_id(0) < N && _coordC.y + 2 < M) ? _C[ ( _coordC.y + 2 ) * N + _coordC.x + get_local_id(0) ] : 0, (_coordC.x + get_local_id(0) < N && _coordC.y + 3 < M) ? _C[ ( _coordC.y + 3 ) * N + _coordC.x + get_local_id(0) ] : 0, (_coordC.x + get_local_id(0) < N && _coordC.y + 4 < M) ? _C[ ( _coordC.y + 4 ) * N + _coordC.x + get_local_id(0) ] : 0, (_coordC.x + get_local_id(0) < N && _coordC.y + 5 < M) ? _C[ ( _coordC.y + 5 ) * N + _coordC.x + get_local_id(0) ] : 0, (_coordC.x + get_local_id(0) < N && _coordC.y + 6 < M) ? _C[ ( _coordC.y + 6 ) * N + _coordC.x + get_local_id(0) ] : 0, (_coordC.x + get_local_id(0) < N && _coordC.y + 7 < M) ? 
_C[ ( _coordC.y + 7 ) * N + _coordC.x + get_local_id(0) ] : 0)", // NOLINT +"", // NOLINT +"#define BLOCKC_WRITE8( _C, _coordC, _val) do { if (_coordC.x + get_local_id(0) < N) { if (_coordC.y < M) _C[ _coordC.y * N + _coordC.x + get_local_id(0) ] = _val.s0; if (_coordC.y + 1 < M) _C[ ( _coordC.y + 1 )* N + _coordC.x + get_local_id(0) ] = _val.s1; if (_coordC.y + 2 < M) _C[ ( _coordC.y + 2 )* N + _coordC.x + get_local_id(0) ] = _val.s2; if (_coordC.y + 3 < M) _C[ ( _coordC.y + 3 )* N + _coordC.x + get_local_id(0) ] = _val.s3; if (_coordC.y + 4 < M) _C[ ( _coordC.y + 4 )* N + _coordC.x + get_local_id(0) ] = _val.s4; if (_coordC.y + 5 < M) _C[ ( _coordC.y + 5 )* N + _coordC.x + get_local_id(0) ] = _val.s5; if (_coordC.y + 6 < M) _C[ ( _coordC.y + 6 )* N + _coordC.x + get_local_id(0) ] = _val.s6; if (_coordC.y + 7 < M) _C[ ( _coordC.y + 7 )* N + _coordC.x + get_local_id(0) ] = _val.s7; }} while(0)", // NOLINT +"#define MATC_PARAMETER __global Dtype * C, const int offC, const int M, const int N", // NOLINT +"#define GEMM_OUTPUT(ALPHA1, BETA_NOT0) GEMM_OUTPUT_EXT(ALPHA1, BETA_NOT0, (C + offC), (C + offC), 1)", // NOLINT +"#endif", // NOLINT +"", // NOLINT +"#define GEMM_OUTPUT_EXT(ALPHA1, BETA_NOT0, _C, _dst, _C_step) int2 coordDst = (int2)( ( group_x * TILE_N ) * _C_step, ( group_y * TILE_M ) ); int2 coordC = coordDst; float8 blockC00; float8 blockC01; float8 blockC02; float8 blockC03; if (BETA_NOT0) { blockC00 = BLOCKC_READ8( _C, coordC ); coordC.y += 8; blockC01 = BLOCKC_READ8( _C, coordC ); coordC.y += 8; blockC02 = BLOCKC_READ8( _C, coordC ); coordC.y += 8; blockC03 = BLOCKC_READ8( _C, coordC ); if (!ALPHA1) { blockC00 *= beta; blockC01 *= beta; blockC02 *= beta; blockC03 *= beta; blockC00 = mad(blockAxB00, (float8)alpha, blockC00); blockC01 = mad(blockAxB01, (float8)alpha, blockC01); blockC02 = mad(blockAxB02, (float8)alpha, blockC02); blockC03 = mad(blockAxB03, (float8)alpha, blockC03); } else { blockC00 = mad(blockC00, (float8)beta, blockAxB00); blockC01 = 
mad(blockC01, (float8)beta, blockAxB01); blockC02 = mad(blockC02, (float8)beta, blockAxB02); blockC03 = mad(blockC03, (float8)beta, blockAxB03); } } else { if (!ALPHA1) { blockC00 = blockAxB00 * alpha; blockC01 = blockAxB01 * alpha; blockC02 = blockAxB02 * alpha; blockC03 = blockAxB03 * alpha; } else { blockC00 = blockAxB00; blockC01 = blockAxB01; blockC02 = blockAxB02; blockC03 = blockAxB03; } } BLOCKC_WRITE8( _dst, coordDst, blockC00 ); coordDst.y += 8; BLOCKC_WRITE8( _dst, coordDst, blockC01 ); coordDst.y += 8; BLOCKC_WRITE8( _dst, coordDst, blockC02 ); coordDst.y += 8; BLOCKC_WRITE8( _dst, coordDst, blockC03 );", // NOLINT +"", // NOLINT +"// Get the specified column of the block of the block", // NOLINT +"#define TRANSPOSE_BLOCK_8( _block, _col ) (float8)( intel_sub_group_shuffle( _block.s0, _col ), intel_sub_group_shuffle( _block.s1, _col ), intel_sub_group_shuffle( _block.s2, _col ), intel_sub_group_shuffle( _block.s3, _col ), intel_sub_group_shuffle( _block.s4, _col ), intel_sub_group_shuffle( _block.s5, _col ), intel_sub_group_shuffle( _block.s6, _col ), intel_sub_group_shuffle( _block.s7, _col ) );", // NOLINT +"", // NOLINT +"// A's column block multiply B 's row block.", // NOLINT +"#define MULTIPLY_BLOCKS_8x8( _result, _blockA, _blockB ) { const float8 acol0 = TRANSPOSE_BLOCK_8( _blockA, 0 ); const float8 acol1 = TRANSPOSE_BLOCK_8( _blockA, 1 ); const float8 acol2 = TRANSPOSE_BLOCK_8( _blockA, 2 ); const float8 acol3 = TRANSPOSE_BLOCK_8( _blockA, 3 ); const float8 acol4 = TRANSPOSE_BLOCK_8( _blockA, 4 ); const float8 acol5 = TRANSPOSE_BLOCK_8( _blockA, 5 ); const float8 acol6 = TRANSPOSE_BLOCK_8( _blockA, 6 ); const float8 acol7 = TRANSPOSE_BLOCK_8( _blockA, 7 ); _result = mad( (float8)(_blockB.s0), acol0, _result ); _result = mad( (float8)(_blockB.s1), acol1, _result ); _result = mad( (float8)(_blockB.s2), acol2, _result ); _result = mad( (float8)(_blockB.s3), acol3, _result ); _result = mad( (float8)(_blockB.s4), acol4, _result ); _result = mad( 
(float8)(_blockB.s5), acol5, _result ); _result = mad( (float8)(_blockB.s6), acol6, _result ); _result = mad( (float8)(_blockB.s7), acol7, _result ); }", // NOLINT +"", // NOLINT +"#define GEMM_NN(ALPHA1, BETA_NOT0) __attribute__((reqd_work_group_size(8, 1, 1))) __kernel void TEMPLATE(gemm_32_1_NN_ ##ALPHA1 ##_ ##BETA_NOT0,Dtype)( __read_only image2d_t A, __read_only image2d_t B, MATC_PARAMETER, float alpha, float beta, int width0) { const int group_x = get_group_id(0); const int group_y = get_group_id(1); float8 blockAxB00 = 0.0f; float8 blockAxB01 = 0.0f; float8 blockAxB02 = 0.0f; float8 blockAxB03 = 0.0f; int2 coordA = (int2)( 0, group_y * TILE_M ); int2 coordB = (int2)( ( group_x * TILE_N ) * sizeof(uint), 0 ); do { int2 coordBTemp = coordB; float8 blockB00 = as_float8( intel_sub_group_block_read8( B, coordBTemp ) ); coordB.y += TILE_K; int2 coordATemp = coordA; float8 blockA00 = as_float8( intel_sub_group_block_read8( A, coordATemp ) ); coordATemp.y += 8; float8 blockA01 = as_float8( intel_sub_group_block_read8( A, coordATemp ) ); coordATemp.y += 8; float8 blockA02 = as_float8( intel_sub_group_block_read8( A, coordATemp ) ); coordATemp.y += 8; float8 blockA03 = as_float8( intel_sub_group_block_read8( A, coordATemp ) ); coordA.x += TILE_K * sizeof(uint); MULTIPLY_BLOCKS_8x8( blockAxB00, blockA00, blockB00 ); MULTIPLY_BLOCKS_8x8( blockAxB01, blockA01, blockB00 ); MULTIPLY_BLOCKS_8x8( blockAxB02, blockA02, blockB00 ); MULTIPLY_BLOCKS_8x8( blockAxB03, blockA03, blockB00 ); } while( coordB.y < width0 ); GEMM_OUTPUT(ALPHA1, BETA_NOT0); }", // NOLINT +"", // NOLINT +"GEMM_NN(1, 0) // ALPHA == 1, BETA == 0", // NOLINT +"GEMM_NN(1, 1) // ALPHA == 1, BETA != 0", // NOLINT +"GEMM_NN(0, 0) // ALPHA != 1, BETA == 0", // NOLINT +"GEMM_NN(0, 1) // ALPHA != 1, BETA != 0", // NOLINT +"", // NOLINT +"#undef TRANSPOSE_BLOCK_8", // NOLINT +"#undef MULTIPLY_BLOCKS_8x8", // NOLINT +"", // NOLINT +"// replicate the first row to column block.", // NOLINT +"#define 
TRANSPOSE_BLOCK_8(_vec) (float8)( intel_sub_group_shuffle(_vec, 0), intel_sub_group_shuffle(_vec, 1), intel_sub_group_shuffle(_vec, 2), intel_sub_group_shuffle(_vec, 3), intel_sub_group_shuffle(_vec, 4), intel_sub_group_shuffle(_vec, 5), intel_sub_group_shuffle(_vec, 6), intel_sub_group_shuffle(_vec, 7) )", // NOLINT +"", // NOLINT +"#define MULTIPLY_BLOCKS_8x8( _result, _blockA, _blockB ) { _result = mad( (float8)(_blockB.s0), TRANSPOSE_BLOCK_8(_blockA.s0), _result ); _result = mad( (float8)(_blockB.s1), TRANSPOSE_BLOCK_8(_blockA.s1), _result ); _result = mad( (float8)(_blockB.s2), TRANSPOSE_BLOCK_8(_blockA.s2), _result ); _result = mad( (float8)(_blockB.s3), TRANSPOSE_BLOCK_8(_blockA.s3), _result ); _result = mad( (float8)(_blockB.s4), TRANSPOSE_BLOCK_8(_blockA.s4), _result ); _result = mad( (float8)(_blockB.s5), TRANSPOSE_BLOCK_8(_blockA.s5), _result ); _result = mad( (float8)(_blockB.s6), TRANSPOSE_BLOCK_8(_blockA.s6), _result ); _result = mad( (float8)(_blockB.s7), TRANSPOSE_BLOCK_8(_blockA.s7), _result ); }", // NOLINT +"", // NOLINT +"#define GEMM_TN(ALPHA1, BETA_NOT0) __attribute__((reqd_work_group_size(8, 1, 1))) __kernel void TEMPLATE(gemm_32_1_TN_ ##ALPHA1 ##_ ##BETA_NOT0,Dtype)( __read_only image2d_t A, __read_only image2d_t B, MATC_PARAMETER, float alpha, float beta, int width0) { const int group_x = get_group_id(0); const int group_y = get_group_id(1); float8 blockAxB00 = 0.0f; float8 blockAxB01 = 0.0f; float8 blockAxB02 = 0.0f; float8 blockAxB03 = 0.0f; int2 coordA = (int2)( group_y * TILE_M * sizeof(uint), 0 ); int2 coordB = (int2)( ( group_x * TILE_N ) * sizeof(uint), 0 ); do { int2 coordBTemp = coordB; float8 blockB00 = as_float8( intel_sub_group_block_read8( B, coordBTemp ) ); coordB.y += TILE_K; int2 coordATemp = coordA; float8 blockA00 = as_float8( intel_sub_group_block_read8( A, coordATemp ) ); coordATemp.x += 8 * sizeof(uint); float8 blockA01 = as_float8( intel_sub_group_block_read8( A, coordATemp ) ); coordATemp.x += 8 * sizeof(uint); float8 
blockA02 = as_float8( intel_sub_group_block_read8( A, coordATemp ) ); coordATemp.x += 8 * sizeof(uint); float8 blockA03 = as_float8( intel_sub_group_block_read8( A, coordATemp ) ); coordA.y += TILE_K; MULTIPLY_BLOCKS_8x8( blockAxB00, blockA00, blockB00 ); MULTIPLY_BLOCKS_8x8( blockAxB01, blockA01, blockB00 ); MULTIPLY_BLOCKS_8x8( blockAxB02, blockA02, blockB00 ); MULTIPLY_BLOCKS_8x8( blockAxB03, blockA03, blockB00 ); } while( coordB.y < width0 ); GEMM_OUTPUT(ALPHA1, BETA_NOT0); }", // NOLINT +"", // NOLINT +"GEMM_TN(1, 0) // ALPHA == 1, BETA == 0", // NOLINT +"GEMM_TN(1, 1) // ALPHA == 1, BETA != 0", // NOLINT +"GEMM_TN(0, 0) // ALPHA != 1, BETA == 0", // NOLINT +"GEMM_TN(0, 1) // ALPHA != 1, BETA != 0", // NOLINT +"", // NOLINT +"#undef MULTIPLY_BLOCKS_8x8", // NOLINT +"#undef TRANSPOSE_BLOCK_8", // NOLINT +"", // NOLINT +"// The same as GEMM_NN", // NOLINT +"#define TRANSPOSE_BLOCK_8( _block, _col ) (float8)( intel_sub_group_shuffle( _block.s0, _col), intel_sub_group_shuffle( _block.s1, _col), intel_sub_group_shuffle( _block.s2, _col), intel_sub_group_shuffle( _block.s3, _col), intel_sub_group_shuffle( _block.s4, _col), intel_sub_group_shuffle( _block.s5, _col), intel_sub_group_shuffle( _block.s6, _col), intel_sub_group_shuffle( _block.s7, _col) )", // NOLINT +"", // NOLINT +"#define MULTIPLY_BLOCKS_8x8( _result, _blockA, _blockB ) { const float8 acol0 = TRANSPOSE_BLOCK_8( _blockA, 0 ); const float8 acol1 = TRANSPOSE_BLOCK_8( _blockA, 1 ); const float8 acol2 = TRANSPOSE_BLOCK_8( _blockA, 2 ); const float8 acol3 = TRANSPOSE_BLOCK_8( _blockA, 3 ); const float8 acol4 = TRANSPOSE_BLOCK_8( _blockA, 4 ); const float8 acol5 = TRANSPOSE_BLOCK_8( _blockA, 5 ); const float8 acol6 = TRANSPOSE_BLOCK_8( _blockA, 6 ); const float8 acol7 = TRANSPOSE_BLOCK_8( _blockA, 7 ); _result = mad( (float8)_blockB.s0, acol0, _result ); _result = mad( (float8)_blockB.s1, acol1, _result ); _result = mad( (float8)_blockB.s2, acol2, _result ); _result = mad( (float8)_blockB.s3, acol3, _result 
); _result = mad( (float8)_blockB.s4, acol4, _result ); _result = mad( (float8)_blockB.s5, acol5, _result ); _result = mad( (float8)_blockB.s6, acol6, _result ); _result = mad( (float8)_blockB.s7, acol7, _result ); }", // NOLINT +"", // NOLINT +"", // NOLINT +"", // NOLINT +"#define GEMM_NT(ALPHA1, BETA_NOT0, VECSCALAR, VECSIZE) __attribute__((reqd_work_group_size(8, 1, 1))) __kernel void TEMPLATE(gemm_32_1_NT_ ##VECSCALAR ##_ ##ALPHA1 ##_ ##BETA_NOT0,Dtype)( __read_only image2d_t A, MATB_PARAMETER, MATC_PARAMETER, float alpha, float beta, int padded_k, int k) { const int group_x = get_group_id(0); const int group_y = get_group_id(1); float8 blockAxB00 = 0.0f; float8 blockAxB01 = 0.0f; float8 blockAxB02 = 0.0f; float8 blockAxB03 = 0.0f; int2 coordA = (int2)( 0, group_y * TILE_M ); int2 coordB = (int2)( 0, ( group_x * TILE_N )); const sampler_t sampler = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; do { float8 blockB00; BLOCKB_READ8(blockB00, B, coordB); int2 coordATemp = coordA; float8 blockA00 = as_float8( intel_sub_group_block_read8( A, coordATemp ) ); coordATemp.y += 8; float8 blockA01 = as_float8( intel_sub_group_block_read8( A, coordATemp ) ); coordATemp.y += 8; float8 blockA02 = as_float8( intel_sub_group_block_read8( A, coordATemp ) ); coordATemp.y += 8; float8 blockA03 = as_float8( intel_sub_group_block_read8( A, coordATemp ) ); coordA.x += TILE_K * sizeof(uint); MULTIPLY_BLOCKS_8x8( blockAxB00, blockA00, blockB00 ); MULTIPLY_BLOCKS_8x8( blockAxB01, blockA01, blockB00 ); MULTIPLY_BLOCKS_8x8( blockAxB02, blockA02, blockB00 ); MULTIPLY_BLOCKS_8x8( blockAxB03, blockA03, blockB00 ); } while( coordB.x < padded_k / VECSIZE ); GEMM_OUTPUT(ALPHA1, BETA_NOT0); }", // NOLINT +"", // NOLINT +"", // NOLINT +"#define BLOCKB_READ8(_blockb, _B, _coordB) int2 _coordBTemp = _coordB; _coordBTemp.y += get_local_id(0); _blockb.s0123 = read_imagef(_B, sampler, _coordBTemp); _coordBTemp.x += 1; _blockb.s4567 = read_imagef(_B, sampler, _coordBTemp); 
_coordB.x += 2;", // NOLINT +"", // NOLINT +"#define MATB_PARAMETER __read_only image2d_t B", // NOLINT +"", // NOLINT +"GEMM_NT(1, 0, VEC4, 4) // ALPHA == 1, BETA == 0", // NOLINT +"GEMM_NT(1, 1, VEC4, 4) // ALPHA == 1, BETA != 0", // NOLINT +"GEMM_NT(0, 0, VEC4, 4) // ALPHA != 1, BETA == 0", // NOLINT +"GEMM_NT(0, 1, VEC4, 4) // ALPHA != 1, BETA != 0", // NOLINT +"#undef BLOCKB_READ8", // NOLINT +"#undef MATB_PARAMETER", // NOLINT +"", // NOLINT +"#define BLOCKB_READ8(_blockb, _B, _coordB) int2 _coordBTemp = _coordB; _coordBTemp.y += get_local_id(0); _blockb = *(__global float8*)&_B[_coordBTemp.y * k + _coordBTemp.x + offB]; _coordB.x += TILE_K;", // NOLINT +"", // NOLINT +"#define MATB_PARAMETER __global float *B, int offB", // NOLINT +"", // NOLINT +"GEMM_NT(1, 0, BUFFER, 1) // ALPHA == 1, BETA == 0", // NOLINT +"GEMM_NT(1, 1, BUFFER, 1) // ALPHA == 1, BETA != 0", // NOLINT +"GEMM_NT(0, 0, BUFFER, 1) // ALPHA != 1, BETA == 0", // NOLINT +"GEMM_NT(0, 1, BUFFER, 1) // ALPHA != 1, BETA != 0", // NOLINT +"#undef BLOCKB_READ8", // NOLINT +"#undef MATB_PARAMETER", // NOLINT +"", // NOLINT +"", // NOLINT +"#define BLOCKB_READ8(_blockb, _B, _coordB) int2 _coordBTemp = _coordB; _coordBTemp.y += get_local_id(0); float4 temp; temp = read_imagef(_B, _coordBTemp); _coordBTemp.x += 1; _blockb.s0 = temp.s0; temp = read_imagef(_B, _coordBTemp); _coordBTemp.x += 1; _blockb.s1 = temp.s0; temp = read_imagef(_B, _coordBTemp); _coordBTemp.x += 1; _blockb.s2 = temp.s0; temp = read_imagef(_B, _coordBTemp); _coordBTemp.x += 1; _blockb.s3 = temp.s0; temp = read_imagef(_B, _coordBTemp); _coordBTemp.x += 1; _blockb.s4 = temp.s0; temp = read_imagef(_B, _coordBTemp); _coordBTemp.x += 1; _blockb.s5 = temp.s0; temp = read_imagef(_B, _coordBTemp); _coordBTemp.x += 1; _blockb.s6 = temp.s0; temp = read_imagef(_B, _coordBTemp); _coordBTemp.x += 1; _blockb.s7 = temp.s0; _coordB.x += 8;", // NOLINT +"", // NOLINT +"#define MATB_PARAMETER __read_only image2d_t B", // NOLINT +"", // NOLINT 
+"GEMM_NT(1, 0, SCALAR, 1) // ALPHA == 1, BETA == 0", // NOLINT +"GEMM_NT(1, 1, SCALAR, 1) // ALPHA == 1, BETA != 0", // NOLINT +"GEMM_NT(0, 0, SCALAR, 1) // ALPHA != 1, BETA == 0", // NOLINT +"GEMM_NT(0, 1, SCALAR, 1) // ALPHA != 1, BETA != 0", // NOLINT +"#undef BLOCKB_READ8", // NOLINT +"#undef MATB_PARAMETER", // NOLINT +"", // NOLINT +"#undef MULTIPLY_BLOCKS_8x8", // NOLINT +"#undef TRANSPOSE_BLOCK_8", // NOLINT +"", // NOLINT +"//The same as GEMM_TN.", // NOLINT +"#define TRANSPOSE_BLOCK_8(_vec) (float8)( intel_sub_group_shuffle(_vec, 0), intel_sub_group_shuffle(_vec, 1), intel_sub_group_shuffle(_vec, 2), intel_sub_group_shuffle(_vec, 3), intel_sub_group_shuffle(_vec, 4), intel_sub_group_shuffle(_vec, 5), intel_sub_group_shuffle(_vec, 6), intel_sub_group_shuffle(_vec, 7) );", // NOLINT +"", // NOLINT +"#define MULTIPLY_BLOCKS_8x8( _result, _blockA, _blockB ) { const float8 acol0 = TRANSPOSE_BLOCK_8( _blockA.s0 ); const float8 acol1 = TRANSPOSE_BLOCK_8( _blockA.s1 ); const float8 acol2 = TRANSPOSE_BLOCK_8( _blockA.s2 ); const float8 acol3 = TRANSPOSE_BLOCK_8( _blockA.s3 ); const float8 acol4 = TRANSPOSE_BLOCK_8( _blockA.s4 ); const float8 acol5 = TRANSPOSE_BLOCK_8( _blockA.s5 ); const float8 acol6 = TRANSPOSE_BLOCK_8( _blockA.s6 ); const float8 acol7 = TRANSPOSE_BLOCK_8( _blockA.s7 ); _result = mad( (float8)_blockB.s0, acol0, _result ); _result = mad( (float8)_blockB.s1, acol1, _result ); _result = mad( (float8)_blockB.s2, acol2, _result ); _result = mad( (float8)_blockB.s3, acol3, _result ); _result = mad( (float8)_blockB.s4, acol4, _result ); _result = mad( (float8)_blockB.s5, acol5, _result ); _result = mad( (float8)_blockB.s6, acol6, _result ); _result = mad( (float8)_blockB.s7, acol7, _result ); }", // NOLINT +"", // NOLINT +"#define GEMM_TT(ALPHA1, BETA_NOT0, VECSCALAR, VECSIZE) __attribute__((reqd_work_group_size(8, 1, 1))) __kernel void TEMPLATE(gemm_32_1_TT_ ##VECSCALAR ##_ ##ALPHA1 ##_ ##BETA_NOT0, Dtype)( __read_only image2d_t A, MATB_PARAMETER, 
MATC_PARAMETER, float alpha, float beta, int padded_k, int k) { const int group_x = get_group_id(0); const int group_y = get_group_id(1); float8 blockAxB00 = 0.0f; float8 blockAxB01 = 0.0f; float8 blockAxB02 = 0.0f; float8 blockAxB03 = 0.0f; int2 coordA = (int2)( group_y * TILE_M * sizeof(uint), 0 ); int2 coordB = (int2)( 0, ( group_x * TILE_N )); const sampler_t sampler = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; do { float8 blockB00; BLOCKB_READ8(blockB00, B, coordB); int2 coordATemp = coordA; float8 blockA00 = as_float8( intel_sub_group_block_read8( A, coordATemp ) ); coordATemp.x += 8 * sizeof(uint); float8 blockA01 = as_float8( intel_sub_group_block_read8( A, coordATemp ) ); coordATemp.x += 8 * sizeof(uint); float8 blockA02 = as_float8( intel_sub_group_block_read8( A, coordATemp ) ); coordATemp.x += 8 * sizeof(uint); float8 blockA03 = as_float8( intel_sub_group_block_read8( A, coordATemp ) ); coordA.y += TILE_K; MULTIPLY_BLOCKS_8x8( blockAxB00, blockA00 , blockB00 ); MULTIPLY_BLOCKS_8x8( blockAxB01, blockA01 , blockB00 ); MULTIPLY_BLOCKS_8x8( blockAxB02, blockA02 , blockB00 ); MULTIPLY_BLOCKS_8x8( blockAxB03, blockA03 , blockB00 ); } while( coordB.x < padded_k / VECSIZE ); GEMM_OUTPUT(ALPHA1, BETA_NOT0);}", // NOLINT +"", // NOLINT +"#define BLOCKB_READ8(_blockb, _B, _coordB) int2 _coordBTemp = _coordB; _coordBTemp.y += get_local_id(0); blockB00.s0123 = read_imagef(B, _coordBTemp); _coordBTemp.x += 1; blockB00.s4567 = read_imagef(B, _coordBTemp); _coordB.x += 2;", // NOLINT +"", // NOLINT +"#define MATB_PARAMETER __read_only image2d_t B", // NOLINT +"", // NOLINT +"GEMM_TT(1, 0, VEC4, 4) // ALPHA == 1, BETA == 0", // NOLINT +"GEMM_TT(1, 1, VEC4, 4) // ALPHA == 1, BETA != 0", // NOLINT +"GEMM_TT(0, 0, VEC4, 4) // ALPHA != 1, BETA == 0", // NOLINT +"GEMM_TT(0, 1, VEC4, 4) // ALPHA != 1, BETA != 0", // NOLINT +"#undef BLOCKB_READ8", // NOLINT +"#undef MATB_PARAMETER", // NOLINT +"", // NOLINT +"#define BLOCKB_READ8(_blockb, _B, 
_coordB) int2 _coordBTemp = _coordB; _coordBTemp.y += get_local_id(0); _blockb = *(__global float8*)&_B[_coordBTemp.y * k + _coordBTemp.x + offB]; _coordB.x += TILE_K;", // NOLINT +"", // NOLINT +"#define MATB_PARAMETER __global float *B, int offB", // NOLINT +"", // NOLINT +"GEMM_TT(1, 0, BUFFER, 1) // ALPHA == 1, BETA == 0", // NOLINT +"GEMM_TT(1, 1, BUFFER, 1) // ALPHA == 1, BETA != 0", // NOLINT +"GEMM_TT(0, 0, BUFFER, 1) // ALPHA != 1, BETA == 0", // NOLINT +"GEMM_TT(0, 1, BUFFER, 1) // ALPHA != 1, BETA != 0", // NOLINT +"#undef BLOCKB_READ8", // NOLINT +"#undef MATB_PARAMETER", // NOLINT +"", // NOLINT +"#define BLOCKB_READ8(_blockb, _B, _coordB) int2 _coordBTemp = _coordB; _coordBTemp.y += get_local_id(0); float4 temp; temp = read_imagef(B, _coordBTemp); _coordBTemp.x += 1; _blockb.s0 = temp.s0; temp = read_imagef(B, _coordBTemp); _coordBTemp.x += 1; _blockb.s1 = temp.s0; temp = read_imagef(B, _coordBTemp); _coordBTemp.x += 1; _blockb.s2 = temp.s0; temp = read_imagef(B, _coordBTemp); _coordBTemp.x += 1; _blockb.s3 = temp.s0; temp = read_imagef(B, _coordBTemp); _coordBTemp.x += 1; _blockb.s4 = temp.s0; temp = read_imagef(B, _coordBTemp); _coordBTemp.x += 1; _blockb.s5 = temp.s0; temp = read_imagef(B, _coordBTemp); _coordBTemp.x += 1; _blockb.s6 = temp.s0; temp = read_imagef(B, _coordBTemp); _coordBTemp.x += 1; _blockb.s7 = temp.s0; _coordB.x += 8;", // NOLINT +"", // NOLINT +"#define MATB_PARAMETER __read_only image2d_t B", // NOLINT +"", // NOLINT +"GEMM_TT(1, 0, SCALAR, 1) // ALPHA == 1, BETA == 0", // NOLINT +"GEMM_TT(1, 1, SCALAR, 1) // ALPHA == 1, BETA != 0", // NOLINT +"GEMM_TT(0, 0, SCALAR, 1) // ALPHA != 1, BETA == 0", // NOLINT +"GEMM_TT(0, 1, SCALAR, 1) // ALPHA != 1, BETA != 0", // NOLINT +"#undef BLOCKB_READ8", // NOLINT +"#undef MATB_PARAMETER", // NOLINT +"", // NOLINT +"#undef MULTIPLY_BLOCKS_8x8", // NOLINT +"#undef TRANSPOSE_BLOCK_8", // NOLINT +"", // NOLINT +"#undef TILE_M", // NOLINT +"#undef TILE_K", // NOLINT +"#undef TILE_N", // NOLINT 
+"", // NOLINT +"__kernel void TEMPLATE(gemm_buffer_copy_image,Dtype)(", // NOLINT +"__global float* A,", // NOLINT +"__write_only image2d_t ImA,", // NOLINT +"int offA,", // NOLINT +"int width,", // NOLINT +"int height)", // NOLINT +"{", // NOLINT +"const int gidx = get_global_id(0);", // NOLINT +"const int gidy = get_global_id(1);", // NOLINT +"int2 coord_dst = (int2)(gidx, gidy);", // NOLINT +"if (gidx >= width || gidy >= height) {", // NOLINT +"write_imageui(ImA, coord_dst, (uint4)0);", // NOLINT +"return;", // NOLINT +"}", // NOLINT +"__global float* A_off = A + offA;", // NOLINT +"uint4 srcA = convert_uint4(as_uchar4(A_off[gidy * width + gidx]));", // NOLINT +"write_imageui(ImA, coord_dst, srcA);", // NOLINT +"}", // NOLINT +"", // NOLINT +"#define VEC_SIZE 4", // NOLINT +"#define LWG_HEIGHT 4", // NOLINT +"#define TILE_M 8", // NOLINT +"#define TILE_K 16", // NOLINT +"#define TILE_N 32", // NOLINT +"", // NOLINT +"__attribute__((reqd_work_group_size(8, LWG_HEIGHT, 1)))", // NOLINT +"__kernel void TEMPLATE(gemm_buffer_NN, Dtype)(", // NOLINT +"const __global float *src0, int off0,", // NOLINT +"const __global float *src1, int off1,", // NOLINT +"__global float *dst, int offd,", // NOLINT +"int M,", // NOLINT +"int N,", // NOLINT +"int K,", // NOLINT +"float alpha,", // NOLINT +"float beta,", // NOLINT +"int start_index)", // NOLINT +"{", // NOLINT +"const int group_x = get_group_id(0);", // NOLINT +"const int group_y = get_group_id(1);", // NOLINT +"const int local_x = get_local_id(0);", // NOLINT +"const int local_y = get_local_id(1);", // NOLINT +"const int global_x = get_global_id(0);", // NOLINT +"const int global_y = get_global_id(1);", // NOLINT +"", // NOLINT +"float4 brow;", // NOLINT +"float2 arow0, arow1, arow2, arow3, arow4, arow5, arow6, arow7;", // NOLINT +"", // NOLINT +"__global float *dst_write0 = dst + local_x * VEC_SIZE + ( group_x * TILE_N ) + ( group_y * LWG_HEIGHT * TILE_M + local_y * TILE_M) * N + offd;", // NOLINT +"", // NOLINT +"const 
__global float *src0_read = src0 + local_x * ( TILE_K / 8 ) + ( group_y * LWG_HEIGHT * TILE_M + local_y * TILE_M ) * K + start_index + off0;", // NOLINT +"", // NOLINT +"const __global float *src1_read0 = src1 + local_x * VEC_SIZE + ( group_x * TILE_N ) + start_index * N + off1;", // NOLINT +"", // NOLINT +"int border = -(group_y * LWG_HEIGHT * TILE_M + local_y * TILE_M);", // NOLINT +"", // NOLINT +"int row0 = mad24(global_y, TILE_M, 0) < M ? 0 : border;", // NOLINT +"int row1 = mad24(global_y, TILE_M, 1) < M ? 1 : border;", // NOLINT +"int row2 = mad24(global_y, TILE_M, 2) < M ? 2 : border;", // NOLINT +"int row3 = mad24(global_y, TILE_M, 3) < M ? 3 : border;", // NOLINT +"int row4 = mad24(global_y, TILE_M, 4) < M ? 4 : border;", // NOLINT +"int row5 = mad24(global_y, TILE_M, 5) < M ? 5 : border;", // NOLINT +"int row6 = mad24(global_y, TILE_M, 6) < M ? 6 : border;", // NOLINT +"int row7 = mad24(global_y, TILE_M, 7) < M ? 7 : border;", // NOLINT +"", // NOLINT +"float4 dot00 = (start_index != 0) ? ((__global float4 *)dst_write0)[0] : beta * ((__global float4 *)dst_write0)[0];", // NOLINT +"float4 dot01 = (start_index != 0) ? ((__global float4 *)(dst_write0 + 1 * N))[0] : beta * ((__global float4 *)(dst_write0 + 1 * N))[0];", // NOLINT +"float4 dot02 = (start_index != 0) ? ((__global float4 *)(dst_write0 + 2 * N))[0] : beta * ((__global float4 *)(dst_write0 + 2 * N))[0];", // NOLINT +"float4 dot03 = (start_index != 0) ? ((__global float4 *)(dst_write0 + 3 * N))[0] : beta * ((__global float4 *)(dst_write0 + 3 * N))[0];", // NOLINT +"float4 dot04 = (start_index != 0) ? ((__global float4 *)(dst_write0 + 4 * N))[0] : beta * ((__global float4 *)(dst_write0 + 4 * N))[0];", // NOLINT +"float4 dot05 = (start_index != 0) ? ((__global float4 *)(dst_write0 + 5 * N))[0] : beta * ((__global float4 *)(dst_write0 + 5 * N))[0];", // NOLINT +"float4 dot06 = (start_index != 0) ? 
((__global float4 *)(dst_write0 + 6 * N))[0] : beta * ((__global float4 *)(dst_write0 + 6 * N))[0];", // NOLINT +"float4 dot07 = (start_index != 0) ? ((__global float4 *)(dst_write0 + 7 * N))[0] : beta * ((__global float4 *)(dst_write0 + 7 * N))[0];", // NOLINT +"", // NOLINT +"int end_index = min(start_index + 256, K);", // NOLINT +"int w = start_index;", // NOLINT +"while( w + TILE_K <= end_index ) {", // NOLINT +"arow0 = alpha * ((__global float2 *)(src0_read + row0 * K))[0];", // NOLINT +"arow1 = alpha * ((__global float2 *)(src0_read + row1 * K))[0];", // NOLINT +"arow2 = alpha * ((__global float2 *)(src0_read + row2 * K))[0];", // NOLINT +"arow3 = alpha * ((__global float2 *)(src0_read + row3 * K))[0];", // NOLINT +"arow4 = alpha * ((__global float2 *)(src0_read + row4 * K))[0];", // NOLINT +"arow5 = alpha * ((__global float2 *)(src0_read + row5 * K))[0];", // NOLINT +"arow6 = alpha * ((__global float2 *)(src0_read + row6 * K))[0];", // NOLINT +"arow7 = alpha * ((__global float2 *)(src0_read + row7 * K))[0];", // NOLINT +"", // NOLINT +"#define MM_DOT_PRODUCT( index, suffix ) brow = ((__global float4 *)src1_read0)[0]; src1_read0 += N; dot00 = mad( (float4)(intel_sub_group_shuffle( arow0, index ).s##suffix), brow, dot00 ); dot01 = mad( (float4)(intel_sub_group_shuffle( arow1, index ).s##suffix), brow, dot01 ); dot02 = mad( (float4)(intel_sub_group_shuffle( arow2, index ).s##suffix), brow, dot02 ); dot03 = mad( (float4)(intel_sub_group_shuffle( arow3, index ).s##suffix), brow, dot03 ); dot04 = mad( (float4)(intel_sub_group_shuffle( arow4, index ).s##suffix), brow, dot04 ); dot05 = mad( (float4)(intel_sub_group_shuffle( arow5, index ).s##suffix), brow, dot05 ); dot06 = mad( (float4)(intel_sub_group_shuffle( arow6, index ).s##suffix), brow, dot06 ); dot07 = mad( (float4)(intel_sub_group_shuffle( arow7, index ).s##suffix), brow, dot07 );", // NOLINT +"MM_DOT_PRODUCT(0, 0);", // NOLINT +"MM_DOT_PRODUCT(0, 1);", // NOLINT +"MM_DOT_PRODUCT(1, 0);", // NOLINT 
+"MM_DOT_PRODUCT(1, 1);", // NOLINT +"MM_DOT_PRODUCT(2, 0);", // NOLINT +"MM_DOT_PRODUCT(2, 1);", // NOLINT +"MM_DOT_PRODUCT(3, 0);", // NOLINT +"MM_DOT_PRODUCT(3, 1);", // NOLINT +"MM_DOT_PRODUCT(4, 0);", // NOLINT +"MM_DOT_PRODUCT(4, 1);", // NOLINT +"MM_DOT_PRODUCT(5, 0);", // NOLINT +"MM_DOT_PRODUCT(5, 1);", // NOLINT +"MM_DOT_PRODUCT(6, 0);", // NOLINT +"MM_DOT_PRODUCT(6, 1);", // NOLINT +"MM_DOT_PRODUCT(7, 0);", // NOLINT +"MM_DOT_PRODUCT(7, 1);", // NOLINT +"#undef MM_DOT_PRODUCT", // NOLINT +"", // NOLINT +"src0_read += TILE_K;", // NOLINT +"w += TILE_K;", // NOLINT +"}", // NOLINT +"", // NOLINT +"if(w < end_index) {", // NOLINT +"arow0.x = ((w + local_x * 2) < K) ? alpha * (src0_read + row0 * K)[0] : 0.0f;", // NOLINT +"arow0.y = ((w + local_x * 2 + 1) < K) ? alpha * (src0_read + row0 * K)[1] : 0.0f;", // NOLINT +"arow1.x = ((w + local_x * 2) < K) ? alpha * (src0_read + row1 * K)[0] : 0.0f;", // NOLINT +"arow1.y = ((w + local_x * 2 + 1) < K) ? alpha * (src0_read + row1 * K)[1] : 0.0f;", // NOLINT +"arow2.x = ((w + local_x * 2) < K) ? alpha * (src0_read + row2 * K)[0] : 0.0f;", // NOLINT +"arow2.y = ((w + local_x * 2 + 1) < K) ? alpha * (src0_read + row2 * K)[1] : 0.0f;", // NOLINT +"arow3.x = ((w + local_x * 2) < K) ? alpha * (src0_read + row3 * K)[0] : 0.0f;", // NOLINT +"arow3.y = ((w + local_x * 2 + 1) < K) ? alpha * (src0_read + row3 * K)[1] : 0.0f;", // NOLINT +"arow4.x = ((w + local_x * 2) < K) ? alpha * (src0_read + row4 * K)[0] : 0.0f;", // NOLINT +"arow4.y = ((w + local_x * 2 + 1) < K) ? alpha * (src0_read + row4 * K)[1] : 0.0f;", // NOLINT +"arow5.x = ((w + local_x * 2) < K) ? alpha * (src0_read + row5 * K)[0] : 0.0f;", // NOLINT +"arow5.y = ((w + local_x * 2 + 1) < K) ? alpha * (src0_read + row5 * K)[1] : 0.0f;", // NOLINT +"arow6.x = ((w + local_x * 2) < K) ? alpha * (src0_read + row6 * K)[0] : 0.0f;", // NOLINT +"arow6.y = ((w + local_x * 2 + 1) < K) ? alpha * (src0_read + row6 * K)[1] : 0.0f;", // NOLINT +"arow7.x = ((w + local_x * 2) < K) ? 
alpha * (src0_read + row7 * K)[0] : 0.0f;", // NOLINT +"arow7.y = ((w + local_x * 2 + 1) < K) ? alpha * (src0_read + row7 * K)[1] : 0.0f;", // NOLINT +"", // NOLINT +"#define MM_DOT_PRODUCT( index, suffix ) brow = (w < K) ? ((__global float4 *)src1_read0)[0] : 0.0f; src1_read0 += N; w++; dot00 = mad( (float4)(intel_sub_group_shuffle( arow0, index ).s##suffix), brow, dot00 ); dot01 = mad( (float4)(intel_sub_group_shuffle( arow1, index ).s##suffix), brow, dot01 ); dot02 = mad( (float4)(intel_sub_group_shuffle( arow2, index ).s##suffix), brow, dot02 ); dot03 = mad( (float4)(intel_sub_group_shuffle( arow3, index ).s##suffix), brow, dot03 ); dot04 = mad( (float4)(intel_sub_group_shuffle( arow4, index ).s##suffix), brow, dot04 ); dot05 = mad( (float4)(intel_sub_group_shuffle( arow5, index ).s##suffix), brow, dot05 ); dot06 = mad( (float4)(intel_sub_group_shuffle( arow6, index ).s##suffix), brow, dot06 ); dot07 = mad( (float4)(intel_sub_group_shuffle( arow7, index ).s##suffix), brow, dot07 );", // NOLINT +"MM_DOT_PRODUCT(0, 0);", // NOLINT +"MM_DOT_PRODUCT(0, 1);", // NOLINT +"MM_DOT_PRODUCT(1, 0);", // NOLINT +"MM_DOT_PRODUCT(1, 1);", // NOLINT +"MM_DOT_PRODUCT(2, 0);", // NOLINT +"MM_DOT_PRODUCT(2, 1);", // NOLINT +"MM_DOT_PRODUCT(3, 0);", // NOLINT +"MM_DOT_PRODUCT(3, 1);", // NOLINT +"MM_DOT_PRODUCT(4, 0);", // NOLINT +"MM_DOT_PRODUCT(4, 1);", // NOLINT +"MM_DOT_PRODUCT(5, 0);", // NOLINT +"MM_DOT_PRODUCT(5, 1);", // NOLINT +"MM_DOT_PRODUCT(6, 0);", // NOLINT +"MM_DOT_PRODUCT(6, 1);", // NOLINT +"MM_DOT_PRODUCT(7, 0);", // NOLINT +"MM_DOT_PRODUCT(7, 1);", // NOLINT +"#undef MM_DOT_PRODUCT", // NOLINT +"}", // NOLINT +"", // NOLINT +"if(global_x * 4 < N && global_y * 8 < M) {", // NOLINT +"if(mad24(global_x, 4, 3) < N) {", // NOLINT +"__global float4 *dst_write = (__global float4 *)dst_write0;", // NOLINT +"dst_write[0] = dot00; dst_write0 += N; dst_write = (__global float4 *)dst_write0;", // NOLINT +"if(mad24(global_y, 8, 1) < M) { dst_write[0] = dot01; dst_write0 += 
N; dst_write = (__global float4 *)dst_write0; }", // NOLINT +"else return;", // NOLINT +"if(mad24(global_y, 8, 2) < M) { dst_write[0] = dot02; dst_write0 += N; dst_write = (__global float4 *)dst_write0; }", // NOLINT +"else return;", // NOLINT +"if(mad24(global_y, 8, 3) < M) { dst_write[0] = dot03; dst_write0 += N; dst_write = (__global float4 *)dst_write0; }", // NOLINT +"else return;", // NOLINT +"if(mad24(global_y, 8, 4) < M) { dst_write[0] = dot04; dst_write0 += N; dst_write = (__global float4 *)dst_write0; }", // NOLINT +"else return;", // NOLINT +"if(mad24(global_y, 8, 5) < M) { dst_write[0] = dot05; dst_write0 += N; dst_write = (__global float4 *)dst_write0; }", // NOLINT +"else return;", // NOLINT +"if(mad24(global_y, 8, 6) < M) { dst_write[0] = dot06; dst_write0 += N; dst_write = (__global float4 *)dst_write0; }", // NOLINT +"else return;", // NOLINT +"if(mad24(global_y, 8, 7) < M) { dst_write[0] = dot07; }", // NOLINT +"} else if(mad24(global_x, 4, 2) < N) {", // NOLINT +"__global float2 *dst_write = (__global float2 *)dst_write0;", // NOLINT +"dst_write[0] = dot00.xy;", // NOLINT +"dst_write0[2] = dot00.z;", // NOLINT +"dst_write0 += N; dst_write = (__global float2 *)dst_write0;", // NOLINT +"if(mad24(global_y, 8, 1) < M) {", // NOLINT +"dst_write[0] = dot01.xy;", // NOLINT +"dst_write0[2] = dot01.z;", // NOLINT +"dst_write0 += N; dst_write = (__global float2 *)dst_write0;", // NOLINT +"} else", // NOLINT +"return;", // NOLINT +"if(mad24(global_y, 8, 2) < M) {", // NOLINT +"dst_write[0] = dot02.xy;", // NOLINT +"dst_write0[2] = dot02.z;", // NOLINT +"dst_write0 += N; dst_write = (__global float2 *)dst_write0;", // NOLINT +"} else", // NOLINT +"return;", // NOLINT +"if(mad24(global_y, 8, 3) < M) {", // NOLINT +"dst_write[0] = dot03.xy;", // NOLINT +"dst_write0[2] = dot03.z;", // NOLINT +"dst_write0 += N; dst_write = (__global float2 *)dst_write0;", // NOLINT +"} else", // NOLINT +"return;", // NOLINT +"if(mad24(global_y, 8, 4) < M) {", // NOLINT 
+"dst_write[0] = dot04.xy;", // NOLINT +"dst_write0[2] = dot04.z;", // NOLINT +"dst_write0 += N; dst_write = (__global float2 *)dst_write0;", // NOLINT +"} else", // NOLINT +"return;", // NOLINT +"if(mad24(global_y, 8, 5) < M) {", // NOLINT +"dst_write[0] = dot05.xy;", // NOLINT +"dst_write0[2] = dot05.z;", // NOLINT +"dst_write0 += N; dst_write = (__global float2 *)dst_write0;", // NOLINT +"} else", // NOLINT +"return;", // NOLINT +"if(mad24(global_y, 8, 6) < M) {", // NOLINT +"dst_write[0] = dot06.xy;", // NOLINT +"dst_write0[2] = dot06.z;", // NOLINT +"dst_write0 += N; dst_write = (__global float2 *)dst_write0;", // NOLINT +"} else", // NOLINT +"return;", // NOLINT +"if(mad24(global_y, 8, 7) < M) {", // NOLINT +"dst_write[0] = dot07.xy;", // NOLINT +"dst_write0[2] = dot07.z;", // NOLINT +"}", // NOLINT +"} else if(mad24(global_x, 4, 1) < N) {", // NOLINT +"__global float2 *dst_write = (__global float2 *)dst_write0;", // NOLINT +"dst_write[0] = dot00.xy; dst_write0 += N; dst_write = (__global float2 *)dst_write0;", // NOLINT +"if(mad24(global_y, 8, 1) < M) { dst_write[0] = dot01.xy; dst_write0 += N; dst_write = (__global float2 *)dst_write0; }", // NOLINT +"else return;", // NOLINT +"if(mad24(global_y, 8, 2) < M) { dst_write[0] = dot02.xy; dst_write0 += N; dst_write = (__global float2 *)dst_write0; }", // NOLINT +"else return;", // NOLINT +"if(mad24(global_y, 8, 3) < M) { dst_write[0] = dot03.xy; dst_write0 += N; dst_write = (__global float2 *)dst_write0; }", // NOLINT +"else return;", // NOLINT +"if(mad24(global_y, 8, 4) < M) { dst_write[0] = dot04.xy; dst_write0 += N; dst_write = (__global float2 *)dst_write0; }", // NOLINT +"else return;", // NOLINT +"if(mad24(global_y, 8, 5) < M) { dst_write[0] = dot05.xy; dst_write0 += N; dst_write = (__global float2 *)dst_write0; }", // NOLINT +"else return;", // NOLINT +"if(mad24(global_y, 8, 6) < M) { dst_write[0] = dot06.xy; dst_write0 += N; dst_write = (__global float2 *)dst_write0; }", // NOLINT +"else return;", // 
NOLINT +"if(mad24(global_y, 8, 7) < M) { dst_write[0] = dot07.xy; }", // NOLINT +"} else {", // NOLINT +"dst_write0[0] = dot00.x; dst_write0 += N;", // NOLINT +"if(mad24(global_y, 8, 1) < M) { dst_write0[0] = dot01.x; dst_write0 += N; }", // NOLINT +"else return;", // NOLINT +"if(mad24(global_y, 8, 2) < M) { dst_write0[0] = dot02.x; dst_write0 += N; }", // NOLINT +"else return;", // NOLINT +"if(mad24(global_y, 8, 3) < M) { dst_write0[0] = dot03.x; dst_write0 += N; }", // NOLINT +"else return;", // NOLINT +"if(mad24(global_y, 8, 4) < M) { dst_write0[0] = dot04.x; dst_write0 += N; }", // NOLINT +"else return;", // NOLINT +"if(mad24(global_y, 8, 5) < M) { dst_write0[0] = dot05.x; dst_write0 += N; }", // NOLINT +"else return;", // NOLINT +"if(mad24(global_y, 8, 6) < M) { dst_write0[0] = dot06.x; dst_write0 += N; }", // NOLINT +"else return;", // NOLINT +"if(mad24(global_y, 8, 7) < M) { dst_write0[0] = dot07.x; }", // NOLINT +"}", // NOLINT +"}", // NOLINT +"}", // NOLINT +"", // NOLINT +"#undef VEC_SIZE", // NOLINT +"#undef LWG_HEIGHT", // NOLINT +"#undef TILE_M", // NOLINT +"#undef TILE_K", // NOLINT +"#undef TILE_N", // NOLINT +"", // NOLINT +"#define VEC_SIZE 1", // NOLINT +"#define LWG_HEIGHT 16", // NOLINT +"#define TILE_M 8", // NOLINT +"#define TILE_K 32", // NOLINT +"#define TILE_N 8", // NOLINT +"#define SLM_BLOCK 512", // NOLINT +"", // NOLINT +"__attribute__((reqd_work_group_size(8, LWG_HEIGHT, 1)))", // NOLINT +"__kernel void TEMPLATE(gemm_buffer_NT, Dtype)(", // NOLINT +"const __global float *src0, int off0,", // NOLINT +"const __global float *src1, int off1,", // NOLINT +"__global float *dst, int offd,", // NOLINT +"int M,", // NOLINT +"int N,", // NOLINT +"int K,", // NOLINT +"float alpha,", // NOLINT +"float beta)", // NOLINT +"{", // NOLINT +"const int group_x = get_group_id(0);", // NOLINT +"const int group_y = get_group_id(1);", // NOLINT +"const int local_x = get_local_id(0);", // NOLINT +"const int local_y = get_local_id(1);", // NOLINT +"const int 
global_x = get_global_id(0);", // NOLINT +"const int global_y = get_global_id(1);", // NOLINT +"", // NOLINT +"float8 dot00 = 0.f;", // NOLINT +"float8 dot01 = 0.f;", // NOLINT +"float8 dot02 = 0.f;", // NOLINT +"float8 dot03 = 0.f;", // NOLINT +"float8 dot04 = 0.f;", // NOLINT +"float8 dot05 = 0.f;", // NOLINT +"float8 dot06 = 0.f;", // NOLINT +"float8 dot07 = 0.f;", // NOLINT +"", // NOLINT +"float4 brow0;", // NOLINT +"float4 brow1;", // NOLINT +"float4 brow2;", // NOLINT +"float4 brow3;", // NOLINT +"float4 brow4;", // NOLINT +"float4 brow5;", // NOLINT +"float4 brow6;", // NOLINT +"float4 brow7;", // NOLINT +"", // NOLINT +"__global float *dst_write0 = dst + local_x * VEC_SIZE + ( group_x * TILE_N ) + ( group_y * LWG_HEIGHT * TILE_M + local_y * TILE_M) * N + offd;", // NOLINT +"", // NOLINT +"const __global float *src0_read = src0 + local_x * ( TILE_K / 8 ) + ( group_y * LWG_HEIGHT * TILE_M + local_y * TILE_M ) * K + off0;", // NOLINT +"", // NOLINT +"const __global float *src1_read0 = src1 + ( group_x * TILE_N ) * K + off1;", // NOLINT +"", // NOLINT +"__local float slm_brow[8 * SLM_BLOCK];", // NOLINT +"__local float* slm_brow0;", // NOLINT +"", // NOLINT +"int local_index = mad24(local_y, 8, local_x) * 4;", // NOLINT +"int w;", // NOLINT +"for(int b_tile = 0; b_tile < K; b_tile += SLM_BLOCK) {", // NOLINT +"barrier(CLK_LOCAL_MEM_FENCE);", // NOLINT +"((__local float4 *)(slm_brow + mad24(0, SLM_BLOCK, local_index)))[0] = ((__global float4 *)(src1_read0 + mad24(0, K, local_index)))[0];", // NOLINT +"((__local float4 *)(slm_brow + mad24(1, SLM_BLOCK, local_index)))[0] = ((__global float4 *)(src1_read0 + mad24(1, K, local_index)))[0];", // NOLINT +"((__local float4 *)(slm_brow + mad24(2, SLM_BLOCK, local_index)))[0] = ((__global float4 *)(src1_read0 + mad24(2, K, local_index)))[0];", // NOLINT +"((__local float4 *)(slm_brow + mad24(3, SLM_BLOCK, local_index)))[0] = ((__global float4 *)(src1_read0 + mad24(3, K, local_index)))[0];", // NOLINT +"((__local float4 
*)(slm_brow + mad24(4, SLM_BLOCK, local_index)))[0] = ((__global float4 *)(src1_read0 + mad24(4, K, local_index)))[0];", // NOLINT +"((__local float4 *)(slm_brow + mad24(5, SLM_BLOCK, local_index)))[0] = ((__global float4 *)(src1_read0 + mad24(5, K, local_index)))[0];", // NOLINT +"((__local float4 *)(slm_brow + mad24(6, SLM_BLOCK, local_index)))[0] = ((__global float4 *)(src1_read0 + mad24(6, K, local_index)))[0];", // NOLINT +"((__local float4 *)(slm_brow + mad24(7, SLM_BLOCK, local_index)))[0] = ((__global float4 *)(src1_read0 + mad24(7, K, local_index)))[0];", // NOLINT +"barrier(CLK_LOCAL_MEM_FENCE);", // NOLINT +"", // NOLINT +"slm_brow0 = slm_brow + local_x * (TILE_K / 8);", // NOLINT +"w = b_tile;", // NOLINT +"int end_w = min(b_tile + SLM_BLOCK, K);", // NOLINT +"while( w + TILE_K <= end_w ) {", // NOLINT +"float4 arow;", // NOLINT +"", // NOLINT +"brow0 = ((__local float4 *)(slm_brow0 + 0 * SLM_BLOCK))[0];", // NOLINT +"brow1 = ((__local float4 *)(slm_brow0 + 1 * SLM_BLOCK))[0];", // NOLINT +"brow2 = ((__local float4 *)(slm_brow0 + 2 * SLM_BLOCK))[0];", // NOLINT +"brow3 = ((__local float4 *)(slm_brow0 + 3 * SLM_BLOCK))[0];", // NOLINT +"brow4 = ((__local float4 *)(slm_brow0 + 4 * SLM_BLOCK))[0];", // NOLINT +"brow5 = ((__local float4 *)(slm_brow0 + 5 * SLM_BLOCK))[0];", // NOLINT +"brow6 = ((__local float4 *)(slm_brow0 + 6 * SLM_BLOCK))[0];", // NOLINT +"brow7 = ((__local float4 *)(slm_brow0 + 7 * SLM_BLOCK))[0];", // NOLINT +"", // NOLINT +"#define MM_DOT_PRODUCT( _row, _dot ) arow = ((__global float4 *)(src0_read + _row * K))[0]; _dot = mad( (float8)(arow.x), (float8)(brow0.x, brow1.x, brow2.x, brow3.x, brow4.x, brow5.x, brow6.x, brow7.x), _dot ); _dot = mad( (float8)(arow.y), (float8)(brow0.y, brow1.y, brow2.y, brow3.y, brow4.y, brow5.y, brow6.y, brow7.y), _dot ); _dot = mad( (float8)(arow.z), (float8)(brow0.z, brow1.z, brow2.z, brow3.z, brow4.z, brow5.z, brow6.z, brow7.z), _dot ); _dot = mad( (float8)(arow.w), (float8)(brow0.w, brow1.w, brow2.w, 
brow3.w, brow4.w, brow5.w, brow6.w, brow7.w), _dot );", // NOLINT +"MM_DOT_PRODUCT( 0, dot00 );", // NOLINT +"MM_DOT_PRODUCT( 1, dot01 );", // NOLINT +"MM_DOT_PRODUCT( 2, dot02 );", // NOLINT +"MM_DOT_PRODUCT( 3, dot03 );", // NOLINT +"MM_DOT_PRODUCT( 4, dot04 );", // NOLINT +"MM_DOT_PRODUCT( 5, dot05 );", // NOLINT +"MM_DOT_PRODUCT( 6, dot06 );", // NOLINT +"MM_DOT_PRODUCT( 7, dot07 );", // NOLINT +"#undef MM_DOT_PRODUCT", // NOLINT +"", // NOLINT +"src0_read += TILE_K;", // NOLINT +"slm_brow0 += TILE_K;", // NOLINT +"w += TILE_K;", // NOLINT +"}", // NOLINT +"src1_read0 += SLM_BLOCK;", // NOLINT +"}", // NOLINT +"", // NOLINT +"if(w < K) {", // NOLINT +"float4 arow;", // NOLINT +"", // NOLINT +"#define READ_BROW(_brow, _row) _brow = ((__local float4 *)(slm_brow0 + _row * SLM_BLOCK))[0]; _brow.x = (mad24(local_x, 4, w) < K) ? _brow.x : 0.0f; _brow.y = (mad24(local_x, 4, w + 1) < K) ? _brow.y : 0.0f; _brow.z = (mad24(local_x, 4, w + 2) < K) ? _brow.z : 0.0f; _brow.w = (mad24(local_x, 4, w + 3) < K) ? _brow.w : 0.0f;", // NOLINT +"READ_BROW(brow0, 0);", // NOLINT +"READ_BROW(brow1, 1);", // NOLINT +"READ_BROW(brow2, 2);", // NOLINT +"READ_BROW(brow3, 3);", // NOLINT +"READ_BROW(brow4, 4);", // NOLINT +"READ_BROW(brow5, 5);", // NOLINT +"READ_BROW(brow6, 6);", // NOLINT +"READ_BROW(brow7, 7);", // NOLINT +"", // NOLINT +"#define MM_DOT_PRODUCT( _row, _dot ) arow = ((__global float4 *)(src0_read + _row * K))[0]; arow.x = (mad24(local_x, 4, w) < K) ? arow.x : 0.0f; arow.y = (mad24(local_x, 4, w + 1) < K) ? arow.y : 0.0f; arow.z = (mad24(local_x, 4, w + 2) < K) ? arow.z : 0.0f; arow.w = (mad24(local_x, 4, w + 3) < K) ? 
arow.w : 0.0f; _dot = mad( (float8)(arow.x), (float8)(brow0.x, brow1.x, brow2.x, brow3.x, brow4.x, brow5.x, brow6.x, brow7.x), _dot ); _dot = mad( (float8)(arow.y), (float8)(brow0.y, brow1.y, brow2.y, brow3.y, brow4.y, brow5.y, brow6.y, brow7.y), _dot ); _dot = mad( (float8)(arow.z), (float8)(brow0.z, brow1.z, brow2.z, brow3.z, brow4.z, brow5.z, brow6.z, brow7.z), _dot ); _dot = mad( (float8)(arow.w), (float8)(brow0.w, brow1.w, brow2.w, brow3.w, brow4.w, brow5.w, brow6.w, brow7.w), _dot );", // NOLINT +"MM_DOT_PRODUCT( 0, dot00 );", // NOLINT +"MM_DOT_PRODUCT( 1, dot01 );", // NOLINT +"MM_DOT_PRODUCT( 2, dot02 );", // NOLINT +"MM_DOT_PRODUCT( 3, dot03 );", // NOLINT +"MM_DOT_PRODUCT( 4, dot04 );", // NOLINT +"MM_DOT_PRODUCT( 5, dot05 );", // NOLINT +"MM_DOT_PRODUCT( 6, dot06 );", // NOLINT +"MM_DOT_PRODUCT( 7, dot07 );", // NOLINT +"#undef MM_DOT_PRODUCT", // NOLINT +"}", // NOLINT +"", // NOLINT +"#define REDUCE(_dot) _dot = intel_sub_group_shuffle(_dot, 0) + intel_sub_group_shuffle(_dot, 1) + intel_sub_group_shuffle(_dot, 2) + intel_sub_group_shuffle(_dot, 3) + intel_sub_group_shuffle(_dot, 4) + intel_sub_group_shuffle(_dot, 5) + intel_sub_group_shuffle(_dot, 6) + intel_sub_group_shuffle(_dot, 7);", // NOLINT +"REDUCE(dot00);", // NOLINT +"REDUCE(dot01);", // NOLINT +"REDUCE(dot02);", // NOLINT +"REDUCE(dot03);", // NOLINT +"REDUCE(dot04);", // NOLINT +"REDUCE(dot05);", // NOLINT +"REDUCE(dot06);", // NOLINT +"REDUCE(dot07);", // NOLINT +"#undef REDUCE", // NOLINT +"", // NOLINT +"float output = 0.0f;", // NOLINT +"#define OUTPUT( _dot) output = (local_x == 0) ? _dot.s0 : output; output = (local_x == 1) ? _dot.s1 : output; output = (local_x == 2) ? _dot.s2 : output; output = (local_x == 3) ? _dot.s3 : output; output = (local_x == 4) ? _dot.s4 : output; output = (local_x == 5) ? _dot.s5 : output; output = (local_x == 6) ? _dot.s6 : output; output = (local_x == 7) ? 
_dot.s7 : output; dst_write0[0] = mad(output, alpha, beta * dst_write0[0]); dst_write0 += N;", // NOLINT +"", // NOLINT +"if(global_x < N && global_y * 8 < M) {", // NOLINT +"OUTPUT(dot00);", // NOLINT +"if(mad24(global_y, 8, 1) < M) { OUTPUT(dot01); }", // NOLINT +"if(mad24(global_y, 8, 2) < M) { OUTPUT(dot02); }", // NOLINT +"if(mad24(global_y, 8, 3) < M) { OUTPUT(dot03); }", // NOLINT +"if(mad24(global_y, 8, 4) < M) { OUTPUT(dot04); }", // NOLINT +"if(mad24(global_y, 8, 5) < M) { OUTPUT(dot05); }", // NOLINT +"if(mad24(global_y, 8, 6) < M) { OUTPUT(dot06); }", // NOLINT +"if(mad24(global_y, 8, 7) < M) { OUTPUT(dot07); }", // NOLINT +"}", // NOLINT +"#undef OUTPUT", // NOLINT +"}", // NOLINT +"", // NOLINT +"#undef VEC_SIZE", // NOLINT +"#undef LWG_HEIGHT", // NOLINT +"#undef TILE_M", // NOLINT +"#undef TILE_K", // NOLINT +"#undef TILE_N", // NOLINT +"#undef SLM_BLOCK", // NOLINT +"", // NOLINT +"#define SLM_SIZE 64", // NOLINT +"void TEMPLATE(gemm_buffer_NT_M_2_edgerows,Dtype)(", // NOLINT +"const __global Dtype* srca_read0,", // NOLINT +"const __global Dtype* srca_read1,", // NOLINT +"const __global Dtype* srcb_read,", // NOLINT +"__local Dtype4* work0,", // NOLINT +"__local Dtype4* work1,", // NOLINT +"int N,", // NOLINT +"int K,", // NOLINT +"int x_gid,", // NOLINT +"int lid,", // NOLINT +"Dtype alpha,", // NOLINT +"Dtype beta,", // NOLINT +"__global Dtype* dstc0,", // NOLINT +"__global Dtype* dstc1)", // NOLINT +"{", // NOLINT +"__local Dtype* work_each0 = (__local Dtype*)work0;", // NOLINT +"__local Dtype* work_each1 = (__local Dtype*)work1;", // NOLINT +"", // NOLINT +"int rows = N - x_gid * 4;", // NOLINT +"", // NOLINT +"Dtype4 dot0[3] = {(Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.)};", // NOLINT +"Dtype4 dot1[3] = {(Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.)};", // NOLINT +"", // NOLINT +"int i = lid;", // NOLINT +"while( i < K / 4) {", // NOLINT +"const Dtype4 b0 = {srca_read0[i*4], srca_read0[(i*4+1)], srca_read0[(i*4+2)], srca_read0[(i*4+3)]};", // NOLINT 
+"const Dtype4 b1 = {srca_read1[i*4], srca_read1[(i*4+1)], srca_read1[(i*4+2)], srca_read1[(i*4+3)]};", // NOLINT +"#pragma unroll", // NOLINT +"for(int j = 0; j < rows; ++j) {", // NOLINT +"dot0[j] += b0 * vload4(i, srcb_read + j * K);", // NOLINT +"dot1[j] += b1 * vload4(i, srcb_read + j * K);", // NOLINT +"}", // NOLINT +"", // NOLINT +"i += get_local_size(0);", // NOLINT +"}", // NOLINT +"#pragma unroll", // NOLINT +"for(int j = 0; j < rows; ++j) {", // NOLINT +"work_each0[lid * 4 + j] = dot0[j].x + dot0[j].y + dot0[j].z + dot0[j].w;", // NOLINT +"work_each1[lid * 4 + j] = dot1[j].x + dot1[j].y + dot1[j].z + dot1[j].w;", // NOLINT +"}", // NOLINT +"", // NOLINT +"if(i == K / 4) {", // NOLINT +"short tail_items = K % 4;", // NOLINT +"", // NOLINT +"if(tail_items != 0) {", // NOLINT +"const __global Dtype *srcb_tail = srcb_read + i * 4;", // NOLINT +"const __global Dtype *srca_tail0 = srca_read0 + i * 4;", // NOLINT +"const __global Dtype *srca_tail1 = srca_read1 + i * 4;", // NOLINT +"#pragma unroll", // NOLINT +"for(short i = 0; i < tail_items; ++i) {", // NOLINT +"const Dtype at0 = srca_tail0[i];", // NOLINT +"const Dtype at1 = srca_tail1[i];", // NOLINT +"#pragma unroll", // NOLINT +"for(int j = 0; j < rows; ++j) {", // NOLINT +"work_each0[lid * 4 + j] += at0 * srcb_tail[i + j * K];", // NOLINT +"work_each1[lid * 4 + j] += at1 * srcb_tail[i + j * K];", // NOLINT +"}", // NOLINT +"}", // NOLINT +"}", // NOLINT +"}", // NOLINT +"", // NOLINT +"for(int stride = get_local_size(0) >> 1; stride > 0 ; stride >>= 1) {", // NOLINT +"barrier(CLK_LOCAL_MEM_FENCE);", // NOLINT +"if(lid < stride) {", // NOLINT +"work0[lid] += work0[lid+stride];", // NOLINT +"work1[lid] += work1[lid+stride];", // NOLINT +"}", // NOLINT +"}", // NOLINT +"", // NOLINT +"if(lid == 0) {", // NOLINT +"#pragma unroll", // NOLINT +"for(int j = 0; j < rows; ++j) {", // NOLINT +"dstc0[(x_gid * 4 + j)] = alpha * work_each0[j] + beta * dstc0[(x_gid * 4 + j)];", // NOLINT +"dstc1[(x_gid * 4 + j)] = 
alpha * work_each1[j] + beta * dstc1[(x_gid * 4 + j)];", // NOLINT +"}", // NOLINT +"}", // NOLINT +"}", // NOLINT +"", // NOLINT +"__kernel void TEMPLATE(gemm_buffer_NT_M_2,Dtype)(", // NOLINT +"__global const Dtype * A,", // NOLINT +"int offA,", // NOLINT +"__global const Dtype * B,", // NOLINT +"int offB,", // NOLINT +"__global Dtype * C,", // NOLINT +"int offC,", // NOLINT +"int M,", // NOLINT +"int N,", // NOLINT +"int K,", // NOLINT +"float alpha_f,", // NOLINT +"float beta_f)", // NOLINT +"{", // NOLINT +"Dtype alpha = (Dtype)alpha_f;", // NOLINT +"Dtype beta = (Dtype)beta_f;", // NOLINT +"int x_gid = get_group_id(0);", // NOLINT +"int lid = get_local_id(0);", // NOLINT +"", // NOLINT +"const __global Dtype *srca_read0 = A + offA;", // NOLINT +"const __global Dtype *srca_read1 = srca_read0 + K;", // NOLINT +"", // NOLINT +"const __global Dtype *srcb_read = B + x_gid * 4 * K + offB;", // NOLINT +"", // NOLINT +"__global Dtype4 *dstc0 = (__global Dtype4*)(C + offC);", // NOLINT +"__global Dtype4 *dstc1 = (__global Dtype4*)((__global Dtype*)(dstc0) + N);", // NOLINT +"", // NOLINT +"__local Dtype4 work0[SLM_SIZE];", // NOLINT +"__local Dtype4 work1[SLM_SIZE];", // NOLINT +"__local Dtype* work_each0 = (__local Dtype*)work0;", // NOLINT +"__local Dtype* work_each1 = (__local Dtype*)work1;", // NOLINT +"", // NOLINT +"if(x_gid == N / 4) {", // NOLINT +"TEMPLATE(gemm_buffer_NT_M_2_edgerows,Dtype) (srca_read0, srca_read1, srcb_read, work0, work1, N, K, x_gid, lid, alpha, beta, (__global Dtype*)dstc0, (__global Dtype*)dstc1);", // NOLINT +"} else {", // NOLINT +"Dtype4 dot0[4] = {(Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.)};", // NOLINT +"Dtype4 dot1[4] = {(Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.)};", // NOLINT +"int i = lid;", // NOLINT +"while( i < K / 4) {", // NOLINT +"const Dtype4 b0 = vload4(i, srca_read0);", // NOLINT +"const Dtype4 b1 = vload4(i, srca_read1);", // NOLINT +"#pragma unroll", // NOLINT +"for(int j = 0; j < 4; ++j) {", // 
NOLINT +"Dtype4 a = vload4(i, srcb_read + j * K);", // NOLINT +"dot0[j] += b0 * a;", // NOLINT +"dot1[j] += b1 * a;", // NOLINT +"}", // NOLINT +"i += get_local_size(0);", // NOLINT +"}", // NOLINT +"", // NOLINT +"#pragma unroll", // NOLINT +"for(int j = 0; j < 4; ++j) {", // NOLINT +"work_each0[lid * 4 + j] = dot0[j].x + dot0[j].y + dot0[j].z + dot0[j].w;", // NOLINT +"work_each1[lid * 4 + j] = dot1[j].x + dot1[j].y + dot1[j].z + dot1[j].w;", // NOLINT +"}", // NOLINT +"", // NOLINT +"if(i == K / 4) {", // NOLINT +"short tail_items = K % 4;", // NOLINT +"if(tail_items != 0) {", // NOLINT +"const __global Dtype *srcb_tail = srcb_read + i * 4;", // NOLINT +"", // NOLINT +"const __global Dtype *srca_tail0 = srca_read0 + i * 4;", // NOLINT +"const __global Dtype *srca_tail1 = srca_read1 + i * 4;", // NOLINT +"#pragma unroll", // NOLINT +"for(short i = 0; i < tail_items; ++i) {", // NOLINT +"const Dtype at0 = srca_tail0[i];", // NOLINT +"const Dtype at1 = srca_tail1[i];", // NOLINT +"#pragma unroll", // NOLINT +"for(int j = 0; j < 4; ++j) {", // NOLINT +"work_each0[lid * 4 + j] += at0 * srcb_tail[i + j * K];", // NOLINT +"work_each1[lid * 4 + j] += at1 * srcb_tail[i + j * K];", // NOLINT +"}", // NOLINT +"}", // NOLINT +"}", // NOLINT +"}", // NOLINT +"", // NOLINT +"for(int stride = get_local_size(0) >> 1; stride > 0 ; stride >>= 1) {", // NOLINT +"barrier(CLK_LOCAL_MEM_FENCE);", // NOLINT +"if(lid < stride) {", // NOLINT +"work0[lid] += work0[lid+stride];", // NOLINT +"work1[lid] += work1[lid+stride];", // NOLINT +"}", // NOLINT +"}", // NOLINT +"", // NOLINT +"if(lid == 0) {", // NOLINT +"dstc0[x_gid] = alpha * work0[0] + beta * dstc0[x_gid];", // NOLINT +"dstc1[x_gid] = alpha * work1[0] + beta * dstc1[x_gid];", // NOLINT +"}", // NOLINT +"}", // NOLINT +"}", // NOLINT +"#undef SLM_SIZE", // NOLINT +"", // NOLINT +"#define SLM_SIZE 32", // NOLINT +"void TEMPLATE(gemm_buffer_NT_M_4_edgerows,Dtype)(", // NOLINT +"const __global Dtype* srca_read0,", // NOLINT +"const 
__global Dtype* srca_read1,", // NOLINT +"const __global Dtype* srca_read2,", // NOLINT +"const __global Dtype* srca_read3,", // NOLINT +"const __global Dtype* srcb_read,", // NOLINT +"__local Dtype4* work0,", // NOLINT +"__local Dtype4* work1,", // NOLINT +"__local Dtype4* work2,", // NOLINT +"__local Dtype4* work3,", // NOLINT +"int N,", // NOLINT +"int K,", // NOLINT +"int x_gid,", // NOLINT +"int lid,", // NOLINT +"Dtype alpha,", // NOLINT +"Dtype beta,", // NOLINT +"__global Dtype* dstc0,", // NOLINT +"__global Dtype* dstc1,", // NOLINT +"__global Dtype* dstc2,", // NOLINT +"__global Dtype* dstc3)", // NOLINT +"{", // NOLINT +"__local Dtype* work_each0 = (__local Dtype*)(work0 + lid);", // NOLINT +"__local Dtype* work_each1 = (__local Dtype*)(work1 + lid);", // NOLINT +"__local Dtype* work_each2 = (__local Dtype*)(work2 + lid);", // NOLINT +"__local Dtype* work_each3 = (__local Dtype*)(work3 + lid);", // NOLINT +"", // NOLINT +"int rows = N - x_gid * 4;", // NOLINT +"", // NOLINT +"Dtype4 dot0[3] = {(Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.)};", // NOLINT +"Dtype4 dot1[3] = {(Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.)};", // NOLINT +"Dtype4 dot2[3] = {(Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.)};", // NOLINT +"Dtype4 dot3[3] = {(Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.)};", // NOLINT +"", // NOLINT +"int i = lid;", // NOLINT +"while( i < K / 4) {", // NOLINT +"const Dtype4 a0 = {srca_read0[i*4], srca_read0[(i*4+1)], srca_read0[(i*4+2)], srca_read0[(i*4+3)]};", // NOLINT +"const Dtype4 a1 = {srca_read1[i*4], srca_read1[(i*4+1)], srca_read1[(i*4+2)], srca_read1[(i*4+3)]};", // NOLINT +"const Dtype4 a2 = {srca_read2[i*4], srca_read2[(i*4+1)], srca_read2[(i*4+2)], srca_read2[(i*4+3)]};", // NOLINT +"const Dtype4 a3 = {srca_read3[i*4], srca_read3[(i*4+1)], srca_read3[(i*4+2)], srca_read3[(i*4+3)]};", // NOLINT +"#pragma unroll", // NOLINT +"for(int j = 0; j < rows; ++j) {", // NOLINT +"dot0[j] += a0 * vload4(i, srcb_read + j * K);", // NOLINT +"dot1[j] += a1 * vload4(i, 
srcb_read + j * K);", // NOLINT +"dot2[j] += a2 * vload4(i, srcb_read + j * K);", // NOLINT +"dot3[j] += a3 * vload4(i, srcb_read + j * K);", // NOLINT +"}", // NOLINT +"", // NOLINT +"i += get_local_size(0);", // NOLINT +"}", // NOLINT +"#pragma unroll", // NOLINT +"for(int j = 0; j < rows; ++j) {", // NOLINT +"work_each0[j] = dot0[j].x + dot0[j].y + dot0[j].z + dot0[j].w;", // NOLINT +"work_each1[j] = dot1[j].x + dot1[j].y + dot1[j].z + dot1[j].w;", // NOLINT +"work_each2[j] = dot2[j].x + dot2[j].y + dot2[j].z + dot2[j].w;", // NOLINT +"work_each3[j] = dot3[j].x + dot3[j].y + dot3[j].z + dot3[j].w;", // NOLINT +"}", // NOLINT +"", // NOLINT +"if(i == K / 4) {", // NOLINT +"short tail_items = K % 4;", // NOLINT +"", // NOLINT +"if(tail_items != 0) {", // NOLINT +"const __global Dtype *srcb_tail = srcb_read + i * 4;", // NOLINT +"", // NOLINT +"const __global Dtype *srca_tail0 = srca_read0 + i * 4;", // NOLINT +"const __global Dtype *srca_tail1 = srca_read1 + i * 4;", // NOLINT +"const __global Dtype *srca_tail2 = srca_read2 + i * 4;", // NOLINT +"const __global Dtype *srca_tail3 = srca_read3 + i * 4;", // NOLINT +"#pragma unroll", // NOLINT +"for(short i = 0; i < tail_items; ++i) {", // NOLINT +"const Dtype at0 = srca_tail0[i];", // NOLINT +"const Dtype at1 = srca_tail1[i];", // NOLINT +"const Dtype at2 = srca_tail2[i];", // NOLINT +"const Dtype at3 = srca_tail3[i];", // NOLINT +"#pragma unroll", // NOLINT +"for(int j = 0; j < rows; ++j) {", // NOLINT +"work_each0[j] += at0 * srcb_tail[i + j * K];", // NOLINT +"work_each1[j] += at1 * srcb_tail[i + j * K];", // NOLINT +"work_each2[j] += at2 * srcb_tail[i + j * K];", // NOLINT +"work_each3[j] += at3 * srcb_tail[i + j * K];", // NOLINT +"}", // NOLINT +"}", // NOLINT +"}", // NOLINT +"}", // NOLINT +"", // NOLINT +"for(int stride = get_local_size(0) >> 1; stride > 0 ; stride >>= 1) {", // NOLINT +"barrier(CLK_LOCAL_MEM_FENCE);", // NOLINT +"if(lid < stride) {", // NOLINT +"work0[lid] += work0[lid+stride];", // NOLINT 
+"work1[lid] += work1[lid+stride];", // NOLINT +"work2[lid] += work2[lid+stride];", // NOLINT +"work3[lid] += work3[lid+stride];", // NOLINT +"}", // NOLINT +"}", // NOLINT +"", // NOLINT +"if(lid == 0) {", // NOLINT +"#pragma unroll", // NOLINT +"for(int j = 0; j < rows; ++j) {", // NOLINT +"dstc0[(x_gid * 4 + j)] = alpha * work_each0[j] + beta * dstc0[(x_gid * 4 + j)];", // NOLINT +"dstc1[(x_gid * 4 + j)] = alpha * work_each1[j] + beta * dstc1[(x_gid * 4 + j)];", // NOLINT +"dstc2[(x_gid * 4 + j)] = alpha * work_each2[j] + beta * dstc2[(x_gid * 4 + j)];", // NOLINT +"dstc3[(x_gid * 4 + j)] = alpha * work_each3[j] + beta * dstc3[(x_gid * 4 + j)];", // NOLINT +"}", // NOLINT +"}", // NOLINT +"}", // NOLINT +"", // NOLINT +"__kernel void TEMPLATE(gemm_buffer_NT_M_4,Dtype)(", // NOLINT +"__global const Dtype * A,", // NOLINT +"int offA,", // NOLINT +"__global const Dtype * B,", // NOLINT +"int offB,", // NOLINT +"__global Dtype * C,", // NOLINT +"int offC,", // NOLINT +"int M,", // NOLINT +"int N,", // NOLINT +"int K,", // NOLINT +"float alpha_f,", // NOLINT +"float beta_f)", // NOLINT +"{", // NOLINT +"Dtype alpha = (Dtype)alpha_f;", // NOLINT +"Dtype beta = (Dtype)beta_f;", // NOLINT +"int x_gid = get_group_id(0);", // NOLINT +"int lid = get_local_id(0);", // NOLINT +"int lsize = get_local_size(0);", // NOLINT +"", // NOLINT +"const __global Dtype *srca_read0 = A + offA;", // NOLINT +"const __global Dtype *srca_read1 = srca_read0 + K;", // NOLINT +"const __global Dtype *srca_read2 = srca_read1 + K;", // NOLINT +"const __global Dtype *srca_read3 = srca_read2 + K;", // NOLINT +"", // NOLINT +"const __global Dtype *srcb_read = B + x_gid * 4 * K + offB;", // NOLINT +"", // NOLINT +"__global Dtype4 *dstc0 = (__global Dtype4*)(C + offC);", // NOLINT +"__global Dtype4 *dstc1 = (__global Dtype4*)((__global Dtype*)(dstc0) + N);", // NOLINT +"__global Dtype4 *dstc2 = (__global Dtype4*)((__global Dtype*)(dstc1) + N);", // NOLINT +"__global Dtype4 *dstc3 = (__global 
Dtype4*)((__global Dtype*)(dstc2) + N);", // NOLINT +"", // NOLINT +"__local Dtype4 work0[SLM_SIZE];", // NOLINT +"__local Dtype4 work1[SLM_SIZE];", // NOLINT +"__local Dtype4 work2[SLM_SIZE];", // NOLINT +"__local Dtype4 work3[SLM_SIZE];", // NOLINT +"__local Dtype* work_each0 = (__local Dtype*)(work0 + lid);", // NOLINT +"__local Dtype* work_each1 = (__local Dtype*)(work1 + lid);", // NOLINT +"__local Dtype* work_each2 = (__local Dtype*)(work2 + lid);", // NOLINT +"__local Dtype* work_each3 = (__local Dtype*)(work3 + lid);", // NOLINT +"", // NOLINT +"if(x_gid == N / 4) {", // NOLINT +"TEMPLATE(gemm_buffer_NT_M_4_edgerows,Dtype) (srca_read0, srca_read1, srca_read2, srca_read3, srcb_read, work0, work1, work2, work3, N, K, x_gid, lid, alpha, beta, (__global Dtype*)dstc0, (__global Dtype*)dstc1, (__global Dtype*)dstc2, (__global Dtype*)dstc3);", // NOLINT +"} else {", // NOLINT +"Dtype4 dot0[4] = {(Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.)};", // NOLINT +"Dtype4 dot1[4] = {(Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.)};", // NOLINT +"Dtype4 dot2[4] = {(Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.)};", // NOLINT +"Dtype4 dot3[4] = {(Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.)};", // NOLINT +"", // NOLINT +"int kid = lid;", // NOLINT +"while( kid < K / 4) {", // NOLINT +"const Dtype4 b0 = vload4(kid, srca_read0);", // NOLINT +"const Dtype4 b1 = vload4(kid, srca_read1);", // NOLINT +"const Dtype4 b2 = vload4(kid, srca_read2);", // NOLINT +"const Dtype4 b3 = vload4(kid, srca_read3);", // NOLINT +"#pragma unroll", // NOLINT +"for(int j = 0; j < 4; ++j) {", // NOLINT +"Dtype4 a = vload4(kid, srcb_read + j * K);", // NOLINT +"dot0[j] += b0 * a;", // NOLINT +"dot1[j] += b1 * a;", // NOLINT +"dot2[j] += b2 * a;", // NOLINT +"dot3[j] += b3 * a;", // NOLINT +"}", // NOLINT +"kid += lsize;", // NOLINT +"}", // NOLINT +"#pragma unroll", // NOLINT +"for(int j = 0; j < 4; ++j) {", // NOLINT +"work_each0[j] = dot0[j].x + dot0[j].y + dot0[j].z 
+ dot0[j].w;", // NOLINT +"work_each1[j] = dot1[j].x + dot1[j].y + dot1[j].z + dot1[j].w;", // NOLINT +"work_each2[j] = dot2[j].x + dot2[j].y + dot2[j].z + dot2[j].w;", // NOLINT +"work_each3[j] = dot3[j].x + dot3[j].y + dot3[j].z + dot3[j].w;", // NOLINT +"}", // NOLINT +"", // NOLINT +"if(kid == (K >> 2)) {", // NOLINT +"short tail_items = K % 4;", // NOLINT +"if(tail_items != 0) {", // NOLINT +"int offset = kid << 2;", // NOLINT +"const __global Dtype *srcb_tail = srcb_read + offset;", // NOLINT +"", // NOLINT +"const __global Dtype *srca_tail0 = srca_read0 + offset;", // NOLINT +"const __global Dtype *srca_tail1 = srca_read1 + offset;", // NOLINT +"const __global Dtype *srca_tail2 = srca_read2 + offset;", // NOLINT +"const __global Dtype *srca_tail3 = srca_read3 + offset;", // NOLINT +"#pragma unroll", // NOLINT +"for(short i = 0; i < tail_items; ++i) {", // NOLINT +"const Dtype at0 = srca_tail0[i];", // NOLINT +"const Dtype at1 = srca_tail1[i];", // NOLINT +"const Dtype at2 = srca_tail2[i];", // NOLINT +"const Dtype at3 = srca_tail3[i];", // NOLINT +"#pragma unroll", // NOLINT +"for(int j = 0; j < 4; ++j) {", // NOLINT +"work_each0[j] += at0 * srcb_tail[i + j * K];", // NOLINT +"work_each1[j] += at1 * srcb_tail[i + j * K];", // NOLINT +"work_each2[j] += at2 * srcb_tail[i + j * K];", // NOLINT +"work_each3[j] += at3 * srcb_tail[i + j * K];", // NOLINT +"}", // NOLINT +"}", // NOLINT +"}", // NOLINT +"}", // NOLINT +"", // NOLINT +"for(int stride = get_local_size(0) >> 1; stride > 0 ; stride >>= 1) {", // NOLINT +"barrier(CLK_LOCAL_MEM_FENCE);", // NOLINT +"if(lid < stride) {", // NOLINT +"work0[lid] += work0[lid+stride];", // NOLINT +"work1[lid] += work1[lid+stride];", // NOLINT +"work2[lid] += work2[lid+stride];", // NOLINT +"work3[lid] += work3[lid+stride];", // NOLINT +"}", // NOLINT +"}", // NOLINT +"", // NOLINT +"if(lid == 0) {", // NOLINT +"dstc0[x_gid] = alpha * work0[0] + beta * dstc0[x_gid];", // NOLINT +"dstc1[x_gid] = alpha * work1[0] + beta * 
dstc1[x_gid];", // NOLINT +"dstc2[x_gid] = alpha * work2[0] + beta * dstc2[x_gid];", // NOLINT +"dstc3[x_gid] = alpha * work3[0] + beta * dstc3[x_gid];", // NOLINT +"}", // NOLINT +"}", // NOLINT +"}", // NOLINT +"#undef SLM_SIZE", // NOLINT +"", // NOLINT +"#define SLM_SIZE 16", // NOLINT +"__kernel void TEMPLATE(gemm_buffer_NT_M_8,Dtype)(", // NOLINT +"__global const Dtype * A,", // NOLINT +"int offA,", // NOLINT +"__global const Dtype * B,", // NOLINT +"int offB,", // NOLINT +"__global Dtype * C,", // NOLINT +"int offC,", // NOLINT +"int M,", // NOLINT +"int N,", // NOLINT +"int K,", // NOLINT +"float alpha_f,", // NOLINT +"float beta_f)", // NOLINT +"{", // NOLINT +"Dtype alpha = (Dtype)alpha_f;", // NOLINT +"Dtype beta = (Dtype)beta_f;", // NOLINT +"int x_gid = get_group_id(0);", // NOLINT +"int lid = get_local_id(0);", // NOLINT +"int lsize = get_local_size(0);", // NOLINT +"", // NOLINT +"const __global Dtype *srca_read0 = A + offA;", // NOLINT +"const __global Dtype *srca_read1 = srca_read0 + K;", // NOLINT +"const __global Dtype *srca_read2 = srca_read1 + K;", // NOLINT +"const __global Dtype *srca_read3 = srca_read2 + K;", // NOLINT +"const __global Dtype *srca_read4 = srca_read3 + K;", // NOLINT +"const __global Dtype *srca_read5 = srca_read4 + K;", // NOLINT +"const __global Dtype *srca_read6 = srca_read5 + K;", // NOLINT +"const __global Dtype *srca_read7 = srca_read6 + K;", // NOLINT +"", // NOLINT +"const __global Dtype *srcb_read = B + x_gid * K + offB;", // NOLINT +"", // NOLINT +"__global Dtype *dstc0 = C + offC;", // NOLINT +"__global Dtype *dstc1 = dstc0 + N;", // NOLINT +"__global Dtype *dstc2 = dstc1 + N;", // NOLINT +"__global Dtype *dstc3 = dstc2 + N;", // NOLINT +"__global Dtype *dstc4 = dstc3 + N;", // NOLINT +"__global Dtype *dstc5 = dstc4 + N;", // NOLINT +"__global Dtype *dstc6 = dstc5 + N;", // NOLINT +"__global Dtype *dstc7 = dstc6 + N;", // NOLINT +"", // NOLINT +"__local Dtype work0[SLM_SIZE];", // NOLINT +"__local Dtype 
work1[SLM_SIZE];", // NOLINT +"__local Dtype work2[SLM_SIZE];", // NOLINT +"__local Dtype work3[SLM_SIZE];", // NOLINT +"__local Dtype work4[SLM_SIZE];", // NOLINT +"__local Dtype work5[SLM_SIZE];", // NOLINT +"__local Dtype work6[SLM_SIZE];", // NOLINT +"__local Dtype work7[SLM_SIZE];", // NOLINT +"", // NOLINT +"Dtype4 dot0 = (Dtype4)(0.);", // NOLINT +"Dtype4 dot1 = (Dtype4)(0.);", // NOLINT +"Dtype4 dot2 = (Dtype4)(0.);", // NOLINT +"Dtype4 dot3 = (Dtype4)(0.);", // NOLINT +"Dtype4 dot4 = (Dtype4)(0.);", // NOLINT +"Dtype4 dot5 = (Dtype4)(0.);", // NOLINT +"Dtype4 dot6 = (Dtype4)(0.);", // NOLINT +"Dtype4 dot7 = (Dtype4)(0.);", // NOLINT +"", // NOLINT +"int kid = lid;", // NOLINT +"while( kid < K / 4) {", // NOLINT +"const Dtype4 a0 = vload4(kid, srca_read0);", // NOLINT +"const Dtype4 a1 = vload4(kid, srca_read1);", // NOLINT +"const Dtype4 a2 = vload4(kid, srca_read2);", // NOLINT +"const Dtype4 a3 = vload4(kid, srca_read3);", // NOLINT +"const Dtype4 a4 = vload4(kid, srca_read4);", // NOLINT +"const Dtype4 a5 = vload4(kid, srca_read5);", // NOLINT +"const Dtype4 a6 = vload4(kid, srca_read6);", // NOLINT +"const Dtype4 a7 = vload4(kid, srca_read7);", // NOLINT +"Dtype4 b = vload4(kid, srcb_read);", // NOLINT +"dot0 += a0 * b;", // NOLINT +"dot1 += a1 * b;", // NOLINT +"dot2 += a2 * b;", // NOLINT +"dot3 += a3 * b;", // NOLINT +"dot4 += a4 * b;", // NOLINT +"dot5 += a5 * b;", // NOLINT +"dot6 += a6 * b;", // NOLINT +"dot7 += a7 * b;", // NOLINT +"", // NOLINT +"kid += lsize;", // NOLINT +"}", // NOLINT +"work0[lid] = dot0.x + dot0.y + dot0.z + dot0.w;", // NOLINT +"work1[lid] = dot1.x + dot1.y + dot1.z + dot1.w;", // NOLINT +"work2[lid] = dot2.x + dot2.y + dot2.z + dot2.w;", // NOLINT +"work3[lid] = dot3.x + dot3.y + dot3.z + dot3.w;", // NOLINT +"work4[lid] = dot4.x + dot4.y + dot4.z + dot4.w;", // NOLINT +"work5[lid] = dot5.x + dot5.y + dot5.z + dot5.w;", // NOLINT +"work6[lid] = dot6.x + dot6.y + dot6.z + dot6.w;", // NOLINT +"work7[lid] = dot7.x + dot7.y 
+ dot7.z + dot7.w;", // NOLINT +"", // NOLINT +"if(kid == (K >> 2)) {", // NOLINT +"short tail_items = K % 4;", // NOLINT +"if(tail_items != 0) {", // NOLINT +"int offset = kid << 2;", // NOLINT +"const __global Dtype *srcb_tail = srcb_read + offset;", // NOLINT +"", // NOLINT +"const __global Dtype *srca_tail0 = srca_read0 + offset;", // NOLINT +"const __global Dtype *srca_tail1 = srca_read1 + offset;", // NOLINT +"const __global Dtype *srca_tail2 = srca_read2 + offset;", // NOLINT +"const __global Dtype *srca_tail3 = srca_read3 + offset;", // NOLINT +"const __global Dtype *srca_tail4 = srca_read4 + offset;", // NOLINT +"const __global Dtype *srca_tail5 = srca_read5 + offset;", // NOLINT +"const __global Dtype *srca_tail6 = srca_read6 + offset;", // NOLINT +"const __global Dtype *srca_tail7 = srca_read7 + offset;", // NOLINT +"#pragma unroll", // NOLINT +"for(short item = 0; item < tail_items; ++item) {", // NOLINT +"work0[lid] += srca_tail0[item] * srcb_tail[item];", // NOLINT +"work1[lid] += srca_tail1[item] * srcb_tail[item];", // NOLINT +"work2[lid] += srca_tail2[item] * srcb_tail[item];", // NOLINT +"work3[lid] += srca_tail3[item] * srcb_tail[item];", // NOLINT +"work4[lid] += srca_tail4[item] * srcb_tail[item];", // NOLINT +"work5[lid] += srca_tail5[item] * srcb_tail[item];", // NOLINT +"work6[lid] += srca_tail6[item] * srcb_tail[item];", // NOLINT +"work7[lid] += srca_tail7[item] * srcb_tail[item];", // NOLINT +"}", // NOLINT +"}", // NOLINT +"}", // NOLINT +"", // NOLINT +"for(int stride = get_local_size(0) >> 1; stride > 0 ; stride >>= 1) {", // NOLINT +"barrier(CLK_LOCAL_MEM_FENCE);", // NOLINT +"if(lid < stride) {", // NOLINT +"work0[lid] += work0[lid+stride];", // NOLINT +"work1[lid] += work1[lid+stride];", // NOLINT +"work2[lid] += work2[lid+stride];", // NOLINT +"work3[lid] += work3[lid+stride];", // NOLINT +"work4[lid] += work4[lid+stride];", // NOLINT +"work5[lid] += work5[lid+stride];", // NOLINT +"work6[lid] += work6[lid+stride];", // NOLINT 
+"work7[lid] += work7[lid+stride];", // NOLINT +"}", // NOLINT +"}", // NOLINT +"", // NOLINT +"if(lid == 0) {", // NOLINT +"dstc0[x_gid] = alpha * work0[0] + beta * dstc0[x_gid];", // NOLINT +"dstc1[x_gid] = alpha * work1[0] + beta * dstc1[x_gid];", // NOLINT +"dstc2[x_gid] = alpha * work2[0] + beta * dstc2[x_gid];", // NOLINT +"dstc3[x_gid] = alpha * work3[0] + beta * dstc3[x_gid];", // NOLINT +"dstc4[x_gid] = alpha * work4[0] + beta * dstc4[x_gid];", // NOLINT +"dstc5[x_gid] = alpha * work5[0] + beta * dstc5[x_gid];", // NOLINT +"dstc6[x_gid] = alpha * work6[0] + beta * dstc6[x_gid];", // NOLINT +"dstc7[x_gid] = alpha * work7[0] + beta * dstc7[x_gid];", // NOLINT +"}", // NOLINT +"}", // NOLINT +"#undef SLM_SIZE", // NOLINT +"", // NOLINT +"#define VEC_SIZE 4", // NOLINT +"#define LWG_HEIGHT 4", // NOLINT +"#define TILE_M 8", // NOLINT +"#define TILE_K 16", // NOLINT +"#define TILE_N 32", // NOLINT +"", // NOLINT +"__attribute__((reqd_work_group_size(8, LWG_HEIGHT, 1)))", // NOLINT +"__kernel void TEMPLATE(gemm_buffer_TN, Dtype)(", // NOLINT +"const __global float *src0, int off0,", // NOLINT +"const __global float *src1, int off1,", // NOLINT +"__global float *dst, int offd,", // NOLINT +"int M,", // NOLINT +"int N,", // NOLINT +"int K,", // NOLINT +"float alpha,", // NOLINT +"float beta,", // NOLINT +"int start_index)", // NOLINT +"", // NOLINT +"{", // NOLINT +"const int group_x = get_group_id(0);", // NOLINT +"const int group_y = get_group_id(1);", // NOLINT +"const int local_x = get_local_id(0);", // NOLINT +"const int local_y = get_local_id(1);", // NOLINT +"const int global_x = get_global_id(0);", // NOLINT +"const int global_y = get_global_id(1);", // NOLINT +"", // NOLINT +"float4 brow;", // NOLINT +"", // NOLINT +"__global float *dst_write0 = dst + local_x * VEC_SIZE + ( group_x * TILE_N ) + ( group_y * LWG_HEIGHT * TILE_M + local_y * TILE_M) * N + offd;", // NOLINT +"", // NOLINT +"const __global float *src0_read = src0 + (local_x * ( TILE_K / 8 ) + 
start_index) * M + group_y * LWG_HEIGHT * TILE_M + local_y * TILE_M + off0;", // NOLINT +"", // NOLINT +"const __global float *src1_read0 = src1 + local_x * VEC_SIZE + ( group_x * TILE_N ) + start_index * N + off1;", // NOLINT +"", // NOLINT +"float4 dot00 = (start_index != 0) ? ((__global float4 *)dst_write0)[0] : (beta * ((__global float4 *)dst_write0)[0]);", // NOLINT +"float4 dot01 = (start_index != 0) ? ((__global float4 *)(dst_write0 + N))[0] : (beta * ((__global float4 *)(dst_write0 + N))[0]);", // NOLINT +"float4 dot02 = (start_index != 0) ? ((__global float4 *)(dst_write0 + 2 * N))[0] : (beta * ((__global float4 *)(dst_write0 + 2 * N))[0]);", // NOLINT +"float4 dot03 = (start_index != 0) ? ((__global float4 *)(dst_write0 + 3 * N))[0] : (beta * ((__global float4 *)(dst_write0 + 3 * N))[0]);", // NOLINT +"float4 dot04 = (start_index != 0) ? ((__global float4 *)(dst_write0 + 4 * N))[0] : (beta * ((__global float4 *)(dst_write0 + 4 * N))[0]);", // NOLINT +"float4 dot05 = (start_index != 0) ? ((__global float4 *)(dst_write0 + 5 * N))[0] : (beta * ((__global float4 *)(dst_write0 + 5 * N))[0]);", // NOLINT +"float4 dot06 = (start_index != 0) ? ((__global float4 *)(dst_write0 + 6 * N))[0] : (beta * ((__global float4 *)(dst_write0 + 6 * N))[0]);", // NOLINT +"float4 dot07 = (start_index != 0) ? 
((__global float4 *)(dst_write0 + 7 * N))[0] : (beta * ((__global float4 *)(dst_write0 + 7 * N))[0]);", // NOLINT +"", // NOLINT +"int end_index = min(start_index + 256, K);", // NOLINT +"while( start_index + TILE_K <= end_index ) {", // NOLINT +"float8 arow0 = alpha * ((__global float8 *)src0_read)[0];", // NOLINT +"float8 arow1 = alpha * ((__global float8 *)(src0_read + M))[0];", // NOLINT +"", // NOLINT +"#define MM_DOT_PRODUCT( _arow ) brow = ((__global float4 *)src1_read0)[0]; src1_read0 += N; dot00 = mad( (float4)(_arow.s0), brow, dot00 ); dot01 = mad( (float4)(_arow.s1), brow, dot01 ); dot02 = mad( (float4)(_arow.s2), brow, dot02 ); dot03 = mad( (float4)(_arow.s3), brow, dot03 ); dot04 = mad( (float4)(_arow.s4), brow, dot04 ); dot05 = mad( (float4)(_arow.s5), brow, dot05 ); dot06 = mad( (float4)(_arow.s6), brow, dot06 ); dot07 = mad( (float4)(_arow.s7), brow, dot07 );", // NOLINT +"MM_DOT_PRODUCT( intel_sub_group_shuffle( arow0, 0 ) );", // NOLINT +"MM_DOT_PRODUCT( intel_sub_group_shuffle( arow1, 0 ) );", // NOLINT +"MM_DOT_PRODUCT( intel_sub_group_shuffle( arow0, 1 ) );", // NOLINT +"MM_DOT_PRODUCT( intel_sub_group_shuffle( arow1, 1 ) );", // NOLINT +"MM_DOT_PRODUCT( intel_sub_group_shuffle( arow0, 2 ) );", // NOLINT +"MM_DOT_PRODUCT( intel_sub_group_shuffle( arow1, 2 ) );", // NOLINT +"MM_DOT_PRODUCT( intel_sub_group_shuffle( arow0, 3 ) );", // NOLINT +"MM_DOT_PRODUCT( intel_sub_group_shuffle( arow1, 3 ) );", // NOLINT +"MM_DOT_PRODUCT( intel_sub_group_shuffle( arow0, 4 ) );", // NOLINT +"MM_DOT_PRODUCT( intel_sub_group_shuffle( arow1, 4 ) );", // NOLINT +"MM_DOT_PRODUCT( intel_sub_group_shuffle( arow0, 5 ) );", // NOLINT +"MM_DOT_PRODUCT( intel_sub_group_shuffle( arow1, 5 ) );", // NOLINT +"MM_DOT_PRODUCT( intel_sub_group_shuffle( arow0, 6 ) );", // NOLINT +"MM_DOT_PRODUCT( intel_sub_group_shuffle( arow1, 6 ) );", // NOLINT +"MM_DOT_PRODUCT( intel_sub_group_shuffle( arow0, 7 ) );", // NOLINT +"MM_DOT_PRODUCT( intel_sub_group_shuffle( arow1, 7 ) );", // 
NOLINT +"#undef MM_DOT_PRODUCT", // NOLINT +"", // NOLINT +"src0_read += TILE_K * M;", // NOLINT +"start_index += TILE_K;", // NOLINT +"}", // NOLINT +"", // NOLINT +"if(start_index < end_index) {", // NOLINT +"float8 arow0 = ((start_index + local_x * 2) < K) ? (alpha * ((__global float8 *)src0_read)[0]) : 0.0f;", // NOLINT +"float8 arow1 = ((start_index + local_x * 2 + 1) < K) ? (alpha * ((__global float8 *)(src0_read + M))[0]) : 0.0f;", // NOLINT +"", // NOLINT +"#define MM_DOT_PRODUCT( _arow ) brow = (start_index < K) ? ((__global float4 *)src1_read0)[0] : 0.0f; src1_read0 += N; start_index++; dot00 = mad( (float4)(_arow.s0), brow, dot00 ); dot01 = mad( (float4)(_arow.s1), brow, dot01 ); dot02 = mad( (float4)(_arow.s2), brow, dot02 ); dot03 = mad( (float4)(_arow.s3), brow, dot03 ); dot04 = mad( (float4)(_arow.s4), brow, dot04 ); dot05 = mad( (float4)(_arow.s5), brow, dot05 ); dot06 = mad( (float4)(_arow.s6), brow, dot06 ); dot07 = mad( (float4)(_arow.s7), brow, dot07 );", // NOLINT +"MM_DOT_PRODUCT( intel_sub_group_shuffle( arow0, 0 ) );", // NOLINT +"MM_DOT_PRODUCT( intel_sub_group_shuffle( arow1, 0 ) );", // NOLINT +"MM_DOT_PRODUCT( intel_sub_group_shuffle( arow0, 1 ) );", // NOLINT +"MM_DOT_PRODUCT( intel_sub_group_shuffle( arow1, 1 ) );", // NOLINT +"MM_DOT_PRODUCT( intel_sub_group_shuffle( arow0, 2 ) );", // NOLINT +"MM_DOT_PRODUCT( intel_sub_group_shuffle( arow1, 2 ) );", // NOLINT +"MM_DOT_PRODUCT( intel_sub_group_shuffle( arow0, 3 ) );", // NOLINT +"MM_DOT_PRODUCT( intel_sub_group_shuffle( arow1, 3 ) );", // NOLINT +"MM_DOT_PRODUCT( intel_sub_group_shuffle( arow0, 4 ) );", // NOLINT +"MM_DOT_PRODUCT( intel_sub_group_shuffle( arow1, 4 ) );", // NOLINT +"MM_DOT_PRODUCT( intel_sub_group_shuffle( arow0, 5 ) );", // NOLINT +"MM_DOT_PRODUCT( intel_sub_group_shuffle( arow1, 5 ) );", // NOLINT +"MM_DOT_PRODUCT( intel_sub_group_shuffle( arow0, 6 ) );", // NOLINT +"MM_DOT_PRODUCT( intel_sub_group_shuffle( arow1, 6 ) );", // NOLINT +"MM_DOT_PRODUCT( 
intel_sub_group_shuffle( arow0, 7 ) );", // NOLINT +"MM_DOT_PRODUCT( intel_sub_group_shuffle( arow1, 7 ) );", // NOLINT +"#undef MM_DOT_PRODUCT", // NOLINT +"}", // NOLINT +"", // NOLINT +"if(global_x * 4 < N && global_y * 8 < M) {", // NOLINT +"if(mad24(global_x, 4, 3) < N) {", // NOLINT +"__global float4 *dst_write = (__global float4 *)dst_write0;", // NOLINT +"dst_write[0] = dot00; dst_write0 += N; dst_write = (__global float4 *)dst_write0;", // NOLINT +"if(mad24(global_y, 8, 1) < M) { dst_write[0] = dot01; dst_write0 += N; dst_write = (__global float4 *)dst_write0; }", // NOLINT +"else return;", // NOLINT +"if(mad24(global_y, 8, 2) < M) { dst_write[0] = dot02; dst_write0 += N; dst_write = (__global float4 *)dst_write0; }", // NOLINT +"else return;", // NOLINT +"if(mad24(global_y, 8, 3) < M) { dst_write[0] = dot03; dst_write0 += N; dst_write = (__global float4 *)dst_write0; }", // NOLINT +"else return;", // NOLINT +"if(mad24(global_y, 8, 4) < M) { dst_write[0] = dot04; dst_write0 += N; dst_write = (__global float4 *)dst_write0; }", // NOLINT +"else return;", // NOLINT +"if(mad24(global_y, 8, 5) < M) { dst_write[0] = dot05; dst_write0 += N; dst_write = (__global float4 *)dst_write0; }", // NOLINT +"else return;", // NOLINT +"if(mad24(global_y, 8, 6) < M) { dst_write[0] = dot06; dst_write0 += N; dst_write = (__global float4 *)dst_write0; }", // NOLINT +"else return;", // NOLINT +"if(mad24(global_y, 8, 7) < M) { dst_write[0] = dot07; }", // NOLINT +"} else if(mad24(global_x, 4, 2) < N) {", // NOLINT +"__global float2 *dst_write = (__global float2 *)dst_write0;", // NOLINT +"dst_write[0] = dot00.xy; dst_write0[2] = dot00.z;", // NOLINT +"dst_write0 += N; dst_write = (__global float2 *)dst_write0;", // NOLINT +"if(mad24(global_y, 8, 1) < M) {", // NOLINT +"dst_write[0] = dot01.xy; dst_write0[2] = dot01.z;", // NOLINT +"dst_write0 += N; dst_write = (__global float2 *)dst_write0;", // NOLINT +"} else", // NOLINT +"return;", // NOLINT +"if(mad24(global_y, 8, 2) < M) {", 
// NOLINT +"dst_write[0] = dot02.xy; dst_write0[2] = dot02.z;", // NOLINT +"dst_write0 += N; dst_write = (__global float2 *)dst_write0;", // NOLINT +"} else", // NOLINT +"return;", // NOLINT +"if(mad24(global_y, 8, 3) < M) {", // NOLINT +"dst_write[0] = dot03.xy; dst_write0[2] = dot03.z;", // NOLINT +"dst_write0 += N; dst_write = (__global float2 *)dst_write0;", // NOLINT +"} else", // NOLINT +"return;", // NOLINT +"if(mad24(global_y, 8, 4) < M) {", // NOLINT +"dst_write[0] = dot04.xy; dst_write0[2] = dot04.z;", // NOLINT +"dst_write0 += N; dst_write = (__global float2 *)dst_write0;", // NOLINT +"} else", // NOLINT +"return;", // NOLINT +"if(mad24(global_y, 8, 5) < M) {", // NOLINT +"dst_write[0] = dot05.xy; dst_write0[2] = dot05.z;", // NOLINT +"dst_write0 += N; dst_write = (__global float2 *)dst_write0;", // NOLINT +"} else", // NOLINT +"return;", // NOLINT +"if(mad24(global_y, 8, 6) < M) {", // NOLINT +"dst_write[0] = dot06.xy; dst_write0[2] = dot06.z;", // NOLINT +"dst_write0 += N; dst_write = (__global float2 *)dst_write0;", // NOLINT +"} else", // NOLINT +"return;", // NOLINT +"if(mad24(global_y, 8, 7) < M) {", // NOLINT +"dst_write[0] = dot07.xy; dst_write0[2] = dot07.z;", // NOLINT +"}", // NOLINT +"} else if(mad24(global_x, 4, 1) < N) {", // NOLINT +"__global float2 *dst_write = (__global float2 *)dst_write0;", // NOLINT +"dst_write[0] = dot00.xy; dst_write0 += N; dst_write = (__global float2 *)dst_write0;", // NOLINT +"if(mad24(global_y, 8, 1) < M) { dst_write[0] = dot01.xy; dst_write0 += N; dst_write = (__global float2 *)dst_write0; }", // NOLINT +"else return;", // NOLINT +"if(mad24(global_y, 8, 2) < M) { dst_write[0] = dot02.xy; dst_write0 += N; dst_write = (__global float2 *)dst_write0; }", // NOLINT +"else return;", // NOLINT +"if(mad24(global_y, 8, 3) < M) { dst_write[0] = dot03.xy; dst_write0 += N; dst_write = (__global float2 *)dst_write0; }", // NOLINT +"else return;", // NOLINT +"if(mad24(global_y, 8, 4) < M) { dst_write[0] = dot04.xy; 
dst_write0 += N; dst_write = (__global float2 *)dst_write0; }", // NOLINT +"else return;", // NOLINT +"if(mad24(global_y, 8, 5) < M) { dst_write[0] = dot05.xy; dst_write0 += N; dst_write = (__global float2 *)dst_write0; }", // NOLINT +"else return;", // NOLINT +"if(mad24(global_y, 8, 6) < M) { dst_write[0] = dot06.xy; dst_write0 += N; dst_write = (__global float2 *)dst_write0; }", // NOLINT +"else return;", // NOLINT +"if(mad24(global_y, 8, 7) < M) { dst_write[0] = dot07.xy; }", // NOLINT +"} else {", // NOLINT +"dst_write0[0] = dot00.x; dst_write0 += N;", // NOLINT +"if(mad24(global_y, 8, 1) < M) { dst_write0[0] = dot01.x; dst_write0 += N; }", // NOLINT +"else return;", // NOLINT +"if(mad24(global_y, 8, 2) < M) { dst_write0[0] = dot02.x; dst_write0 += N; }", // NOLINT +"else return;", // NOLINT +"if(mad24(global_y, 8, 3) < M) { dst_write0[0] = dot03.x; dst_write0 += N; }", // NOLINT +"else return;", // NOLINT +"if(mad24(global_y, 8, 4) < M) { dst_write0[0] = dot04.x; dst_write0 += N; }", // NOLINT +"else return;", // NOLINT +"if(mad24(global_y, 8, 5) < M) { dst_write0[0] = dot05.x; dst_write0 += N; }", // NOLINT +"else return;", // NOLINT +"if(mad24(global_y, 8, 6) < M) { dst_write0[0] = dot06.x; dst_write0 += N; }", // NOLINT +"else return;", // NOLINT +"if(mad24(global_y, 8, 7) < M) { dst_write0[0] = dot07.x; }", // NOLINT +"}", // NOLINT +"}", // NOLINT +"}", // NOLINT +"", // NOLINT +"#undef VEC_SIZE", // NOLINT +"#undef LWG_HEIGHT", // NOLINT +"#undef TILE_M", // NOLINT +"#undef TILE_K", // NOLINT +"#undef TILE_N", // NOLINT +"", // NOLINT +"#define VEC_SIZE 4", // NOLINT +"#define LWG_HEIGHT 4", // NOLINT +"#define TILE_M 8", // NOLINT +"#define TILE_K 16", // NOLINT +"#define TILE_N 32", // NOLINT +"", // NOLINT +"__attribute__((reqd_work_group_size(8, LWG_HEIGHT, 1)))", // NOLINT +"__kernel void TEMPLATE(gemm_buffer_TT, Dtype)(", // NOLINT +"const __global float *src0, int off0,", // NOLINT +"const __global float *src1, int off1,", // NOLINT +"__global 
float *dst, int offd,", // NOLINT +"int M,", // NOLINT +"int N,", // NOLINT +"int K,", // NOLINT +"float alpha,", // NOLINT +"float beta,", // NOLINT +"int start_index)", // NOLINT +"", // NOLINT +"{", // NOLINT +"const int group_x = get_group_id(0);", // NOLINT +"const int group_y = get_group_id(1);", // NOLINT +"const int local_x = get_local_id(0);", // NOLINT +"const int local_y = get_local_id(1);", // NOLINT +"const int global_x = get_global_id(0);", // NOLINT +"const int global_y = get_global_id(1);", // NOLINT +"", // NOLINT +"float8 dot0 = 0.f;", // NOLINT +"float8 dot1 = 0.f;", // NOLINT +"float8 dot2 = 0.f;", // NOLINT +"float8 dot3 = 0.f;", // NOLINT +"", // NOLINT +"float16 brow0;", // NOLINT +"float16 brow1;", // NOLINT +"float16 brow2;", // NOLINT +"float16 brow3;", // NOLINT +"", // NOLINT +"__global float *dst_write0 = dst + local_x * VEC_SIZE + ( group_x * TILE_N ) + ( group_y * LWG_HEIGHT * TILE_M + local_y * TILE_M) * N + offd;", // NOLINT +"", // NOLINT +"const __global float *src0_read = src0 + (local_x * ( TILE_K / 8 ) + start_index) * M + group_y * LWG_HEIGHT * TILE_M + local_y * TILE_M + off0;", // NOLINT +"", // NOLINT +"const __global float *src1_read0 = src1 + (local_x * VEC_SIZE + ( group_x * TILE_N )) * K + start_index + off1;", // NOLINT +"", // NOLINT +"float4 dot00 = (start_index != 0) ? ((__global float4 *)dst_write0)[0] : (beta * ((__global float4 *)dst_write0)[0]);", // NOLINT +"float4 dot01 = (start_index != 0) ? ((__global float4 *)(dst_write0 + N))[0] : (beta * ((__global float4 *)(dst_write0 + N))[0]);", // NOLINT +"float4 dot02 = (start_index != 0) ? ((__global float4 *)(dst_write0 + 2 * N))[0] : (beta * ((__global float4 *)(dst_write0 + 2 * N))[0]);", // NOLINT +"float4 dot03 = (start_index != 0) ? ((__global float4 *)(dst_write0 + 3 * N))[0] : (beta * ((__global float4 *)(dst_write0 + 3 * N))[0]);", // NOLINT +"float4 dot04 = (start_index != 0) ? 
((__global float4 *)(dst_write0 + 4 * N))[0] : (beta * ((__global float4 *)(dst_write0 + 4 * N))[0]);", // NOLINT +"float4 dot05 = (start_index != 0) ? ((__global float4 *)(dst_write0 + 5 * N))[0] : (beta * ((__global float4 *)(dst_write0 + 5 * N))[0]);", // NOLINT +"float4 dot06 = (start_index != 0) ? ((__global float4 *)(dst_write0 + 6 * N))[0] : (beta * ((__global float4 *)(dst_write0 + 6 * N))[0]);", // NOLINT +"float4 dot07 = (start_index != 0) ? ((__global float4 *)(dst_write0 + 7 * N))[0] : (beta * ((__global float4 *)(dst_write0 + 7 * N))[0]);", // NOLINT +"", // NOLINT +"int end_index = min(start_index + 256, K);", // NOLINT +"while( start_index + TILE_K <= end_index ) {", // NOLINT +"brow0 = ((__global float16 *)src1_read0)[0];", // NOLINT +"brow1 = ((__global float16 *)(src1_read0 + K))[0];", // NOLINT +"brow2 = ((__global float16 *)(src1_read0 + 2 * K))[0];", // NOLINT +"brow3 = ((__global float16 *)(src1_read0 + 3 * K))[0];", // NOLINT +"", // NOLINT +"float8 arow0 = alpha * ((__global float8 *)src0_read)[0];", // NOLINT +"float8 arow1 = alpha * ((__global float8 *)(src0_read + M))[0];", // NOLINT +"", // NOLINT +"#define MM_DOT_PRODUCT( _brow, _dot) _dot = mad( intel_sub_group_shuffle( arow0, 0 ), (float8)_brow.s0, _dot ); _dot = mad( intel_sub_group_shuffle( arow1, 0 ), (float8)_brow.s1, _dot ); _dot = mad( intel_sub_group_shuffle( arow0, 1 ), (float8)_brow.s2, _dot ); _dot = mad( intel_sub_group_shuffle( arow1, 1 ), (float8)_brow.s3, _dot ); _dot = mad( intel_sub_group_shuffle( arow0, 2 ), (float8)_brow.s4, _dot ); _dot = mad( intel_sub_group_shuffle( arow1, 2 ), (float8)_brow.s5, _dot ); _dot = mad( intel_sub_group_shuffle( arow0, 3 ), (float8)_brow.s6, _dot ); _dot = mad( intel_sub_group_shuffle( arow1, 3 ), (float8)_brow.s7, _dot ); _dot = mad( intel_sub_group_shuffle( arow0, 4 ), (float8)_brow.s8, _dot ); _dot = mad( intel_sub_group_shuffle( arow1, 4 ), (float8)_brow.s9, _dot ); _dot = mad( intel_sub_group_shuffle( arow0, 5 ), (float8)_brow.sa, 
_dot ); _dot = mad( intel_sub_group_shuffle( arow1, 5 ), (float8)_brow.sb, _dot ); _dot = mad( intel_sub_group_shuffle( arow0, 6 ), (float8)_brow.sc, _dot ); _dot = mad( intel_sub_group_shuffle( arow1, 6 ), (float8)_brow.sd, _dot ); _dot = mad( intel_sub_group_shuffle( arow0, 7 ), (float8)_brow.se, _dot ); _dot = mad( intel_sub_group_shuffle( arow1, 7 ), (float8)_brow.sf, _dot );", // NOLINT +"MM_DOT_PRODUCT( brow0, dot0 );", // NOLINT +"MM_DOT_PRODUCT( brow1, dot1 );", // NOLINT +"MM_DOT_PRODUCT( brow2, dot2 );", // NOLINT +"MM_DOT_PRODUCT( brow3, dot3 );", // NOLINT +"#undef MM_DOT_PRODUCT", // NOLINT +"", // NOLINT +"src1_read0 += TILE_K;", // NOLINT +"src0_read += TILE_K * M;", // NOLINT +"start_index += TILE_K;", // NOLINT +"}", // NOLINT +"", // NOLINT +"if(start_index < end_index) {", // NOLINT +"brow0 = ((__global float16 *)src1_read0)[0]; src1_read0 += K;", // NOLINT +"brow1 = ((__global float16 *)src1_read0)[0]; src1_read0 += K;", // NOLINT +"brow2 = ((__global float16 *)src1_read0)[0]; src1_read0 += K;", // NOLINT +"brow3 = ((__global float16 *)src1_read0)[0];", // NOLINT +"", // NOLINT +"float8 arow0 = alpha * ((__global float8 *)src0_read)[0];", // NOLINT +"float8 arow1 = alpha * ((__global float8 *)(src0_read + M))[0];", // NOLINT +"", // NOLINT +"#define MM_DOT_PRODUCT( _brow, _dot) _dot = (w++ < K) ? mad( intel_sub_group_shuffle( arow0, 0 ), (float8)_brow.s0, _dot ) : _dot; _dot = (w++ < K) ? mad( intel_sub_group_shuffle( arow1, 0 ), (float8)_brow.s1, _dot ) : _dot; _dot = (w++ < K) ? mad( intel_sub_group_shuffle( arow0, 1 ), (float8)_brow.s2, _dot ) : _dot; _dot = (w++ < K) ? mad( intel_sub_group_shuffle( arow1, 1 ), (float8)_brow.s3, _dot ) : _dot; _dot = (w++ < K) ? mad( intel_sub_group_shuffle( arow0, 2 ), (float8)_brow.s4, _dot ) : _dot; _dot = (w++ < K) ? mad( intel_sub_group_shuffle( arow1, 2 ), (float8)_brow.s5, _dot ) : _dot; _dot = (w++ < K) ? mad( intel_sub_group_shuffle( arow0, 3 ), (float8)_brow.s6, _dot ) : _dot; _dot = (w++ < K) ? 
mad( intel_sub_group_shuffle( arow1, 3 ), (float8)_brow.s7, _dot ) : _dot; _dot = (w++ < K) ? mad( intel_sub_group_shuffle( arow0, 4 ), (float8)_brow.s8, _dot ) : _dot; _dot = (w++ < K) ? mad( intel_sub_group_shuffle( arow1, 4 ), (float8)_brow.s9, _dot ) : _dot; _dot = (w++ < K) ? mad( intel_sub_group_shuffle( arow0, 5 ), (float8)_brow.sa, _dot ) : _dot; _dot = (w++ < K) ? mad( intel_sub_group_shuffle( arow1, 5 ), (float8)_brow.sb, _dot ) : _dot; _dot = (w++ < K) ? mad( intel_sub_group_shuffle( arow0, 6 ), (float8)_brow.sc, _dot ) : _dot; _dot = (w++ < K) ? mad( intel_sub_group_shuffle( arow1, 6 ), (float8)_brow.sd, _dot ) : _dot; _dot = (w++ < K) ? mad( intel_sub_group_shuffle( arow0, 7 ), (float8)_brow.se, _dot ) : _dot; _dot = (w++ < K) ? mad( intel_sub_group_shuffle( arow1, 7 ), (float8)_brow.sf, _dot ) : _dot;", // NOLINT +"int w = start_index;", // NOLINT +"MM_DOT_PRODUCT( brow0, dot0 );", // NOLINT +"w = start_index;", // NOLINT +"MM_DOT_PRODUCT( brow1, dot1 );", // NOLINT +"w = start_index;", // NOLINT +"MM_DOT_PRODUCT( brow2, dot2 );", // NOLINT +"w = start_index;", // NOLINT +"MM_DOT_PRODUCT( brow3, dot3 );", // NOLINT +"#undef MM_DOT_PRODUCT", // NOLINT +"}", // NOLINT +"", // NOLINT +"dot00 += (float4)(dot0.s0, dot1.s0, dot2.s0, dot3.s0);", // NOLINT +"dot01 += (float4)(dot0.s1, dot1.s1, dot2.s1, dot3.s1);", // NOLINT +"dot02 += (float4)(dot0.s2, dot1.s2, dot2.s2, dot3.s2);", // NOLINT +"dot03 += (float4)(dot0.s3, dot1.s3, dot2.s3, dot3.s3);", // NOLINT +"dot04 += (float4)(dot0.s4, dot1.s4, dot2.s4, dot3.s4);", // NOLINT +"dot05 += (float4)(dot0.s5, dot1.s5, dot2.s5, dot3.s5);", // NOLINT +"dot06 += (float4)(dot0.s6, dot1.s6, dot2.s6, dot3.s6);", // NOLINT +"dot07 += (float4)(dot0.s7, dot1.s7, dot2.s7, dot3.s7);", // NOLINT +"", // NOLINT +"if(global_x * 4 < N && global_y * 8 < M) {", // NOLINT +"if(mad24(global_x, 4, 3) < N) {", // NOLINT +"__global float4 *dst_write = (__global float4 *)dst_write0;", // NOLINT +"dst_write[0] = dot00; dst_write0 += N; 
dst_write = (__global float4 *)dst_write0;", // NOLINT +"if(mad24(global_y, 8, 1) < M) { dst_write[0] = dot01; dst_write0 += N; dst_write = (__global float4 *)dst_write0; }", // NOLINT +"else return;", // NOLINT +"if(mad24(global_y, 8, 2) < M) { dst_write[0] = dot02; dst_write0 += N; dst_write = (__global float4 *)dst_write0; }", // NOLINT +"else return;", // NOLINT +"if(mad24(global_y, 8, 3) < M) { dst_write[0] = dot03; dst_write0 += N; dst_write = (__global float4 *)dst_write0; }", // NOLINT +"else return;", // NOLINT +"if(mad24(global_y, 8, 4) < M) { dst_write[0] = dot04; dst_write0 += N; dst_write = (__global float4 *)dst_write0; }", // NOLINT +"else return;", // NOLINT +"if(mad24(global_y, 8, 5) < M) { dst_write[0] = dot05; dst_write0 += N; dst_write = (__global float4 *)dst_write0; }", // NOLINT +"else return;", // NOLINT +"if(mad24(global_y, 8, 6) < M) { dst_write[0] = dot06; dst_write0 += N; dst_write = (__global float4 *)dst_write0; }", // NOLINT +"else return;", // NOLINT +"if(mad24(global_y, 8, 7) < M) { dst_write[0] = dot07; }", // NOLINT +"} else if(mad24(global_x, 4, 2) < N) {", // NOLINT +"__global float2 *dst_write = (__global float2 *)dst_write0;", // NOLINT +"dst_write[0] = dot00.xy; dst_write0[2] = dot00.z;", // NOLINT +"dst_write0 += N; dst_write = (__global float2 *)dst_write0;", // NOLINT +"if(mad24(global_y, 8, 1) < M) {", // NOLINT +"dst_write[0] = dot01.xy; dst_write0[2] = dot01.z;", // NOLINT +"dst_write0 += N; dst_write = (__global float2 *)dst_write0;", // NOLINT +"} else", // NOLINT +"return;", // NOLINT +"if(mad24(global_y, 8, 2) < M) {", // NOLINT +"dst_write[0] = dot02.xy; dst_write0[2] = dot02.z;", // NOLINT +"dst_write0 += N; dst_write = (__global float2 *)dst_write0;", // NOLINT +"} else", // NOLINT +"return;", // NOLINT +"if(mad24(global_y, 8, 3) < M) {", // NOLINT +"dst_write[0] = dot03.xy; dst_write0[2] = dot03.z;", // NOLINT +"dst_write0 += N; dst_write = (__global float2 *)dst_write0;", // NOLINT +"} else", // NOLINT 
+"return;", // NOLINT +"if(mad24(global_y, 8, 4) < M) {", // NOLINT +"dst_write[0] = dot04.xy; dst_write0[2] = dot04.z;", // NOLINT +"dst_write0 += N; dst_write = (__global float2 *)dst_write0;", // NOLINT +"} else", // NOLINT +"return;", // NOLINT +"if(mad24(global_y, 8, 5) < M) {", // NOLINT +"dst_write[0] = dot05.xy; dst_write0[2] = dot05.z;", // NOLINT +"dst_write0 += N; dst_write = (__global float2 *)dst_write0;", // NOLINT +"} else", // NOLINT +"return;", // NOLINT +"if(mad24(global_y, 8, 6) < M) {", // NOLINT +"dst_write[0] = dot06.xy; dst_write0[2] = dot06.z;", // NOLINT +"dst_write0 += N; dst_write = (__global float2 *)dst_write0;", // NOLINT +"} else", // NOLINT +"return;", // NOLINT +"if(mad24(global_y, 8, 7) < M) {", // NOLINT +"dst_write[0] = dot07.xy; dst_write0[2] = dot07.z;", // NOLINT +"}", // NOLINT +"} else if(mad24(global_x, 4, 1) < N) {", // NOLINT +"__global float2 *dst_write = (__global float2 *)dst_write0;", // NOLINT +"dst_write[0] = dot00.xy; dst_write0 += N; dst_write = (__global float2 *)dst_write0;", // NOLINT +"if(mad24(global_y, 8, 1) < M) { dst_write[0] = dot01.xy; dst_write0 += N; dst_write = (__global float2 *)dst_write0; }", // NOLINT +"else return;", // NOLINT +"if(mad24(global_y, 8, 2) < M) { dst_write[0] = dot02.xy; dst_write0 += N; dst_write = (__global float2 *)dst_write0; }", // NOLINT +"else return;", // NOLINT +"if(mad24(global_y, 8, 3) < M) { dst_write[0] = dot03.xy; dst_write0 += N; dst_write = (__global float2 *)dst_write0; }", // NOLINT +"else return;", // NOLINT +"if(mad24(global_y, 8, 4) < M) { dst_write[0] = dot04.xy; dst_write0 += N; dst_write = (__global float2 *)dst_write0; }", // NOLINT +"else return;", // NOLINT +"if(mad24(global_y, 8, 5) < M) { dst_write[0] = dot05.xy; dst_write0 += N; dst_write = (__global float2 *)dst_write0; }", // NOLINT +"else return;", // NOLINT +"if(mad24(global_y, 8, 6) < M) { dst_write[0] = dot06.xy; dst_write0 += N; dst_write = (__global float2 *)dst_write0; }", // NOLINT +"else 
return;", // NOLINT +"if(mad24(global_y, 8, 7) < M) { dst_write[0] = dot07.xy; }", // NOLINT +"} else {", // NOLINT +"dst_write0[0] = dot00.x; dst_write0 += N;", // NOLINT +"if(mad24(global_y, 8, 1) < M) { dst_write0[0] = dot01.x; dst_write0 += N; }", // NOLINT +"else return;", // NOLINT +"if(mad24(global_y, 8, 2) < M) { dst_write0[0] = dot02.x; dst_write0 += N; }", // NOLINT +"else return;", // NOLINT +"if(mad24(global_y, 8, 3) < M) { dst_write0[0] = dot03.x; dst_write0 += N; }", // NOLINT +"else return;", // NOLINT +"if(mad24(global_y, 8, 4) < M) { dst_write0[0] = dot04.x; dst_write0 += N; }", // NOLINT +"else return;", // NOLINT +"if(mad24(global_y, 8, 5) < M) { dst_write0[0] = dot05.x; dst_write0 += N; }", // NOLINT +"else return;", // NOLINT +"if(mad24(global_y, 8, 6) < M) { dst_write0[0] = dot06.x; dst_write0 += N; }", // NOLINT +"else return;", // NOLINT +"if(mad24(global_y, 8, 7) < M) { dst_write0[0] = dot07.x; }", // NOLINT +"}", // NOLINT +"}", // NOLINT +"}", // NOLINT +"", // NOLINT +"#undef VEC_SIZE", // NOLINT +"#undef LWG_HEIGHT", // NOLINT +"#undef TILE_M", // NOLINT +"#undef TILE_K", // NOLINT +"#undef TILE_N", // NOLINT +""}, // NOLINT + {"#ifndef __OPENCL_VERSION__", // NOLINT +"#include \"header.cl\"", // NOLINT +"#endif", // NOLINT +"", // NOLINT "__kernel void TEMPLATE(im2col,Dtype)(const int_tp n,", // NOLINT "__global const Dtype* data_im,", // NOLINT "const int_tp data_im_off,", // NOLINT @@ -3757,6 +5242,150 @@ static std::vector> cl_kernels{ "#include \"header.cl\"", // NOLINT "#endif", // NOLINT "", // NOLINT +"__kernel void TEMPLATE(matvec_mul4,Dtype)(", // NOLINT +"__global const float * A,", // NOLINT +"int offA,", // NOLINT +"unsigned int A_col_size,", // NOLINT +"unsigned int trail_item,", // NOLINT +"__global const float * v,", // NOLINT +"int offv,", // NOLINT +"float alpha,", // NOLINT +"float beta,", // NOLINT +"__global float4 * result,", // NOLINT +"int offr,", // NOLINT +"__local float4 * work)", // NOLINT +"{", // NOLINT 
+"unsigned int row_gid = get_group_id(0);", // NOLINT +"unsigned int lid = get_local_id(0);", // NOLINT +"const __global float *src0_read = A + row_gid * 4 * A_col_size + offA;", // NOLINT +"const __global float *src1_read = v + offv;", // NOLINT +"result = (__global float4*)((__global float*)result + offr);", // NOLINT +"float4 dot0 = (float4)(0.f);", // NOLINT +"float4 dot1 = (float4)(0.f);", // NOLINT +"float4 dot2 = (float4)(0.f);", // NOLINT +"float4 dot3 = (float4)(0.f);", // NOLINT +"", // NOLINT +"unsigned int i = lid;", // NOLINT +"while( i < A_col_size / 4) {", // NOLINT +"const float4 a0 = vload4(i, src0_read);", // NOLINT +"const float4 a1 = vload4(i, src0_read + A_col_size);", // NOLINT +"const float4 a2 = vload4(i, src0_read + 2 * A_col_size);", // NOLINT +"const float4 a3 = vload4(i, src0_read + 3 * A_col_size);", // NOLINT +"", // NOLINT +"const float4 b0 = vload4(i, src1_read);", // NOLINT +"", // NOLINT +"dot0 += a0 * b0;", // NOLINT +"dot1 += a1 * b0;", // NOLINT +"dot2 += a2 * b0;", // NOLINT +"dot3 += a3 * b0;", // NOLINT +"", // NOLINT +"i += get_local_size(0);", // NOLINT +"}", // NOLINT +"", // NOLINT +"work[lid].s0 = dot0.x + dot0.y + dot0.z + dot0.w;", // NOLINT +"work[lid].s1 = dot1.x + dot1.y + dot1.z + dot1.w;", // NOLINT +"work[lid].s2 = dot2.x + dot2.y + dot2.z + dot2.w;", // NOLINT +"work[lid].s3 = dot3.x + dot3.y + dot3.z + dot3.w;", // NOLINT +"", // NOLINT +"if(i == A_col_size / 4)", // NOLINT +"{", // NOLINT +"if(trail_item != 0)", // NOLINT +"{", // NOLINT +"const __global float *src0_trail = src0_read + i * 4;", // NOLINT +"const __global float *src1_trail = src1_read + i * 4;", // NOLINT +"for(unsigned int i = 0; i < trail_item; ++i) {", // NOLINT +"const float at0 = src0_trail[i];", // NOLINT +"const float at1 = src0_trail[i + A_col_size];", // NOLINT +"const float at2 = src0_trail[i + 2 * A_col_size];", // NOLINT +"const float at3 = src0_trail[i + 3 * A_col_size];", // NOLINT +"", // NOLINT +"const float bt = 
src1_trail[i];", // NOLINT +"", // NOLINT +"work[lid].s0 += at0 * bt;", // NOLINT +"work[lid].s1 += at1 * bt;", // NOLINT +"work[lid].s2 += at2 * bt;", // NOLINT +"work[lid].s3 += at3 * bt;", // NOLINT +"}", // NOLINT +"}", // NOLINT +"", // NOLINT +"}", // NOLINT +"", // NOLINT +"for(unsigned int stride=get_local_size(0)/2 ; stride>0 ; stride>>=1) {", // NOLINT +"barrier(CLK_LOCAL_MEM_FENCE);", // NOLINT +"if(lid < stride)", // NOLINT +"work[lid] += work[lid+stride];", // NOLINT +"}", // NOLINT +"if(lid == 0)", // NOLINT +"result[row_gid] = alpha * work[0] + beta * result[row_gid];", // NOLINT +"}", // NOLINT +"", // NOLINT +"/* This kernel used for the trailing rows when row_of_A %4 !=0 */", // NOLINT +"__kernel void TEMPLATE(matvec_mul1,Dtype)(", // NOLINT +"__global const float * A,", // NOLINT +"int offA,", // NOLINT +"unsigned int A_col_size,", // NOLINT +"unsigned int row_offset,", // NOLINT +"unsigned int trail_item,", // NOLINT +"__global const float * v,", // NOLINT +"int offv,", // NOLINT +"float alpha,", // NOLINT +"float beta,", // NOLINT +"__global float * result,", // NOLINT +"int offr,", // NOLINT +"__local float * work)", // NOLINT +"{", // NOLINT +"unsigned int row_gid = get_group_id(0);", // NOLINT +"unsigned int lid = get_local_id(0);", // NOLINT +"", // NOLINT +"const __global float *src0_read = A + (row_offset + row_gid) * A_col_size + offA;", // NOLINT +"const __global float *src1_read = v + + offv;", // NOLINT +"result = result + offr;", // NOLINT +"float4 dot0 = (float4)(0.f);", // NOLINT +"", // NOLINT +"unsigned int i = lid;", // NOLINT +"while( i < A_col_size / 4)", // NOLINT +"{", // NOLINT +"const float4 a0 = vload4(i, src0_read);", // NOLINT +"const float4 b0 = vload4(i, src1_read);", // NOLINT +"", // NOLINT +"dot0 += a0 * b0;", // NOLINT +"i += get_local_size(0);", // NOLINT +"}", // NOLINT +"", // NOLINT +"work[lid] = dot0.x + dot0.y + dot0.z + dot0.w;", // NOLINT +"", // NOLINT +"if(i == A_col_size / 4)", // NOLINT +"{", // NOLINT 
+"if(trail_item != 0)", // NOLINT +"{", // NOLINT +"const __global float *src0_trail = src0_read + i * 4;", // NOLINT +"const __global float *src1_trail = src1_read + i * 4;", // NOLINT +"for(unsigned int i = 0; i < trail_item; ++i) {", // NOLINT +"const float at0 = src0_trail[i];", // NOLINT +"const float bt = src1_trail[i];", // NOLINT +"", // NOLINT +"work[lid] += at0 * bt;", // NOLINT +"}", // NOLINT +"}", // NOLINT +"", // NOLINT +"}", // NOLINT +"for(unsigned int stride=get_local_size(0)/2 ; stride>0 ; stride>>=1) {", // NOLINT +"barrier(CLK_LOCAL_MEM_FENCE);", // NOLINT +"if(lid < stride)", // NOLINT +"work[lid] += work[lid+stride];", // NOLINT +"}", // NOLINT +"", // NOLINT +"if(lid == 0) {", // NOLINT +"result[row_gid+row_offset] *= beta;", // NOLINT +"result[row_gid+row_offset] += alpha * work[0];", // NOLINT +"//result[row_gid+row_offset] = alpha * work[0] + beta * result[row_gid+row_offset];", // NOLINT +"}", // NOLINT +"}", // NOLINT +""}, // NOLINT + {"#ifndef __OPENCL_VERSION__", // NOLINT +"#include \"header.cl\"", // NOLINT +"#endif", // NOLINT +"", // NOLINT "__kernel void TEMPLATE(merge_copy_forward_stack, Dtype)(const int_tp nthreads,", // NOLINT "const int_tp dims,", // NOLINT "__global const Dtype* bottom_a,", // NOLINT @@ -5197,11 +6826,13 @@ static std::string cl_kernel_names[] = { "embed", // NOLINT "fft", // NOLINT "fillbuffer", // NOLINT + "gemm", // NOLINT "im2col", // NOLINT "im2col_nd", // NOLINT "lrn", // NOLINT "lstm_unit", // NOLINT "math", // NOLINT + "matvec_mul", // NOLINT "mergecrop", // NOLINT "pooling", // NOLINT "pooling_nd", // NOLINT @@ -5254,10 +6885,15 @@ viennacl::ocl::program & RegisterKernels(viennacl::ocl::context *ctx) { ss << "#endif // DOUBLE_SUPPORT_AVAILABLE" << "\n\n"; // NOLINT std::string kernel_string = ss.str(); const char* kernel_program = kernel_string.c_str(); - // ctx->build_options("-cl-fast-relaxed-math -cl-mad-enable"); + string options; #ifdef USE_FFT - ctx->build_options("-DFFT"); + options = " 
-DFFT " #endif + bool is_beignet = ctx->devices()[0].opencl_c_version().find("beignet") + != std::string::npos; + if (!is_beignet) + options += (" -cl-no-subgroup-ifp "); + ctx->build_options(options); viennacl::ocl::program &program = ctx->add_program(kernel_program, "kernel_program"); return program; @@ -5288,6 +6924,10 @@ viennacl::ocl::context *ctx, string name, string options) { } } } + bool is_beignet = ctx->devices()[0].opencl_c_version().find("beignet") + != std::string::npos; + if (!is_beignet) + options += (" -cl-no-subgroup-ifp "); ctx->build_options(options); viennacl::ocl::program &program = ctx->add_program(ss.str(), name); return program; diff --git a/src/caffe/greentea/cl_kernels.sh b/src/caffe/greentea/cl_kernels.sh index 562bf0a9211..5c1fd66b129 100755 --- a/src/caffe/greentea/cl_kernels.sh +++ b/src/caffe/greentea/cl_kernels.sh @@ -188,10 +188,15 @@ echo " ss << \"#endif // DOUBLE_SUPPORT_AVAILABLE\" << \"\\n\\n\"; // NOLINT" echo " std::string kernel_string = ss.str();" >> $SOURCE echo " const char* kernel_program = kernel_string.c_str();" >> $SOURCE -echo " // ctx->build_options(\"-cl-fast-relaxed-math -cl-mad-enable\");" >> $SOURCE +echo " string options;" >> $SOURCE echo "#ifdef USE_FFT" >> $SOURCE -echo " ctx->build_options(\"-DFFT\");" >> $SOURCE +echo " options = \" -DFFT \"" >> $SOURCE echo "#endif" >> $SOURCE +echo " bool is_beignet = ctx->devices()[0].opencl_c_version().find(\"beignet\")" >> $SOURCE +echo " != std::string::npos;" >> $SOURCE +echo " if (!is_beignet)" >> $SOURCE +echo " options += (\" -cl-no-subgroup-ifp \");" >> $SOURCE +echo " ctx->build_options(options);" >> $SOURCE echo " viennacl::ocl::program &program = ctx->add_program(kernel_program," >> $SOURCE echo " \"kernel_program\");" >> $SOURCE echo " return program;" >> $SOURCE @@ -222,6 +227,10 @@ echo " ss << cl_kernels[i][j] << \"\n\n\";" >> $SOURCE echo " }" >> $SOURCE echo " }" >> $SOURCE echo " }" >> $SOURCE +echo " bool is_beignet = 
ctx->devices()[0].opencl_c_version().find(\"beignet\")" >> $SOURCE +echo " != std::string::npos;" >> $SOURCE +echo " if (!is_beignet)" >> $SOURCE +echo " options += (\" -cl-no-subgroup-ifp \");" >> $SOURCE echo " ctx->build_options(options);" >> $SOURCE echo " viennacl::ocl::program &program = ctx->add_program(ss.str(), name);" >> $SOURCE echo " return program;" >> $SOURCE diff --git a/src/caffe/greentea/cl_kernels/gemm.cl b/src/caffe/greentea/cl_kernels/gemm.cl new file mode 100644 index 00000000000..de5fc39c3fd --- /dev/null +++ b/src/caffe/greentea/cl_kernels/gemm.cl @@ -0,0 +1,1958 @@ +#ifndef __OPENCL_VERSION__ +#include "header.cl" +#endif + +#define TILE_M 32 +#define TILE_K 8 +#define TILE_N 8 + +// common block to calculate (alpha * AxB + beta * C) and output to destination image. + +//#define USE_IMAGE_C +#ifdef USE_IMAGE_C +#define BLOCKC_READ8( _C, _coordC ) as_float8( intel_sub_group_block_read8( _C, _coordC ) ) +#define BLOCKC_WRITE8( _C, _coordC, _val ) intel_sub_group_block_write8( _C, _coordC, as_uint8( _val ) ) +#define MATC_PARAMETER __read_only image2d_t C, __write_only image2d_t dst +#define GEMM_OUTPUT(ALPHA1, BETA_NOT0) GEMM_OUTPUT_EXT(ALPHA1, BETA_NOT0, C, dst, sizeof(uint)) +#else +#define BLOCKC_READ8( _C, _coordC ) \ + (float8) ( (_coordC.x + get_local_id(0) < N && _coordC.y < M) ? _C[ _coordC.y * N + _coordC.x + get_local_id(0) ] : 0, \ + (_coordC.x + get_local_id(0) < N && _coordC.y + 1 < M) ? _C[ ( _coordC.y + 1 ) * N + _coordC.x + get_local_id(0) ] : 0, \ + (_coordC.x + get_local_id(0) < N && _coordC.y + 2 < M) ? _C[ ( _coordC.y + 2 ) * N + _coordC.x + get_local_id(0) ] : 0, \ + (_coordC.x + get_local_id(0) < N && _coordC.y + 3 < M) ? _C[ ( _coordC.y + 3 ) * N + _coordC.x + get_local_id(0) ] : 0, \ + (_coordC.x + get_local_id(0) < N && _coordC.y + 4 < M) ? _C[ ( _coordC.y + 4 ) * N + _coordC.x + get_local_id(0) ] : 0, \ + (_coordC.x + get_local_id(0) < N && _coordC.y + 5 < M) ? 
_C[ ( _coordC.y + 5 ) * N + _coordC.x + get_local_id(0) ] : 0, \ + (_coordC.x + get_local_id(0) < N && _coordC.y + 6 < M) ? _C[ ( _coordC.y + 6 ) * N + _coordC.x + get_local_id(0) ] : 0, \ + (_coordC.x + get_local_id(0) < N && _coordC.y + 7 < M) ? _C[ ( _coordC.y + 7 ) * N + _coordC.x + get_local_id(0) ] : 0) + +#define BLOCKC_WRITE8( _C, _coordC, _val) do {\ + if (_coordC.x + get_local_id(0) < N) { \ + if (_coordC.y < M) \ + _C[ _coordC.y * N + _coordC.x + get_local_id(0) ] = _val.s0; \ + if (_coordC.y + 1 < M) \ + _C[ ( _coordC.y + 1 )* N + _coordC.x + get_local_id(0) ] = _val.s1; \ + if (_coordC.y + 2 < M) \ + _C[ ( _coordC.y + 2 )* N + _coordC.x + get_local_id(0) ] = _val.s2; \ + if (_coordC.y + 3 < M) \ + _C[ ( _coordC.y + 3 )* N + _coordC.x + get_local_id(0) ] = _val.s3; \ + if (_coordC.y + 4 < M) \ + _C[ ( _coordC.y + 4 )* N + _coordC.x + get_local_id(0) ] = _val.s4; \ + if (_coordC.y + 5 < M) \ + _C[ ( _coordC.y + 5 )* N + _coordC.x + get_local_id(0) ] = _val.s5; \ + if (_coordC.y + 6 < M) \ + _C[ ( _coordC.y + 6 )* N + _coordC.x + get_local_id(0) ] = _val.s6; \ + if (_coordC.y + 7 < M) \ + _C[ ( _coordC.y + 7 )* N + _coordC.x + get_local_id(0) ] = _val.s7; \ + }} while(0) +#define MATC_PARAMETER __global Dtype * C, const int offC, const int M, const int N +#define GEMM_OUTPUT(ALPHA1, BETA_NOT0) GEMM_OUTPUT_EXT(ALPHA1, BETA_NOT0, (C + offC), (C + offC), 1) +#endif + +#define GEMM_OUTPUT_EXT(ALPHA1, BETA_NOT0, _C, _dst, _C_step) \ + int2 coordDst = (int2)( ( group_x * TILE_N ) * _C_step, ( group_y * TILE_M ) ); \ + int2 coordC = coordDst; \ + float8 blockC00; \ + float8 blockC01; \ + float8 blockC02; \ + float8 blockC03; \ + if (BETA_NOT0) { \ + blockC00 = BLOCKC_READ8( _C, coordC ); coordC.y += 8; \ + blockC01 = BLOCKC_READ8( _C, coordC ); coordC.y += 8; \ + blockC02 = BLOCKC_READ8( _C, coordC ); coordC.y += 8; \ + blockC03 = BLOCKC_READ8( _C, coordC ); \ + if (!ALPHA1) { \ + blockC00 *= beta; \ + blockC01 *= beta; \ + blockC02 *= beta; \ + blockC03 *= 
beta; \ + blockC00 = mad(blockAxB00, (float8)alpha, blockC00); \ + blockC01 = mad(blockAxB01, (float8)alpha, blockC01); \ + blockC02 = mad(blockAxB02, (float8)alpha, blockC02); \ + blockC03 = mad(blockAxB03, (float8)alpha, blockC03); \ + } else { \ + blockC00 = mad(blockC00, (float8)beta, blockAxB00); \ + blockC01 = mad(blockC01, (float8)beta, blockAxB01); \ + blockC02 = mad(blockC02, (float8)beta, blockAxB02); \ + blockC03 = mad(blockC03, (float8)beta, blockAxB03); \ + } \ + } else { \ + if (!ALPHA1) { \ + blockC00 = blockAxB00 * alpha; \ + blockC01 = blockAxB01 * alpha; \ + blockC02 = blockAxB02 * alpha; \ + blockC03 = blockAxB03 * alpha; \ + } else { \ + blockC00 = blockAxB00; \ + blockC01 = blockAxB01; \ + blockC02 = blockAxB02; \ + blockC03 = blockAxB03; \ + } \ + } \ + BLOCKC_WRITE8( _dst, coordDst, blockC00 ); coordDst.y += 8; \ + BLOCKC_WRITE8( _dst, coordDst, blockC01 ); coordDst.y += 8; \ + BLOCKC_WRITE8( _dst, coordDst, blockC02 ); coordDst.y += 8; \ + BLOCKC_WRITE8( _dst, coordDst, blockC03 ); + +// Get the specified column of the block of the block +#define TRANSPOSE_BLOCK_8( _block, _col ) \ + (float8)( intel_sub_group_shuffle( _block.s0, _col ), \ + intel_sub_group_shuffle( _block.s1, _col ), \ + intel_sub_group_shuffle( _block.s2, _col ), \ + intel_sub_group_shuffle( _block.s3, _col ), \ + intel_sub_group_shuffle( _block.s4, _col ), \ + intel_sub_group_shuffle( _block.s5, _col ), \ + intel_sub_group_shuffle( _block.s6, _col ), \ + intel_sub_group_shuffle( _block.s7, _col ) ); + +// A's column block multiply B 's row block. 
+#define MULTIPLY_BLOCKS_8x8( _result, _blockA, _blockB ) \ + { \ + const float8 acol0 = TRANSPOSE_BLOCK_8( _blockA, 0 ); \ + const float8 acol1 = TRANSPOSE_BLOCK_8( _blockA, 1 ); \ + const float8 acol2 = TRANSPOSE_BLOCK_8( _blockA, 2 ); \ + const float8 acol3 = TRANSPOSE_BLOCK_8( _blockA, 3 ); \ + const float8 acol4 = TRANSPOSE_BLOCK_8( _blockA, 4 ); \ + const float8 acol5 = TRANSPOSE_BLOCK_8( _blockA, 5 ); \ + const float8 acol6 = TRANSPOSE_BLOCK_8( _blockA, 6 ); \ + const float8 acol7 = TRANSPOSE_BLOCK_8( _blockA, 7 ); \ + _result = mad( (float8)(_blockB.s0), acol0, _result ); \ + _result = mad( (float8)(_blockB.s1), acol1, _result ); \ + _result = mad( (float8)(_blockB.s2), acol2, _result ); \ + _result = mad( (float8)(_blockB.s3), acol3, _result ); \ + _result = mad( (float8)(_blockB.s4), acol4, _result ); \ + _result = mad( (float8)(_blockB.s5), acol5, _result ); \ + _result = mad( (float8)(_blockB.s6), acol6, _result ); \ + _result = mad( (float8)(_blockB.s7), acol7, _result ); \ + } + +#define GEMM_NN(ALPHA1, BETA_NOT0) \ +__attribute__((reqd_work_group_size(8, 1, 1))) \ +__kernel void TEMPLATE(gemm_32_1_NN_ ##ALPHA1 ##_ ##BETA_NOT0,Dtype)( \ + __read_only image2d_t A, \ + __read_only image2d_t B, \ + MATC_PARAMETER, \ + float alpha, \ + float beta, \ + int width0) \ +{ \ + const int group_x = get_group_id(0); \ + const int group_y = get_group_id(1); \ + float8 blockAxB00 = 0.0f; \ + float8 blockAxB01 = 0.0f; \ + float8 blockAxB02 = 0.0f; \ + float8 blockAxB03 = 0.0f; \ + int2 coordA = (int2)( 0, group_y * TILE_M ); \ + int2 coordB = (int2)( ( group_x * TILE_N ) * sizeof(uint), 0 ); \ + do \ + { \ + int2 coordBTemp = coordB; \ + float8 blockB00 = as_float8( intel_sub_group_block_read8( B, coordBTemp ) ); coordB.y += TILE_K; \ + int2 coordATemp = coordA; \ + float8 blockA00 = as_float8( intel_sub_group_block_read8( A, coordATemp ) ); coordATemp.y += 8; \ + float8 blockA01 = as_float8( intel_sub_group_block_read8( A, coordATemp ) ); coordATemp.y += 8; \ + 
float8 blockA02 = as_float8( intel_sub_group_block_read8( A, coordATemp ) ); coordATemp.y += 8; \ + float8 blockA03 = as_float8( intel_sub_group_block_read8( A, coordATemp ) ); coordA.x += TILE_K * sizeof(uint); \ + MULTIPLY_BLOCKS_8x8( blockAxB00, blockA00, blockB00 ); \ + MULTIPLY_BLOCKS_8x8( blockAxB01, blockA01, blockB00 ); \ + MULTIPLY_BLOCKS_8x8( blockAxB02, blockA02, blockB00 ); \ + MULTIPLY_BLOCKS_8x8( blockAxB03, blockA03, blockB00 ); \ + } \ + while( coordB.y < width0 ); \ + GEMM_OUTPUT(ALPHA1, BETA_NOT0); \ +} + +GEMM_NN(1, 0) // ALPHA == 1, BETA == 0 +GEMM_NN(1, 1) // ALPHA == 1, BETA != 0 +GEMM_NN(0, 0) // ALPHA != 1, BETA == 0 +GEMM_NN(0, 1) // ALPHA != 1, BETA != 0 + +#undef TRANSPOSE_BLOCK_8 +#undef MULTIPLY_BLOCKS_8x8 + +// replicate the first row to column block. +#define TRANSPOSE_BLOCK_8(_vec) \ + (float8)( intel_sub_group_shuffle(_vec, 0), \ + intel_sub_group_shuffle(_vec, 1), \ + intel_sub_group_shuffle(_vec, 2), \ + intel_sub_group_shuffle(_vec, 3), \ + intel_sub_group_shuffle(_vec, 4), \ + intel_sub_group_shuffle(_vec, 5), \ + intel_sub_group_shuffle(_vec, 6), \ + intel_sub_group_shuffle(_vec, 7) ) + +#define MULTIPLY_BLOCKS_8x8( _result, _blockA, _blockB ) \ + { \ + _result = mad( (float8)(_blockB.s0), TRANSPOSE_BLOCK_8(_blockA.s0), _result ); \ + _result = mad( (float8)(_blockB.s1), TRANSPOSE_BLOCK_8(_blockA.s1), _result ); \ + _result = mad( (float8)(_blockB.s2), TRANSPOSE_BLOCK_8(_blockA.s2), _result ); \ + _result = mad( (float8)(_blockB.s3), TRANSPOSE_BLOCK_8(_blockA.s3), _result ); \ + _result = mad( (float8)(_blockB.s4), TRANSPOSE_BLOCK_8(_blockA.s4), _result ); \ + _result = mad( (float8)(_blockB.s5), TRANSPOSE_BLOCK_8(_blockA.s5), _result ); \ + _result = mad( (float8)(_blockB.s6), TRANSPOSE_BLOCK_8(_blockA.s6), _result ); \ + _result = mad( (float8)(_blockB.s7), TRANSPOSE_BLOCK_8(_blockA.s7), _result ); \ + } + +#define GEMM_TN(ALPHA1, BETA_NOT0) \ +__attribute__((reqd_work_group_size(8, 1, 1))) \ +__kernel void 
TEMPLATE(gemm_32_1_TN_ ##ALPHA1 ##_ ##BETA_NOT0,Dtype)( \ + __read_only image2d_t A, \ + __read_only image2d_t B, \ + MATC_PARAMETER, \ + float alpha, \ + float beta, \ + int width0) \ +{ \ + const int group_x = get_group_id(0);\ + const int group_y = get_group_id(1);\ + float8 blockAxB00 = 0.0f;\ + float8 blockAxB01 = 0.0f;\ + float8 blockAxB02 = 0.0f;\ + float8 blockAxB03 = 0.0f;\ + int2 coordA = (int2)( group_y * TILE_M * sizeof(uint), 0 );\ + int2 coordB = (int2)( ( group_x * TILE_N ) * sizeof(uint), 0 );\ + do\ + {\ + int2 coordBTemp = coordB;\ + float8 blockB00 = as_float8( intel_sub_group_block_read8( B, coordBTemp ) ); coordB.y += TILE_K;\ + int2 coordATemp = coordA;\ + float8 blockA00 = as_float8( intel_sub_group_block_read8( A, coordATemp ) ); coordATemp.x += 8 * sizeof(uint);\ + float8 blockA01 = as_float8( intel_sub_group_block_read8( A, coordATemp ) ); coordATemp.x += 8 * sizeof(uint);\ + float8 blockA02 = as_float8( intel_sub_group_block_read8( A, coordATemp ) ); coordATemp.x += 8 * sizeof(uint);\ + float8 blockA03 = as_float8( intel_sub_group_block_read8( A, coordATemp ) ); coordA.y += TILE_K;\ + MULTIPLY_BLOCKS_8x8( blockAxB00, blockA00, blockB00 ); \ + MULTIPLY_BLOCKS_8x8( blockAxB01, blockA01, blockB00 ); \ + MULTIPLY_BLOCKS_8x8( blockAxB02, blockA02, blockB00 ); \ + MULTIPLY_BLOCKS_8x8( blockAxB03, blockA03, blockB00 ); \ + } \ + while( coordB.y < width0 ); \ + GEMM_OUTPUT(ALPHA1, BETA_NOT0); \ +} + +GEMM_TN(1, 0) // ALPHA == 1, BETA == 0 +GEMM_TN(1, 1) // ALPHA == 1, BETA != 0 +GEMM_TN(0, 0) // ALPHA != 1, BETA == 0 +GEMM_TN(0, 1) // ALPHA != 1, BETA != 0 + +#undef MULTIPLY_BLOCKS_8x8 +#undef TRANSPOSE_BLOCK_8 + +// The same as GEMM_NN +#define TRANSPOSE_BLOCK_8( _block, _col ) \ + (float8)( intel_sub_group_shuffle( _block.s0, _col), \ + intel_sub_group_shuffle( _block.s1, _col), \ + intel_sub_group_shuffle( _block.s2, _col), \ + intel_sub_group_shuffle( _block.s3, _col), \ + intel_sub_group_shuffle( _block.s4, _col), \ + 
intel_sub_group_shuffle( _block.s5, _col), \ + intel_sub_group_shuffle( _block.s6, _col), \ + intel_sub_group_shuffle( _block.s7, _col) ) + +#define MULTIPLY_BLOCKS_8x8( _result, _blockA, _blockB ) \ + { \ + const float8 acol0 = TRANSPOSE_BLOCK_8( _blockA, 0 ); \ + const float8 acol1 = TRANSPOSE_BLOCK_8( _blockA, 1 ); \ + const float8 acol2 = TRANSPOSE_BLOCK_8( _blockA, 2 ); \ + const float8 acol3 = TRANSPOSE_BLOCK_8( _blockA, 3 ); \ + const float8 acol4 = TRANSPOSE_BLOCK_8( _blockA, 4 ); \ + const float8 acol5 = TRANSPOSE_BLOCK_8( _blockA, 5 ); \ + const float8 acol6 = TRANSPOSE_BLOCK_8( _blockA, 6 ); \ + const float8 acol7 = TRANSPOSE_BLOCK_8( _blockA, 7 ); \ + _result = mad( (float8)_blockB.s0, acol0, _result ); \ + _result = mad( (float8)_blockB.s1, acol1, _result ); \ + _result = mad( (float8)_blockB.s2, acol2, _result ); \ + _result = mad( (float8)_blockB.s3, acol3, _result ); \ + _result = mad( (float8)_blockB.s4, acol4, _result ); \ + _result = mad( (float8)_blockB.s5, acol5, _result ); \ + _result = mad( (float8)_blockB.s6, acol6, _result ); \ + _result = mad( (float8)_blockB.s7, acol7, _result ); \ + } + + + +#define GEMM_NT(ALPHA1, BETA_NOT0, VECSCALAR, VECSIZE) \ +__attribute__((reqd_work_group_size(8, 1, 1))) \ +__kernel void TEMPLATE(gemm_32_1_NT_ ##VECSCALAR ##_ ##ALPHA1 ##_ ##BETA_NOT0,Dtype)( \ + __read_only image2d_t A, \ + MATB_PARAMETER, \ + MATC_PARAMETER, \ + float alpha, \ + float beta, \ + int padded_k, \ + int k) \ +{ \ + const int group_x = get_group_id(0); \ + const int group_y = get_group_id(1); \ + float8 blockAxB00 = 0.0f; \ + float8 blockAxB01 = 0.0f; \ + float8 blockAxB02 = 0.0f; \ + float8 blockAxB03 = 0.0f; \ + int2 coordA = (int2)( 0, group_y * TILE_M ); \ + int2 coordB = (int2)( 0, ( group_x * TILE_N )); \ + const sampler_t sampler = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; \ + do \ + { \ + float8 blockB00; \ + BLOCKB_READ8(blockB00, B, coordB); \ + int2 coordATemp = coordA; \ + float8 blockA00 = 
as_float8( intel_sub_group_block_read8( A, coordATemp ) ); coordATemp.y += 8; \ + float8 blockA01 = as_float8( intel_sub_group_block_read8( A, coordATemp ) ); coordATemp.y += 8; \ + float8 blockA02 = as_float8( intel_sub_group_block_read8( A, coordATemp ) ); coordATemp.y += 8; \ + float8 blockA03 = as_float8( intel_sub_group_block_read8( A, coordATemp ) ); coordA.x += TILE_K * sizeof(uint); \ + MULTIPLY_BLOCKS_8x8( blockAxB00, blockA00, blockB00 ); \ + MULTIPLY_BLOCKS_8x8( blockAxB01, blockA01, blockB00 ); \ + MULTIPLY_BLOCKS_8x8( blockAxB02, blockA02, blockB00 ); \ + MULTIPLY_BLOCKS_8x8( blockAxB03, blockA03, blockB00 ); \ + } \ + while( coordB.x < padded_k / VECSIZE ); \ + GEMM_OUTPUT(ALPHA1, BETA_NOT0); \ +} + + +#define BLOCKB_READ8(_blockb, _B, _coordB) \ + int2 _coordBTemp = _coordB; \ + _coordBTemp.y += get_local_id(0); \ + _blockb.s0123 = read_imagef(_B, sampler, _coordBTemp); _coordBTemp.x += 1; \ + _blockb.s4567 = read_imagef(_B, sampler, _coordBTemp); _coordB.x += 2; + +#define MATB_PARAMETER __read_only image2d_t B + +GEMM_NT(1, 0, VEC4, 4) // ALPHA == 1, BETA == 0 +GEMM_NT(1, 1, VEC4, 4) // ALPHA == 1, BETA != 0 +GEMM_NT(0, 0, VEC4, 4) // ALPHA != 1, BETA == 0 +GEMM_NT(0, 1, VEC4, 4) // ALPHA != 1, BETA != 0 +#undef BLOCKB_READ8 +#undef MATB_PARAMETER + +#define BLOCKB_READ8(_blockb, _B, _coordB) \ + int2 _coordBTemp = _coordB; \ + _coordBTemp.y += get_local_id(0); \ + _blockb = *(__global float8*)&_B[_coordBTemp.y * k + _coordBTemp.x + offB];\ + _coordB.x += TILE_K; + +#define MATB_PARAMETER __global float *B, int offB + +GEMM_NT(1, 0, BUFFER, 1) // ALPHA == 1, BETA == 0 +GEMM_NT(1, 1, BUFFER, 1) // ALPHA == 1, BETA != 0 +GEMM_NT(0, 0, BUFFER, 1) // ALPHA != 1, BETA == 0 +GEMM_NT(0, 1, BUFFER, 1) // ALPHA != 1, BETA != 0 +#undef BLOCKB_READ8 +#undef MATB_PARAMETER + + +#define BLOCKB_READ8(_blockb, _B, _coordB) \ + int2 _coordBTemp = _coordB; \ + _coordBTemp.y += get_local_id(0); \ + float4 temp; \ + temp = read_imagef(_B, _coordBTemp); _coordBTemp.x 
+= 1; \ + _blockb.s0 = temp.s0; \ + temp = read_imagef(_B, _coordBTemp); _coordBTemp.x += 1; \ + _blockb.s1 = temp.s0; \ + temp = read_imagef(_B, _coordBTemp); _coordBTemp.x += 1; \ + _blockb.s2 = temp.s0; \ + temp = read_imagef(_B, _coordBTemp); _coordBTemp.x += 1; \ + _blockb.s3 = temp.s0; \ + temp = read_imagef(_B, _coordBTemp); _coordBTemp.x += 1; \ + _blockb.s4 = temp.s0; \ + temp = read_imagef(_B, _coordBTemp); _coordBTemp.x += 1; \ + _blockb.s5 = temp.s0; \ + temp = read_imagef(_B, _coordBTemp); _coordBTemp.x += 1; \ + _blockb.s6 = temp.s0; \ + temp = read_imagef(_B, _coordBTemp); _coordBTemp.x += 1; \ + _blockb.s7 = temp.s0; \ + _coordB.x += 8; + +#define MATB_PARAMETER __read_only image2d_t B + +GEMM_NT(1, 0, SCALAR, 1) // ALPHA == 1, BETA == 0 +GEMM_NT(1, 1, SCALAR, 1) // ALPHA == 1, BETA != 0 +GEMM_NT(0, 0, SCALAR, 1) // ALPHA != 1, BETA == 0 +GEMM_NT(0, 1, SCALAR, 1) // ALPHA != 1, BETA != 0 +#undef BLOCKB_READ8 +#undef MATB_PARAMETER + +#undef MULTIPLY_BLOCKS_8x8 +#undef TRANSPOSE_BLOCK_8 + +//The same as GEMM_TN. 
+#define TRANSPOSE_BLOCK_8(_vec) \ + (float8)( intel_sub_group_shuffle(_vec, 0), \ + intel_sub_group_shuffle(_vec, 1), \ + intel_sub_group_shuffle(_vec, 2), \ + intel_sub_group_shuffle(_vec, 3), \ + intel_sub_group_shuffle(_vec, 4), \ + intel_sub_group_shuffle(_vec, 5), \ + intel_sub_group_shuffle(_vec, 6), \ + intel_sub_group_shuffle(_vec, 7) ); + +#define MULTIPLY_BLOCKS_8x8( _result, _blockA, _blockB ) \ + { \ + const float8 acol0 = TRANSPOSE_BLOCK_8( _blockA.s0 ); \ + const float8 acol1 = TRANSPOSE_BLOCK_8( _blockA.s1 ); \ + const float8 acol2 = TRANSPOSE_BLOCK_8( _blockA.s2 ); \ + const float8 acol3 = TRANSPOSE_BLOCK_8( _blockA.s3 ); \ + const float8 acol4 = TRANSPOSE_BLOCK_8( _blockA.s4 ); \ + const float8 acol5 = TRANSPOSE_BLOCK_8( _blockA.s5 ); \ + const float8 acol6 = TRANSPOSE_BLOCK_8( _blockA.s6 ); \ + const float8 acol7 = TRANSPOSE_BLOCK_8( _blockA.s7 ); \ + _result = mad( (float8)_blockB.s0, acol0, _result ); \ + _result = mad( (float8)_blockB.s1, acol1, _result ); \ + _result = mad( (float8)_blockB.s2, acol2, _result ); \ + _result = mad( (float8)_blockB.s3, acol3, _result ); \ + _result = mad( (float8)_blockB.s4, acol4, _result ); \ + _result = mad( (float8)_blockB.s5, acol5, _result ); \ + _result = mad( (float8)_blockB.s6, acol6, _result ); \ + _result = mad( (float8)_blockB.s7, acol7, _result ); \ + } + +#define GEMM_TT(ALPHA1, BETA_NOT0, VECSCALAR, VECSIZE) \ +__attribute__((reqd_work_group_size(8, 1, 1))) \ +__kernel void TEMPLATE(gemm_32_1_TT_ ##VECSCALAR ##_ ##ALPHA1 ##_ ##BETA_NOT0, Dtype)( \ + __read_only image2d_t A, \ + MATB_PARAMETER, \ + MATC_PARAMETER, \ + float alpha, \ + float beta, \ + int padded_k, \ + int k) \ +{ \ + const int group_x = get_group_id(0); \ + const int group_y = get_group_id(1); \ + float8 blockAxB00 = 0.0f; \ + float8 blockAxB01 = 0.0f; \ + float8 blockAxB02 = 0.0f; \ + float8 blockAxB03 = 0.0f; \ + int2 coordA = (int2)( group_y * TILE_M * sizeof(uint), 0 ); \ + int2 coordB = (int2)( 0, ( group_x * TILE_N )); \ + 
const sampler_t sampler = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; \ + do \ + { \ + float8 blockB00; \ + BLOCKB_READ8(blockB00, B, coordB); \ + int2 coordATemp = coordA; \ + float8 blockA00 = as_float8( intel_sub_group_block_read8( A, coordATemp ) ); coordATemp.x += 8 * sizeof(uint); \ + float8 blockA01 = as_float8( intel_sub_group_block_read8( A, coordATemp ) ); coordATemp.x += 8 * sizeof(uint); \ + float8 blockA02 = as_float8( intel_sub_group_block_read8( A, coordATemp ) ); coordATemp.x += 8 * sizeof(uint); \ + float8 blockA03 = as_float8( intel_sub_group_block_read8( A, coordATemp ) ); coordA.y += TILE_K; \ + MULTIPLY_BLOCKS_8x8( blockAxB00, blockA00 , blockB00 ); \ + MULTIPLY_BLOCKS_8x8( blockAxB01, blockA01 , blockB00 ); \ + MULTIPLY_BLOCKS_8x8( blockAxB02, blockA02 , blockB00 ); \ + MULTIPLY_BLOCKS_8x8( blockAxB03, blockA03 , blockB00 ); \ + } \ + while( coordB.x < padded_k / VECSIZE ); \ + GEMM_OUTPUT(ALPHA1, BETA_NOT0);\ +} + +#define BLOCKB_READ8(_blockb, _B, _coordB) \ + int2 _coordBTemp = _coordB; \ + _coordBTemp.y += get_local_id(0); \ + blockB00.s0123 = read_imagef(B, _coordBTemp); _coordBTemp.x += 1; \ + blockB00.s4567 = read_imagef(B, _coordBTemp); _coordB.x += 2; + +#define MATB_PARAMETER __read_only image2d_t B + +GEMM_TT(1, 0, VEC4, 4) // ALPHA == 1, BETA == 0 +GEMM_TT(1, 1, VEC4, 4) // ALPHA == 1, BETA != 0 +GEMM_TT(0, 0, VEC4, 4) // ALPHA != 1, BETA == 0 +GEMM_TT(0, 1, VEC4, 4) // ALPHA != 1, BETA != 0 +#undef BLOCKB_READ8 +#undef MATB_PARAMETER + +#define BLOCKB_READ8(_blockb, _B, _coordB) \ + int2 _coordBTemp = _coordB; \ + _coordBTemp.y += get_local_id(0); \ + _blockb = *(__global float8*)&_B[_coordBTemp.y * k + _coordBTemp.x + offB];\ + _coordB.x += TILE_K; + +#define MATB_PARAMETER __global float *B, int offB + +GEMM_TT(1, 0, BUFFER, 1) // ALPHA == 1, BETA == 0 +GEMM_TT(1, 1, BUFFER, 1) // ALPHA == 1, BETA != 0 +GEMM_TT(0, 0, BUFFER, 1) // ALPHA != 1, BETA == 0 +GEMM_TT(0, 1, BUFFER, 1) // ALPHA != 1, BETA != 0 
+#undef BLOCKB_READ8 +#undef MATB_PARAMETER + +#define BLOCKB_READ8(_blockb, _B, _coordB) \ + int2 _coordBTemp = _coordB; \ + _coordBTemp.y += get_local_id(0); \ + float4 temp; \ + temp = read_imagef(B, _coordBTemp); _coordBTemp.x += 1; \ + _blockb.s0 = temp.s0; \ + temp = read_imagef(B, _coordBTemp); _coordBTemp.x += 1; \ + _blockb.s1 = temp.s0; \ + temp = read_imagef(B, _coordBTemp); _coordBTemp.x += 1; \ + _blockb.s2 = temp.s0; \ + temp = read_imagef(B, _coordBTemp); _coordBTemp.x += 1; \ + _blockb.s3 = temp.s0; \ + temp = read_imagef(B, _coordBTemp); _coordBTemp.x += 1; \ + _blockb.s4 = temp.s0; \ + temp = read_imagef(B, _coordBTemp); _coordBTemp.x += 1; \ + _blockb.s5 = temp.s0; \ + temp = read_imagef(B, _coordBTemp); _coordBTemp.x += 1; \ + _blockb.s6 = temp.s0; \ + temp = read_imagef(B, _coordBTemp); _coordBTemp.x += 1; \ + _blockb.s7 = temp.s0; \ + _coordB.x += 8; + +#define MATB_PARAMETER __read_only image2d_t B + +GEMM_TT(1, 0, SCALAR, 1) // ALPHA == 1, BETA == 0 +GEMM_TT(1, 1, SCALAR, 1) // ALPHA == 1, BETA != 0 +GEMM_TT(0, 0, SCALAR, 1) // ALPHA != 1, BETA == 0 +GEMM_TT(0, 1, SCALAR, 1) // ALPHA != 1, BETA != 0 +#undef BLOCKB_READ8 +#undef MATB_PARAMETER + +#undef MULTIPLY_BLOCKS_8x8 +#undef TRANSPOSE_BLOCK_8 + +#undef TILE_M +#undef TILE_K +#undef TILE_N + +__kernel void TEMPLATE(gemm_buffer_copy_image,Dtype)( + __global float* A, + __write_only image2d_t ImA, + int offA, + int width, + int height) +{ + const int gidx = get_global_id(0); + const int gidy = get_global_id(1); + int2 coord_dst = (int2)(gidx, gidy); + if (gidx >= width || gidy >= height) { + write_imageui(ImA, coord_dst, (uint4)0); + return; + } + __global float* A_off = A + offA; + uint4 srcA = convert_uint4(as_uchar4(A_off[gidy * width + gidx])); + write_imageui(ImA, coord_dst, srcA); +} + +#define VEC_SIZE 4 +#define LWG_HEIGHT 4 +#define TILE_M 8 +#define TILE_K 16 +#define TILE_N 32 + +__attribute__((reqd_work_group_size(8, LWG_HEIGHT, 1))) +__kernel void TEMPLATE(gemm_buffer_NN, 
Dtype)( + const __global float *src0, int off0, + const __global float *src1, int off1, + __global float *dst, int offd, + int M, + int N, + int K, + float alpha, + float beta, + int start_index) +{ + const int group_x = get_group_id(0); + const int group_y = get_group_id(1); + const int local_x = get_local_id(0); + const int local_y = get_local_id(1); + const int global_x = get_global_id(0); + const int global_y = get_global_id(1); + + float4 brow; + float2 arow0, arow1, arow2, arow3, arow4, arow5, arow6, arow7; + + __global float *dst_write0 = dst + local_x * VEC_SIZE + ( group_x * TILE_N ) + ( group_y * LWG_HEIGHT * TILE_M + local_y * TILE_M) * N + offd; + + const __global float *src0_read = src0 + local_x * ( TILE_K / 8 ) + ( group_y * LWG_HEIGHT * TILE_M + local_y * TILE_M ) * K + start_index + off0; + + const __global float *src1_read0 = src1 + local_x * VEC_SIZE + ( group_x * TILE_N ) + start_index * N + off1; + + int border = -(group_y * LWG_HEIGHT * TILE_M + local_y * TILE_M); + + int row0 = mad24(global_y, TILE_M, 0) < M ? 0 : border; + int row1 = mad24(global_y, TILE_M, 1) < M ? 1 : border; + int row2 = mad24(global_y, TILE_M, 2) < M ? 2 : border; + int row3 = mad24(global_y, TILE_M, 3) < M ? 3 : border; + int row4 = mad24(global_y, TILE_M, 4) < M ? 4 : border; + int row5 = mad24(global_y, TILE_M, 5) < M ? 5 : border; + int row6 = mad24(global_y, TILE_M, 6) < M ? 6 : border; + int row7 = mad24(global_y, TILE_M, 7) < M ? 7 : border; + + float4 dot00 = (start_index != 0) ? ((__global float4 *)dst_write0)[0] : beta * ((__global float4 *)dst_write0)[0]; + float4 dot01 = (start_index != 0) ? ((__global float4 *)(dst_write0 + 1 * N))[0] : beta * ((__global float4 *)(dst_write0 + 1 * N))[0]; + float4 dot02 = (start_index != 0) ? ((__global float4 *)(dst_write0 + 2 * N))[0] : beta * ((__global float4 *)(dst_write0 + 2 * N))[0]; + float4 dot03 = (start_index != 0) ? 
((__global float4 *)(dst_write0 + 3 * N))[0] : beta * ((__global float4 *)(dst_write0 + 3 * N))[0]; + float4 dot04 = (start_index != 0) ? ((__global float4 *)(dst_write0 + 4 * N))[0] : beta * ((__global float4 *)(dst_write0 + 4 * N))[0]; + float4 dot05 = (start_index != 0) ? ((__global float4 *)(dst_write0 + 5 * N))[0] : beta * ((__global float4 *)(dst_write0 + 5 * N))[0]; + float4 dot06 = (start_index != 0) ? ((__global float4 *)(dst_write0 + 6 * N))[0] : beta * ((__global float4 *)(dst_write0 + 6 * N))[0]; + float4 dot07 = (start_index != 0) ? ((__global float4 *)(dst_write0 + 7 * N))[0] : beta * ((__global float4 *)(dst_write0 + 7 * N))[0]; + + int end_index = min(start_index + 256, K); + int w = start_index; + while( w + TILE_K <= end_index ) { + arow0 = alpha * ((__global float2 *)(src0_read + row0 * K))[0]; + arow1 = alpha * ((__global float2 *)(src0_read + row1 * K))[0]; + arow2 = alpha * ((__global float2 *)(src0_read + row2 * K))[0]; + arow3 = alpha * ((__global float2 *)(src0_read + row3 * K))[0]; + arow4 = alpha * ((__global float2 *)(src0_read + row4 * K))[0]; + arow5 = alpha * ((__global float2 *)(src0_read + row5 * K))[0]; + arow6 = alpha * ((__global float2 *)(src0_read + row6 * K))[0]; + arow7 = alpha * ((__global float2 *)(src0_read + row7 * K))[0]; + +#define MM_DOT_PRODUCT( index, suffix ) \ + brow = ((__global float4 *)src1_read0)[0]; src1_read0 += N; \ + dot00 = mad( (float4)(intel_sub_group_shuffle( arow0, index ).s##suffix), brow, dot00 ); \ + dot01 = mad( (float4)(intel_sub_group_shuffle( arow1, index ).s##suffix), brow, dot01 ); \ + dot02 = mad( (float4)(intel_sub_group_shuffle( arow2, index ).s##suffix), brow, dot02 ); \ + dot03 = mad( (float4)(intel_sub_group_shuffle( arow3, index ).s##suffix), brow, dot03 ); \ + dot04 = mad( (float4)(intel_sub_group_shuffle( arow4, index ).s##suffix), brow, dot04 ); \ + dot05 = mad( (float4)(intel_sub_group_shuffle( arow5, index ).s##suffix), brow, dot05 ); \ + dot06 = mad( 
(float4)(intel_sub_group_shuffle( arow6, index ).s##suffix), brow, dot06 ); \ + dot07 = mad( (float4)(intel_sub_group_shuffle( arow7, index ).s##suffix), brow, dot07 ); \ + + MM_DOT_PRODUCT(0, 0); + MM_DOT_PRODUCT(0, 1); + MM_DOT_PRODUCT(1, 0); + MM_DOT_PRODUCT(1, 1); + MM_DOT_PRODUCT(2, 0); + MM_DOT_PRODUCT(2, 1); + MM_DOT_PRODUCT(3, 0); + MM_DOT_PRODUCT(3, 1); + MM_DOT_PRODUCT(4, 0); + MM_DOT_PRODUCT(4, 1); + MM_DOT_PRODUCT(5, 0); + MM_DOT_PRODUCT(5, 1); + MM_DOT_PRODUCT(6, 0); + MM_DOT_PRODUCT(6, 1); + MM_DOT_PRODUCT(7, 0); + MM_DOT_PRODUCT(7, 1); +#undef MM_DOT_PRODUCT + + src0_read += TILE_K; + w += TILE_K; + } + + if(w < end_index) { + arow0.x = ((w + local_x * 2) < K) ? alpha * (src0_read + row0 * K)[0] : 0.0f; + arow0.y = ((w + local_x * 2 + 1) < K) ? alpha * (src0_read + row0 * K)[1] : 0.0f; + arow1.x = ((w + local_x * 2) < K) ? alpha * (src0_read + row1 * K)[0] : 0.0f; + arow1.y = ((w + local_x * 2 + 1) < K) ? alpha * (src0_read + row1 * K)[1] : 0.0f; + arow2.x = ((w + local_x * 2) < K) ? alpha * (src0_read + row2 * K)[0] : 0.0f; + arow2.y = ((w + local_x * 2 + 1) < K) ? alpha * (src0_read + row2 * K)[1] : 0.0f; + arow3.x = ((w + local_x * 2) < K) ? alpha * (src0_read + row3 * K)[0] : 0.0f; + arow3.y = ((w + local_x * 2 + 1) < K) ? alpha * (src0_read + row3 * K)[1] : 0.0f; + arow4.x = ((w + local_x * 2) < K) ? alpha * (src0_read + row4 * K)[0] : 0.0f; + arow4.y = ((w + local_x * 2 + 1) < K) ? alpha * (src0_read + row4 * K)[1] : 0.0f; + arow5.x = ((w + local_x * 2) < K) ? alpha * (src0_read + row5 * K)[0] : 0.0f; + arow5.y = ((w + local_x * 2 + 1) < K) ? alpha * (src0_read + row5 * K)[1] : 0.0f; + arow6.x = ((w + local_x * 2) < K) ? alpha * (src0_read + row6 * K)[0] : 0.0f; + arow6.y = ((w + local_x * 2 + 1) < K) ? alpha * (src0_read + row6 * K)[1] : 0.0f; + arow7.x = ((w + local_x * 2) < K) ? alpha * (src0_read + row7 * K)[0] : 0.0f; + arow7.y = ((w + local_x * 2 + 1) < K) ? 
alpha * (src0_read + row7 * K)[1] : 0.0f; + +#define MM_DOT_PRODUCT( index, suffix ) \ + brow = (w < K) ? ((__global float4 *)src1_read0)[0] : 0.0f; src1_read0 += N; w++; \ + dot00 = mad( (float4)(intel_sub_group_shuffle( arow0, index ).s##suffix), brow, dot00 ); \ + dot01 = mad( (float4)(intel_sub_group_shuffle( arow1, index ).s##suffix), brow, dot01 ); \ + dot02 = mad( (float4)(intel_sub_group_shuffle( arow2, index ).s##suffix), brow, dot02 ); \ + dot03 = mad( (float4)(intel_sub_group_shuffle( arow3, index ).s##suffix), brow, dot03 ); \ + dot04 = mad( (float4)(intel_sub_group_shuffle( arow4, index ).s##suffix), brow, dot04 ); \ + dot05 = mad( (float4)(intel_sub_group_shuffle( arow5, index ).s##suffix), brow, dot05 ); \ + dot06 = mad( (float4)(intel_sub_group_shuffle( arow6, index ).s##suffix), brow, dot06 ); \ + dot07 = mad( (float4)(intel_sub_group_shuffle( arow7, index ).s##suffix), brow, dot07 ); \ + + MM_DOT_PRODUCT(0, 0); + MM_DOT_PRODUCT(0, 1); + MM_DOT_PRODUCT(1, 0); + MM_DOT_PRODUCT(1, 1); + MM_DOT_PRODUCT(2, 0); + MM_DOT_PRODUCT(2, 1); + MM_DOT_PRODUCT(3, 0); + MM_DOT_PRODUCT(3, 1); + MM_DOT_PRODUCT(4, 0); + MM_DOT_PRODUCT(4, 1); + MM_DOT_PRODUCT(5, 0); + MM_DOT_PRODUCT(5, 1); + MM_DOT_PRODUCT(6, 0); + MM_DOT_PRODUCT(6, 1); + MM_DOT_PRODUCT(7, 0); + MM_DOT_PRODUCT(7, 1); +#undef MM_DOT_PRODUCT + } + + if(global_x * 4 < N && global_y * 8 < M) { + if(mad24(global_x, 4, 3) < N) { + __global float4 *dst_write = (__global float4 *)dst_write0; + dst_write[0] = dot00; dst_write0 += N; dst_write = (__global float4 *)dst_write0; + if(mad24(global_y, 8, 1) < M) { dst_write[0] = dot01; dst_write0 += N; dst_write = (__global float4 *)dst_write0; } + else return; + if(mad24(global_y, 8, 2) < M) { dst_write[0] = dot02; dst_write0 += N; dst_write = (__global float4 *)dst_write0; } + else return; + if(mad24(global_y, 8, 3) < M) { dst_write[0] = dot03; dst_write0 += N; dst_write = (__global float4 *)dst_write0; } + else return; + if(mad24(global_y, 8, 4) < M) { 
dst_write[0] = dot04; dst_write0 += N; dst_write = (__global float4 *)dst_write0; } + else return; + if(mad24(global_y, 8, 5) < M) { dst_write[0] = dot05; dst_write0 += N; dst_write = (__global float4 *)dst_write0; } + else return; + if(mad24(global_y, 8, 6) < M) { dst_write[0] = dot06; dst_write0 += N; dst_write = (__global float4 *)dst_write0; } + else return; + if(mad24(global_y, 8, 7) < M) { dst_write[0] = dot07; } + } else if(mad24(global_x, 4, 2) < N) { + __global float2 *dst_write = (__global float2 *)dst_write0; + dst_write[0] = dot00.xy; + dst_write0[2] = dot00.z; + dst_write0 += N; dst_write = (__global float2 *)dst_write0; + if(mad24(global_y, 8, 1) < M) { + dst_write[0] = dot01.xy; + dst_write0[2] = dot01.z; + dst_write0 += N; dst_write = (__global float2 *)dst_write0; + } else + return; + if(mad24(global_y, 8, 2) < M) { + dst_write[0] = dot02.xy; + dst_write0[2] = dot02.z; + dst_write0 += N; dst_write = (__global float2 *)dst_write0; + } else + return; + if(mad24(global_y, 8, 3) < M) { + dst_write[0] = dot03.xy; + dst_write0[2] = dot03.z; + dst_write0 += N; dst_write = (__global float2 *)dst_write0; + } else + return; + if(mad24(global_y, 8, 4) < M) { + dst_write[0] = dot04.xy; + dst_write0[2] = dot04.z; + dst_write0 += N; dst_write = (__global float2 *)dst_write0; + } else + return; + if(mad24(global_y, 8, 5) < M) { + dst_write[0] = dot05.xy; + dst_write0[2] = dot05.z; + dst_write0 += N; dst_write = (__global float2 *)dst_write0; + } else + return; + if(mad24(global_y, 8, 6) < M) { + dst_write[0] = dot06.xy; + dst_write0[2] = dot06.z; + dst_write0 += N; dst_write = (__global float2 *)dst_write0; + } else + return; + if(mad24(global_y, 8, 7) < M) { + dst_write[0] = dot07.xy; + dst_write0[2] = dot07.z; + } + } else if(mad24(global_x, 4, 1) < N) { + __global float2 *dst_write = (__global float2 *)dst_write0; + dst_write[0] = dot00.xy; dst_write0 += N; dst_write = (__global float2 *)dst_write0; + if(mad24(global_y, 8, 1) < M) { dst_write[0] = dot01.xy; 
dst_write0 += N; dst_write = (__global float2 *)dst_write0; } + else return; + if(mad24(global_y, 8, 2) < M) { dst_write[0] = dot02.xy; dst_write0 += N; dst_write = (__global float2 *)dst_write0; } + else return; + if(mad24(global_y, 8, 3) < M) { dst_write[0] = dot03.xy; dst_write0 += N; dst_write = (__global float2 *)dst_write0; } + else return; + if(mad24(global_y, 8, 4) < M) { dst_write[0] = dot04.xy; dst_write0 += N; dst_write = (__global float2 *)dst_write0; } + else return; + if(mad24(global_y, 8, 5) < M) { dst_write[0] = dot05.xy; dst_write0 += N; dst_write = (__global float2 *)dst_write0; } + else return; + if(mad24(global_y, 8, 6) < M) { dst_write[0] = dot06.xy; dst_write0 += N; dst_write = (__global float2 *)dst_write0; } + else return; + if(mad24(global_y, 8, 7) < M) { dst_write[0] = dot07.xy; } + } else { + dst_write0[0] = dot00.x; dst_write0 += N; + if(mad24(global_y, 8, 1) < M) { dst_write0[0] = dot01.x; dst_write0 += N; } + else return; + if(mad24(global_y, 8, 2) < M) { dst_write0[0] = dot02.x; dst_write0 += N; } + else return; + if(mad24(global_y, 8, 3) < M) { dst_write0[0] = dot03.x; dst_write0 += N; } + else return; + if(mad24(global_y, 8, 4) < M) { dst_write0[0] = dot04.x; dst_write0 += N; } + else return; + if(mad24(global_y, 8, 5) < M) { dst_write0[0] = dot05.x; dst_write0 += N; } + else return; + if(mad24(global_y, 8, 6) < M) { dst_write0[0] = dot06.x; dst_write0 += N; } + else return; + if(mad24(global_y, 8, 7) < M) { dst_write0[0] = dot07.x; } + } + } +} + +#undef VEC_SIZE +#undef LWG_HEIGHT +#undef TILE_M +#undef TILE_K +#undef TILE_N + +#define VEC_SIZE 1 +#define LWG_HEIGHT 16 +#define TILE_M 8 +#define TILE_K 32 +#define TILE_N 8 +#define SLM_BLOCK 512 + +__attribute__((reqd_work_group_size(8, LWG_HEIGHT, 1))) +__kernel void TEMPLATE(gemm_buffer_NT, Dtype)( + const __global float *src0, int off0, + const __global float *src1, int off1, + __global float *dst, int offd, + int M, + int N, + int K, + float alpha, + float beta) +{ + const int 
group_x = get_group_id(0); + const int group_y = get_group_id(1); + const int local_x = get_local_id(0); + const int local_y = get_local_id(1); + const int global_x = get_global_id(0); + const int global_y = get_global_id(1); + + float8 dot00 = 0.f; + float8 dot01 = 0.f; + float8 dot02 = 0.f; + float8 dot03 = 0.f; + float8 dot04 = 0.f; + float8 dot05 = 0.f; + float8 dot06 = 0.f; + float8 dot07 = 0.f; + + float4 brow0; + float4 brow1; + float4 brow2; + float4 brow3; + float4 brow4; + float4 brow5; + float4 brow6; + float4 brow7; + + __global float *dst_write0 = dst + local_x * VEC_SIZE + ( group_x * TILE_N ) + ( group_y * LWG_HEIGHT * TILE_M + local_y * TILE_M) * N + offd; + + const __global float *src0_read = src0 + local_x * ( TILE_K / 8 ) + ( group_y * LWG_HEIGHT * TILE_M + local_y * TILE_M ) * K + off0; + + const __global float *src1_read0 = src1 + ( group_x * TILE_N ) * K + off1; + + __local float slm_brow[8 * SLM_BLOCK]; + __local float* slm_brow0; + + int local_index = mad24(local_y, 8, local_x) * 4; + int w; + for(int b_tile = 0; b_tile < K; b_tile += SLM_BLOCK) { + barrier(CLK_LOCAL_MEM_FENCE); + ((__local float4 *)(slm_brow + mad24(0, SLM_BLOCK, local_index)))[0] = ((__global float4 *)(src1_read0 + mad24(0, K, local_index)))[0]; + ((__local float4 *)(slm_brow + mad24(1, SLM_BLOCK, local_index)))[0] = ((__global float4 *)(src1_read0 + mad24(1, K, local_index)))[0]; + ((__local float4 *)(slm_brow + mad24(2, SLM_BLOCK, local_index)))[0] = ((__global float4 *)(src1_read0 + mad24(2, K, local_index)))[0]; + ((__local float4 *)(slm_brow + mad24(3, SLM_BLOCK, local_index)))[0] = ((__global float4 *)(src1_read0 + mad24(3, K, local_index)))[0]; + ((__local float4 *)(slm_brow + mad24(4, SLM_BLOCK, local_index)))[0] = ((__global float4 *)(src1_read0 + mad24(4, K, local_index)))[0]; + ((__local float4 *)(slm_brow + mad24(5, SLM_BLOCK, local_index)))[0] = ((__global float4 *)(src1_read0 + mad24(5, K, local_index)))[0]; + ((__local float4 *)(slm_brow + mad24(6, 
SLM_BLOCK, local_index)))[0] = ((__global float4 *)(src1_read0 + mad24(6, K, local_index)))[0]; + ((__local float4 *)(slm_brow + mad24(7, SLM_BLOCK, local_index)))[0] = ((__global float4 *)(src1_read0 + mad24(7, K, local_index)))[0]; + barrier(CLK_LOCAL_MEM_FENCE); + + slm_brow0 = slm_brow + local_x * (TILE_K / 8); + w = b_tile; + int end_w = min(b_tile + SLM_BLOCK, K); + while( w + TILE_K <= end_w ) { + float4 arow; + + brow0 = ((__local float4 *)(slm_brow0 + 0 * SLM_BLOCK))[0]; + brow1 = ((__local float4 *)(slm_brow0 + 1 * SLM_BLOCK))[0]; + brow2 = ((__local float4 *)(slm_brow0 + 2 * SLM_BLOCK))[0]; + brow3 = ((__local float4 *)(slm_brow0 + 3 * SLM_BLOCK))[0]; + brow4 = ((__local float4 *)(slm_brow0 + 4 * SLM_BLOCK))[0]; + brow5 = ((__local float4 *)(slm_brow0 + 5 * SLM_BLOCK))[0]; + brow6 = ((__local float4 *)(slm_brow0 + 6 * SLM_BLOCK))[0]; + brow7 = ((__local float4 *)(slm_brow0 + 7 * SLM_BLOCK))[0]; + +#define MM_DOT_PRODUCT( _row, _dot ) \ + arow = ((__global float4 *)(src0_read + _row * K))[0]; \ + _dot = mad( (float8)(arow.x), (float8)(brow0.x, brow1.x, brow2.x, brow3.x, brow4.x, brow5.x, brow6.x, brow7.x), _dot ); \ + _dot = mad( (float8)(arow.y), (float8)(brow0.y, brow1.y, brow2.y, brow3.y, brow4.y, brow5.y, brow6.y, brow7.y), _dot ); \ + _dot = mad( (float8)(arow.z), (float8)(brow0.z, brow1.z, brow2.z, brow3.z, brow4.z, brow5.z, brow6.z, brow7.z), _dot ); \ + _dot = mad( (float8)(arow.w), (float8)(brow0.w, brow1.w, brow2.w, brow3.w, brow4.w, brow5.w, brow6.w, brow7.w), _dot ); \ + + MM_DOT_PRODUCT( 0, dot00 ); + MM_DOT_PRODUCT( 1, dot01 ); + MM_DOT_PRODUCT( 2, dot02 ); + MM_DOT_PRODUCT( 3, dot03 ); + MM_DOT_PRODUCT( 4, dot04 ); + MM_DOT_PRODUCT( 5, dot05 ); + MM_DOT_PRODUCT( 6, dot06 ); + MM_DOT_PRODUCT( 7, dot07 ); +#undef MM_DOT_PRODUCT + + src0_read += TILE_K; + slm_brow0 += TILE_K; + w += TILE_K; + } + src1_read0 += SLM_BLOCK; + } + + if(w < K) { + float4 arow; + +#define READ_BROW(_brow, _row) \ + _brow = ((__local float4 *)(slm_brow0 + _row * 
SLM_BLOCK))[0]; \ + _brow.x = (mad24(local_x, 4, w) < K) ? _brow.x : 0.0f; \ + _brow.y = (mad24(local_x, 4, w + 1) < K) ? _brow.y : 0.0f; \ + _brow.z = (mad24(local_x, 4, w + 2) < K) ? _brow.z : 0.0f; \ + _brow.w = (mad24(local_x, 4, w + 3) < K) ? _brow.w : 0.0f; \ + + READ_BROW(brow0, 0); + READ_BROW(brow1, 1); + READ_BROW(brow2, 2); + READ_BROW(brow3, 3); + READ_BROW(brow4, 4); + READ_BROW(brow5, 5); + READ_BROW(brow6, 6); + READ_BROW(brow7, 7); + +#define MM_DOT_PRODUCT( _row, _dot ) \ + arow = ((__global float4 *)(src0_read + _row * K))[0]; \ + arow.x = (mad24(local_x, 4, w) < K) ? arow.x : 0.0f; \ + arow.y = (mad24(local_x, 4, w + 1) < K) ? arow.y : 0.0f; \ + arow.z = (mad24(local_x, 4, w + 2) < K) ? arow.z : 0.0f; \ + arow.w = (mad24(local_x, 4, w + 3) < K) ? arow.w : 0.0f; \ + _dot = mad( (float8)(arow.x), (float8)(brow0.x, brow1.x, brow2.x, brow3.x, brow4.x, brow5.x, brow6.x, brow7.x), _dot ); \ + _dot = mad( (float8)(arow.y), (float8)(brow0.y, brow1.y, brow2.y, brow3.y, brow4.y, brow5.y, brow6.y, brow7.y), _dot ); \ + _dot = mad( (float8)(arow.z), (float8)(brow0.z, brow1.z, brow2.z, brow3.z, brow4.z, brow5.z, brow6.z, brow7.z), _dot ); \ + _dot = mad( (float8)(arow.w), (float8)(brow0.w, brow1.w, brow2.w, brow3.w, brow4.w, brow5.w, brow6.w, brow7.w), _dot ); \ + + MM_DOT_PRODUCT( 0, dot00 ); + MM_DOT_PRODUCT( 1, dot01 ); + MM_DOT_PRODUCT( 2, dot02 ); + MM_DOT_PRODUCT( 3, dot03 ); + MM_DOT_PRODUCT( 4, dot04 ); + MM_DOT_PRODUCT( 5, dot05 ); + MM_DOT_PRODUCT( 6, dot06 ); + MM_DOT_PRODUCT( 7, dot07 ); +#undef MM_DOT_PRODUCT + } + +#define REDUCE(_dot) \ + _dot = intel_sub_group_shuffle(_dot, 0) + intel_sub_group_shuffle(_dot, 1) + intel_sub_group_shuffle(_dot, 2) + intel_sub_group_shuffle(_dot, 3) + \ + intel_sub_group_shuffle(_dot, 4) + intel_sub_group_shuffle(_dot, 5) + intel_sub_group_shuffle(_dot, 6) + intel_sub_group_shuffle(_dot, 7); \ + + REDUCE(dot00); + REDUCE(dot01); + REDUCE(dot02); + REDUCE(dot03); + REDUCE(dot04); + REDUCE(dot05); + REDUCE(dot06); 
+ REDUCE(dot07); +#undef REDUCE + + float output = 0.0f; +#define OUTPUT( _dot) \ + output = (local_x == 0) ? _dot.s0 : output; \ + output = (local_x == 1) ? _dot.s1 : output; \ + output = (local_x == 2) ? _dot.s2 : output; \ + output = (local_x == 3) ? _dot.s3 : output; \ + output = (local_x == 4) ? _dot.s4 : output; \ + output = (local_x == 5) ? _dot.s5 : output; \ + output = (local_x == 6) ? _dot.s6 : output; \ + output = (local_x == 7) ? _dot.s7 : output; \ + dst_write0[0] = mad(output, alpha, beta * dst_write0[0]); \ + dst_write0 += N; + + if(global_x < N && global_y * 8 < M) { + OUTPUT(dot00); + if(mad24(global_y, 8, 1) < M) { OUTPUT(dot01); } + if(mad24(global_y, 8, 2) < M) { OUTPUT(dot02); } + if(mad24(global_y, 8, 3) < M) { OUTPUT(dot03); } + if(mad24(global_y, 8, 4) < M) { OUTPUT(dot04); } + if(mad24(global_y, 8, 5) < M) { OUTPUT(dot05); } + if(mad24(global_y, 8, 6) < M) { OUTPUT(dot06); } + if(mad24(global_y, 8, 7) < M) { OUTPUT(dot07); } + } +#undef OUTPUT +} + +#undef VEC_SIZE +#undef LWG_HEIGHT +#undef TILE_M +#undef TILE_K +#undef TILE_N +#undef SLM_BLOCK + +#define SLM_SIZE 64 +void TEMPLATE(gemm_buffer_NT_M_2_edgerows,Dtype)( + const __global Dtype* srca_read0, + const __global Dtype* srca_read1, + const __global Dtype* srcb_read, + __local Dtype4* work0, + __local Dtype4* work1, + int N, + int K, + int x_gid, + int lid, + Dtype alpha, + Dtype beta, + __global Dtype* dstc0, + __global Dtype* dstc1) +{ + __local Dtype* work_each0 = (__local Dtype*)work0; + __local Dtype* work_each1 = (__local Dtype*)work1; + + int rows = N - x_gid * 4; + + Dtype4 dot0[3] = {(Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.)}; + Dtype4 dot1[3] = {(Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.)}; + + int i = lid; + while( i < K / 4) { + const Dtype4 b0 = {srca_read0[i*4], srca_read0[(i*4+1)], srca_read0[(i*4+2)], srca_read0[(i*4+3)]}; + const Dtype4 b1 = {srca_read1[i*4], srca_read1[(i*4+1)], srca_read1[(i*4+2)], srca_read1[(i*4+3)]}; +#pragma unroll + for(int j = 0; j < rows; ++j) { 
+ dot0[j] += b0 * vload4(i, srcb_read + j * K); + dot1[j] += b1 * vload4(i, srcb_read + j * K); + } + + i += get_local_size(0); + } +#pragma unroll + for(int j = 0; j < rows; ++j) { + work_each0[lid * 4 + j] = dot0[j].x + dot0[j].y + dot0[j].z + dot0[j].w; + work_each1[lid * 4 + j] = dot1[j].x + dot1[j].y + dot1[j].z + dot1[j].w; + } + + if(i == K / 4) { + short tail_items = K % 4; + + if(tail_items != 0) { + const __global Dtype *srcb_tail = srcb_read + i * 4; + const __global Dtype *srca_tail0 = srca_read0 + i * 4; + const __global Dtype *srca_tail1 = srca_read1 + i * 4; +#pragma unroll + for(short i = 0; i < tail_items; ++i) { + const Dtype at0 = srca_tail0[i]; + const Dtype at1 = srca_tail1[i]; +#pragma unroll + for(int j = 0; j < rows; ++j) { + work_each0[lid * 4 + j] += at0 * srcb_tail[i + j * K]; + work_each1[lid * 4 + j] += at1 * srcb_tail[i + j * K]; + } + } + } + } + + for(int stride = get_local_size(0) >> 1; stride > 0 ; stride >>= 1) { + barrier(CLK_LOCAL_MEM_FENCE); + if(lid < stride) { + work0[lid] += work0[lid+stride]; + work1[lid] += work1[lid+stride]; + } + } + + if(lid == 0) { +#pragma unroll + for(int j = 0; j < rows; ++j) { + dstc0[(x_gid * 4 + j)] = alpha * work_each0[j] + beta * dstc0[(x_gid * 4 + j)]; + dstc1[(x_gid * 4 + j)] = alpha * work_each1[j] + beta * dstc1[(x_gid * 4 + j)]; + } + } +} + +__kernel void TEMPLATE(gemm_buffer_NT_M_2,Dtype)( + __global const Dtype * A, + int offA, + __global const Dtype * B, + int offB, + __global Dtype * C, + int offC, + int M, + int N, + int K, + float alpha_f, + float beta_f) +{ + Dtype alpha = (Dtype)alpha_f; + Dtype beta = (Dtype)beta_f; + int x_gid = get_group_id(0); + int lid = get_local_id(0); + + const __global Dtype *srca_read0 = A + offA; + const __global Dtype *srca_read1 = srca_read0 + K; + + const __global Dtype *srcb_read = B + x_gid * 4 * K + offB; + + __global Dtype4 *dstc0 = (__global Dtype4*)(C + offC); + __global Dtype4 *dstc1 = (__global Dtype4*)((__global Dtype*)(dstc0) + N); + + 
__local Dtype4 work0[SLM_SIZE]; + __local Dtype4 work1[SLM_SIZE]; + __local Dtype* work_each0 = (__local Dtype*)work0; + __local Dtype* work_each1 = (__local Dtype*)work1; + + if(x_gid == N / 4) { + TEMPLATE(gemm_buffer_NT_M_2_edgerows,Dtype) \ + (srca_read0, srca_read1, srcb_read, work0, work1, N, K, x_gid, lid, alpha, beta, (__global Dtype*)dstc0, (__global Dtype*)dstc1); + } else { + Dtype4 dot0[4] = {(Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.)}; + Dtype4 dot1[4] = {(Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.)}; + int i = lid; + while( i < K / 4) { + const Dtype4 b0 = vload4(i, srca_read0); + const Dtype4 b1 = vload4(i, srca_read1); +#pragma unroll + for(int j = 0; j < 4; ++j) { + Dtype4 a = vload4(i, srcb_read + j * K); + dot0[j] += b0 * a; + dot1[j] += b1 * a; + } + i += get_local_size(0); + } + +#pragma unroll + for(int j = 0; j < 4; ++j) { + work_each0[lid * 4 + j] = dot0[j].x + dot0[j].y + dot0[j].z + dot0[j].w; + work_each1[lid * 4 + j] = dot1[j].x + dot1[j].y + dot1[j].z + dot1[j].w; + } + + if(i == K / 4) { + short tail_items = K % 4; + if(tail_items != 0) { + const __global Dtype *srcb_tail = srcb_read + i * 4; + + const __global Dtype *srca_tail0 = srca_read0 + i * 4; + const __global Dtype *srca_tail1 = srca_read1 + i * 4; +#pragma unroll + for(short i = 0; i < tail_items; ++i) { + const Dtype at0 = srca_tail0[i]; + const Dtype at1 = srca_tail1[i]; +#pragma unroll + for(int j = 0; j < 4; ++j) { + work_each0[lid * 4 + j] += at0 * srcb_tail[i + j * K]; + work_each1[lid * 4 + j] += at1 * srcb_tail[i + j * K]; + } + } + } + } + + for(int stride = get_local_size(0) >> 1; stride > 0 ; stride >>= 1) { + barrier(CLK_LOCAL_MEM_FENCE); + if(lid < stride) { + work0[lid] += work0[lid+stride]; + work1[lid] += work1[lid+stride]; + } + } + + if(lid == 0) { + dstc0[x_gid] = alpha * work0[0] + beta * dstc0[x_gid]; + dstc1[x_gid] = alpha * work1[0] + beta * dstc1[x_gid]; + } + } +} +#undef SLM_SIZE + +#define SLM_SIZE 32 +void 
 TEMPLATE(gemm_buffer_NT_M_4_edgerows,Dtype)( + const __global Dtype* srca_read0, + const __global Dtype* srca_read1, + const __global Dtype* srca_read2, + const __global Dtype* srca_read3, + const __global Dtype* srcb_read, + __local Dtype4* work0, + __local Dtype4* work1, + __local Dtype4* work2, + __local Dtype4* work3, + int N, + int K, + int x_gid, + int lid, + Dtype alpha, + Dtype beta, + __global Dtype* dstc0, + __global Dtype* dstc1, + __global Dtype* dstc2, + __global Dtype* dstc3) +{ + __local Dtype* work_each0 = (__local Dtype*)(work0 + lid); + __local Dtype* work_each1 = (__local Dtype*)(work1 + lid); + __local Dtype* work_each2 = (__local Dtype*)(work2 + lid); + __local Dtype* work_each3 = (__local Dtype*)(work3 + lid); + + int rows = N - x_gid * 4; + + Dtype4 dot0[3] = {(Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.)}; + Dtype4 dot1[3] = {(Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.)}; + Dtype4 dot2[3] = {(Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.)}; + Dtype4 dot3[3] = {(Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.)}; + + int i = lid; + while( i < K / 4) { + const Dtype4 a0 = {srca_read0[i*4], srca_read0[(i*4+1)], srca_read0[(i*4+2)], srca_read0[(i*4+3)]}; + const Dtype4 a1 = {srca_read1[i*4], srca_read1[(i*4+1)], srca_read1[(i*4+2)], srca_read1[(i*4+3)]}; + const Dtype4 a2 = {srca_read2[i*4], srca_read2[(i*4+1)], srca_read2[(i*4+2)], srca_read2[(i*4+3)]}; + const Dtype4 a3 = {srca_read3[i*4], srca_read3[(i*4+1)], srca_read3[(i*4+2)], srca_read3[(i*4+3)]}; +#pragma unroll + for(int j = 0; j < rows; ++j) { + dot0[j] += a0 * vload4(i, srcb_read + j * K); + dot1[j] += a1 * vload4(i, srcb_read + j * K); + dot2[j] += a2 * vload4(i, srcb_read + j * K); + dot3[j] += a3 * vload4(i, srcb_read + j * K); + } + + i += get_local_size(0); + } +#pragma unroll + for(int j = 0; j < rows; ++j) { + work_each0[j] = dot0[j].x + dot0[j].y + dot0[j].z + dot0[j].w; + work_each1[j] = dot1[j].x + dot1[j].y + dot1[j].z + dot1[j].w; + work_each2[j] = dot2[j].x + dot2[j].y + dot2[j].z + dot2[j].w; + 
work_each3[j] = dot3[j].x + dot3[j].y + dot3[j].z + dot3[j].w; + } + + if(i == K / 4) { + short tail_items = K % 4; + + if(tail_items != 0) { + const __global Dtype *srcb_tail = srcb_read + i * 4; + + const __global Dtype *srca_tail0 = srca_read0 + i * 4; + const __global Dtype *srca_tail1 = srca_read1 + i * 4; + const __global Dtype *srca_tail2 = srca_read2 + i * 4; + const __global Dtype *srca_tail3 = srca_read3 + i * 4; +#pragma unroll + for(short i = 0; i < tail_items; ++i) { + const Dtype at0 = srca_tail0[i]; + const Dtype at1 = srca_tail1[i]; + const Dtype at2 = srca_tail2[i]; + const Dtype at3 = srca_tail3[i]; +#pragma unroll + for(int j = 0; j < rows; ++j) { + work_each0[j] += at0 * srcb_tail[i + j * K]; + work_each1[j] += at1 * srcb_tail[i + j * K]; + work_each2[j] += at2 * srcb_tail[i + j * K]; + work_each3[j] += at3 * srcb_tail[i + j * K]; + } + } + } + } + + for(int stride = get_local_size(0) >> 1; stride > 0 ; stride >>= 1) { + barrier(CLK_LOCAL_MEM_FENCE); + if(lid < stride) { + work0[lid] += work0[lid+stride]; + work1[lid] += work1[lid+stride]; + work2[lid] += work2[lid+stride]; + work3[lid] += work3[lid+stride]; + } + } + + if(lid == 0) { +#pragma unroll + for(int j = 0; j < rows; ++j) { + dstc0[(x_gid * 4 + j)] = alpha * work_each0[j] + beta * dstc0[(x_gid * 4 + j)]; + dstc1[(x_gid * 4 + j)] = alpha * work_each1[j] + beta * dstc1[(x_gid * 4 + j)]; + dstc2[(x_gid * 4 + j)] = alpha * work_each2[j] + beta * dstc2[(x_gid * 4 + j)]; + dstc3[(x_gid * 4 + j)] = alpha * work_each3[j] + beta * dstc3[(x_gid * 4 + j)]; + } + } +} + +__kernel void TEMPLATE(gemm_buffer_NT_M_4,Dtype)( + __global const Dtype * A, + int offA, + __global const Dtype * B, + int offB, + __global Dtype * C, + int offC, + int M, + int N, + int K, + float alpha_f, + float beta_f) +{ + Dtype alpha = (Dtype)alpha_f; + Dtype beta = (Dtype)beta_f; + int x_gid = get_group_id(0); + int lid = get_local_id(0); + int lsize = get_local_size(0); + + const __global Dtype *srca_read0 = A + offA; + 
const __global Dtype *srca_read1 = srca_read0 + K; + const __global Dtype *srca_read2 = srca_read1 + K; + const __global Dtype *srca_read3 = srca_read2 + K; + + const __global Dtype *srcb_read = B + x_gid * 4 * K + offB; + + __global Dtype4 *dstc0 = (__global Dtype4*)(C + offC); + __global Dtype4 *dstc1 = (__global Dtype4*)((__global Dtype*)(dstc0) + N); + __global Dtype4 *dstc2 = (__global Dtype4*)((__global Dtype*)(dstc1) + N); + __global Dtype4 *dstc3 = (__global Dtype4*)((__global Dtype*)(dstc2) + N); + + __local Dtype4 work0[SLM_SIZE]; + __local Dtype4 work1[SLM_SIZE]; + __local Dtype4 work2[SLM_SIZE]; + __local Dtype4 work3[SLM_SIZE]; + __local Dtype* work_each0 = (__local Dtype*)(work0 + lid); + __local Dtype* work_each1 = (__local Dtype*)(work1 + lid); + __local Dtype* work_each2 = (__local Dtype*)(work2 + lid); + __local Dtype* work_each3 = (__local Dtype*)(work3 + lid); + + if(x_gid == N / 4) { + TEMPLATE(gemm_buffer_NT_M_4_edgerows,Dtype) \ + (srca_read0, srca_read1, srca_read2, srca_read3, srcb_read, \ + work0, work1, work2, work3, N, K, x_gid, lid, alpha, beta, \ + (__global Dtype*)dstc0, (__global Dtype*)dstc1, (__global Dtype*)dstc2, (__global Dtype*)dstc3); + } else { + Dtype4 dot0[4] = {(Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.)}; + Dtype4 dot1[4] = {(Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.)}; + Dtype4 dot2[4] = {(Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.)}; + Dtype4 dot3[4] = {(Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.)}; + + int kid = lid; + while( kid < K / 4) { + const Dtype4 b0 = vload4(kid, srca_read0); + const Dtype4 b1 = vload4(kid, srca_read1); + const Dtype4 b2 = vload4(kid, srca_read2); + const Dtype4 b3 = vload4(kid, srca_read3); +#pragma unroll + for(int j = 0; j < 4; ++j) { + Dtype4 a = vload4(kid, srcb_read + j * K); + dot0[j] += b0 * a; + dot1[j] += b1 * a; + dot2[j] += b2 * a; + dot3[j] += b3 * a; + } + kid += lsize; + } +#pragma unroll + for(int j = 0; j < 4; ++j) { + work_each0[j] = 
dot0[j].x + dot0[j].y + dot0[j].z + dot0[j].w; + work_each1[j] = dot1[j].x + dot1[j].y + dot1[j].z + dot1[j].w; + work_each2[j] = dot2[j].x + dot2[j].y + dot2[j].z + dot2[j].w; + work_each3[j] = dot3[j].x + dot3[j].y + dot3[j].z + dot3[j].w; + } + + if(kid == (K >> 2)) { + short tail_items = K % 4; + if(tail_items != 0) { + int offset = kid << 2; + const __global Dtype *srcb_tail = srcb_read + offset; + + const __global Dtype *srca_tail0 = srca_read0 + offset; + const __global Dtype *srca_tail1 = srca_read1 + offset; + const __global Dtype *srca_tail2 = srca_read2 + offset; + const __global Dtype *srca_tail3 = srca_read3 + offset; +#pragma unroll + for(short i = 0; i < tail_items; ++i) { + const Dtype at0 = srca_tail0[i]; + const Dtype at1 = srca_tail1[i]; + const Dtype at2 = srca_tail2[i]; + const Dtype at3 = srca_tail3[i]; +#pragma unroll + for(int j = 0; j < 4; ++j) { + work_each0[j] += at0 * srcb_tail[i + j * K]; + work_each1[j] += at1 * srcb_tail[i + j * K]; + work_each2[j] += at2 * srcb_tail[i + j * K]; + work_each3[j] += at3 * srcb_tail[i + j * K]; + } + } + } + } + + for(int stride = get_local_size(0) >> 1; stride > 0 ; stride >>= 1) { + barrier(CLK_LOCAL_MEM_FENCE); + if(lid < stride) { + work0[lid] += work0[lid+stride]; + work1[lid] += work1[lid+stride]; + work2[lid] += work2[lid+stride]; + work3[lid] += work3[lid+stride]; + } + } + + if(lid == 0) { + dstc0[x_gid] = alpha * work0[0] + beta * dstc0[x_gid]; + dstc1[x_gid] = alpha * work1[0] + beta * dstc1[x_gid]; + dstc2[x_gid] = alpha * work2[0] + beta * dstc2[x_gid]; + dstc3[x_gid] = alpha * work3[0] + beta * dstc3[x_gid]; + } + } +} +#undef SLM_SIZE + +#define SLM_SIZE 16 +__kernel void TEMPLATE(gemm_buffer_NT_M_8,Dtype)( + __global const Dtype * A, + int offA, + __global const Dtype * B, + int offB, + __global Dtype * C, + int offC, + int M, + int N, + int K, + float alpha_f, + float beta_f) +{ + Dtype alpha = (Dtype)alpha_f; + Dtype beta = (Dtype)beta_f; + int x_gid = get_group_id(0); + int lid = 
get_local_id(0); + int lsize = get_local_size(0); + + const __global Dtype *srca_read0 = A + offA; + const __global Dtype *srca_read1 = srca_read0 + K; + const __global Dtype *srca_read2 = srca_read1 + K; + const __global Dtype *srca_read3 = srca_read2 + K; + const __global Dtype *srca_read4 = srca_read3 + K; + const __global Dtype *srca_read5 = srca_read4 + K; + const __global Dtype *srca_read6 = srca_read5 + K; + const __global Dtype *srca_read7 = srca_read6 + K; + + const __global Dtype *srcb_read = B + x_gid * K + offB; + + __global Dtype *dstc0 = C + offC; + __global Dtype *dstc1 = dstc0 + N; + __global Dtype *dstc2 = dstc1 + N; + __global Dtype *dstc3 = dstc2 + N; + __global Dtype *dstc4 = dstc3 + N; + __global Dtype *dstc5 = dstc4 + N; + __global Dtype *dstc6 = dstc5 + N; + __global Dtype *dstc7 = dstc6 + N; + + __local Dtype work0[SLM_SIZE]; + __local Dtype work1[SLM_SIZE]; + __local Dtype work2[SLM_SIZE]; + __local Dtype work3[SLM_SIZE]; + __local Dtype work4[SLM_SIZE]; + __local Dtype work5[SLM_SIZE]; + __local Dtype work6[SLM_SIZE]; + __local Dtype work7[SLM_SIZE]; + + Dtype4 dot0 = (Dtype4)(0.); + Dtype4 dot1 = (Dtype4)(0.); + Dtype4 dot2 = (Dtype4)(0.); + Dtype4 dot3 = (Dtype4)(0.); + Dtype4 dot4 = (Dtype4)(0.); + Dtype4 dot5 = (Dtype4)(0.); + Dtype4 dot6 = (Dtype4)(0.); + Dtype4 dot7 = (Dtype4)(0.); + + int kid = lid; + while( kid < K / 4) { + const Dtype4 a0 = vload4(kid, srca_read0); + const Dtype4 a1 = vload4(kid, srca_read1); + const Dtype4 a2 = vload4(kid, srca_read2); + const Dtype4 a3 = vload4(kid, srca_read3); + const Dtype4 a4 = vload4(kid, srca_read4); + const Dtype4 a5 = vload4(kid, srca_read5); + const Dtype4 a6 = vload4(kid, srca_read6); + const Dtype4 a7 = vload4(kid, srca_read7); + Dtype4 b = vload4(kid, srcb_read); + dot0 += a0 * b; + dot1 += a1 * b; + dot2 += a2 * b; + dot3 += a3 * b; + dot4 += a4 * b; + dot5 += a5 * b; + dot6 += a6 * b; + dot7 += a7 * b; + + kid += lsize; + } + work0[lid] = dot0.x + dot0.y + dot0.z + dot0.w; + 
work1[lid] = dot1.x + dot1.y + dot1.z + dot1.w; + work2[lid] = dot2.x + dot2.y + dot2.z + dot2.w; + work3[lid] = dot3.x + dot3.y + dot3.z + dot3.w; + work4[lid] = dot4.x + dot4.y + dot4.z + dot4.w; + work5[lid] = dot5.x + dot5.y + dot5.z + dot5.w; + work6[lid] = dot6.x + dot6.y + dot6.z + dot6.w; + work7[lid] = dot7.x + dot7.y + dot7.z + dot7.w; + + if(kid == (K >> 2)) { + short tail_items = K % 4; + if(tail_items != 0) { + int offset = kid << 2; + const __global Dtype *srcb_tail = srcb_read + offset; + + const __global Dtype *srca_tail0 = srca_read0 + offset; + const __global Dtype *srca_tail1 = srca_read1 + offset; + const __global Dtype *srca_tail2 = srca_read2 + offset; + const __global Dtype *srca_tail3 = srca_read3 + offset; + const __global Dtype *srca_tail4 = srca_read4 + offset; + const __global Dtype *srca_tail5 = srca_read5 + offset; + const __global Dtype *srca_tail6 = srca_read6 + offset; + const __global Dtype *srca_tail7 = srca_read7 + offset; +#pragma unroll + for(short item = 0; item < tail_items; ++item) { + work0[lid] += srca_tail0[item] * srcb_tail[item]; + work1[lid] += srca_tail1[item] * srcb_tail[item]; + work2[lid] += srca_tail2[item] * srcb_tail[item]; + work3[lid] += srca_tail3[item] * srcb_tail[item]; + work4[lid] += srca_tail4[item] * srcb_tail[item]; + work5[lid] += srca_tail5[item] * srcb_tail[item]; + work6[lid] += srca_tail6[item] * srcb_tail[item]; + work7[lid] += srca_tail7[item] * srcb_tail[item]; + } + } + } + + for(int stride = get_local_size(0) >> 1; stride > 0 ; stride >>= 1) { + barrier(CLK_LOCAL_MEM_FENCE); + if(lid < stride) { + work0[lid] += work0[lid+stride]; + work1[lid] += work1[lid+stride]; + work2[lid] += work2[lid+stride]; + work3[lid] += work3[lid+stride]; + work4[lid] += work4[lid+stride]; + work5[lid] += work5[lid+stride]; + work6[lid] += work6[lid+stride]; + work7[lid] += work7[lid+stride]; + } + } + + if(lid == 0) { + dstc0[x_gid] = alpha * work0[0] + beta * dstc0[x_gid]; + dstc1[x_gid] = alpha * work1[0] + beta 
* dstc1[x_gid]; + dstc2[x_gid] = alpha * work2[0] + beta * dstc2[x_gid]; + dstc3[x_gid] = alpha * work3[0] + beta * dstc3[x_gid]; + dstc4[x_gid] = alpha * work4[0] + beta * dstc4[x_gid]; + dstc5[x_gid] = alpha * work5[0] + beta * dstc5[x_gid]; + dstc6[x_gid] = alpha * work6[0] + beta * dstc6[x_gid]; + dstc7[x_gid] = alpha * work7[0] + beta * dstc7[x_gid]; + } +} +#undef SLM_SIZE + +#define VEC_SIZE 4 +#define LWG_HEIGHT 4 +#define TILE_M 8 +#define TILE_K 16 +#define TILE_N 32 + +__attribute__((reqd_work_group_size(8, LWG_HEIGHT, 1))) +__kernel void TEMPLATE(gemm_buffer_TN, Dtype)( + const __global float *src0, int off0, + const __global float *src1, int off1, + __global float *dst, int offd, + int M, + int N, + int K, + float alpha, + float beta, + int start_index) + +{ + const int group_x = get_group_id(0); + const int group_y = get_group_id(1); + const int local_x = get_local_id(0); + const int local_y = get_local_id(1); + const int global_x = get_global_id(0); + const int global_y = get_global_id(1); + + float4 brow; + + __global float *dst_write0 = dst + local_x * VEC_SIZE + ( group_x * TILE_N ) + ( group_y * LWG_HEIGHT * TILE_M + local_y * TILE_M) * N + offd; + + const __global float *src0_read = src0 + (local_x * ( TILE_K / 8 ) + start_index) * M + group_y * LWG_HEIGHT * TILE_M + local_y * TILE_M + off0; + + const __global float *src1_read0 = src1 + local_x * VEC_SIZE + ( group_x * TILE_N ) + start_index * N + off1; + + float4 dot00 = (start_index != 0) ? ((__global float4 *)dst_write0)[0] : (beta * ((__global float4 *)dst_write0)[0]); + float4 dot01 = (start_index != 0) ? ((__global float4 *)(dst_write0 + N))[0] : (beta * ((__global float4 *)(dst_write0 + N))[0]); + float4 dot02 = (start_index != 0) ? ((__global float4 *)(dst_write0 + 2 * N))[0] : (beta * ((__global float4 *)(dst_write0 + 2 * N))[0]); + float4 dot03 = (start_index != 0) ? 
((__global float4 *)(dst_write0 + 3 * N))[0] : (beta * ((__global float4 *)(dst_write0 + 3 * N))[0]); + float4 dot04 = (start_index != 0) ? ((__global float4 *)(dst_write0 + 4 * N))[0] : (beta * ((__global float4 *)(dst_write0 + 4 * N))[0]); + float4 dot05 = (start_index != 0) ? ((__global float4 *)(dst_write0 + 5 * N))[0] : (beta * ((__global float4 *)(dst_write0 + 5 * N))[0]); + float4 dot06 = (start_index != 0) ? ((__global float4 *)(dst_write0 + 6 * N))[0] : (beta * ((__global float4 *)(dst_write0 + 6 * N))[0]); + float4 dot07 = (start_index != 0) ? ((__global float4 *)(dst_write0 + 7 * N))[0] : (beta * ((__global float4 *)(dst_write0 + 7 * N))[0]); + + int end_index = min(start_index + 256, K); + while( start_index + TILE_K <= end_index ) { + float8 arow0 = alpha * ((__global float8 *)src0_read)[0]; + float8 arow1 = alpha * ((__global float8 *)(src0_read + M))[0]; + +#define MM_DOT_PRODUCT( _arow ) \ + brow = ((__global float4 *)src1_read0)[0]; src1_read0 += N; \ + dot00 = mad( (float4)(_arow.s0), brow, dot00 ); \ + dot01 = mad( (float4)(_arow.s1), brow, dot01 ); \ + dot02 = mad( (float4)(_arow.s2), brow, dot02 ); \ + dot03 = mad( (float4)(_arow.s3), brow, dot03 ); \ + dot04 = mad( (float4)(_arow.s4), brow, dot04 ); \ + dot05 = mad( (float4)(_arow.s5), brow, dot05 ); \ + dot06 = mad( (float4)(_arow.s6), brow, dot06 ); \ + dot07 = mad( (float4)(_arow.s7), brow, dot07 ); \ + + MM_DOT_PRODUCT( intel_sub_group_shuffle( arow0, 0 ) ); + MM_DOT_PRODUCT( intel_sub_group_shuffle( arow1, 0 ) ); + MM_DOT_PRODUCT( intel_sub_group_shuffle( arow0, 1 ) ); + MM_DOT_PRODUCT( intel_sub_group_shuffle( arow1, 1 ) ); + MM_DOT_PRODUCT( intel_sub_group_shuffle( arow0, 2 ) ); + MM_DOT_PRODUCT( intel_sub_group_shuffle( arow1, 2 ) ); + MM_DOT_PRODUCT( intel_sub_group_shuffle( arow0, 3 ) ); + MM_DOT_PRODUCT( intel_sub_group_shuffle( arow1, 3 ) ); + MM_DOT_PRODUCT( intel_sub_group_shuffle( arow0, 4 ) ); + MM_DOT_PRODUCT( intel_sub_group_shuffle( arow1, 4 ) ); + MM_DOT_PRODUCT( 
intel_sub_group_shuffle( arow0, 5 ) ); + MM_DOT_PRODUCT( intel_sub_group_shuffle( arow1, 5 ) ); + MM_DOT_PRODUCT( intel_sub_group_shuffle( arow0, 6 ) ); + MM_DOT_PRODUCT( intel_sub_group_shuffle( arow1, 6 ) ); + MM_DOT_PRODUCT( intel_sub_group_shuffle( arow0, 7 ) ); + MM_DOT_PRODUCT( intel_sub_group_shuffle( arow1, 7 ) ); +#undef MM_DOT_PRODUCT + + src0_read += TILE_K * M; + start_index += TILE_K; + } + + if(start_index < end_index) { + float8 arow0 = ((start_index + local_x * 2) < K) ? (alpha * ((__global float8 *)src0_read)[0]) : 0.0f; + float8 arow1 = ((start_index + local_x * 2 + 1) < K) ? (alpha * ((__global float8 *)(src0_read + M))[0]) : 0.0f; + +#define MM_DOT_PRODUCT( _arow ) \ + brow = (start_index < K) ? ((__global float4 *)src1_read0)[0] : 0.0f; src1_read0 += N; start_index++; \ + dot00 = mad( (float4)(_arow.s0), brow, dot00 ); \ + dot01 = mad( (float4)(_arow.s1), brow, dot01 ); \ + dot02 = mad( (float4)(_arow.s2), brow, dot02 ); \ + dot03 = mad( (float4)(_arow.s3), brow, dot03 ); \ + dot04 = mad( (float4)(_arow.s4), brow, dot04 ); \ + dot05 = mad( (float4)(_arow.s5), brow, dot05 ); \ + dot06 = mad( (float4)(_arow.s6), brow, dot06 ); \ + dot07 = mad( (float4)(_arow.s7), brow, dot07 ); \ + + MM_DOT_PRODUCT( intel_sub_group_shuffle( arow0, 0 ) ); + MM_DOT_PRODUCT( intel_sub_group_shuffle( arow1, 0 ) ); + MM_DOT_PRODUCT( intel_sub_group_shuffle( arow0, 1 ) ); + MM_DOT_PRODUCT( intel_sub_group_shuffle( arow1, 1 ) ); + MM_DOT_PRODUCT( intel_sub_group_shuffle( arow0, 2 ) ); + MM_DOT_PRODUCT( intel_sub_group_shuffle( arow1, 2 ) ); + MM_DOT_PRODUCT( intel_sub_group_shuffle( arow0, 3 ) ); + MM_DOT_PRODUCT( intel_sub_group_shuffle( arow1, 3 ) ); + MM_DOT_PRODUCT( intel_sub_group_shuffle( arow0, 4 ) ); + MM_DOT_PRODUCT( intel_sub_group_shuffle( arow1, 4 ) ); + MM_DOT_PRODUCT( intel_sub_group_shuffle( arow0, 5 ) ); + MM_DOT_PRODUCT( intel_sub_group_shuffle( arow1, 5 ) ); + MM_DOT_PRODUCT( intel_sub_group_shuffle( arow0, 6 ) ); + MM_DOT_PRODUCT( 
intel_sub_group_shuffle( arow1, 6 ) ); + MM_DOT_PRODUCT( intel_sub_group_shuffle( arow0, 7 ) ); + MM_DOT_PRODUCT( intel_sub_group_shuffle( arow1, 7 ) ); +#undef MM_DOT_PRODUCT + } + + if(global_x * 4 < N && global_y * 8 < M) { + if(mad24(global_x, 4, 3) < N) { + __global float4 *dst_write = (__global float4 *)dst_write0; + dst_write[0] = dot00; dst_write0 += N; dst_write = (__global float4 *)dst_write0; + if(mad24(global_y, 8, 1) < M) { dst_write[0] = dot01; dst_write0 += N; dst_write = (__global float4 *)dst_write0; } + else return; + if(mad24(global_y, 8, 2) < M) { dst_write[0] = dot02; dst_write0 += N; dst_write = (__global float4 *)dst_write0; } + else return; + if(mad24(global_y, 8, 3) < M) { dst_write[0] = dot03; dst_write0 += N; dst_write = (__global float4 *)dst_write0; } + else return; + if(mad24(global_y, 8, 4) < M) { dst_write[0] = dot04; dst_write0 += N; dst_write = (__global float4 *)dst_write0; } + else return; + if(mad24(global_y, 8, 5) < M) { dst_write[0] = dot05; dst_write0 += N; dst_write = (__global float4 *)dst_write0; } + else return; + if(mad24(global_y, 8, 6) < M) { dst_write[0] = dot06; dst_write0 += N; dst_write = (__global float4 *)dst_write0; } + else return; + if(mad24(global_y, 8, 7) < M) { dst_write[0] = dot07; } + } else if(mad24(global_x, 4, 2) < N) { + __global float2 *dst_write = (__global float2 *)dst_write0; + dst_write[0] = dot00.xy; dst_write0[2] = dot00.z; + dst_write0 += N; dst_write = (__global float2 *)dst_write0; + if(mad24(global_y, 8, 1) < M) { + dst_write[0] = dot01.xy; dst_write0[2] = dot01.z; + dst_write0 += N; dst_write = (__global float2 *)dst_write0; + } else + return; + if(mad24(global_y, 8, 2) < M) { + dst_write[0] = dot02.xy; dst_write0[2] = dot02.z; + dst_write0 += N; dst_write = (__global float2 *)dst_write0; + } else + return; + if(mad24(global_y, 8, 3) < M) { + dst_write[0] = dot03.xy; dst_write0[2] = dot03.z; + dst_write0 += N; dst_write = (__global float2 *)dst_write0; + } else + return; + 
if(mad24(global_y, 8, 4) < M) { + dst_write[0] = dot04.xy; dst_write0[2] = dot04.z; + dst_write0 += N; dst_write = (__global float2 *)dst_write0; + } else + return; + if(mad24(global_y, 8, 5) < M) { + dst_write[0] = dot05.xy; dst_write0[2] = dot05.z; + dst_write0 += N; dst_write = (__global float2 *)dst_write0; + } else + return; + if(mad24(global_y, 8, 6) < M) { + dst_write[0] = dot06.xy; dst_write0[2] = dot06.z; + dst_write0 += N; dst_write = (__global float2 *)dst_write0; + } else + return; + if(mad24(global_y, 8, 7) < M) { + dst_write[0] = dot07.xy; dst_write0[2] = dot07.z; + } + } else if(mad24(global_x, 4, 1) < N) { + __global float2 *dst_write = (__global float2 *)dst_write0; + dst_write[0] = dot00.xy; dst_write0 += N; dst_write = (__global float2 *)dst_write0; + if(mad24(global_y, 8, 1) < M) { dst_write[0] = dot01.xy; dst_write0 += N; dst_write = (__global float2 *)dst_write0; } + else return; + if(mad24(global_y, 8, 2) < M) { dst_write[0] = dot02.xy; dst_write0 += N; dst_write = (__global float2 *)dst_write0; } + else return; + if(mad24(global_y, 8, 3) < M) { dst_write[0] = dot03.xy; dst_write0 += N; dst_write = (__global float2 *)dst_write0; } + else return; + if(mad24(global_y, 8, 4) < M) { dst_write[0] = dot04.xy; dst_write0 += N; dst_write = (__global float2 *)dst_write0; } + else return; + if(mad24(global_y, 8, 5) < M) { dst_write[0] = dot05.xy; dst_write0 += N; dst_write = (__global float2 *)dst_write0; } + else return; + if(mad24(global_y, 8, 6) < M) { dst_write[0] = dot06.xy; dst_write0 += N; dst_write = (__global float2 *)dst_write0; } + else return; + if(mad24(global_y, 8, 7) < M) { dst_write[0] = dot07.xy; } + } else { + dst_write0[0] = dot00.x; dst_write0 += N; + if(mad24(global_y, 8, 1) < M) { dst_write0[0] = dot01.x; dst_write0 += N; } + else return; + if(mad24(global_y, 8, 2) < M) { dst_write0[0] = dot02.x; dst_write0 += N; } + else return; + if(mad24(global_y, 8, 3) < M) { dst_write0[0] = dot03.x; dst_write0 += N; } + else return; + 
if(mad24(global_y, 8, 4) < M) { dst_write0[0] = dot04.x; dst_write0 += N; } + else return; + if(mad24(global_y, 8, 5) < M) { dst_write0[0] = dot05.x; dst_write0 += N; } + else return; + if(mad24(global_y, 8, 6) < M) { dst_write0[0] = dot06.x; dst_write0 += N; } + else return; + if(mad24(global_y, 8, 7) < M) { dst_write0[0] = dot07.x; } + } + } +} + +#undef VEC_SIZE +#undef LWG_HEIGHT +#undef TILE_M +#undef TILE_K +#undef TILE_N + +#define VEC_SIZE 4 +#define LWG_HEIGHT 4 +#define TILE_M 8 +#define TILE_K 16 +#define TILE_N 32 + +__attribute__((reqd_work_group_size(8, LWG_HEIGHT, 1))) +__kernel void TEMPLATE(gemm_buffer_TT, Dtype)( + const __global float *src0, int off0, + const __global float *src1, int off1, + __global float *dst, int offd, + int M, + int N, + int K, + float alpha, + float beta, + int start_index) + +{ + const int group_x = get_group_id(0); + const int group_y = get_group_id(1); + const int local_x = get_local_id(0); + const int local_y = get_local_id(1); + const int global_x = get_global_id(0); + const int global_y = get_global_id(1); + + float8 dot0 = 0.f; + float8 dot1 = 0.f; + float8 dot2 = 0.f; + float8 dot3 = 0.f; + + float16 brow0; + float16 brow1; + float16 brow2; + float16 brow3; + + __global float *dst_write0 = dst + local_x * VEC_SIZE + ( group_x * TILE_N ) + ( group_y * LWG_HEIGHT * TILE_M + local_y * TILE_M) * N + offd; + + const __global float *src0_read = src0 + (local_x * ( TILE_K / 8 ) + start_index) * M + group_y * LWG_HEIGHT * TILE_M + local_y * TILE_M + off0; + + const __global float *src1_read0 = src1 + (local_x * VEC_SIZE + ( group_x * TILE_N )) * K + start_index + off1; + + float4 dot00 = (start_index != 0) ? ((__global float4 *)dst_write0)[0] : (beta * ((__global float4 *)dst_write0)[0]); + float4 dot01 = (start_index != 0) ? ((__global float4 *)(dst_write0 + N))[0] : (beta * ((__global float4 *)(dst_write0 + N))[0]); + float4 dot02 = (start_index != 0) ? 
((__global float4 *)(dst_write0 + 2 * N))[0] : (beta * ((__global float4 *)(dst_write0 + 2 * N))[0]); + float4 dot03 = (start_index != 0) ? ((__global float4 *)(dst_write0 + 3 * N))[0] : (beta * ((__global float4 *)(dst_write0 + 3 * N))[0]); + float4 dot04 = (start_index != 0) ? ((__global float4 *)(dst_write0 + 4 * N))[0] : (beta * ((__global float4 *)(dst_write0 + 4 * N))[0]); + float4 dot05 = (start_index != 0) ? ((__global float4 *)(dst_write0 + 5 * N))[0] : (beta * ((__global float4 *)(dst_write0 + 5 * N))[0]); + float4 dot06 = (start_index != 0) ? ((__global float4 *)(dst_write0 + 6 * N))[0] : (beta * ((__global float4 *)(dst_write0 + 6 * N))[0]); + float4 dot07 = (start_index != 0) ? ((__global float4 *)(dst_write0 + 7 * N))[0] : (beta * ((__global float4 *)(dst_write0 + 7 * N))[0]); + + int end_index = min(start_index + 256, K); + while( start_index + TILE_K <= end_index ) { + brow0 = ((__global float16 *)src1_read0)[0]; + brow1 = ((__global float16 *)(src1_read0 + K))[0]; + brow2 = ((__global float16 *)(src1_read0 + 2 * K))[0]; + brow3 = ((__global float16 *)(src1_read0 + 3 * K))[0]; + + float8 arow0 = alpha * ((__global float8 *)src0_read)[0]; + float8 arow1 = alpha * ((__global float8 *)(src0_read + M))[0]; + +#define MM_DOT_PRODUCT( _brow, _dot) \ + _dot = mad( intel_sub_group_shuffle( arow0, 0 ), (float8)_brow.s0, _dot ); \ + _dot = mad( intel_sub_group_shuffle( arow1, 0 ), (float8)_brow.s1, _dot ); \ + _dot = mad( intel_sub_group_shuffle( arow0, 1 ), (float8)_brow.s2, _dot ); \ + _dot = mad( intel_sub_group_shuffle( arow1, 1 ), (float8)_brow.s3, _dot ); \ + _dot = mad( intel_sub_group_shuffle( arow0, 2 ), (float8)_brow.s4, _dot ); \ + _dot = mad( intel_sub_group_shuffle( arow1, 2 ), (float8)_brow.s5, _dot ); \ + _dot = mad( intel_sub_group_shuffle( arow0, 3 ), (float8)_brow.s6, _dot ); \ + _dot = mad( intel_sub_group_shuffle( arow1, 3 ), (float8)_brow.s7, _dot ); \ + _dot = mad( intel_sub_group_shuffle( arow0, 4 ), (float8)_brow.s8, _dot ); \ + _dot = 
mad( intel_sub_group_shuffle( arow1, 4 ), (float8)_brow.s9, _dot ); \ + _dot = mad( intel_sub_group_shuffle( arow0, 5 ), (float8)_brow.sa, _dot ); \ + _dot = mad( intel_sub_group_shuffle( arow1, 5 ), (float8)_brow.sb, _dot ); \ + _dot = mad( intel_sub_group_shuffle( arow0, 6 ), (float8)_brow.sc, _dot ); \ + _dot = mad( intel_sub_group_shuffle( arow1, 6 ), (float8)_brow.sd, _dot ); \ + _dot = mad( intel_sub_group_shuffle( arow0, 7 ), (float8)_brow.se, _dot ); \ + _dot = mad( intel_sub_group_shuffle( arow1, 7 ), (float8)_brow.sf, _dot ); \ + + MM_DOT_PRODUCT( brow0, dot0 ); + MM_DOT_PRODUCT( brow1, dot1 ); + MM_DOT_PRODUCT( brow2, dot2 ); + MM_DOT_PRODUCT( brow3, dot3 ); +#undef MM_DOT_PRODUCT + + src1_read0 += TILE_K; + src0_read += TILE_K * M; + start_index += TILE_K; + } + + if(start_index < end_index) { + brow0 = ((__global float16 *)src1_read0)[0]; src1_read0 += K; + brow1 = ((__global float16 *)src1_read0)[0]; src1_read0 += K; + brow2 = ((__global float16 *)src1_read0)[0]; src1_read0 += K; + brow3 = ((__global float16 *)src1_read0)[0]; + + float8 arow0 = alpha * ((__global float8 *)src0_read)[0]; + float8 arow1 = alpha * ((__global float8 *)(src0_read + M))[0]; + +#define MM_DOT_PRODUCT( _brow, _dot) \ + _dot = (w++ < K) ? mad( intel_sub_group_shuffle( arow0, 0 ), (float8)_brow.s0, _dot ) : _dot; \ + _dot = (w++ < K) ? mad( intel_sub_group_shuffle( arow1, 0 ), (float8)_brow.s1, _dot ) : _dot; \ + _dot = (w++ < K) ? mad( intel_sub_group_shuffle( arow0, 1 ), (float8)_brow.s2, _dot ) : _dot; \ + _dot = (w++ < K) ? mad( intel_sub_group_shuffle( arow1, 1 ), (float8)_brow.s3, _dot ) : _dot; \ + _dot = (w++ < K) ? mad( intel_sub_group_shuffle( arow0, 2 ), (float8)_brow.s4, _dot ) : _dot; \ + _dot = (w++ < K) ? mad( intel_sub_group_shuffle( arow1, 2 ), (float8)_brow.s5, _dot ) : _dot; \ + _dot = (w++ < K) ? mad( intel_sub_group_shuffle( arow0, 3 ), (float8)_brow.s6, _dot ) : _dot; \ + _dot = (w++ < K) ? 
mad( intel_sub_group_shuffle( arow1, 3 ), (float8)_brow.s7, _dot ) : _dot; \ + _dot = (w++ < K) ? mad( intel_sub_group_shuffle( arow0, 4 ), (float8)_brow.s8, _dot ) : _dot; \ + _dot = (w++ < K) ? mad( intel_sub_group_shuffle( arow1, 4 ), (float8)_brow.s9, _dot ) : _dot; \ + _dot = (w++ < K) ? mad( intel_sub_group_shuffle( arow0, 5 ), (float8)_brow.sa, _dot ) : _dot; \ + _dot = (w++ < K) ? mad( intel_sub_group_shuffle( arow1, 5 ), (float8)_brow.sb, _dot ) : _dot; \ + _dot = (w++ < K) ? mad( intel_sub_group_shuffle( arow0, 6 ), (float8)_brow.sc, _dot ) : _dot; \ + _dot = (w++ < K) ? mad( intel_sub_group_shuffle( arow1, 6 ), (float8)_brow.sd, _dot ) : _dot; \ + _dot = (w++ < K) ? mad( intel_sub_group_shuffle( arow0, 7 ), (float8)_brow.se, _dot ) : _dot; \ + _dot = (w++ < K) ? mad( intel_sub_group_shuffle( arow1, 7 ), (float8)_brow.sf, _dot ) : _dot; \ + + int w = start_index; + MM_DOT_PRODUCT( brow0, dot0 ); + w = start_index; + MM_DOT_PRODUCT( brow1, dot1 ); + w = start_index; + MM_DOT_PRODUCT( brow2, dot2 ); + w = start_index; + MM_DOT_PRODUCT( brow3, dot3 ); +#undef MM_DOT_PRODUCT + } + + dot00 += (float4)(dot0.s0, dot1.s0, dot2.s0, dot3.s0); + dot01 += (float4)(dot0.s1, dot1.s1, dot2.s1, dot3.s1); + dot02 += (float4)(dot0.s2, dot1.s2, dot2.s2, dot3.s2); + dot03 += (float4)(dot0.s3, dot1.s3, dot2.s3, dot3.s3); + dot04 += (float4)(dot0.s4, dot1.s4, dot2.s4, dot3.s4); + dot05 += (float4)(dot0.s5, dot1.s5, dot2.s5, dot3.s5); + dot06 += (float4)(dot0.s6, dot1.s6, dot2.s6, dot3.s6); + dot07 += (float4)(dot0.s7, dot1.s7, dot2.s7, dot3.s7); + + if(global_x * 4 < N && global_y * 8 < M) { + if(mad24(global_x, 4, 3) < N) { + __global float4 *dst_write = (__global float4 *)dst_write0; + dst_write[0] = dot00; dst_write0 += N; dst_write = (__global float4 *)dst_write0; + if(mad24(global_y, 8, 1) < M) { dst_write[0] = dot01; dst_write0 += N; dst_write = (__global float4 *)dst_write0; } + else return; + if(mad24(global_y, 8, 2) < M) { dst_write[0] = dot02; dst_write0 += N; 
dst_write = (__global float4 *)dst_write0; } + else return; + if(mad24(global_y, 8, 3) < M) { dst_write[0] = dot03; dst_write0 += N; dst_write = (__global float4 *)dst_write0; } + else return; + if(mad24(global_y, 8, 4) < M) { dst_write[0] = dot04; dst_write0 += N; dst_write = (__global float4 *)dst_write0; } + else return; + if(mad24(global_y, 8, 5) < M) { dst_write[0] = dot05; dst_write0 += N; dst_write = (__global float4 *)dst_write0; } + else return; + if(mad24(global_y, 8, 6) < M) { dst_write[0] = dot06; dst_write0 += N; dst_write = (__global float4 *)dst_write0; } + else return; + if(mad24(global_y, 8, 7) < M) { dst_write[0] = dot07; } + } else if(mad24(global_x, 4, 2) < N) { + __global float2 *dst_write = (__global float2 *)dst_write0; + dst_write[0] = dot00.xy; dst_write0[2] = dot00.z; + dst_write0 += N; dst_write = (__global float2 *)dst_write0; + if(mad24(global_y, 8, 1) < M) { + dst_write[0] = dot01.xy; dst_write0[2] = dot01.z; + dst_write0 += N; dst_write = (__global float2 *)dst_write0; + } else + return; + if(mad24(global_y, 8, 2) < M) { + dst_write[0] = dot02.xy; dst_write0[2] = dot02.z; + dst_write0 += N; dst_write = (__global float2 *)dst_write0; + } else + return; + if(mad24(global_y, 8, 3) < M) { + dst_write[0] = dot03.xy; dst_write0[2] = dot03.z; + dst_write0 += N; dst_write = (__global float2 *)dst_write0; + } else + return; + if(mad24(global_y, 8, 4) < M) { + dst_write[0] = dot04.xy; dst_write0[2] = dot04.z; + dst_write0 += N; dst_write = (__global float2 *)dst_write0; + } else + return; + if(mad24(global_y, 8, 5) < M) { + dst_write[0] = dot05.xy; dst_write0[2] = dot05.z; + dst_write0 += N; dst_write = (__global float2 *)dst_write0; + } else + return; + if(mad24(global_y, 8, 6) < M) { + dst_write[0] = dot06.xy; dst_write0[2] = dot06.z; + dst_write0 += N; dst_write = (__global float2 *)dst_write0; + } else + return; + if(mad24(global_y, 8, 7) < M) { + dst_write[0] = dot07.xy; dst_write0[2] = dot07.z; + } + } else if(mad24(global_x, 4, 1) < N) { 
+ __global float2 *dst_write = (__global float2 *)dst_write0; + dst_write[0] = dot00.xy; dst_write0 += N; dst_write = (__global float2 *)dst_write0; + if(mad24(global_y, 8, 1) < M) { dst_write[0] = dot01.xy; dst_write0 += N; dst_write = (__global float2 *)dst_write0; } + else return; + if(mad24(global_y, 8, 2) < M) { dst_write[0] = dot02.xy; dst_write0 += N; dst_write = (__global float2 *)dst_write0; } + else return; + if(mad24(global_y, 8, 3) < M) { dst_write[0] = dot03.xy; dst_write0 += N; dst_write = (__global float2 *)dst_write0; } + else return; + if(mad24(global_y, 8, 4) < M) { dst_write[0] = dot04.xy; dst_write0 += N; dst_write = (__global float2 *)dst_write0; } + else return; + if(mad24(global_y, 8, 5) < M) { dst_write[0] = dot05.xy; dst_write0 += N; dst_write = (__global float2 *)dst_write0; } + else return; + if(mad24(global_y, 8, 6) < M) { dst_write[0] = dot06.xy; dst_write0 += N; dst_write = (__global float2 *)dst_write0; } + else return; + if(mad24(global_y, 8, 7) < M) { dst_write[0] = dot07.xy; } + } else { + dst_write0[0] = dot00.x; dst_write0 += N; + if(mad24(global_y, 8, 1) < M) { dst_write0[0] = dot01.x; dst_write0 += N; } + else return; + if(mad24(global_y, 8, 2) < M) { dst_write0[0] = dot02.x; dst_write0 += N; } + else return; + if(mad24(global_y, 8, 3) < M) { dst_write0[0] = dot03.x; dst_write0 += N; } + else return; + if(mad24(global_y, 8, 4) < M) { dst_write0[0] = dot04.x; dst_write0 += N; } + else return; + if(mad24(global_y, 8, 5) < M) { dst_write0[0] = dot05.x; dst_write0 += N; } + else return; + if(mad24(global_y, 8, 6) < M) { dst_write0[0] = dot06.x; dst_write0 += N; } + else return; + if(mad24(global_y, 8, 7) < M) { dst_write0[0] = dot07.x; } + } + } +} + +#undef VEC_SIZE +#undef LWG_HEIGHT +#undef TILE_M +#undef TILE_K +#undef TILE_N diff --git a/src/caffe/greentea/cl_kernels/matvec_mul.cl b/src/caffe/greentea/cl_kernels/matvec_mul.cl new file mode 100644 index 00000000000..d0b9bacb385 --- /dev/null +++ 
b/src/caffe/greentea/cl_kernels/matvec_mul.cl @@ -0,0 +1,143 @@ +#ifndef __OPENCL_VERSION__ +#include "header.cl" +#endif + +__kernel void TEMPLATE(matvec_mul4,Dtype)( + __global const float * A, + int offA, + unsigned int A_col_size, + unsigned int trail_item, + __global const float * v, + int offv, + float alpha, + float beta, + __global float4 * result, + int offr, + __local float4 * work) +{ + unsigned int row_gid = get_group_id(0); + unsigned int lid = get_local_id(0); + const __global float *src0_read = A + row_gid * 4 * A_col_size + offA; + const __global float *src1_read = v + offv; + result = (__global float4*)((__global float*)result + offr); + float4 dot0 = (float4)(0.f); + float4 dot1 = (float4)(0.f); + float4 dot2 = (float4)(0.f); + float4 dot3 = (float4)(0.f); + + unsigned int i = lid; + while( i < A_col_size / 4) { + const float4 a0 = vload4(i, src0_read); + const float4 a1 = vload4(i, src0_read + A_col_size); + const float4 a2 = vload4(i, src0_read + 2 * A_col_size); + const float4 a3 = vload4(i, src0_read + 3 * A_col_size); + + const float4 b0 = vload4(i, src1_read); + + dot0 += a0 * b0; + dot1 += a1 * b0; + dot2 += a2 * b0; + dot3 += a3 * b0; + + i += get_local_size(0); + } + + work[lid].s0 = dot0.x + dot0.y + dot0.z + dot0.w; + work[lid].s1 = dot1.x + dot1.y + dot1.z + dot1.w; + work[lid].s2 = dot2.x + dot2.y + dot2.z + dot2.w; + work[lid].s3 = dot3.x + dot3.y + dot3.z + dot3.w; + + if(i == A_col_size / 4) + { + if(trail_item != 0) + { + const __global float *src0_trail = src0_read + i * 4; + const __global float *src1_trail = src1_read + i * 4; + for(unsigned int i = 0; i < trail_item; ++i) { + const float at0 = src0_trail[i]; + const float at1 = src0_trail[i + A_col_size]; + const float at2 = src0_trail[i + 2 * A_col_size]; + const float at3 = src0_trail[i + 3 * A_col_size]; + + const float bt = src1_trail[i]; + + work[lid].s0 += at0 * bt; + work[lid].s1 += at1 * bt; + work[lid].s2 += at2 * bt; + work[lid].s3 += at3 * bt; + } + } + + } + + 
for(unsigned int stride=get_local_size(0)/2 ; stride>0 ; stride>>=1) { + barrier(CLK_LOCAL_MEM_FENCE); + if(lid < stride) + work[lid] += work[lid+stride]; + } + if(lid == 0) + result[row_gid] = alpha * work[0] + beta * result[row_gid]; +} + +/* This kernel used for the trailing rows when row_of_A %4 !=0 */ +__kernel void TEMPLATE(matvec_mul1,Dtype)( + __global const float * A, + int offA, + unsigned int A_col_size, + unsigned int row_offset, + unsigned int trail_item, + __global const float * v, + int offv, + float alpha, + float beta, + __global float * result, + int offr, + __local float * work) +{ + unsigned int row_gid = get_group_id(0); + unsigned int lid = get_local_id(0); + + const __global float *src0_read = A + (row_offset + row_gid) * A_col_size + offA; + const __global float *src1_read = v + + offv; + result = result + offr; + float4 dot0 = (float4)(0.f); + + unsigned int i = lid; + while( i < A_col_size / 4) + { + const float4 a0 = vload4(i, src0_read); + const float4 b0 = vload4(i, src1_read); + + dot0 += a0 * b0; + i += get_local_size(0); + } + + work[lid] = dot0.x + dot0.y + dot0.z + dot0.w; + + if(i == A_col_size / 4) + { + if(trail_item != 0) + { + const __global float *src0_trail = src0_read + i * 4; + const __global float *src1_trail = src1_read + i * 4; + for(unsigned int i = 0; i < trail_item; ++i) { + const float at0 = src0_trail[i]; + const float bt = src1_trail[i]; + + work[lid] += at0 * bt; + } + } + + } + for(unsigned int stride=get_local_size(0)/2 ; stride>0 ; stride>>=1) { + barrier(CLK_LOCAL_MEM_FENCE); + if(lid < stride) + work[lid] += work[lid+stride]; + } + + if(lid == 0) { + result[row_gid+row_offset] *= beta; + result[row_gid+row_offset] += alpha * work[0]; + //result[row_gid+row_offset] = alpha * work[0] + beta * result[row_gid+row_offset]; + } +} diff --git a/src/caffe/greentea/greentea_math_functions.cpp b/src/caffe/greentea/greentea_math_functions.cpp index 802e66dd1dd..60739c2cec0 100644 --- 
a/src/caffe/greentea/greentea_math_functions.cpp +++ b/src/caffe/greentea/greentea_math_functions.cpp @@ -14,6 +14,7 @@ #include #include +#include #include #include @@ -29,22 +30,23 @@ #include "viennacl/ocl/device.hpp" #include "viennacl/ocl/platform.hpp" +#include "caffe/util/benchmark.hpp" #include "caffe/util/math_functions.hpp" #if defined(USE_CLBLAS) #include // NOLINT #elif defined(USE_CLBLAST) #include // NOLINT -#else - #include "viennacl/linalg/inner_prod.hpp" - #include "viennacl/linalg/norm_1.hpp" - #include "viennacl/linalg/norm_2.hpp" - #include "viennacl/linalg/norm_inf.hpp" - #include "viennacl/linalg/prod.hpp" - #include "viennacl/matrix.hpp" - #include "viennacl/scalar.hpp" - #include "viennacl/vector.hpp" #endif +#include "viennacl/linalg/inner_prod.hpp" +#include "viennacl/linalg/norm_1.hpp" +#include "viennacl/linalg/norm_2.hpp" +#include "viennacl/linalg/norm_inf.hpp" +#include "viennacl/linalg/prod.hpp" +#include "viennacl/matrix.hpp" +#include "viennacl/scalar.hpp" +#include "viennacl/vector.hpp" + // ViennaCL 1.5.1 compability fix #ifndef VIENNACL_MINOR_VERSION @@ -174,14 +176,712 @@ template void greentea_copy(const int_tp N, const cl_mem X, const int_tp offY, viennacl::ocl::context *ctx); +struct gemm_callback_arg { + std::vector evs; + std::vector imgs; +}; + +static void CL_CALLBACK gemm_callback (cl_event event, + cl_int event_command_exec_status, + void *user_data) { + struct gemm_callback_arg *arg = (struct gemm_callback_arg *) user_data; + for(int i = 0; i < arg->evs.size(); i++) { + clReleaseEvent(arg->evs[i]); + } + + for(int i = 0; i < arg->imgs.size(); i++) { + clReleaseMemObject(arg->imgs[i]); + } + delete arg; +} + +// Create and copy buffer to image for GEMM's matrix A and B. +// Will return image to caller if the input image is NULL. Otherwise, +// will use the image directly. It's caller's responsibility to +// release the created image. 
+void greentea_gpu_gemm_copy_buffer_to_image(int_tp ctx_id, + cl_mem *image, cl_mem buffer, int offset, + bool is_matrix_a, bool transpose, + bool padding, int padded_height, + int padded_width, int height, + int width, int wait_list_size, + cl_event *wait_list, + cl_event *event) { + + viennacl::ocl::context &ctx = viennacl::ocl::get_context(ctx_id); + viennacl::ocl::program &program = (Caffe::Get().GetDevice(ctx_id, false)) + ->program(); + cl_image_desc desc; + cl_image_format format; + + memset(&desc, 0, sizeof(desc)); + if (!is_matrix_a && transpose) { + // For matrix B with transpose, we need to handle them differently. + // As we can't use the sub group block read to get a row easily, + // we have to use CL_FLOAT type with read_imagef to get the row. + cl_int err; + format.image_channel_data_type = CL_FLOAT; + desc.image_type = CL_MEM_OBJECT_IMAGE2D; + if ( width % 4 == 0 ) { + desc.image_width = width / 4; + format.image_channel_order = CL_RGBA; + } else { + desc.image_width = width; + format.image_channel_order = CL_R; + } + desc.image_height = height; + // if (offB == 0 && (desc.image_width % 4) == 0 && N > 8 && K > 8) + // desc.mem_object = buffer; + if (*image == NULL) { + *image = clCreateImage( + ctx.handle().get(), + CL_MEM_READ_WRITE, + &format, + &desc, + NULL, + &err); + OCL_CHECK(err); + } + // if (!desc.mem_object) { + size_t origin[] = {0, 0, 0}; + size_t region[] = {(size_t)desc.image_width, + (size_t)desc.image_height, 1}; + OCL_CHECK(clEnqueueCopyBufferToImage(ctx.get_queue().handle().get(), + buffer, *image, sizeof(float) * offset, + origin, region, wait_list_size, + wait_list, event)); + // } + return; + } + + if (*image == NULL) { + desc.image_type = CL_MEM_OBJECT_IMAGE2D; + format.image_channel_data_type = CL_UNSIGNED_INT8; + format.image_channel_order = CL_RGBA; + if (!padding) { + //if (width % 4 == 0 && offset == 0 && height > 8 && width > 8) + // desc.buffer = buffer; + desc.image_width = width; + desc.image_height = height; + } else 
{ + desc.image_width = padded_width; + desc.image_height = padded_height; + } + cl_int err; + *image = clCreateImage(ctx.handle().get(), + desc.buffer ? CL_MEM_READ_ONLY : CL_MEM_READ_WRITE, + &format, + &desc, + NULL, + &err); + OCL_CHECK(err); + } + if (!padding && desc.buffer != NULL) + return; + if (!padding && desc.buffer == NULL) { + // copy without padding. + size_t origin[] = {0, 0, 0}; + size_t region[] = {(size_t)width, (size_t)height, 1}; + OCL_CHECK(clEnqueueCopyBufferToImage(ctx.get_queue().handle().get(), + buffer, *image, sizeof(float) * offset, + origin, region, wait_list_size, wait_list, event)); + return; + } + viennacl::ocl::kernel &oclk_gemm_copy = program.get_kernel( + "gemm_buffer_copy_image_float"); + + size_t global_copy[2]; + global_copy[0] = padding ? padded_width : width; + global_copy[1] = padding ? padded_height : height; + oclk_gemm_copy.arg(0, WrapHandle(buffer, &ctx)); + oclk_gemm_copy.arg(1, WrapHandle(*image, &ctx)); + oclk_gemm_copy.arg(2, offset); + oclk_gemm_copy.arg(3, width); + oclk_gemm_copy.arg(4, height); + OCL_CHECK(clEnqueueNDRangeKernel(ctx.get_queue().handle().get(), + oclk_gemm_copy.handle().get(), + 2, NULL, global_copy, NULL, + wait_list_size, wait_list, + event)); +} + +// #define GEMM_PROFILING +#ifdef GEMM_PROFILING +#define START_TIMER(n) \ + clFinish(ctx.get_queue().handle().get()); \ + gettimeofday(&start[n], NULL); + +#define STOP_TIMER(n) \ + clFinish(ctx.get_queue().handle().get()); \ + gettimeofday(&end[n], NULL); +#else +#define START_TIMER(n) +#define STOP_TIMER(n) +#endif + +enum gemm_type_t { + GEMM_TYPE_NONE = 0, + GEMM_TYPE_CLBLAS, + GEMM_TYPE_CLBLAST, + GEMM_TYPE_VIENNACL, + GEMM_TYPE_FAST_IMAGE_32_1, + GEMM_TYPE_FAST_IMAGE_32_2, + GEMM_TYPE_FAST_BUFFER, + GEMM_TYPE_MAX +}; + +static void greentea_gpu_fast_image_gemm(const int_tp ctx_id, const CBLAS_TRANSPOSE TransA, + const CBLAS_TRANSPOSE TransB, const int_tp M, + const int_tp N, const int_tp K, const float alpha, + const cl_mem A, const int_tp 
offA, const cl_mem B, + const int_tp offB, const float beta, cl_mem C, + const int_tp offC, bool is_image_a, bool is_image_b, + enum gemm_type_t gemm_type) { + CHECK_EQ(gemm_type == GEMM_TYPE_FAST_IMAGE_32_1 + || gemm_type == GEMM_TYPE_FAST_IMAGE_32_2, true) + << "Invalid fast image gemm type." << std::endl; + if (is_image_a) + CHECK_EQ(offA, 0) << "Invalid input image offset." << std::endl; + + if (is_image_b) + CHECK_EQ(offB, 0) << "Invalid input image offset." << std::endl; + + #ifdef GEMM_PROFILING + struct timeval start[4], end[4]; + for(int i = 0; i < 4; i++) + start[i] = end[i]; + #endif + uint32_t widthA = (TransA == CblasNoTrans) ? K : M; + uint32_t heightA = (TransA == CblasNoTrans) ? M : K; + uint32_t widthB = (TransB == CblasNoTrans) ? N : K; + uint32_t heightB = (TransB == CblasNoTrans) ? K : N; + // To fix the edge problem caused by the sub group block read. + // we have to pad the image if it's not multiple of tile. + // just padding one line is enough as the sub group block read + // will clamp to edge according to the spec. + uint32_t padded_k = K + ((K & 7) ? 1 : 0); + uint32_t imageA_w = (TransA == CblasNoTrans) ? padded_k : M; + uint32_t imageA_h = (TransA == CblasNoTrans) ? M : padded_k; + uint32_t imageB_w = (TransB == CblasNoTrans) ? N : padded_k; + uint32_t imageB_h = (TransB == CblasNoTrans) ? 
padded_k : N; + viennacl::ocl::context &ctx = viennacl::ocl::get_context(ctx_id); + viennacl::ocl::program &program = (Caffe::Get().GetDevice(ctx_id, false)) + ->program(); + + cl_mem ImA = NULL; + cl_mem ImB = NULL; + + cl_event ev[5]; + cl_uint ev_idx = 0; + memset(ev, 0, sizeof(cl_event) * 5); + struct gemm_callback_arg * arg = new gemm_callback_arg; + if (TransB == CblasNoTrans) { + bool padding_A = false; + bool padding_B = false; + + if (!is_image_a && !is_image_b) { + if (M * K < N * K) + padding_B = true; + else + padding_A = true; + } + + START_TIMER(0); + if (!is_image_a) { + greentea_gpu_gemm_copy_buffer_to_image(ctx_id, &ImA, A, offA, + true, TransA != CblasNoTrans, + padding_A, imageA_h, imageA_w, + heightA, widthA, 0, NULL, &ev[ev_idx]); + if (ev[ev_idx] != NULL) + ev_idx++; + } + + STOP_TIMER(0); + START_TIMER(1); + + if (!is_image_b) { + greentea_gpu_gemm_copy_buffer_to_image(ctx_id, &ImB, B, offB, + false, false, + padding_B, imageB_h, imageB_w, + heightB, widthB, 0, NULL, &ev[ev_idx]); + if (ev[ev_idx] != NULL) + ev_idx++; + } + STOP_TIMER(1); + } else { + // We will use normal read_imagef to read image B when B has transpose. + // thus we don't need to pad image A at all. 
+ START_TIMER(2); + if (!is_image_a) { + bool padding; + padding = !is_image_b; + greentea_gpu_gemm_copy_buffer_to_image(ctx_id, &ImA, A, offA, + true, TransA != CblasNoTrans, + padding, imageA_h, imageA_w, + heightA, widthA, 0, NULL, &ev[ev_idx]); + if (ev[ev_idx] != NULL) + ev_idx++; + } + STOP_TIMER(2); + } + if (!is_image_a) + arg->imgs.push_back(ImA); + else + ImA = A; + if (!is_image_b) + arg->imgs.push_back(ImB); + else + ImB = B; + + viennacl::ocl::kernel *oclk_gemm_float; + std::string kernel_name("gemm_"); + if (gemm_type == GEMM_TYPE_FAST_IMAGE_32_1) + kernel_name += "32_1_"; + else + kernel_name += "32_2_"; + + if (TransA == CblasNoTrans) + kernel_name += "N"; + else + kernel_name += "T"; + + if (TransB == CblasNoTrans) + kernel_name += "N_"; + else { + kernel_name += "T_"; + if (is_image_b) { + if (K % 4 == 0) + kernel_name += "VEC4_"; + else + kernel_name += "SCALAR_"; + } else { + kernel_name += "BUFFER_"; + } + } + + if (alpha == 1) + kernel_name += "1_"; + else + kernel_name += "0_"; + + if (beta == 0) + kernel_name += "0"; + else + kernel_name += "1"; + kernel_name += "_float"; + + oclk_gemm_float = &program.get_kernel(kernel_name); + + size_t global[2]; + + if (gemm_type == GEMM_TYPE_FAST_IMAGE_32_1) + global[0] = (size_t)( N + 7 ) & ~7; + else + global[0] = (size_t)( (N / 2 ) + 7 ) & ~7; + + global[1] = (size_t)(M + 31) / 32; + const size_t local[] = {8, 1}; + + cl_uint arg_idx = 0; + oclk_gemm_float->arg(arg_idx++, WrapHandle(ImA, &ctx)); + if (TransB == CblasNoTrans || is_image_b) + oclk_gemm_float->arg(arg_idx++, WrapHandle(ImB, &ctx)); + else { + oclk_gemm_float->arg(arg_idx++, WrapHandle(B, &ctx)); + oclk_gemm_float->arg(arg_idx++, offB); + } + oclk_gemm_float->arg(arg_idx++, WrapHandle(C, &ctx)); + oclk_gemm_float->arg(arg_idx++, offC); + oclk_gemm_float->arg(arg_idx++, M); + oclk_gemm_float->arg(arg_idx++, N); + oclk_gemm_float->arg(arg_idx++, alpha); + oclk_gemm_float->arg(arg_idx++, beta); + oclk_gemm_float->arg(arg_idx++, padded_k); + 
if (TransB != CblasNoTrans) + oclk_gemm_float->arg(arg_idx++, K); + + cl_event *wait_list = NULL; + if (ev_idx != 0) + wait_list = &ev[0]; + START_TIMER(3); + OCL_CHECK(clEnqueueNDRangeKernel(ctx.get_queue().handle().get(), + oclk_gemm_float->handle().get(), 2, NULL, + global, local, ev_idx, + wait_list, &ev[ev_idx])); + STOP_TIMER(3); + #ifdef GEMM_PROFILING + double elapsed[4], total_elapsed; + for( int i = 0; i < 4; i++ ) { + elapsed[i] = (end[i].tv_sec - start[i].tv_sec) * 1e6 + (end[i].tv_usec - start[i].tv_usec); + total_elapsed += elapsed[i]; + } + printf("kernel name %s \n", kernel_name.c_str()); + printf("gemm %d %d %d %f %f %d %d %f %f %f %f %fGFLOPS %f GFLOPS\n", + M, K, N, alpha, beta, TransA == CblasNoTrans, TransB == CblasNoTrans, + elapsed[0] / 1000., elapsed[1] / 1000., elapsed[2] / 1000., + elapsed[3] / 1000., + M * N * ( 2*K - 1. ) / ( elapsed[3] * 1e3 ), + M * N * ( 2 * K - 1.) / ( total_elapsed * 1e3 ) ); + #endif + arg->evs.assign(ev, ev + ev_idx + 1); + clSetEventCallback(ev[ev_idx], CL_COMPLETE, &gemm_callback, (void*)arg); +} + +static void greentea_gpu_fast_buffer_gemm(const int_tp ctx_id, const CBLAS_TRANSPOSE TransA, + const CBLAS_TRANSPOSE TransB, const int_tp M, + const int_tp N, const int_tp K, const float alpha, + const cl_mem A, const int_tp offA, const cl_mem B, + const int_tp offB, const float beta, cl_mem C, + const int_tp offC, enum gemm_type_t gemm_type) { + CHECK_EQ(gemm_type == GEMM_TYPE_FAST_BUFFER, true) + << "Invalid fast buffer gemm type." 
<< std::endl; + +#ifdef GEMM_PROFILING + struct timeval start[1], end[1]; + start[0] = end[0]; +#endif + + cl_event ev; + + viennacl::ocl::context &ctx = viennacl::ocl::get_context(ctx_id); + viennacl::ocl::program &program = (Caffe::Get().GetDevice(ctx_id, false)) + ->program(); + size_t sub_group_size = 8; + bool is_small_batch = (M == 2 || M == 4 || M == 8); + viennacl::ocl::kernel *oclk_gemm_float; + std::string kernel_name("gemm_buffer_"); + if(TransA == CblasNoTrans && TransB == CblasNoTrans) { + kernel_name += "NN_float"; + } else if(TransA == CblasNoTrans && TransB != CblasNoTrans) { + if (M == 2) + kernel_name +="NT_M_2_float"; + else if (M == 4) + kernel_name +="NT_M_4_float"; + else if (M == 8) + kernel_name +="NT_M_8_float"; + else + kernel_name += "NT_float"; + } else if(TransA != CblasNoTrans && TransB == CblasNoTrans) { + kernel_name += "TN_float"; + } else { + kernel_name += "TT_float"; + } + oclk_gemm_float = &program.get_kernel(kernel_name); + size_t local[2] = {}; + size_t global[2] = {}; + if (TransA == CblasNoTrans && TransB != CblasNoTrans && is_small_batch ) { + if(M == 8) + local[0] = 16; + else if(M == 4) + local[0] = 32; + else + local[0] = 64; + local[1] = 1; + + if(M == 8) + global[0] = N * local[0]; + else + global[0] = (N + 3) / 4 * local[0]; + global[1] = 1; + } else { + size_t lx = sub_group_size; + size_t ly = (TransB != CblasNoTrans && TransA == CblasNoTrans) ? 16 : 4; + int dx = (TransB != CblasNoTrans && TransA == CblasNoTrans) ? 
1 : 4; + int dy = 8; + size_t gx = (size_t)(N + dx - 1) / dx; + size_t gy = (size_t)(M + dy - 1) / dy; + global[0] = (gx + lx - 1) / lx * lx; + global[1] = (gy + ly - 1) / ly * ly; + local[0] = lx; + local[1] = ly; + } + + cl_uint arg_idx = 0; + oclk_gemm_float->arg(arg_idx++, WrapHandle(A, &ctx)); + oclk_gemm_float->arg(arg_idx++, offA); + oclk_gemm_float->arg(arg_idx++, WrapHandle(B, &ctx)); + oclk_gemm_float->arg(arg_idx++, offB); + oclk_gemm_float->arg(arg_idx++, WrapHandle(C, &ctx)); + oclk_gemm_float->arg(arg_idx++, offC); + oclk_gemm_float->arg(arg_idx++, M); + oclk_gemm_float->arg(arg_idx++, N); + oclk_gemm_float->arg(arg_idx++, K); + oclk_gemm_float->arg(arg_idx++, alpha); + oclk_gemm_float->arg(arg_idx++, beta); + + START_TIMER(0); + if(TransB == CblasNoTrans || TransA != CblasNoTrans) { + int stride = 256; + for(int start_index = 0; start_index < K; start_index += stride) { + oclk_gemm_float->arg(arg_idx, start_index); + OCL_CHECK(clEnqueueNDRangeKernel(ctx.get_queue().handle().get(), + oclk_gemm_float->handle().get(), 2, NULL, + global, local, 0, + NULL, &ev)); + } + } else { + OCL_CHECK(clEnqueueNDRangeKernel(ctx.get_queue().handle().get(), + oclk_gemm_float->handle().get(), 2, NULL, + global, local, 0, + NULL, &ev)); + } + STOP_TIMER(0); + clReleaseEvent(ev); + +#ifdef GEMM_PROFILING + double total_elapsed; + total_elapsed = (end[0].tv_sec - start[0].tv_sec) * 1e6 + (end[0].tv_usec - start[0].tv_usec); + printf("kernel name %s \n", kernel_name.c_str()); + printf("gemm %d %d %d %f %f %d %d %f %fGFLOPS\n", + M, K, N, alpha, beta, TransA == CblasNoTrans, TransB == CblasNoTrans, + total_elapsed / 1000., M * N * ( 2 * K - 1.) 
/ ( total_elapsed * 1e3 ) ); +#endif +} + +template +static void greentea_gpu_gemm_common(const int_tp ctx_id, const CBLAS_TRANSPOSE TransA, + const CBLAS_TRANSPOSE TransB, const int_tp M, + const int_tp N, const int_tp K, const Dtype alpha, + const cl_mem A, const int_tp offA, const cl_mem B, + const int_tp offB, const Dtype beta, cl_mem C, + const int_tp offC, bool is_image_a, bool is_image_b, + gemm_type_t gemm_type) { + + viennacl::ocl::context &ctx = viennacl::ocl::get_context(ctx_id); + int_tp lda = (TransA == CblasNoTrans) ? K : M; + int_tp ldb = (TransB == CblasNoTrans) ? N : K; + int_tp ldc = N; + + if (gemm_type == GEMM_TYPE_FAST_IMAGE_32_1 || + gemm_type == GEMM_TYPE_FAST_IMAGE_32_2) { + greentea_gpu_fast_image_gemm(ctx_id, TransA, TransB, M, N, K, + alpha, A, offA, B, offB, beta, C, + offC, is_image_a, is_image_b, + gemm_type); + } else if (gemm_type == GEMM_TYPE_FAST_BUFFER) { + greentea_gpu_fast_buffer_gemm(ctx_id, TransA, TransB, M, N, K, + alpha, A, offA, B, offB, beta, C, + offC, gemm_type); + } else if (gemm_type == GEMM_TYPE_CLBLAS) { + #if defined(USE_CLBLAS) + if ((M == 2 || M == 4 || M == 8) && std::is_same::value + && TransA == CblasNoTrans && TransB != CblasNoTrans) { + greentea_gpu_fast_buffer_gemm(ctx_id, TransA, TransB, M, N, K, + alpha, A, offA, B, offB, beta, C, + offC, GEMM_TYPE_FAST_BUFFER); + } else { + clblasOrder clOrder = clblasRowMajor; + clblasTranspose clTransA = + (TransA == CblasNoTrans) ? clblasNoTrans : clblasTrans; + clblasTranspose clTransB = + (TransB == CblasNoTrans) ? 
clblasNoTrans : clblasTrans; + + cl_command_queue queue = ctx.get_queue().handle().get(); + + if (std::is_same::value) { + GREENTEA_CL_BLAS_CHECK( + clblasSgemm(clOrder, clTransA, clTransB, + M, N, K, alpha, A, offA, lda, B, offB, ldb, beta, + C, offC, ldc, 1, &queue, 0, NULL, NULL)); + } else { + GREENTEA_CL_BLAS_CHECK( + clblasDgemm(clOrder, clTransA, clTransB, + M, N, K, alpha, A, offA, lda, B, offB, ldb, beta, + C, offC, ldc, 1, &queue, 0, NULL, NULL)); + } + } + #endif + } else if (gemm_type == GEMM_TYPE_CLBLAST) { + #ifdef USE_CLBLAST + cl_command_queue queue = ctx.get_queue().handle().get(); + + clblast::Layout layout = clblast::Layout::kRowMajor; + clblast::Transpose a_transpose = (TransA == CblasNoTrans) ? + clblast::Transpose::kNo : clblast::Transpose::kYes; + clblast::Transpose b_transpose = (TransB == CblasNoTrans) ? + clblast::Transpose::kNo : clblast::Transpose::kYes; + + if (std::is_same::value) { + GREENTEA_CLBLAST_CHECK( + clblast::Gemm( + layout, a_transpose, b_transpose, + M, N, K, + alpha, + A, offA, lda, + B, offB, ldb, + beta, + C, offC, ldc, + &queue)); + } else { + GREENTEA_CLBLAST_CHECK( + clblast::Gemm( + layout, a_transpose, b_transpose, + M, N, K, + alpha, + A, offA, lda, + B, offB, ldb, + beta, + C, offC, ldc, + &queue)); + } + #endif + } else if (gemm_type == GEMM_TYPE_VIENNACL) { + typedef typename viennacl::matrix_base::size_type size_type; + typedef typename viennacl::matrix_base::size_type difference_type; + + size_type A_size1 = static_cast((TransA == CblasTrans) ? K : M); + size_type A_size2 = static_cast((TransA == CblasTrans) ? M : K); + + size_type B_size1 = static_cast((TransB == CblasTrans) ? N : K); + size_type B_size2 = static_cast((TransB == CblasTrans) ? 
K : N); + + viennacl::matrix_base matA(A, ctx, A_size1, + size_type(0), + difference_type(1), + size_type(M), A_size2, + size_type(offA), + difference_type(1), + size_type(lda) + VCL_ROW_MAJOR); + + viennacl::matrix_base matB(B, ctx, B_size1, + size_type(0), + difference_type(1), + size_type(K), B_size2, + size_type(offB), + difference_type(1), + size_type(ldb) + VCL_ROW_MAJOR); + + viennacl::matrix_base matC(C, ctx, size_type(M), + size_type(0), + difference_type(1), + size_type(M), + size_type(N), + size_type(offC), + difference_type(1), + size_type(ldc) + VCL_ROW_MAJOR); + + if (TransA == CblasTrans && TransB == CblasTrans) + viennacl::linalg::prod_impl(viennacl::trans(matA), viennacl::trans(matB), + matC, alpha, beta); + else if (TransA == CblasTrans && TransB == CblasNoTrans) + viennacl::linalg::prod_impl(viennacl::trans(matA), matB, matC, alpha, + beta); + else if (TransA == CblasNoTrans && TransB == CblasTrans) + viennacl::linalg::prod_impl(matA, viennacl::trans(matB), matC, alpha, + beta); + else if (TransA == CblasNoTrans && TransB == CblasNoTrans) + viennacl::linalg::prod_impl(matA, matB, matC, alpha, beta); + } +} + +static void auto_tune_gemm(int ctx_id, const CBLAS_TRANSPOSE TransA, + const CBLAS_TRANSPOSE TransB, + gemm_type_t *tuned_gemm_types, + bool use_fast_gemm_image) { + viennacl::ocl::context &ctx = viennacl::ocl::get_context(ctx_id); + int M = 1024; + int K = 512; + int N = 1024; + cl_int err; + cl_mem A = clCreateBuffer(ctx.handle().get(), CL_MEM_ALLOC_HOST_PTR, M * K * sizeof(float), NULL, &err); + OCL_CHECK(err); + cl_mem B = clCreateBuffer(ctx.handle().get(), CL_MEM_ALLOC_HOST_PTR, K * N * sizeof(float), NULL, &err); + OCL_CHECK(err); + cl_mem C = clCreateBuffer(ctx.handle().get(), CL_MEM_ALLOC_HOST_PTR, M * N * sizeof(float), NULL, &err); + OCL_CHECK(err); + + std::vector gemm_tests; + + gemm_tests.push_back(GEMM_TYPE_VIENNACL); + if(use_fast_gemm_image) + gemm_tests.push_back(GEMM_TYPE_FAST_IMAGE_32_1); + 
gemm_tests.push_back(GEMM_TYPE_FAST_BUFFER); + +#ifdef USE_CLBLAS + gemm_tests.push_back(GEMM_TYPE_CLBLAS); +#endif +#ifdef USE_CLBLAST + gemm_tests.push_back(GEMM_TYPE_CLBLAST); +#endif + // warm up. + for( int i = 0; i < gemm_tests.size(); i++ ) { + greentea_gpu_gemm_common(ctx_id, TransA, TransB, M, N, K, + 1.0f, A, 0, B, 0, 0.0f, C, 0, false, false, + gemm_tests[i]); + } + float fastest_time = 1e10; + int fastest_index = -1; + clFinish(ctx.get_queue().handle().get()); + for( int i = 0; i < gemm_tests.size(); i++ ) { + struct timeval start, end; + gettimeofday(&start, NULL); + greentea_gpu_gemm_common(ctx_id, TransA, TransB, M, N, K, + 1.0f, A, 0, B, 0, 0.0f, C, 0, false, false, + gemm_tests[i]); + clFinish(ctx.get_queue().handle().get()); + gettimeofday(&end, NULL); + float elapsed = (end.tv_sec - start.tv_sec) * 1e6 + (end.tv_usec - start.tv_usec); + if (elapsed < fastest_time) { + fastest_time = elapsed; + fastest_index = i; + } + } + clReleaseMemObject(A); + clReleaseMemObject(B); + clReleaseMemObject(C); + + if (fastest_index >= 0) { + tuned_gemm_types[ctx_id] = gemm_tests[fastest_index]; +#ifdef GEMM_PROFILING + printf("The tuned GEMM kernel get %f GFLOPS with kernel type %d.\n", + M*N*(2*(double)K-1)/(fastest_time * 1e3), + tuned_gemm_types[ctx_id]); +#endif + } +} + +static gemm_type_t tuned_gemm_nn_types_with_image[16] = {GEMM_TYPE_NONE}; +static gemm_type_t tuned_gemm_nt_types_with_image[16] = {GEMM_TYPE_NONE}; +static gemm_type_t tuned_gemm_tn_types_with_image[16] = {GEMM_TYPE_NONE}; +static gemm_type_t tuned_gemm_tt_types_with_image[16] = {GEMM_TYPE_NONE}; + +static gemm_type_t tuned_gemm_nn_types_without_image[16] = {GEMM_TYPE_NONE}; +static gemm_type_t tuned_gemm_nt_types_without_image[16] = {GEMM_TYPE_NONE}; +static gemm_type_t tuned_gemm_tn_types_without_image[16] = {GEMM_TYPE_NONE}; +static gemm_type_t tuned_gemm_tt_types_without_image[16] = {GEMM_TYPE_NONE}; + +static void auto_tune_gemm_all(int ctx_id, bool use_fast_gemm_image) { + 
if(use_fast_gemm_image) { + auto_tune_gemm(ctx_id, CblasNoTrans, CblasNoTrans, tuned_gemm_nn_types_with_image, true); + auto_tune_gemm(ctx_id, CblasNoTrans, CblasTrans, tuned_gemm_nt_types_with_image, true); + auto_tune_gemm(ctx_id, CblasTrans, CblasNoTrans, tuned_gemm_tn_types_with_image, true); + auto_tune_gemm(ctx_id, CblasTrans, CblasTrans, tuned_gemm_tt_types_with_image, true); + } else { + auto_tune_gemm(ctx_id, CblasNoTrans, CblasNoTrans, tuned_gemm_nn_types_without_image, false); + auto_tune_gemm(ctx_id, CblasNoTrans, CblasTrans, tuned_gemm_nt_types_without_image, false); + auto_tune_gemm(ctx_id, CblasTrans, CblasNoTrans, tuned_gemm_tn_types_without_image, false); + auto_tune_gemm(ctx_id, CblasTrans, CblasTrans, tuned_gemm_tt_types_without_image, false); + } +} + +static boost::mutex auto_tune_gemm_mutex; + template void greentea_gpu_gemm(const int_tp ctx_id, const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB, const int_tp M, const int_tp N, const int_tp K, const Dtype alpha, const cl_mem A, const int_tp offA, const cl_mem B, const int_tp offB, const Dtype beta, cl_mem C, - const int_tp offC) { + const int_tp offC, bool is_image_a, bool is_image_b) { + CHECK_LT(ctx_id, 16) << "Too many GPU devices."; viennacl::ocl::context &ctx = viennacl::ocl::get_context(ctx_id); + bool use_fast_gemm_image = false; + bool use_fast_gemm_buffer = false; if (ctx.devices()[0].type() == CL_DEVICE_TYPE_CPU) { Dtype* Aptr = reinterpret_cast(clEnqueueMapBuffer( @@ -203,122 +903,80 @@ void greentea_gpu_gemm(const int_tp ctx_id, const CBLAS_TRANSPOSE TransA, NULL); clEnqueueUnmapMemObject(ctx.get_queue().handle().get(), C, Cptr, 0, NULL, NULL); - } else { - int_tp lda = (TransA == CblasNoTrans) ? K : M; - int_tp ldb = (TransB == CblasNoTrans) ? N : K; - int_tp ldc = N; - -#if defined(USE_CLBLAS) - - clblasOrder clOrder = clblasRowMajor; - clblasTranspose clTransA = - (TransA == CblasNoTrans) ? 
clblasNoTrans : clblasTrans; - clblasTranspose clTransB = - (TransB == CblasNoTrans) ? clblasNoTrans : clblasTrans; - - cl_command_queue queue = ctx.get_queue().handle().get(); + return; + } - if (std::is_same::value) { - GREENTEA_CL_BLAS_CHECK( - clblasSgemm(clOrder, clTransA, clTransB, - M, N, K, alpha, A, offA, lda, B, offB, ldb, beta, - C, offC, ldc, 1, &queue, 0, NULL, NULL)); - } else { - GREENTEA_CL_BLAS_CHECK( - clblasDgemm(clOrder, clTransA, clTransB, - M, N, K, alpha, A, offA, lda, B, offB, ldb, beta, - C, offC, ldc, 1, &queue, 0, NULL, NULL)); + } + + if (ctx.devices()[0].type() == CL_DEVICE_TYPE_GPU && + std::is_same::value) { + // Check whether can/should we use the fast gemm driver. + // There are the following considerations/restrictions: + // 1. The fast gemm kernel is using image which has a size limitation. + // 2. The fast gemm kernel is using the intel sub group extension. + // 3. Currently, only the IGC compiler (the driver version is 16.xxx) + // can get better performance with the fast gemm. + // Cap at 1 MB to capture faulty OpenCL implementations (nVidia) + bool has_sub_group_ext = ctx.devices()[0].extensions().find("cl_intel_subgroups") + != std::string::npos; + if (has_sub_group_ext) { + size_t max_image_size = std::min(ctx.devices()[0].image2d_max_width(), + ctx.devices()[0].image2d_max_height()); + if (M <= max_image_size && + K <= max_image_size && + N <= max_image_size) { + use_fast_gemm_image = true; + } + use_fast_gemm_buffer = true; + } + -#elif defined(USE_CLBLAST) - - cl_command_queue queue = ctx.get_queue().handle().get(); - - clblast::Layout layout = clblast::Layout::kRowMajor; - clblast::Transpose a_transpose = (TransA == CblasNoTrans) ? - clblast::Transpose::kNo : clblast::Transpose::kYes; - clblast::Transpose b_transpose = (TransB == CblasNoTrans) ?
- clblast::Transpose::kNo : clblast::Transpose::kYes; + gemm_type_t preferred_gemm_type = GEMM_TYPE_VIENNACL; +#ifdef USE_CLBLAS + preferred_gemm_type = GEMM_TYPE_CLBLAS; +#endif +#ifdef USE_CLBLAST + preferred_gemm_type = GEMM_TYPE_CLBLAST; +#endif - if (std::is_same::value) { - GREENTEA_CLBLAST_CHECK( - clblast::Gemm( - layout, a_transpose, b_transpose, - M, N, K, - alpha, - A, offA, lda, - B, offB, ldb, - beta, - C, offC, ldc, - &queue)); - } else { - GREENTEA_CLBLAST_CHECK( - clblast::Gemm( - layout, a_transpose, b_transpose, - M, N, K, - alpha, - A, offA, lda, - B, offB, ldb, - beta, - C, offC, ldc, - &queue)); + { + boost::mutex::scoped_lock lock(auto_tune_gemm_mutex); + if(use_fast_gemm_image) { + if (tuned_gemm_nn_types_with_image[ctx_id] == GEMM_TYPE_NONE) { + auto_tune_gemm_all(ctx_id, true); + } + + if (TransA == CblasNoTrans && TransB == CblasNoTrans) + preferred_gemm_type = tuned_gemm_nn_types_with_image[ctx_id]; + else if (TransA == CblasTrans && TransB == CblasNoTrans) + preferred_gemm_type = tuned_gemm_tn_types_with_image[ctx_id]; + else if (TransA == CblasNoTrans && TransB == CblasTrans) + preferred_gemm_type = tuned_gemm_nt_types_with_image[ctx_id]; + else if (TransA == CblasTrans && TransB == CblasTrans) + preferred_gemm_type = tuned_gemm_tt_types_with_image[ctx_id]; + } else if(use_fast_gemm_buffer) { + if (tuned_gemm_nn_types_without_image[ctx_id] == GEMM_TYPE_NONE) { + auto_tune_gemm_all(ctx_id, false); + } + + if (TransA == CblasNoTrans && TransB == CblasNoTrans) + preferred_gemm_type = tuned_gemm_nn_types_without_image[ctx_id]; + else if (TransA == CblasTrans && TransB == CblasNoTrans) + preferred_gemm_type = tuned_gemm_tn_types_without_image[ctx_id]; + else if (TransA == CblasNoTrans && TransB == CblasTrans) + preferred_gemm_type = tuned_gemm_nt_types_without_image[ctx_id]; + else if (TransA == CblasTrans && TransB == CblasTrans) + preferred_gemm_type = tuned_gemm_tt_types_without_image[ctx_id]; } + } -#else // default (ViennaCL) - - 
typedef typename viennacl::matrix_base::size_type size_type; - typedef typename viennacl::matrix_base::size_type difference_type; - - size_type A_size1 = static_cast((TransA == CblasTrans) ? K : M); - size_type A_size2 = static_cast((TransA == CblasTrans) ? M : K); - - size_type B_size1 = static_cast((TransB == CblasTrans) ? N : K); - size_type B_size2 = static_cast((TransB == CblasTrans) ? K : N); - - viennacl::matrix_base matA(A, ctx, A_size1, - size_type(0), - difference_type(1), - size_type(M), A_size2, - size_type(offA), - difference_type(1), - size_type(lda) - VCL_ROW_MAJOR); - - viennacl::matrix_base matB(B, ctx, B_size1, - size_type(0), - difference_type(1), - size_type(K), B_size2, - size_type(offB), - difference_type(1), - size_type(ldb) - VCL_ROW_MAJOR); - - viennacl::matrix_base matC(C, ctx, size_type(M), - size_type(0), - difference_type(1), - size_type(M), - size_type(N), - size_type(offC), - difference_type(1), - size_type(ldc) - VCL_ROW_MAJOR); + CHECK_EQ(use_fast_gemm_image || (!is_image_a && !is_image_b), true) + << "Invalid GEMM parameters."; - if (TransA == CblasTrans && TransB == CblasTrans) - viennacl::linalg::prod_impl(viennacl::trans(matA), viennacl::trans(matB), - matC, alpha, beta); - else if (TransA == CblasTrans && TransB == CblasNoTrans) - viennacl::linalg::prod_impl(viennacl::trans(matA), matB, matC, alpha, - beta); - else if (TransA == CblasNoTrans && TransB == CblasTrans) - viennacl::linalg::prod_impl(matA, viennacl::trans(matB), matC, alpha, - beta); - else if (TransA == CblasNoTrans && TransB == CblasNoTrans) - viennacl::linalg::prod_impl(matA, matB, matC, alpha, beta); + if (is_image_a || is_image_b) + preferred_gemm_type = GEMM_TYPE_FAST_IMAGE_32_1; -#endif // clBLAS, CLBlast, or default (ViennaCL) - } + greentea_gpu_gemm_common(ctx_id, TransA, TransB, M, N, K, alpha, A, offA, + B, offB, beta, C, offC, is_image_a, is_image_b, + preferred_gemm_type); } template void greentea_gpu_gemm(const int_tp ctx_id, @@ -329,7 +987,9 @@ 
template void greentea_gpu_gemm(const int_tp ctx_id, const cl_mem A, const int_tp offA, const cl_mem B, const int_tp offB, const float beta, cl_mem C, - const int_tp offC); + const int_tp offC, + const bool is_image_a = false, + const bool is_image_b = false); template void greentea_gpu_gemm(const int_tp ctx_id, const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB, @@ -338,7 +998,34 @@ template void greentea_gpu_gemm(const int_tp ctx_id, const cl_mem A, const int_tp offA, const cl_mem B, const int_tp offB, const double beta, cl_mem C, - const int_tp offC); + const int_tp offC, + const bool is_image_a = false, + const bool is_image_b = false); + +template void greentea_gpu_gemm_common(const int_tp ctx_id, + const CBLAS_TRANSPOSE TransA, + const CBLAS_TRANSPOSE TransB, + const int_tp M, const int_tp N, + const int_tp K, const float alpha, + const cl_mem A, const int_tp offA, + const cl_mem B, const int_tp offB, + const float beta, cl_mem C, + const int_tp offC, + const bool is_image_a, + const bool is_image_b, + const gemm_type_t); +template void greentea_gpu_gemm_common(const int_tp ctx_id, + const CBLAS_TRANSPOSE TransA, + const CBLAS_TRANSPOSE TransB, + const int_tp M, const int_tp N, + const int_tp K, const double alpha, + const cl_mem A, const int_tp offA, + const cl_mem B, const int_tp offB, + const double beta, cl_mem C, + const int_tp offC, + const bool is_image_a, + const bool is_image_b, + const gemm_type_t); template void greentea_gpu_gemv(const int_tp ctx_id, const CBLAS_TRANSPOSE TransA, @@ -372,91 +1059,153 @@ void greentea_gpu_gemv(const int_tp ctx_id, const CBLAS_TRANSPOSE TransA, clEnqueueUnmapMemObject(ctx.get_queue().handle().get(), y, yptr, 0, NULL, NULL); } else { + if (std::is_same::value && TransA == CblasNoTrans) { + viennacl::ocl::program &program = + (Caffe::Get().GetDevice(ctx_id, false)) + ->program(); + viennacl::ocl::kernel &k = + program.get_kernel(CL_KERNEL_SELECT("matvec_mul4")); + uint row_size = M; + uint col_size = N; + 
size_t localsize = 128; + size_t globalsize = row_size / 4 * localsize; + + uint argId = 0; + k.arg(argId++, WrapHandle(A, &ctx)); + k.arg(argId++, offA); + k.arg(argId++, cl_uint(col_size)); + k.arg(argId++, cl_uint(col_size%4)); + k.arg(argId++, WrapHandle(x, &ctx)); + k.arg(argId++, offx); + k.arg(argId++, alpha); + k.arg(argId++, beta); + k.arg(argId++, WrapHandle(y, &ctx)); + k.arg(argId++, offy); + k.arg(argId++, viennacl::ocl::local_mem(sizeof(cl_float4) * localsize)); + + clEnqueueNDRangeKernel(ctx.get_queue().handle().get(), + k.handle().get(), 1, + NULL, + &globalsize, + &localsize, 0, NULL, + NULL); + if ((row_size % 4) != 0) { + viennacl::ocl::kernel &k_1 = + program.get_kernel(CL_KERNEL_SELECT("matvec_mul1")); + size_t localsize = 128; + size_t globalsize = row_size % 4 * localsize; + uint row_offset = row_size - (row_size % 4); + + uint argId = 0; + k_1.arg(argId++, WrapHandle(A, &ctx)); + k_1.arg(argId++, offA); + k_1.arg(argId++, cl_uint(col_size)); + k_1.arg(argId++, cl_uint(row_offset)); + k_1.arg(argId++, cl_uint(col_size%4)); + k_1.arg(argId++, WrapHandle(x, &ctx)); + k_1.arg(argId++, offx); + k_1.arg(argId++, alpha); + k_1.arg(argId++, beta); + k_1.arg(argId++, WrapHandle(y, &ctx)); + k_1.arg(argId++, offy); + k_1.arg(argId++, + viennacl::ocl::local_mem(sizeof(cl_float) * localsize)); + + clEnqueueNDRangeKernel(ctx.get_queue().handle().get(), + k_1.handle().get(), 1, + NULL, + &globalsize, + &localsize, 0, NULL, + NULL); + } + } else { #if defined(USE_CLBLAS) - clblasTranspose clTransA = - (TransA == CblasNoTrans) ? clblasNoTrans : clblasTrans; + clblasTranspose clTransA = + (TransA == CblasNoTrans) ? 
clblasNoTrans : clblasTrans; - cl_command_queue queue = ctx.get_queue().handle().get(); + cl_command_queue queue = ctx.get_queue().handle().get(); - if (std::is_same::value) { - GREENTEA_CL_BLAS_CHECK( - clblasSgemv(clblasRowMajor, - clTransA, M, N, alpha, A, offA, N, x, offx, 1, - beta, y, offy, 1, 1, &queue, 0, NULL, NULL)); - } else { - GREENTEA_CL_BLAS_CHECK( - clblasDgemv(clblasRowMajor, - clTransA, M, N, alpha, A, offA, N, x, offx, 1, - beta, y, offy, 1, 1, &queue, 0, NULL, NULL)); - } + if (std::is_same::value) { + GREENTEA_CL_BLAS_CHECK( + clblasSgemv(clblasRowMajor, + clTransA, M, N, alpha, A, offA, N, x, offx, 1, + beta, y, offy, 1, 1, &queue, 0, NULL, NULL)); + } else { + GREENTEA_CL_BLAS_CHECK( + clblasDgemv(clblasRowMajor, + clTransA, M, N, alpha, A, offA, N, x, offx, 1, + beta, y, offy, 1, 1, &queue, 0, NULL, NULL)); + } #elif defined(USE_CLBLAST) - cl_command_queue queue = ctx.get_queue().handle().get(); - - clblast::Layout layout = clblast::Layout::kRowMajor; - clblast::Transpose a_transpose = (TransA == CblasNoTrans) ? - clblast::Transpose::kNo : clblast::Transpose::kYes; - - const size_t ldA = N; - const size_t incx = 1; - const size_t incy = 1; - - if (std::is_same::value) { - GREENTEA_CLBLAST_CHECK( - clblast::Gemv( - layout, a_transpose, - M, N, - alpha, - A, offA, ldA, - x, offx, incx, - beta, - y, offy, incy, - &queue)); - } else { - GREENTEA_CLBLAST_CHECK( - clblast::Gemv( - layout, a_transpose, - M, N, - alpha, - A, offA, ldA, - x, offx, incx, - beta, - y, offy, incy, - &queue)); - } + cl_command_queue queue = ctx.get_queue().handle().get(); + + clblast::Layout layout = clblast::Layout::kRowMajor; + clblast::Transpose a_transpose = (TransA == CblasNoTrans) ? 
+ clblast::Transpose::kNo : clblast::Transpose::kYes; + + const size_t ldA = N; + const size_t incx = 1; + const size_t incy = 1; + + if (std::is_same::value) { + GREENTEA_CLBLAST_CHECK( + clblast::Gemv( + layout, a_transpose, + M, N, + alpha, + A, offA, ldA, + x, offx, incx, + beta, + y, offy, incy, + &queue)); + } else { + GREENTEA_CLBLAST_CHECK( + clblast::Gemv( + layout, a_transpose, + M, N, + alpha, + A, offA, ldA, + x, offx, incx, + beta, + y, offy, incy, + &queue)); + } #else // default (ViennaCL) - typedef typename viennacl::vector_base::size_type size_type; - typedef typename viennacl::vector_base::size_type difference_type; - - viennacl::vector_base v1( - x, size_type((TransA == CblasTrans) ? M : N), size_type(offx), - difference_type(1), ctx); - viennacl::vector_base v2( - y, size_type((TransA == CblasTrans) ? N : M), size_type(offy), - difference_type(1), ctx); - viennacl::matrix_base mat(A, ctx, size_type(M), - size_type(0), - difference_type(1), - size_type(M), - size_type(N), - size_type(offA), - difference_type(1), - size_type(N) - VCL_ROW_MAJOR); - v2 *= beta; - if (TransA == CblasTrans) { - v2 += alpha * viennacl::linalg::prod(viennacl::trans(mat), v1); - } else { - v2 += alpha * viennacl::linalg::prod(mat, v1); - } + typedef typename viennacl::vector_base::size_type size_type; + typedef typename viennacl::vector_base::size_type difference_type; + + viennacl::vector_base v1( + x, size_type((TransA == CblasTrans) ? M : N), size_type(offx), + difference_type(1), ctx); + viennacl::vector_base v2( + y, size_type((TransA == CblasTrans) ? 
N : M), size_type(offy), + difference_type(1), ctx); + viennacl::matrix_base mat( + A, ctx, size_type(M), + size_type(0), + difference_type(1), + size_type(M), + size_type(N), + size_type(offA), + difference_type(1), + size_type(N) + VCL_ROW_MAJOR); + v2 *= beta; + if (TransA == CblasTrans) { + v2 += alpha * viennacl::linalg::prod(viennacl::trans(mat), v1); + } else { + v2 += alpha * viennacl::linalg::prod(mat, v1); + } #endif // clBLAS, CLBlast, or default (ViennaCL) + } } } From 3c9c415caa3f38a1f1daa669cdbe7056d89b2d4b Mon Sep 17 00:00:00 2001 From: Zhigang Gong Date: Sat, 1 Apr 2017 02:58:22 +0800 Subject: [PATCH 03/33] Prepare to support layer fusions. Signed-off-by: Zhigang Gong --- src/caffe/proto/caffe.proto | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/src/caffe/proto/caffe.proto b/src/caffe/proto/caffe.proto index eeea828b0ff..cc54e8c0491 100644 --- a/src/caffe/proto/caffe.proto +++ b/src/caffe/proto/caffe.proto @@ -526,6 +526,8 @@ message BatchNormParameter { // Small value to add to the variance estimate so that we don't divide by // zero. optional float eps = 3 [default = 1e-5]; + + optional bool fused_relu = 4 [ default = false ]; } message BiasParameter { @@ -626,6 +628,15 @@ message ConvolutionParameter { // implementation; for input blobs with num_axes != 2, this option is // ignored and the ND implementation will be used.) 
optional bool force_nd_im2col = 17 [default = false]; + + enum FuseType { + UNFUSED = 0; + FUSED_CONV_MAX_POOLING_RELU = 1; + FUSED_CONV_RELU = 2; + FUSED_CONV_ELTWISE_RELU = 3; + } + optional FuseType fuse_type = 19 [default = UNFUSED]; // Whether to fuse convolution with other layers + optional EltwiseParameter eltwise_param = 20; } message CropParameter { @@ -682,6 +693,7 @@ message DataParameter { message DropoutParameter { optional float dropout_ratio = 1 [default = 0.5]; // dropout ratio + optional bool scale_train = 2 [default = true]; // scale train or test phase } // DummyDataLayer fills any number of arbitrarily shaped blobs with random @@ -879,6 +891,14 @@ message LRNParameter { CUDNN = 2; } optional Engine engine = 6 [default = DEFAULT]; + enum FuseType { + UNFUSED = 0; + FUSED_POOL_MAX = 1; + } + optional FuseType fuse_type = 7 [default = UNFUSED]; // Whether to fuse convolution with other layers + optional PoolingParameter pooling_param = 8; + optional bool unit_test_mode = 9 [default = false]; + optional bool unit_test_fuse_kernel = 10 [default = false]; } message MemoryDataParameter { From bc013a50282a0b50609c5c9f7390936b6e38fdee Mon Sep 17 00:00:00 2001 From: Zhigang Gong Date: Sat, 1 Apr 2017 02:51:22 +0800 Subject: [PATCH 04/33] Implement layer fusion in spatial convolution engine. 
Support the following fusion types: ConvolutionParameter_FuseType_FUSED_CONV_RELU - CONV + Relu ConvolutionParameter_FuseType_FUSED_CONV_MAX_POOLING_RELU - CONV + Max pooling(without padding) + Relu ConvolutionParameter_FuseType_FUSED_CONV_ELTWISE_RELU - Conv + EltWise + Relu Signed-off-by: Zhigang Gong --- include/caffe/layers/conv_spatial_layer.hpp | 30 +++- src/caffe/greentea/cl_kernels.cpp | 157 ++++++++++++-------- .../greentea/cl_kernels/conv_layer_spatial.cl | 163 +++++++++++++-------- src/caffe/layers/conv_layer_spatial.cpp | 120 ++++++++++++--- 4 files changed, 325 insertions(+), 145 deletions(-) diff --git a/include/caffe/layers/conv_spatial_layer.hpp b/include/caffe/layers/conv_spatial_layer.hpp index ef79e7d704c..5290fb65299 100644 --- a/include/caffe/layers/conv_spatial_layer.hpp +++ b/include/caffe/layers/conv_spatial_layer.hpp @@ -64,7 +64,7 @@ class ConvolutionLayerSpatial : public BaseConvolutionLayer { return 1; } virtual inline bool EqualNumBottomTopBlobs() const { - return true; + return IsFusedWithEltwiseReLU() ? 
false : true; } protected: @@ -187,6 +187,27 @@ class ConvolutionLayerSpatial : public BaseConvolutionLayer { const vector*>& top); std::map, cl_mem> subBufferMap; std::vector tmpSubBuffers; + + bool IsFused() const + { + return (this->layer_param_.convolution_param().fuse_type() != ConvolutionParameter_FuseType_UNFUSED); + } + + bool IsFusedWithReLU() const + { + return (this->layer_param_.convolution_param().fuse_type() == ConvolutionParameter_FuseType_FUSED_CONV_RELU); + } + + bool IsFusedWithMaxPoolAndReLU() const + { + return (this->layer_param_.convolution_param().fuse_type() == ConvolutionParameter_FuseType_FUSED_CONV_MAX_POOLING_RELU); + } + + bool IsFusedWithEltwiseReLU() const + { + return (this->layer_param_.convolution_param().fuse_type() == ConvolutionParameter_FuseType_FUSED_CONV_ELTWISE_RELU); + } + #endif #endif @@ -239,6 +260,13 @@ class ConvolutionLayerSpatial : public BaseConvolutionLayer { vector kernelQueue; kernelConfig* bestKernelConfig; + + // parameters for fused eltwise layer. 
+ EltwiseParameter_EltwiseOp op_; vector coeffs_; + Blob max_idx_; + + bool stable_prod_grad_; + }; } // namespace caffe diff --git a/src/caffe/greentea/cl_kernels.cpp b/src/caffe/greentea/cl_kernels.cpp index ddae5f6f898..a35ae6b2a8a 100644 --- a/src/caffe/greentea/cl_kernels.cpp +++ b/src/caffe/greentea/cl_kernels.cpp @@ -528,6 +528,9 @@ static std::vector> cl_kernels{ "", // NOLINT "#ifdef MULTI", // NOLINT "__kernel void CFMultiNoPadding(", // NOLINT +"#ifdef FUSED_CONV_ELTWISE", // NOLINT +"__global Dtype* eltwise_data,", // NOLINT +"#endif", // NOLINT "__global Dtype* image_data,", // NOLINT "int_tp image_offset,", // NOLINT "__global Dtype* kernel_data, int_tp kernel_offset,", // NOLINT @@ -612,7 +615,6 @@ static std::vector> cl_kernels{ "//Begin IDLF kernels below here", // NOLINT "#ifdef IDLF", // NOLINT "", // NOLINT -"#define activation_function(x) (x)", // NOLINT "#define OUT_BLOCK_SIZE (OUT_BLOCK_WIDTH*OUT_BLOCK_HEIGHT)", // NOLINT "", // NOLINT "// Each work-item computes a OUT_BLOCK_WIDTH * OUT_BLOCK_HEIGHT region of one output map.", // NOLINT @@ -621,8 +623,11 @@ static std::vector> cl_kernels{ "", // NOLINT "// NOTE: for beignet this reqd_work_group_size does not guarantee that SIMD16/8 mode will be used, the compiler could choose to use two SIMD8 threads, and if that happens the code will break.", // NOLINT "__attribute__((reqd_work_group_size(1, 1, SIMD_SIZE)))", // NOLINT -"kernel void", // NOLINT +"__kernel void", // NOLINT "convolve_simd( // __global float *inputs, __global float* weights, __global float* outputs", // NOLINT +"#ifdef FUSED_CONV_ELTWISE", // NOLINT +"__global Dtype* eltwise_data,", // NOLINT +"#endif", // NOLINT "__global float* inputs_base,", // NOLINT "filter_qualifier float* weights_base,", // NOLINT "__global float* biases_base,", // NOLINT @@ -630,7 +635,9 @@ static std::vector> cl_kernels{ "const ushort input_width,", // NOLINT "const ushort input_height,", // NOLINT "const ushort output_width,", // NOLINT -"const ushort 
output_height)", // NOLINT +"const ushort output_height,", // NOLINT +"const ushort last_block_width,", // NOLINT +"const ushort last_block_height)", // NOLINT "{", // NOLINT "__global float* outputs = outputs_base;", // NOLINT "__global float* inputs = inputs_base;", // NOLINT @@ -658,8 +665,10 @@ static std::vector> cl_kernels{ "", // NOLINT "uint_tp input_batch_offset = num_in_batch * input_height * input_width * TOTAL_INPUT_DEPTH_SIZE;", // NOLINT "", // NOLINT -"int curr_y = or * STRIDEY + INPUT_START_Y + ( lid / ( TILE_X / 4 ) );", // NOLINT -"int curr_x = oc * STRIDEX + INPUT_START_X + ( lid % ( TILE_X / 4 ) ) * 4;", // NOLINT +"int curr_local_y = ( lid / ( TILE_X / 4 ) );", // NOLINT +"int curr_local_x = ( lid % ( TILE_X / 4 ) ) * 4;", // NOLINT +"int curr_y = or * STRIDEY + INPUT_START_Y + curr_local_y;", // NOLINT +"int curr_x = oc * STRIDEX + INPUT_START_X + curr_local_x;", // NOLINT "#if INPUT_PAD_W != 0 || INPUT_PAD_H != 0", // NOLINT "int saved_y = curr_y;", // NOLINT "#endif", // NOLINT @@ -677,6 +686,7 @@ static std::vector> cl_kernels{ "int_tp reg = 0;", // NOLINT "LOOP(INVEC_SIZE, reg,", // NOLINT "{", // NOLINT +"if (curr_local_y + reg * TILE_Y_STRIDE < TILE_Y || INVEC_SIZE * TILE_Y_STRIDE == TILE_Y || reg < INVEC_SIZE - 1) {", // NOLINT "#if INPUT_PAD_W != 0 || INPUT_PAD_H != 0", // NOLINT "if (curr_y >= INPUT_PAD_H && curr_y < input_height + INPUT_PAD_H && curr_x + 3 >= INPUT_PAD_W && curr_x < input_width + INPUT_PAD_W) {", // NOLINT "if (curr_x < INPUT_PAD_W) {", // NOLINT @@ -706,6 +716,7 @@ static std::vector> cl_kernels{ "#else", // NOLINT "in_buf.in_vec[reg] = *(global float4*)(inputs + in_offset); // read SIMD_SIZE elements", // NOLINT "#endif", // NOLINT +"}", // NOLINT "in_offset += input_width * TILE_Y_STRIDE;", // NOLINT "});", // NOLINT "in_addr += input_height * input_width;", // NOLINT @@ -784,24 +795,50 @@ static std::vector> cl_kernels{ "if (ALIGNED_NUM_FILTERS != NUM_FILTERS && fm > 0xfffffffeul) {", // NOLINT "outputs[0] = 
BLOCK_IN(fm % SIMD_SIZE);", // NOLINT "}", // NOLINT -"", // NOLINT "fm = fm % ALIGNED_NUM_FILTERS;", // NOLINT "", // NOLINT "if ((ALIGNED_NUM_FILTERS == NUM_FILTERS || fm < NUM_FILTERS)) {", // NOLINT -"", // NOLINT "uint_tp out_addr = OUT_BUFF_OFFSET + ( num_in_batch * TOTAL_OUTPUT_DEPTH + fm ) * output_width * output_height;", // NOLINT "out_addr += or * output_width + oc;", // NOLINT -"float bias = biases[fm];", // NOLINT -"", // NOLINT +"float bias = biases[(fm % ALIGNED_NUM_FILTERS)];", // NOLINT +"#ifndef WRITE_PADDED_VALUES", // NOLINT +"if (or + OUT_BLOCK_HEIGHT < output_height &&", // NOLINT +"oc + OUT_BLOCK_WIDTH < output_width)", // NOLINT +"{", // NOLINT +"#endif", // NOLINT "for(uint_tp r = 0; r < OUT_BLOCK_HEIGHT; r++) {", // NOLINT -"if (r + or >= output_height) break;", // NOLINT "for(uint_tp c = 0; c < OUT_BLOCK_WIDTH; c++) {", // NOLINT -"if (c + oc >= output_width) break;", // NOLINT "// this does a scattered write to SIMD_SIZE different feature maps, so that data within one map is contiguous, thus ready for input to next layer.", // NOLINT -"outputs[out_addr + r * output_width + c] = activation_function(bias + out[r * OUT_BLOCK_WIDTH + c]);", // NOLINT +"ACTIVATION_FUNCTION(outputs, out_addr + r * output_width + c, bias + out[r * OUT_BLOCK_WIDTH + c]);", // NOLINT +"}", // NOLINT +"}", // NOLINT +"#ifndef WRITE_PADDED_VALUES", // NOLINT +"} else if ( or + OUT_BLOCK_HEIGHT < output_height )", // NOLINT +"{", // NOLINT +"for(uint_tp r = 0; r < OUT_BLOCK_HEIGHT; r++) {", // NOLINT +"for(uint_tp c = 0; c < last_block_width; c++) {", // NOLINT +"ACTIVATION_FUNCTION(outputs, out_addr + r * output_width + c, bias + out[r * OUT_BLOCK_WIDTH + c]);", // NOLINT +"}", // NOLINT +"}", // NOLINT +"}", // NOLINT +"else if ( oc + OUT_BLOCK_WIDTH < output_width )", // NOLINT +"{", // NOLINT +"for(uint_tp r = 0; r < last_block_height; r++) {", // NOLINT +"for(uint_tp c = 0; c < OUT_BLOCK_WIDTH; c++) {", // NOLINT +"ACTIVATION_FUNCTION(outputs, out_addr + r * 
output_width + c, bias + out[r * OUT_BLOCK_WIDTH + c]);", // NOLINT +"}", // NOLINT +"}", // NOLINT +"}", // NOLINT +"else", // NOLINT +"{", // NOLINT +"for(uint_tp r = 0; r < last_block_height; r++) {", // NOLINT +"for(uint_tp c = 0; c < last_block_width; c++) {", // NOLINT +"ACTIVATION_FUNCTION(outputs, out_addr + r * output_width + c, bias + out[r * OUT_BLOCK_WIDTH + c]);", // NOLINT "}", // NOLINT "}", // NOLINT "}", // NOLINT +"#endif //#ifndef WRITE_PADDED_VALUES", // NOLINT +"}", // NOLINT "}", // NOLINT "#endif", // NOLINT "", // NOLINT @@ -1016,12 +1053,12 @@ static std::vector> cl_kernels{ "", // NOLINT "// Dst resembles a cube of width x height x (output channel * batches). Each tile writes:", // NOLINT "// (SIMD * TILE_M) x 1 x TILE_N. Partial writes most likely generated if padding used.", // NOLINT -"__global float *out = dst", // NOLINT -"+ global_z * out_pitch_z // batch offset", // NOLINT +"int_tp out_offset = global_z * out_pitch_z // batch offset", // NOLINT "+ ( group_x * TILE_N ) * out_pitch_y // channel offset", // NOLINT "+ ( ( global_y * TILE_M ) / output_width + OUT_PADDING_HEIGHT) * OUT_PITCH_X // y offset", // NOLINT "+ ( ( global_y * TILE_M ) % output_width ) + OUT_PADDING_LEFT; // x offset", // NOLINT "", // NOLINT +"__global float *out = dst + out_offset;", // NOLINT "float bias[4];", // NOLINT "float4 *bias_vec;", // NOLINT "bias_vec = (float4*)bias;", // NOLINT @@ -1031,10 +1068,10 @@ static std::vector> cl_kernels{ "{", // NOLINT "for (int i = 0; i < 8; i++)", // NOLINT "{", // NOLINT -"out[( 0+i) * out_pitch_y] = blockC00[i] + intel_sub_group_shuffle(bias[0], i);", // NOLINT -"out[( 8+i) * out_pitch_y] = blockC10[i] + intel_sub_group_shuffle(bias[1], i);", // NOLINT -"out[(16+i) * out_pitch_y] = blockC20[i] + intel_sub_group_shuffle(bias[2], i);", // NOLINT -"out[(24+i) * out_pitch_y] = blockC30[i] + intel_sub_group_shuffle(bias[3], i);", // NOLINT +"ACTIVATION_FUNCTION(dst, out_offset + ( 0 + i ) * out_pitch_y, blockC00[i] + 
intel_sub_group_shuffle(bias[0], i));", // NOLINT +"ACTIVATION_FUNCTION(dst, out_offset + ( 8 + i ) * out_pitch_y, blockC10[i] + intel_sub_group_shuffle(bias[1], i));", // NOLINT +"ACTIVATION_FUNCTION(dst, out_offset + ( 16 + i ) * out_pitch_y, blockC20[i] + intel_sub_group_shuffle(bias[2], i));", // NOLINT +"ACTIVATION_FUNCTION(dst, out_offset + ( 24 + i ) * out_pitch_y, blockC30[i] + intel_sub_group_shuffle(bias[3], i));", // NOLINT "}", // NOLINT "}", // NOLINT "}", // NOLINT @@ -1176,12 +1213,12 @@ static std::vector> cl_kernels{ "", // NOLINT "// Dst resembles a cube of width x height x (output channel * batches). Each tile writes:", // NOLINT "// (SIMD * TILE_M) x 1 x TILE_N. Partial writes most likely generated if padding used.", // NOLINT -"__global float *out = dst", // NOLINT -"+ global_z * out_pitch_z // batch offset", // NOLINT +"int_tp out_offset = global_z * out_pitch_z // batch offset", // NOLINT "+ ( group_x * TILE_N ) * out_pitch_y // channel offset", // NOLINT "+ ( ( global_y * TILE_M ) / output_width + OUT_PADDING_HEIGHT) * OUT_PITCH_X // y offset", // NOLINT "+ ( ( global_y * TILE_M ) % output_width ) + OUT_PADDING_LEFT; // x offset", // NOLINT "", // NOLINT +"__global float *out = dst + out_offset;", // NOLINT "float bias[4];", // NOLINT "float4 *bias_vec;", // NOLINT "bias_vec = (float4*)bias;", // NOLINT @@ -1191,10 +1228,10 @@ static std::vector> cl_kernels{ "{", // NOLINT "for (int i = 0; i < 8; i++)", // NOLINT "{", // NOLINT -"if ( TILE_N_LAST_DIV8 > 0 ) out[( 0+i) * out_pitch_y] = blockC[0][i] + intel_sub_group_shuffle(bias[0], i);", // NOLINT -"if ( TILE_N_LAST_DIV8 > 1 ) out[( 8+i) * out_pitch_y] = blockC[1][i] + intel_sub_group_shuffle(bias[1], i);", // NOLINT -"if ( TILE_N_LAST_DIV8 > 2 ) out[(16+i) * out_pitch_y] = blockC[2][i] + intel_sub_group_shuffle(bias[2], i);", // NOLINT -"if ( TILE_N_LAST_DIV8 > 3 ) out[(24+i) * out_pitch_y] = blockC[3][i] + intel_sub_group_shuffle(bias[3], i);", // NOLINT +"if ( TILE_N_LAST_DIV8 > 0 ) 
ACTIVATION_FUNCTION(dst, out_offset + ( 0+i) * out_pitch_y, blockC[0][i] + intel_sub_group_shuffle(bias[0], i));", // NOLINT +"if ( TILE_N_LAST_DIV8 > 1 ) ACTIVATION_FUNCTION(dst, out_offset + ( 8+i) * out_pitch_y, blockC[1][i] + intel_sub_group_shuffle(bias[1], i));", // NOLINT +"if ( TILE_N_LAST_DIV8 > 2 ) ACTIVATION_FUNCTION(dst, out_offset + (16+i) * out_pitch_y, blockC[2][i] + intel_sub_group_shuffle(bias[2], i));", // NOLINT +"if ( TILE_N_LAST_DIV8 > 3 ) ACTIVATION_FUNCTION(dst, out_offset + (24+i) * out_pitch_y, blockC[3][i] + intel_sub_group_shuffle(bias[3], i));", // NOLINT "}", // NOLINT "}", // NOLINT "}", // NOLINT @@ -1341,42 +1378,43 @@ static std::vector> cl_kernels{ "", // NOLINT "// Dst resembles a cube of width x height x (output channel * batches). Each tile writes:", // NOLINT "// (SIMD * TILE_M) x 1 x TILE_N. Partial writes most likely generated if padding used.", // NOLINT -"__global Dtype *out = dst", // NOLINT -"+ global_z * out_pitch_z // batch offset", // NOLINT +"int_tp out_offset = global_z * out_pitch_z // batch offset", // NOLINT "+ ( group_x * TILE_N ) * out_pitch_y // channel offset", // NOLINT "+ ( ( global_y * TILE_M ) / output_width + OUT_PADDING_HEIGHT) * OUT_PITCH_X // y offset", // NOLINT "+ ( ( global_y * TILE_M ) % output_width ) + OUT_PADDING_LEFT; // x offset", // NOLINT +"__global Dtype *out = dst + out_offset;", // NOLINT "", // NOLINT "Dtype bias[2];", // NOLINT "Dtype2 *bias_vec;", // NOLINT "bias_vec = (Dtype2*)bias;", // NOLINT "*bias_vec = as_float2(intel_sub_group_block_read2((__global uint *)biases + group_x * TILE_N));", // NOLINT "// Work around a potential compiler bug.", // NOLINT -"if (group_x > 0xFFFFFFFEul)", // NOLINT +"if (group_x > 0xFFFFFFFEul) {", // NOLINT "out[0] = bias[0] + bias[1];", // NOLINT +"}", // NOLINT "", // NOLINT "if (global_y * TILE_M < output_width * output_height )", // NOLINT "{", // NOLINT "#if ( ( OUT_DEPTH % TILE_N ) == 0 )", // NOLINT "for (int i = 0; i < 16; i++)", // NOLINT "{", 
// NOLINT -"out[( 0+i) * out_pitch_y] = blockC00[i] + intel_sub_group_shuffle(bias[0], i);", // NOLINT -"out[(16+i) * out_pitch_y] = blockC10[i] + intel_sub_group_shuffle(bias[1], i);;", // NOLINT +"ACTIVATION_FUNCTION(dst, out_offset + ( 0+i) * out_pitch_y, blockC00[i] + intel_sub_group_shuffle(bias[0], i));", // NOLINT +"ACTIVATION_FUNCTION(dst, out_offset + (16+i) * out_pitch_y, blockC10[i] + intel_sub_group_shuffle(bias[1], i));", // NOLINT "}", // NOLINT "#elif ( ( OUT_DEPTH % 16 ) == 0 )", // NOLINT "if ( ( global_x + 1 ) < get_global_size(0) )", // NOLINT "{", // NOLINT "for ( int i = 0; i < 16; i++ )", // NOLINT "{", // NOLINT -"out[( 0+i) * out_pitch_y] = blockC00[i] + intel_sub_group_shuffle(bias[0], i);;", // NOLINT -"out[(16+i) * out_pitch_y] = blockC10[i] + intel_sub_group_shuffle(bias[1], i);;", // NOLINT +"ACTIVATION_FUNCTION(dst, out_offset + ( 0+i) * out_pitch_y, blockC00[i] + intel_sub_group_shuffle(bias[0], i));", // NOLINT +"ACTIVATION_FUNCTION(dst, out_offset + (16+i) * out_pitch_y, blockC10[i] + intel_sub_group_shuffle(bias[1], i));", // NOLINT "}", // NOLINT "}", // NOLINT "else", // NOLINT "{", // NOLINT "for (int i = 0; i < 16; i++)", // NOLINT "{", // NOLINT -"out[( 0+i) * out_pitch_y] = blockC00[i] + intel_sub_group_shuffle(bias[0], i);;", // NOLINT +"ACTIVATION_FUNCTION(dst, out_offset + ( 0+i) * out_pitch_y, blockC00[i] + intel_sub_group_shuffle(bias[0], i));", // NOLINT "}", // NOLINT "}", // NOLINT "#else", // NOLINT @@ -1384,8 +1422,8 @@ static std::vector> cl_kernels{ "{", // NOLINT "for ( int i = 0; i < 16; i++ )", // NOLINT "{", // NOLINT -"out[( 0+i) * out_pitch_y] = blockC00[i] + intel_sub_group_shuffle(bias[0], i);;", // NOLINT -"out[(16+i) * out_pitch_y] = blockC10[i] + intel_sub_group_shuffle(bias[1], i);;", // NOLINT +"ACTIVATION_FUNCTION(dst, out_offset + ( 0+i) * out_pitch_y, blockC00[i] + intel_sub_group_shuffle(bias[0], i));", // NOLINT +"ACTIVATION_FUNCTION(dst, out_offset + (16+i) * out_pitch_y, blockC10[i] + 
intel_sub_group_shuffle(bias[1], i));", // NOLINT "}", // NOLINT "}", // NOLINT "else", // NOLINT @@ -1394,18 +1432,18 @@ static std::vector> cl_kernels{ "{", // NOLINT "for (int i = 0; i < 16 ; i++)", // NOLINT "{", // NOLINT -"out[( 0+i) * out_pitch_y] = blockC00[i] + intel_sub_group_shuffle(bias[0], i);;", // NOLINT +"ACTIVATION_FUNCTION(dst, out_offset + ( 0+i) * out_pitch_y, blockC00[i] + intel_sub_group_shuffle(bias[0], i));", // NOLINT "}", // NOLINT "for (int i = 0; i < OUT_DEPTH % 16 ; i++)", // NOLINT "{", // NOLINT -"out[(16+i) * out_pitch_y] = blockC10[i] + intel_sub_group_shuffle(bias[1], i);;", // NOLINT +"ACTIVATION_FUNCTION(dst, out_offset + (16+i) * out_pitch_y, blockC10[i] + intel_sub_group_shuffle(bias[1], i));", // NOLINT "}", // NOLINT "}", // NOLINT "#else", // NOLINT "{", // NOLINT "for (int i = 0; i < OUT_DEPTH % 16 ; i++)", // NOLINT "{", // NOLINT -"out[( 0+i) * out_pitch_y] = blockC00[i] + intel_sub_group_shuffle(bias[0], i);;", // NOLINT +"ACTIVATION_FUNCTION(dst, out_offset + ( 0+i) * out_pitch_y, blockC00[i] + intel_sub_group_shuffle(bias[0], i));", // NOLINT "}", // NOLINT "}", // NOLINT "#endif", // NOLINT @@ -1556,7 +1594,6 @@ static std::vector> cl_kernels{ "p4BlockB00[KERNEL_WIDTH - 1] = as_float4( intel_sub_group_block_read4( (const __global uint*)src1_read ) );", // NOLINT "src1_read += WIDTH1 * 2;", // NOLINT "}", // NOLINT -"", // NOLINT "// Perform MADs", // NOLINT "kernel_idx = 0;", // NOLINT "interleaved_y = 0;", // NOLINT @@ -1608,13 +1645,11 @@ static std::vector> cl_kernels{ "", // NOLINT "// Dst resembles a cube of width x height x (output channel * batches). Each tile writes:", // NOLINT "// (SIMD * TILE_M) x 1 x TILE_N. 
Partial writes most likely generated if padding used.", // NOLINT -"__global float *out0 = dst", // NOLINT -"+ global_z * out_pitch_z // batch offset", // NOLINT +"int_tp out0_offset = global_z * out_pitch_z // batch offset", // NOLINT "+ ( group_x * TILE_N ) * out_pitch_y // channel offset", // NOLINT "+ ( ( global_y * TILE_M + 0 ) / output_width + OUT_PADDING_HEIGHT ) * OUT_PITCH_X // y offset", // NOLINT "+ ( ( global_y * TILE_M + 0 ) % output_width ) + OUT_PADDING_LEFT; // x offset", // NOLINT -"__global float *out1 = dst", // NOLINT -"+ global_z * out_pitch_z // batch offset", // NOLINT +"int_tp out1_offset = global_z * out_pitch_z // batch offset", // NOLINT "+ ( group_x * TILE_N ) * out_pitch_y // channel offset", // NOLINT "+ ( ( global_y * TILE_M + 1 ) / output_width + OUT_PADDING_HEIGHT ) * OUT_PITCH_X // y offset", // NOLINT "+ ( ( global_y * TILE_M + 1 ) % output_width ) + OUT_PADDING_LEFT; // x offset", // NOLINT @@ -1628,20 +1663,20 @@ static std::vector> cl_kernels{ "{", // NOLINT "for( int i = 0; i < 8; i++ )", // NOLINT "{", // NOLINT -"out0[( 0+i) * out_pitch_y] = blockC00[i] + intel_sub_group_shuffle(bias[0], i);", // NOLINT -"out0[( 8+i) * out_pitch_y] = blockC10[i] + intel_sub_group_shuffle(bias[1], i);", // NOLINT -"out0[(16+i) * out_pitch_y] = blockC20[i] + intel_sub_group_shuffle(bias[2], i);", // NOLINT -"out0[(24+i) * out_pitch_y] = blockC30[i] + intel_sub_group_shuffle(bias[3], i);", // NOLINT +"ACTIVATION_FUNCTION(dst, out0_offset + ( 0+i) * out_pitch_y, blockC00[i] + intel_sub_group_shuffle(bias[0], i));", // NOLINT +"ACTIVATION_FUNCTION(dst, out0_offset + ( 8+i) * out_pitch_y, blockC10[i] + intel_sub_group_shuffle(bias[1], i));", // NOLINT +"ACTIVATION_FUNCTION(dst, out0_offset + (16+i) * out_pitch_y, blockC20[i] + intel_sub_group_shuffle(bias[2], i));", // NOLINT +"ACTIVATION_FUNCTION(dst, out0_offset + (24+i) * out_pitch_y, blockC30[i] + intel_sub_group_shuffle(bias[3], i));", // NOLINT "}", // NOLINT "}", // NOLINT "if( global_y * 
TILE_M + 1 < output_width * output_height )", // NOLINT "{", // NOLINT "for( int i = 0; i < 8; i++ )", // NOLINT "{", // NOLINT -"out1[( 0+i) * out_pitch_y] = blockC01[i] + intel_sub_group_shuffle(bias[0], i);", // NOLINT -"out1[( 8+i) * out_pitch_y] = blockC11[i] + intel_sub_group_shuffle(bias[1], i);", // NOLINT -"out1[(16+i) * out_pitch_y] = blockC21[i] + intel_sub_group_shuffle(bias[2], i);", // NOLINT -"out1[(24+i) * out_pitch_y] = blockC31[i] + intel_sub_group_shuffle(bias[3], i);", // NOLINT +"ACTIVATION_FUNCTION(dst, out1_offset + ( 0+i) * out_pitch_y, blockC01[i] + intel_sub_group_shuffle(bias[0], i));", // NOLINT +"ACTIVATION_FUNCTION(dst, out1_offset + ( 8+i) * out_pitch_y, blockC11[i] + intel_sub_group_shuffle(bias[1], i));", // NOLINT +"ACTIVATION_FUNCTION(dst, out1_offset + (16+i) * out_pitch_y, blockC21[i] + intel_sub_group_shuffle(bias[2], i));", // NOLINT +"ACTIVATION_FUNCTION(dst, out1_offset + (24+i) * out_pitch_y, blockC31[i] + intel_sub_group_shuffle(bias[3], i));", // NOLINT "}", // NOLINT "}", // NOLINT "}", // NOLINT @@ -1816,13 +1851,11 @@ static std::vector> cl_kernels{ "", // NOLINT "// Dst resembles a cube of width x height x (output channel * batches). Each tile writes:", // NOLINT "// (SIMD * TILE_M) x 1 x TILE_N. 
Partial writes most likely generated if padding used.", // NOLINT -"__global float *out0 = dst", // NOLINT -"+ global_z * out_pitch_z // batch offset", // NOLINT +"int_tp out0_offset = global_z * out_pitch_z // batch offset", // NOLINT "+ ( group_x * TILE_N ) * out_pitch_y // channel offset", // NOLINT "+ ( ( global_y * TILE_M + 0 ) / output_width + OUT_PADDING_HEIGHT ) * OUT_PITCH_X // y offset", // NOLINT "+ ( ( global_y * TILE_M + 0 ) % output_width ) + OUT_PADDING_LEFT; // x offset", // NOLINT -"__global float *out1 = dst", // NOLINT -"+ global_z * out_pitch_z // batch offset", // NOLINT +"int_tp out1_offset = global_z * out_pitch_z // batch offset", // NOLINT "+ ( group_x * TILE_N ) * out_pitch_y // channel offset", // NOLINT "+ ( ( global_y * TILE_M + 1 ) / output_width + OUT_PADDING_HEIGHT ) * OUT_PITCH_X // y offset", // NOLINT "+ ( ( global_y * TILE_M + 1 ) % output_width ) + OUT_PADDING_LEFT; // x offset", // NOLINT @@ -1835,20 +1868,20 @@ static std::vector> cl_kernels{ "{", // NOLINT "for( int i = 0; i < 8; i++ )", // NOLINT "{", // NOLINT -"if ( TILE_N_LAST_DIV8 > 0 ) out0[( 0+i) * out_pitch_y] = blockC0[0][i] + intel_sub_group_shuffle(bias[0], i);", // NOLINT -"if ( TILE_N_LAST_DIV8 > 1 ) out0[( 8+i) * out_pitch_y] = blockC0[1][i] + intel_sub_group_shuffle(bias[1], i);", // NOLINT -"if ( TILE_N_LAST_DIV8 > 2 ) out0[(16+i) * out_pitch_y] = blockC0[2][i] + intel_sub_group_shuffle(bias[2], i);", // NOLINT -"if ( TILE_N_LAST_DIV8 > 3 ) out0[(24+i) * out_pitch_y] = blockC0[3][i] + intel_sub_group_shuffle(bias[3], i);", // NOLINT +"if ( TILE_N_LAST_DIV8 > 0 ) ACTIVATION_FUNCTION(dst, out0_offset + ( 0+i) * out_pitch_y, blockC0[0][i] + intel_sub_group_shuffle(bias[0], i));", // NOLINT +"if ( TILE_N_LAST_DIV8 > 1 ) ACTIVATION_FUNCTION(dst, out0_offset + ( 8+i) * out_pitch_y, blockC0[1][i] + intel_sub_group_shuffle(bias[1], i));", // NOLINT +"if ( TILE_N_LAST_DIV8 > 2 ) ACTIVATION_FUNCTION(dst, out0_offset + (16+i) * out_pitch_y, blockC0[2][i] + 
intel_sub_group_shuffle(bias[2], i));", // NOLINT +"if ( TILE_N_LAST_DIV8 > 3 ) ACTIVATION_FUNCTION(dst, out0_offset + (24+i) * out_pitch_y, blockC0[3][i] + intel_sub_group_shuffle(bias[3], i));", // NOLINT "}", // NOLINT "}", // NOLINT "if( global_y * TILE_M + 1 < output_width * output_height )", // NOLINT "{", // NOLINT "for( int i = 0; i < 8; i++ )", // NOLINT "{", // NOLINT -"if ( TILE_N_LAST_DIV8 > 0 ) out1[( 0+i) * out_pitch_y] = blockC1[0][i] + intel_sub_group_shuffle(bias[0], i);", // NOLINT -"if ( TILE_N_LAST_DIV8 > 1 ) out1[( 8+i) * out_pitch_y] = blockC1[1][i] + intel_sub_group_shuffle(bias[1], i);", // NOLINT -"if ( TILE_N_LAST_DIV8 > 2 ) out1[(16+i) * out_pitch_y] = blockC1[2][i] + intel_sub_group_shuffle(bias[2], i);", // NOLINT -"if ( TILE_N_LAST_DIV8 > 3 ) out1[(24+i) * out_pitch_y] = blockC1[3][i] + intel_sub_group_shuffle(bias[3], i);", // NOLINT +"if ( TILE_N_LAST_DIV8 > 0 ) ACTIVATION_FUNCTION(dst, out1_offset + ( 0+i) * out_pitch_y, blockC1[0][i] + intel_sub_group_shuffle(bias[0], i));", // NOLINT +"if ( TILE_N_LAST_DIV8 > 1 ) ACTIVATION_FUNCTION(dst, out1_offset + ( 8+i) * out_pitch_y, blockC1[1][i] + intel_sub_group_shuffle(bias[1], i));", // NOLINT +"if ( TILE_N_LAST_DIV8 > 2 ) ACTIVATION_FUNCTION(dst, out1_offset + (16+i) * out_pitch_y, blockC1[2][i] + intel_sub_group_shuffle(bias[2], i));", // NOLINT +"if ( TILE_N_LAST_DIV8 > 3 ) ACTIVATION_FUNCTION(dst, out1_offset + (24+i) * out_pitch_y, blockC1[3][i] + intel_sub_group_shuffle(bias[3], i));", // NOLINT "}", // NOLINT "}", // NOLINT "}", // NOLINT diff --git a/src/caffe/greentea/cl_kernels/conv_layer_spatial.cl b/src/caffe/greentea/cl_kernels/conv_layer_spatial.cl index a7e96f6f9d6..cf9912c1838 100644 --- a/src/caffe/greentea/cl_kernels/conv_layer_spatial.cl +++ b/src/caffe/greentea/cl_kernels/conv_layer_spatial.cl @@ -31,6 +31,9 @@ __kernel void TEMPLATE(conv_layer_spatial_phony,Dtype)(Dtype arg) { #ifdef MULTI __kernel void CFMultiNoPadding( +#ifdef FUSED_CONV_ELTWISE + __global Dtype* 
eltwise_data, +#endif __global Dtype* image_data, int_tp image_offset, __global Dtype* kernel_data, int_tp kernel_offset, @@ -115,7 +118,6 @@ __kernel void CFMultiNoPadding( //Begin IDLF kernels below here #ifdef IDLF -#define activation_function(x) (x) #define OUT_BLOCK_SIZE (OUT_BLOCK_WIDTH*OUT_BLOCK_HEIGHT) // Each work-item computes a OUT_BLOCK_WIDTH * OUT_BLOCK_HEIGHT region of one output map. @@ -124,8 +126,11 @@ __kernel void CFMultiNoPadding( // NOTE: for beignet this reqd_work_group_size does not guarantee that SIMD16/8 mode will be used, the compiler could choose to use two SIMD8 threads, and if that happens the code will break. __attribute__((reqd_work_group_size(1, 1, SIMD_SIZE))) -kernel void +__kernel void convolve_simd( // __global float *inputs, __global float* weights, __global float* outputs +#ifdef FUSED_CONV_ELTWISE + __global Dtype* eltwise_data, +#endif __global float* inputs_base, filter_qualifier float* weights_base, __global float* biases_base, @@ -133,7 +138,9 @@ convolve_simd( // __global float *inputs, __global float* weights, __global flo const ushort input_width, const ushort input_height, const ushort output_width, - const ushort output_height) + const ushort output_height, + const ushort last_block_width, + const ushort last_block_height) { __global float* outputs = outputs_base; __global float* inputs = inputs_base; @@ -161,8 +168,10 @@ convolve_simd( // __global float *inputs, __global float* weights, __global flo uint_tp input_batch_offset = num_in_batch * input_height * input_width * TOTAL_INPUT_DEPTH_SIZE; - int curr_y = or * STRIDEY + INPUT_START_Y + ( lid / ( TILE_X / 4 ) ); - int curr_x = oc * STRIDEX + INPUT_START_X + ( lid % ( TILE_X / 4 ) ) * 4; + int curr_local_y = ( lid / ( TILE_X / 4 ) ); + int curr_local_x = ( lid % ( TILE_X / 4 ) ) * 4; + int curr_y = or * STRIDEY + INPUT_START_Y + curr_local_y; + int curr_x = oc * STRIDEX + INPUT_START_X + curr_local_x; #if INPUT_PAD_W != 0 || INPUT_PAD_H != 0 int saved_y = curr_y; 
#endif @@ -180,6 +189,7 @@ convolve_simd( // __global float *inputs, __global float* weights, __global flo int_tp reg = 0; LOOP(INVEC_SIZE, reg, { + if (curr_local_y + reg * TILE_Y_STRIDE < TILE_Y || INVEC_SIZE * TILE_Y_STRIDE == TILE_Y || reg < INVEC_SIZE - 1) { #if INPUT_PAD_W != 0 || INPUT_PAD_H != 0 if (curr_y >= INPUT_PAD_H && curr_y < input_height + INPUT_PAD_H && curr_x + 3 >= INPUT_PAD_W && curr_x < input_width + INPUT_PAD_W) { if (curr_x < INPUT_PAD_W) { @@ -209,6 +219,7 @@ convolve_simd( // __global float *inputs, __global float* weights, __global flo #else in_buf.in_vec[reg] = *(global float4*)(inputs + in_offset); // read SIMD_SIZE elements #endif + } in_offset += input_width * TILE_Y_STRIDE; }); in_addr += input_height * input_width; @@ -287,24 +298,50 @@ convolve_simd( // __global float *inputs, __global float* weights, __global flo if (ALIGNED_NUM_FILTERS != NUM_FILTERS && fm > 0xfffffffeul) { outputs[0] = BLOCK_IN(fm % SIMD_SIZE); } - fm = fm % ALIGNED_NUM_FILTERS; - + if ((ALIGNED_NUM_FILTERS == NUM_FILTERS || fm < NUM_FILTERS)) { - - uint_tp out_addr = OUT_BUFF_OFFSET + ( num_in_batch * TOTAL_OUTPUT_DEPTH + fm ) * output_width * output_height; - out_addr += or * output_width + oc; - float bias = biases[fm]; - + uint_tp out_addr = OUT_BUFF_OFFSET + ( num_in_batch * TOTAL_OUTPUT_DEPTH + fm ) * output_width * output_height; + out_addr += or * output_width + oc; + float bias = biases[(fm % ALIGNED_NUM_FILTERS)]; +#ifndef WRITE_PADDED_VALUES + if (or + OUT_BLOCK_HEIGHT < output_height && + oc + OUT_BLOCK_WIDTH < output_width) + { +#endif for(uint_tp r = 0; r < OUT_BLOCK_HEIGHT; r++) { - if (r + or >= output_height) break; for(uint_tp c = 0; c < OUT_BLOCK_WIDTH; c++) { - if (c + oc >= output_width) break; // this does a scattered write to SIMD_SIZE different feature maps, so that data within one map is contiguous, thus ready for input to next layer. 
- outputs[out_addr + r * output_width + c] = activation_function(bias + out[r * OUT_BLOCK_WIDTH + c]); + ACTIVATION_FUNCTION(outputs, out_addr + r * output_width + c, bias + out[r * OUT_BLOCK_WIDTH + c]); + } + } +#ifndef WRITE_PADDED_VALUES + } else if ( or + OUT_BLOCK_HEIGHT < output_height ) + { + for(uint_tp r = 0; r < OUT_BLOCK_HEIGHT; r++) { + for(uint_tp c = 0; c < last_block_width; c++) { + ACTIVATION_FUNCTION(outputs, out_addr + r * output_width + c, bias + out[r * OUT_BLOCK_WIDTH + c]); } } } + else if ( oc + OUT_BLOCK_WIDTH < output_width ) + { + for(uint_tp r = 0; r < last_block_height; r++) { + for(uint_tp c = 0; c < OUT_BLOCK_WIDTH; c++) { + ACTIVATION_FUNCTION(outputs, out_addr + r * output_width + c, bias + out[r * OUT_BLOCK_WIDTH + c]); + } + } + } + else + { + for(uint_tp r = 0; r < last_block_height; r++) { + for(uint_tp c = 0; c < last_block_width; c++) { + ACTIVATION_FUNCTION(outputs, out_addr + r * output_width + c, bias + out[r * OUT_BLOCK_WIDTH + c]); + } + } + } +#endif //#ifndef WRITE_PADDED_VALUES + } } #endif @@ -554,12 +591,12 @@ __kernel void Conv_Interleaved(GEMM_LIKE_KERNEL_ARGS) // Dst resembles a cube of width x height x (output channel * batches). Each tile writes: // (SIMD * TILE_M) x 1 x TILE_N. Partial writes most likely generated if padding used. 
- __global float *out = dst - + global_z * out_pitch_z // batch offset + int_tp out_offset = global_z * out_pitch_z // batch offset + ( group_x * TILE_N ) * out_pitch_y // channel offset + ( ( global_y * TILE_M ) / output_width + OUT_PADDING_HEIGHT) * OUT_PITCH_X // y offset + ( ( global_y * TILE_M ) % output_width ) + OUT_PADDING_LEFT; // x offset + __global float *out = dst + out_offset; float bias[4]; float4 *bias_vec; bias_vec = (float4*)bias; @@ -569,10 +606,10 @@ __kernel void Conv_Interleaved(GEMM_LIKE_KERNEL_ARGS) { for (int i = 0; i < 8; i++) { - out[( 0+i) * out_pitch_y] = blockC00[i] + intel_sub_group_shuffle(bias[0], i); - out[( 8+i) * out_pitch_y] = blockC10[i] + intel_sub_group_shuffle(bias[1], i); - out[(16+i) * out_pitch_y] = blockC20[i] + intel_sub_group_shuffle(bias[2], i); - out[(24+i) * out_pitch_y] = blockC30[i] + intel_sub_group_shuffle(bias[3], i); + ACTIVATION_FUNCTION(dst, out_offset + ( 0 + i ) * out_pitch_y, blockC00[i] + intel_sub_group_shuffle(bias[0], i)); + ACTIVATION_FUNCTION(dst, out_offset + ( 8 + i ) * out_pitch_y, blockC10[i] + intel_sub_group_shuffle(bias[1], i)); + ACTIVATION_FUNCTION(dst, out_offset + ( 16 + i ) * out_pitch_y, blockC20[i] + intel_sub_group_shuffle(bias[2], i)); + ACTIVATION_FUNCTION(dst, out_offset + ( 24 + i ) * out_pitch_y, blockC30[i] + intel_sub_group_shuffle(bias[3], i)); } } } @@ -714,12 +751,12 @@ __kernel void Conv_Interleaved(GEMM_LIKE_KERNEL_ARGS) // Dst resembles a cube of width x height x (output channel * batches). Each tile writes: // (SIMD * TILE_M) x 1 x TILE_N. Partial writes most likely generated if padding used. 
- __global float *out = dst - + global_z * out_pitch_z // batch offset + int_tp out_offset = global_z * out_pitch_z // batch offset + ( group_x * TILE_N ) * out_pitch_y // channel offset + ( ( global_y * TILE_M ) / output_width + OUT_PADDING_HEIGHT) * OUT_PITCH_X // y offset + ( ( global_y * TILE_M ) % output_width ) + OUT_PADDING_LEFT; // x offset + __global float *out = dst + out_offset; float bias[4]; float4 *bias_vec; bias_vec = (float4*)bias; @@ -729,10 +766,10 @@ __kernel void Conv_Interleaved(GEMM_LIKE_KERNEL_ARGS) { for (int i = 0; i < 8; i++) { - if ( TILE_N_LAST_DIV8 > 0 ) out[( 0+i) * out_pitch_y] = blockC[0][i] + intel_sub_group_shuffle(bias[0], i); - if ( TILE_N_LAST_DIV8 > 1 ) out[( 8+i) * out_pitch_y] = blockC[1][i] + intel_sub_group_shuffle(bias[1], i); - if ( TILE_N_LAST_DIV8 > 2 ) out[(16+i) * out_pitch_y] = blockC[2][i] + intel_sub_group_shuffle(bias[2], i); - if ( TILE_N_LAST_DIV8 > 3 ) out[(24+i) * out_pitch_y] = blockC[3][i] + intel_sub_group_shuffle(bias[3], i); + if ( TILE_N_LAST_DIV8 > 0 ) ACTIVATION_FUNCTION(dst, out_offset + ( 0+i) * out_pitch_y, blockC[0][i] + intel_sub_group_shuffle(bias[0], i)); + if ( TILE_N_LAST_DIV8 > 1 ) ACTIVATION_FUNCTION(dst, out_offset + ( 8+i) * out_pitch_y, blockC[1][i] + intel_sub_group_shuffle(bias[1], i)); + if ( TILE_N_LAST_DIV8 > 2 ) ACTIVATION_FUNCTION(dst, out_offset + (16+i) * out_pitch_y, blockC[2][i] + intel_sub_group_shuffle(bias[2], i)); + if ( TILE_N_LAST_DIV8 > 3 ) ACTIVATION_FUNCTION(dst, out_offset + (24+i) * out_pitch_y, blockC[3][i] + intel_sub_group_shuffle(bias[3], i)); } } } @@ -897,42 +934,43 @@ __kernel void Conv_Interleaved(GEMM_LIKE_KERNEL_ARGS) // Dst resembles a cube of width x height x (output channel * batches). Each tile writes: // (SIMD * TILE_M) x 1 x TILE_N. Partial writes most likely generated if padding used. 
- __global Dtype *out = dst - + global_z * out_pitch_z // batch offset + int_tp out_offset = global_z * out_pitch_z // batch offset + ( group_x * TILE_N ) * out_pitch_y // channel offset + ( ( global_y * TILE_M ) / output_width + OUT_PADDING_HEIGHT) * OUT_PITCH_X // y offset + ( ( global_y * TILE_M ) % output_width ) + OUT_PADDING_LEFT; // x offset + __global Dtype *out = dst + out_offset; Dtype bias[2]; Dtype2 *bias_vec; bias_vec = (Dtype2*)bias; *bias_vec = as_float2(intel_sub_group_block_read2((__global uint *)biases + group_x * TILE_N)); // Work around a potential compiler bug. - if (group_x > 0xFFFFFFFEul) + if (group_x > 0xFFFFFFFEul) { out[0] = bias[0] + bias[1]; + } if (global_y * TILE_M < output_width * output_height ) { #if ( ( OUT_DEPTH % TILE_N ) == 0 ) for (int i = 0; i < 16; i++) { - out[( 0+i) * out_pitch_y] = blockC00[i] + intel_sub_group_shuffle(bias[0], i); - out[(16+i) * out_pitch_y] = blockC10[i] + intel_sub_group_shuffle(bias[1], i);; + ACTIVATION_FUNCTION(dst, out_offset + ( 0+i) * out_pitch_y, blockC00[i] + intel_sub_group_shuffle(bias[0], i)); + ACTIVATION_FUNCTION(dst, out_offset + (16+i) * out_pitch_y, blockC10[i] + intel_sub_group_shuffle(bias[1], i)); } #elif ( ( OUT_DEPTH % 16 ) == 0 ) if ( ( global_x + 1 ) < get_global_size(0) ) { for ( int i = 0; i < 16; i++ ) { - out[( 0+i) * out_pitch_y] = blockC00[i] + intel_sub_group_shuffle(bias[0], i);; - out[(16+i) * out_pitch_y] = blockC10[i] + intel_sub_group_shuffle(bias[1], i);; + ACTIVATION_FUNCTION(dst, out_offset + ( 0+i) * out_pitch_y, blockC00[i] + intel_sub_group_shuffle(bias[0], i)); + ACTIVATION_FUNCTION(dst, out_offset + (16+i) * out_pitch_y, blockC10[i] + intel_sub_group_shuffle(bias[1], i)); } } else { for (int i = 0; i < 16; i++) { - out[( 0+i) * out_pitch_y] = blockC00[i] + intel_sub_group_shuffle(bias[0], i);; + ACTIVATION_FUNCTION(dst, out_offset + ( 0+i) * out_pitch_y, blockC00[i] + intel_sub_group_shuffle(bias[0], i)); } } #else @@ -940,8 +978,8 @@ __kernel void 
Conv_Interleaved(GEMM_LIKE_KERNEL_ARGS) { for ( int i = 0; i < 16; i++ ) { - out[( 0+i) * out_pitch_y] = blockC00[i] + intel_sub_group_shuffle(bias[0], i);; - out[(16+i) * out_pitch_y] = blockC10[i] + intel_sub_group_shuffle(bias[1], i);; + ACTIVATION_FUNCTION(dst, out_offset + ( 0+i) * out_pitch_y, blockC00[i] + intel_sub_group_shuffle(bias[0], i)); + ACTIVATION_FUNCTION(dst, out_offset + (16+i) * out_pitch_y, blockC10[i] + intel_sub_group_shuffle(bias[1], i)); } } else @@ -950,18 +988,18 @@ __kernel void Conv_Interleaved(GEMM_LIKE_KERNEL_ARGS) { for (int i = 0; i < 16 ; i++) { - out[( 0+i) * out_pitch_y] = blockC00[i] + intel_sub_group_shuffle(bias[0], i);; + ACTIVATION_FUNCTION(dst, out_offset + ( 0+i) * out_pitch_y, blockC00[i] + intel_sub_group_shuffle(bias[0], i)); } for (int i = 0; i < OUT_DEPTH % 16 ; i++) { - out[(16+i) * out_pitch_y] = blockC10[i] + intel_sub_group_shuffle(bias[1], i);; + ACTIVATION_FUNCTION(dst, out_offset + (16+i) * out_pitch_y, blockC10[i] + intel_sub_group_shuffle(bias[1], i)); } } #else { for (int i = 0; i < OUT_DEPTH % 16 ; i++) { - out[( 0+i) * out_pitch_y] = blockC00[i] + intel_sub_group_shuffle(bias[0], i);; + ACTIVATION_FUNCTION(dst, out_offset + ( 0+i) * out_pitch_y, blockC00[i] + intel_sub_group_shuffle(bias[0], i)); } } #endif @@ -1122,7 +1160,6 @@ __kernel void Conv_Interleaved(GEMM_LIKE_KERNEL_ARGS) p4BlockB00[KERNEL_WIDTH - 1] = as_float4( intel_sub_group_block_read4( (const __global uint*)src1_read ) ); src1_read += WIDTH1 * 2; } - // Perform MADs kernel_idx = 0; interleaved_y = 0; @@ -1174,13 +1211,11 @@ __kernel void Conv_Interleaved(GEMM_LIKE_KERNEL_ARGS) // Dst resembles a cube of width x height x (output channel * batches). Each tile writes: // (SIMD * TILE_M) x 1 x TILE_N. Partial writes most likely generated if padding used. 
- __global float *out0 = dst - + global_z * out_pitch_z // batch offset + int_tp out0_offset = global_z * out_pitch_z // batch offset + ( group_x * TILE_N ) * out_pitch_y // channel offset + ( ( global_y * TILE_M + 0 ) / output_width + OUT_PADDING_HEIGHT ) * OUT_PITCH_X // y offset + ( ( global_y * TILE_M + 0 ) % output_width ) + OUT_PADDING_LEFT; // x offset - __global float *out1 = dst - + global_z * out_pitch_z // batch offset + int_tp out1_offset = global_z * out_pitch_z // batch offset + ( group_x * TILE_N ) * out_pitch_y // channel offset + ( ( global_y * TILE_M + 1 ) / output_width + OUT_PADDING_HEIGHT ) * OUT_PITCH_X // y offset + ( ( global_y * TILE_M + 1 ) % output_width ) + OUT_PADDING_LEFT; // x offset @@ -1194,20 +1229,20 @@ __kernel void Conv_Interleaved(GEMM_LIKE_KERNEL_ARGS) { for( int i = 0; i < 8; i++ ) { - out0[( 0+i) * out_pitch_y] = blockC00[i] + intel_sub_group_shuffle(bias[0], i); - out0[( 8+i) * out_pitch_y] = blockC10[i] + intel_sub_group_shuffle(bias[1], i); - out0[(16+i) * out_pitch_y] = blockC20[i] + intel_sub_group_shuffle(bias[2], i); - out0[(24+i) * out_pitch_y] = blockC30[i] + intel_sub_group_shuffle(bias[3], i); + ACTIVATION_FUNCTION(dst, out0_offset + ( 0+i) * out_pitch_y, blockC00[i] + intel_sub_group_shuffle(bias[0], i)); + ACTIVATION_FUNCTION(dst, out0_offset + ( 8+i) * out_pitch_y, blockC10[i] + intel_sub_group_shuffle(bias[1], i)); + ACTIVATION_FUNCTION(dst, out0_offset + (16+i) * out_pitch_y, blockC20[i] + intel_sub_group_shuffle(bias[2], i)); + ACTIVATION_FUNCTION(dst, out0_offset + (24+i) * out_pitch_y, blockC30[i] + intel_sub_group_shuffle(bias[3], i)); } } if( global_y * TILE_M + 1 < output_width * output_height ) { for( int i = 0; i < 8; i++ ) { - out1[( 0+i) * out_pitch_y] = blockC01[i] + intel_sub_group_shuffle(bias[0], i); - out1[( 8+i) * out_pitch_y] = blockC11[i] + intel_sub_group_shuffle(bias[1], i); - out1[(16+i) * out_pitch_y] = blockC21[i] + intel_sub_group_shuffle(bias[2], i); - out1[(24+i) * out_pitch_y] = 
blockC31[i] + intel_sub_group_shuffle(bias[3], i); + ACTIVATION_FUNCTION(dst, out1_offset + ( 0+i) * out_pitch_y, blockC01[i] + intel_sub_group_shuffle(bias[0], i)); + ACTIVATION_FUNCTION(dst, out1_offset + ( 8+i) * out_pitch_y, blockC11[i] + intel_sub_group_shuffle(bias[1], i)); + ACTIVATION_FUNCTION(dst, out1_offset + (16+i) * out_pitch_y, blockC21[i] + intel_sub_group_shuffle(bias[2], i)); + ACTIVATION_FUNCTION(dst, out1_offset + (24+i) * out_pitch_y, blockC31[i] + intel_sub_group_shuffle(bias[3], i)); } } } @@ -1382,13 +1417,11 @@ __kernel void Conv_Interleaved(GEMM_LIKE_KERNEL_ARGS) // Dst resembles a cube of width x height x (output channel * batches). Each tile writes: // (SIMD * TILE_M) x 1 x TILE_N. Partial writes most likely generated if padding used. - __global float *out0 = dst - + global_z * out_pitch_z // batch offset + int_tp out0_offset = global_z * out_pitch_z // batch offset + ( group_x * TILE_N ) * out_pitch_y // channel offset + ( ( global_y * TILE_M + 0 ) / output_width + OUT_PADDING_HEIGHT ) * OUT_PITCH_X // y offset + ( ( global_y * TILE_M + 0 ) % output_width ) + OUT_PADDING_LEFT; // x offset - __global float *out1 = dst - + global_z * out_pitch_z // batch offset + int_tp out1_offset = global_z * out_pitch_z // batch offset + ( group_x * TILE_N ) * out_pitch_y // channel offset + ( ( global_y * TILE_M + 1 ) / output_width + OUT_PADDING_HEIGHT ) * OUT_PITCH_X // y offset + ( ( global_y * TILE_M + 1 ) % output_width ) + OUT_PADDING_LEFT; // x offset @@ -1401,20 +1434,20 @@ __kernel void Conv_Interleaved(GEMM_LIKE_KERNEL_ARGS) { for( int i = 0; i < 8; i++ ) { - if ( TILE_N_LAST_DIV8 > 0 ) out0[( 0+i) * out_pitch_y] = blockC0[0][i] + intel_sub_group_shuffle(bias[0], i); - if ( TILE_N_LAST_DIV8 > 1 ) out0[( 8+i) * out_pitch_y] = blockC0[1][i] + intel_sub_group_shuffle(bias[1], i); - if ( TILE_N_LAST_DIV8 > 2 ) out0[(16+i) * out_pitch_y] = blockC0[2][i] + intel_sub_group_shuffle(bias[2], i); - if ( TILE_N_LAST_DIV8 > 3 ) out0[(24+i) * out_pitch_y] 
= blockC0[3][i] + intel_sub_group_shuffle(bias[3], i); + if ( TILE_N_LAST_DIV8 > 0 ) ACTIVATION_FUNCTION(dst, out0_offset + ( 0+i) * out_pitch_y, blockC0[0][i] + intel_sub_group_shuffle(bias[0], i)); + if ( TILE_N_LAST_DIV8 > 1 ) ACTIVATION_FUNCTION(dst, out0_offset + ( 8+i) * out_pitch_y, blockC0[1][i] + intel_sub_group_shuffle(bias[1], i)); + if ( TILE_N_LAST_DIV8 > 2 ) ACTIVATION_FUNCTION(dst, out0_offset + (16+i) * out_pitch_y, blockC0[2][i] + intel_sub_group_shuffle(bias[2], i)); + if ( TILE_N_LAST_DIV8 > 3 ) ACTIVATION_FUNCTION(dst, out0_offset + (24+i) * out_pitch_y, blockC0[3][i] + intel_sub_group_shuffle(bias[3], i)); } } if( global_y * TILE_M + 1 < output_width * output_height ) { for( int i = 0; i < 8; i++ ) { - if ( TILE_N_LAST_DIV8 > 0 ) out1[( 0+i) * out_pitch_y] = blockC1[0][i] + intel_sub_group_shuffle(bias[0], i); - if ( TILE_N_LAST_DIV8 > 1 ) out1[( 8+i) * out_pitch_y] = blockC1[1][i] + intel_sub_group_shuffle(bias[1], i); - if ( TILE_N_LAST_DIV8 > 2 ) out1[(16+i) * out_pitch_y] = blockC1[2][i] + intel_sub_group_shuffle(bias[2], i); - if ( TILE_N_LAST_DIV8 > 3 ) out1[(24+i) * out_pitch_y] = blockC1[3][i] + intel_sub_group_shuffle(bias[3], i); + if ( TILE_N_LAST_DIV8 > 0 ) ACTIVATION_FUNCTION(dst, out1_offset + ( 0+i) * out_pitch_y, blockC1[0][i] + intel_sub_group_shuffle(bias[0], i)); + if ( TILE_N_LAST_DIV8 > 1 ) ACTIVATION_FUNCTION(dst, out1_offset + ( 8+i) * out_pitch_y, blockC1[1][i] + intel_sub_group_shuffle(bias[1], i)); + if ( TILE_N_LAST_DIV8 > 2 ) ACTIVATION_FUNCTION(dst, out1_offset + (16+i) * out_pitch_y, blockC1[2][i] + intel_sub_group_shuffle(bias[2], i)); + if ( TILE_N_LAST_DIV8 > 3 ) ACTIVATION_FUNCTION(dst, out1_offset + (24+i) * out_pitch_y, blockC1[3][i] + intel_sub_group_shuffle(bias[3], i)); } } } diff --git a/src/caffe/layers/conv_layer_spatial.cpp b/src/caffe/layers/conv_layer_spatial.cpp index 4c044fe4302..ff4493d1c5e 100644 --- a/src/caffe/layers/conv_layer_spatial.cpp +++ b/src/caffe/layers/conv_layer_spatial.cpp @@ -23,9 
+23,10 @@ #include // #define TEST_ALL_KERNELS + namespace caffe { -#define ALIGN(val, N) (((val) + (N) - 1) & ~((N) - 1)) +#define ALIGN(val,N) ( ( (val) + (N) - 1 ) & ~( (N) - 1 ) ) template void ConvolutionLayerSpatial::compute_output_shape() { @@ -65,11 +66,19 @@ void ConvolutionLayerSpatial::LayerSetUp( dilation_w_ = dilation_data[1]; M_ = this->num_output_ / this->group_; K_ = this->channels_ * kernel_h_ * kernel_w_ / this->group_; - swizzled_weights_blob_.Reshape((this->num_output_ + 15) & ~15, + swizzled_weights_blob_.Reshape(ALIGN(this->num_output_, 16), this->channels_, - kernel_h_, (kernel_w_ + 1) & ~1); + kernel_h_, ALIGN(kernel_w_, 2)); swizzled_weights_ = NULL; bias_ = NULL; + + if (IsFusedWithEltwiseReLU()) { + CHECK(this->layer_param().eltwise_param().coeff_size() == 0); + CHECK(bottom.size() == 2); + op_ = this->layer_param_.eltwise_param().operation(); + CHECK(op_ == EltwiseParameter_EltwiseOp_SUM); + } + if (std::getenv("CLCAFFE_CACHE_PATH")) cache_path_ << std::getenv("CLCAFFE_CACHE_PATH"); else if (std::getenv("VIENNACL_CACHE_PATH")) @@ -97,7 +106,13 @@ void ConvolutionLayerSpatial::LayerSetUp( template void ConvolutionLayerSpatial::Reshape(const vector*>& bottom, const vector*>& top) { - BaseConvolutionLayer::Reshape(bottom, top); + //printf("handle layer %s bottom size %ld \n", this->layer_param_.name().c_str(), bottom.size()); + if (IsFusedWithEltwiseReLU()) { + const vector*> bottom_image(bottom.begin(), bottom.end() - 1); + BaseConvolutionLayer::Reshape(bottom_image, top); + } else { + BaseConvolutionLayer::Reshape(bottom, top); + } height_ = bottom[0]->shape(this->channel_axis_ + 1); width_ = bottom[0]->shape(this->channel_axis_ + 2); const int_tp kernel_extent_h = dilation_h_ * (kernel_h_ - 1) + 1; @@ -140,6 +155,7 @@ template void ConvolutionLayerSpatial::Forward_cpu( const vector*>& bottom, const vector*>& top) { const Dtype* weight = this->blobs_[0]->cpu_data(); + CHECK(IsFusedWithEltwiseReLU() == false && IsFusedWithReLU() == false); 
for (int_tp i = 0; i < bottom.size(); ++i) { const Dtype* bottom_data = bottom[i]->cpu_data(); Dtype* top_data = top[i]->mutable_cpu_data(); @@ -208,7 +224,8 @@ void ConvolutionLayerSpatial::Backward_cpu( template<> void ConvolutionLayerSpatial::generate_key() { std::stringstream keyBuilder; - keyBuilder << kernel_w_ << "_" + keyBuilder << this->layer_param_.convolution_param().fuse_type() << "_" + << kernel_w_ << "_" << kernel_h_ << "_" << channels_ << "_" << group_ << "_" @@ -373,7 +390,7 @@ void ConvolutionLayerSpatial::swizzleWeights( * kernel_w_ + c) * M_ + od] = weight_cpu[((od * this->channels_ + id) * kernel_h_ + r) * kernel_w_ + c ]; - interleaveMatrix(cpu_swizzled_weight, tmpSwizzledWeight, + interleaveMatrix( cpu_swizzled_weight, tmpSwizzledWeight, kernel_w_ * kernel_h_ * this->channels_, M_, interleavedRows, nonInterleavedRows, blockWidth, rowAlignment); free(tmpSwizzledWeight); @@ -414,9 +431,8 @@ bool ConvolutionLayerSpatial::create_basic_kernel( workItemOutput[1] = 1; workItemOutput[2] = 1; - kernel_name_ = "U"; + kernel_name_ = "BASIC_"; kernel_name_ += kernelUKey.c_str(); - kernel_name_ += "_BASIC"; // Build list of options and defines optionsString.str(""); @@ -433,6 +449,14 @@ bool ConvolutionLayerSpatial::create_basic_kernel( << " -D " << kernelDef.c_str() << " -D CFMultiNoPadding=" << kernel_name_; + if (IsFusedWithEltwiseReLU()) { + optionsString << " -DFUSED_CONV_RELU=1 -DFUSED_CONV_ELTWISE=1"; + } + + if (IsFusedWithReLU()) { + optionsString << " -DFUSED_CONV_RELU=1"; + } + viennacl::ocl::context &ctx = viennacl::ocl::get_context(this->device_->id()); if (IsBeignet(&ctx)) optionsString << " -D__BEIGNET__"; @@ -530,6 +554,9 @@ cl_int ConvolutionLayerSpatial::convolve( int_tp kernel_offset = kernel_h_ * kernel_w_ * (channels_ / group_) * M_ * g; cl_uint argIdx = 0; + if (IsFusedWithEltwiseReLU()) + kernel.arg(argIdx++, WrapHandle((cl_mem) bottom[1]->gpu_data(), &ctx)); + try { setBufferKernelArg(bottom, top, &kernel, argIdx++, &ctx, (cl_mem) 
bottom_data, @@ -562,6 +589,12 @@ cl_int ConvolutionLayerSpatial::convolve( kernel.arg(argIdx++, (uint16_t)output_h_); const int_tp output_block_w = config->workItem_output[0]; const int_tp output_block_h = config->workItem_output[1]; + const int_tp last_block_width = ((output_w_ % output_block_w) == 0) ? + output_block_w : output_w_ % output_block_w; + const int_tp last_block_height =((output_h_ % output_block_h) == 0) ? + output_block_h : output_h_ % output_block_h; + kernel.arg(argIdx++, (uint16_t)last_block_width); + kernel.arg(argIdx++, (uint16_t)last_block_height); size_t global_size[3] = { (size_t) (output_w_ + output_block_w - 1) / output_block_w, (size_t) (output_h_ + output_block_h - 1) / output_block_h, (size_t) config->global_work_size[2]}; @@ -594,6 +627,9 @@ cl_int ConvolutionLayerSpatial::convolve( int_tp output_image_offset = output_w_ * output_h_ * M_ * g; cl_uint argIdx = 0; + if (IsFusedWithEltwiseReLU()) + kernel.arg(argIdx++, WrapHandle((cl_mem) bottom[1]->gpu_data(), &ctx)); + int_tp kernel_offset = kernel_h_ * kernel_w_ * (channels_ / group_) * M_ * g; try { @@ -626,6 +662,31 @@ cl_int ConvolutionLayerSpatial::convolve( kernel.arg(argIdx++, (uint16_t)height_); kernel.arg(argIdx++, (uint16_t)output_w_); kernel.arg(argIdx++, (uint16_t)output_h_); + int out_pitch_y = output_w_ * output_h_; + int out_pitch_z = out_pitch_y * M_; + int aligned_input_size = height_ * width_ * channels_ / group_; + int slice_pitch = width_ * height_; + kernel.arg(argIdx++, (uint32_t)out_pitch_y); + kernel.arg(argIdx++, (uint32_t)out_pitch_z); + kernel.arg(argIdx++, (uint32_t)aligned_input_size); + kernel.arg(argIdx++, (uint32_t)slice_pitch); + + int blockM = config->workItem_output[0]; + int blockK = config->workItem_output[1]; + int blockN = config->workItem_output[2]; + int_tp alignedFilterWidth = ALIGN(M_, blockN); + int_tp alignedExpandHeight = ALIGN(output_w_ * output_h_, blockM); + int_tp globalWorkSizeDX = blockN; + int_tp globalWorkSizeDY = blockM; + size_t 
sgemm_m = alignedExpandHeight; + size_t sgemm_n = alignedFilterWidth; + size_t gx = (size_t) ceil( (float) sgemm_n / + (float) globalWorkSizeDX ); + size_t gy = (size_t) ceil( (float) sgemm_m / + (float) globalWorkSizeDY ); + gy = ALIGN(gy, blockK); + size_t global_size[3] = { gx, gy, config->global_work_size[2] }; + viennacl::ocl::context &ctx = viennacl::ocl::get_context(this->device_->id()); int out_pitch_y = output_w_ * output_h_; @@ -680,8 +741,12 @@ cl_int ConvolutionLayerSpatial::convolve( + output_w_ * output_h_ * M_ * g; cl_uint argIdx = 0; - int_tp kernel_offset = kernel_h_ * kernel_w_ - * (channels_ / group_) * M_ * g; + if (IsFusedWithEltwiseReLU()) + kernel.arg(argIdx++, + WrapHandle((cl_mem) bottom[1]->gpu_data(), &ctx)); + + int_tp kernel_offset = kernel_h_ * kernel_w_ * (channels_ / group_) * M_ + * g; kernel.arg(argIdx++, WrapHandle((cl_mem) bottom_data, &ctx)); kernel.arg(argIdx++, image_offset); @@ -792,10 +857,16 @@ bool ConvolutionLayerSpatial::verify_result( return true; else if (config->tested) return false; - greentea_memset(this->device_->id(), top[index]->count(), 0, + + greentea_memset(this->device_->id(), top[index]->count() * sizeof(float), + 0xff, (cl_mem)top[index]->mutable_gpu_data(), 0); config->executionTime = timed_convolve(bottom, top, index, numImages, config); + // Currently we can't do verification when conv is fused because the results + // won't match the results of forward_gpu_gemm. Need more work to fix it. 
+ if (IsFused()) + return true; const float *verify_data = verify_blob.cpu_data(); const float *data = top[index]->cpu_data(); @@ -880,6 +951,13 @@ bool ConvolutionLayerSpatial::create_gemm_like_conv_kernel( " -DTILE_N_LAST=" << M_ % 32 << " -DTILE_N_LAST_DIV8=" << (M_ % 32) / 8; + if (IsFusedWithEltwiseReLU()) { + optionsString << " -DFUSED_CONV_RELU=1 -DFUSED_CONV_ELTWISE=1"; + } + + if (IsFusedWithReLU()) { + optionsString << " -DFUSED_CONV_RELU=1"; + } optionsString << " -DINPUT_PAD_W=" << pad_w_ << " -DINPUT_PAD_H=" << pad_h_; size_t gz = num_batches; size_t global_size[3] = { 0, 0, gz }; @@ -888,9 +966,6 @@ bool ConvolutionLayerSpatial::create_gemm_like_conv_kernel( viennacl::ocl::context &ctx = viennacl::ocl::get_context(this->device_->id()); if (IsBeignet(&ctx)) optionsString << " -D__BEIGNET__"; - else - optionsString << - " -cl-no-subgroup-ifp "; string options = optionsString.str(); viennacl::ocl::program & program = submit_conv_spatial_program(&ctx, @@ -938,7 +1013,7 @@ bool ConvolutionLayerSpatial::setup_IDLF( int_tp output_block_height = blockHeight; int_tp num_batches = num_; - kernel_name_ = "U"; + kernel_name_ = "IDLF_"; kernel_name_ += kernelUKey.c_str(); if (simd_size == 16) @@ -986,10 +1061,18 @@ bool ConvolutionLayerSpatial::setup_IDLF( optionsString << " -DINPUT_PAD_W=" << pad_w_ << " -DINPUT_PAD_H=" << pad_h_; - string options = optionsString.str(); + if (IsFusedWithEltwiseReLU()) { + optionsString << " -DFUSED_CONV_RELU=1 -DFUSED_CONV_ELTWISE=1"; + } + + if (IsFusedWithReLU()) { + optionsString << " -DFUSED_CONV_RELU=1"; + } + viennacl::ocl::context &ctx = viennacl::ocl::get_context(this->device_->id()); if (IsBeignet(&ctx)) optionsString << " -D__BEIGNET__"; + string options = optionsString.str(); viennacl::ocl::program & program = submit_conv_spatial_program(&ctx, kernel_name_, options); @@ -1346,7 +1429,10 @@ void ConvolutionLayerSpatial::Forward_gpu( if (bias_term_) bias_ = this->blobs_[1]->gpu_data(); - for (int_tp i = 0; i < 
bottom.size(); ++i) { + int bottom_size = bottom.size(); + if (IsFusedWithEltwiseReLU()) + bottom_size = 1; + for (int_tp i = 0; i < bottom_size; ++i) { bottom_index_ = i; bottom_data = bottom[i]->gpu_data(); top_data = top[i]->mutable_gpu_data(); From e270af670b2214353fa955d0011db1ff5cebfd56 Mon Sep 17 00:00:00 2001 From: Zhigang Gong Date: Sat, 1 Apr 2017 02:55:00 +0800 Subject: [PATCH 05/33] Add LRN fusion with Pooling layer. Signed-off-by: Zhigang Gong --- include/caffe/layers/lrn_layer.hpp | 27 +++++ src/caffe/greentea/cl_kernels.cpp | 151 +++++++++++++++++++++++++++- src/caffe/greentea/cl_kernels/lrn.cl | 110 +++++++++++++++++++++ src/caffe/greentea/cl_kernels/pooling.cl | 49 +++++++-- src/caffe/layers/conv_layer_spatial.cpp | 25 ----- src/caffe/layers/lrn_layer.cpp | 43 +++++++- src/caffe/layers/lrn_layer.cu | 164 +++++++++++++++++++++++++++---- src/caffe/test/test_lrn_layer.cpp | 70 +++++++++++++ 8 files changed, 582 insertions(+), 57 deletions(-) diff --git a/include/caffe/layers/lrn_layer.hpp b/include/caffe/layers/lrn_layer.hpp index eb4a1a31304..f059a7d60fd 100644 --- a/include/caffe/layers/lrn_layer.hpp +++ b/include/caffe/layers/lrn_layer.hpp @@ -33,6 +33,18 @@ class LRNLayer : public Layer { virtual inline int_tp ExactNumBottomBlobs() const { return 1; } virtual inline int_tp ExactNumTopBlobs() const { return 1; } + bool IsFused() const + { + return (this->layer_param_.lrn_param().fuse_type() != LRNParameter_FuseType_UNFUSED); + } + + + bool IsFusedWithPoolMax() const + { + return (this->layer_param_.lrn_param().fuse_type() == LRNParameter_FuseType_FUSED_POOL_MAX); + } + + protected: virtual void Forward_cpu(const vector*>& bottom, const vector*>& top); @@ -56,6 +68,10 @@ class LRNLayer : public Layer { virtual void WithinChannelBackward(const vector*>& top, const vector& propagate_down, const vector*>& bottom); + virtual void CrossChannelForward_fuse_pooling_gpu(const vector*>& bottom, + const vector*>& top, + bool use_fuse); + int_tp size_; int_tp 
pre_pad_; Dtype alpha_; @@ -66,6 +82,17 @@ class LRNLayer : public Layer { int_tp height_; int_tp width_; + int_tp pool_w_; + int_tp pool_h_; + int_tp pool_stride_w_; + int_tp pool_stride_h_; + int_tp pooled_width_; + int_tp pooled_height_; + bool fuse_tuned_; + bool tuned_use_fuse_; + Blob lrn_top_blob_; + vector*> lrn_top_vec_; // for pooling fusing + // Fields used for normalization ACROSS_CHANNELS // scale_ stores the int_tpermediate summing results Blob scale_; diff --git a/src/caffe/greentea/cl_kernels.cpp b/src/caffe/greentea/cl_kernels.cpp index a35ae6b2a8a..474ce231e3b 100644 --- a/src/caffe/greentea/cl_kernels.cpp +++ b/src/caffe/greentea/cl_kernels.cpp @@ -4959,6 +4959,116 @@ static std::vector> cl_kernels{ "}", // NOLINT "}", // NOLINT "", // NOLINT +"#define SIMD_WIDTH 16", // NOLINT +"#define TILE_W SIMD_WIDTH", // NOLINT +"#define TILE_H 8", // NOLINT +"", // NOLINT +"#ifndef BEIGNET", // NOLINT +"__attribute__((intel_reqd_sub_group_size(SIMD_WIDTH)))", // NOLINT +"#endif", // NOLINT +"// Fuse pooling max layer into LRN across channel layer.", // NOLINT +"// Currently, only support non-padding, non-dilation mode and pool_w/h == pool_stride_w + 1.", // NOLINT +"// This kernel only get better performance on those Intel platforms with edram.", // NOLINT +"__kernel void TEMPLATE(lrn_fuse_pool_max,Dtype)(", // NOLINT +"__global const Dtype* in,", // NOLINT +"const int_tp channels,", // NOLINT +"const int_tp height, const int_tp width,", // NOLINT +"const int_tp tiled_height, int_tp tiled_width,", // NOLINT +"const int_tp size,", // NOLINT +"const Dtype alpha_over_size, const Dtype k,", // NOLINT +"__global Dtype* const out,", // NOLINT +"const Dtype negative_beta,", // NOLINT +"const int_tp pool_h, const int_tp pool_w, const int_tp pool_stride_h, int_tp pool_stride_w,", // NOLINT +"const int_tp pooled_height, const int_tp pooled_width,", // NOLINT +"const int_tp tile_pooled_block_h, const int_tp tile_pooled_block_w) {", // NOLINT +"// find out the local 
offset", // NOLINT +"const int_tp block_x = get_global_id(0) % tiled_width;", // NOLINT +"const int_tp block_y = (get_global_id(0) / tiled_width) % tiled_height;", // NOLINT +"const int_tp n = get_global_id(0) / (tiled_width * tiled_height);", // NOLINT +"", // NOLINT +"const int_tp w = block_x * tile_pooled_block_w * pool_stride_w;", // NOLINT +"const int_tp h = block_y * tile_pooled_block_h * pool_stride_h;", // NOLINT +"const int_tp offset = (n * channels * height + h) * width + w;", // NOLINT +"const int_tp out_h = block_y * tile_pooled_block_h;", // NOLINT +"const int_tp out_w = block_x * tile_pooled_block_w;", // NOLINT +"const int_tp out_offset = (n * channels * pooled_height + out_h) * pooled_width + out_w + get_local_id(1);", // NOLINT +"const int_tp step = height * width;", // NOLINT +"const int_tp out_step = pooled_height * pooled_width;", // NOLINT +"__global const Dtype* in_off = in + offset + get_local_id(1);", // NOLINT +"__global Dtype* out_off = out + out_offset;", // NOLINT +"Dtype scale_val;", // NOLINT +"int_tp head = 0;", // NOLINT +"const int_tp pre_pad = (size - 1) / 2;", // NOLINT +"const int_tp post_pad = size - pre_pad - 1;", // NOLINT +"Dtype accum_scale[TILE_H] = {0};", // NOLINT +"if (w + get_local_id(1) >= width)", // NOLINT +"return;", // NOLINT +"", // NOLINT +"while ( head < channels + post_pad ) {", // NOLINT +"int ph = 0;", // NOLINT +"int cur_out_h = 0;", // NOLINT +"Dtype output_val = -FLT_MAX;", // NOLINT +"// fill the scale at [n, :, h, w]", // NOLINT +"// accumulate values", // NOLINT +"for( int lrn_out_h = 0; lrn_out_h < TILE_H && (lrn_out_h + h) < height; lrn_out_h++) {", // NOLINT +"Dtype prev_val = accum_scale[lrn_out_h];", // NOLINT +"// add", // NOLINT +"if (head < channels) {", // NOLINT +"prev_val += in_off[head * step + width * lrn_out_h] * in_off[head * step + width * lrn_out_h];", // NOLINT +"}", // NOLINT +"// subtract", // NOLINT +"if (head - size >= 0) {", // NOLINT +"prev_val -= in_off[(head - size) * step + 
width * lrn_out_h] * in_off[(head - size) * step + width * lrn_out_h];", // NOLINT +"}", // NOLINT +"// compute output.", // NOLINT +"if (head >= post_pad) {", // NOLINT +"scale_val = k + prev_val * alpha_over_size;", // NOLINT +"Dtype tmp = -FLT_MAX;", // NOLINT +"//if (w + get_local_id(1) < width)", // NOLINT +"tmp = in_off[(head - post_pad) * step + width * lrn_out_h] * native_powr(scale_val, negative_beta);", // NOLINT +"", // NOLINT +"Dtype h_max_val = -FLT_MAX;", // NOLINT +"int index = (get_local_id(1) * pool_stride_w) % SIMD_WIDTH;", // NOLINT +"for(int i = 0; i < pool_w; i++) {", // NOLINT +"Dtype val = intel_sub_group_shuffle(tmp, index);", // NOLINT +"if (h_max_val < val && (index + w < width))", // NOLINT +"h_max_val = val;", // NOLINT +"", // NOLINT +"index = (index + 1) % SIMD_WIDTH;", // NOLINT +"}", // NOLINT +"// update output value.", // NOLINT +"output_val = (output_val > h_max_val) ?", // NOLINT +"output_val : h_max_val;", // NOLINT +"// time to write previous output and move to next value", // NOLINT +"if (lrn_out_h - cur_out_h + 1 == pool_h) {", // NOLINT +"if (get_local_id(1) < tile_pooled_block_w && (out_w + get_local_id(1)) < pooled_width) {", // NOLINT +"out_off[(head - post_pad) * out_step + ph * pooled_width] = output_val;", // NOLINT +"", // NOLINT +"output_val = h_max_val;", // NOLINT +"}", // NOLINT +"++ph;", // NOLINT +"cur_out_h += pool_stride_h;", // NOLINT +"}", // NOLINT +"}", // NOLINT +"accum_scale[lrn_out_h] = prev_val;", // NOLINT +"}", // NOLINT +"// Handle the incomplete pool box", // NOLINT +"// an incomplete tiling box and we are not hitting the end of the pooled output.", // NOLINT +"if (head >= post_pad &&", // NOLINT +"ph < tile_pooled_block_h &&", // NOLINT +"ph + out_h < pooled_height &&", // NOLINT +"get_local_id(1) < tile_pooled_block_w &&", // NOLINT +"(out_w + get_local_id(1)) < pooled_width) {", // NOLINT +"out_off[(head - post_pad) * out_step + ph * pooled_width] = output_val;", // NOLINT +"}", // NOLINT 
+"head++;", // NOLINT +"}", // NOLINT +"}", // NOLINT +"", // NOLINT +"#undef TILE_W", // NOLINT +"#undef TILE_H", // NOLINT +"#undef SIMD_WIDTH", // NOLINT +"", // NOLINT "__kernel void TEMPLATE(lrn_full_no_scale,Dtype)(const int_tp nthreads, __global const Dtype* in,", // NOLINT "const int_tp num, const int_tp channels,", // NOLINT "const int_tp height, const int_tp width, const int_tp size,", // NOLINT @@ -5631,14 +5741,14 @@ static std::vector> cl_kernels{ "#include \"header.cl\"", // NOLINT "#endif", // NOLINT "", // NOLINT -"__kernel void TEMPLATE(max_pool_forward,Dtype)(", // NOLINT +"void TEMPLATE(max_pool_forward_impl, Dtype)(", // NOLINT "const int_tp nthreads, __global const Dtype* bottom_data, const int_tp num,", // NOLINT "const int_tp channels, const int_tp height, const int_tp width,", // NOLINT "const int_tp pooled_height, const int_tp pooled_width, const int_tp kernel_h,", // NOLINT "const int_tp kernel_w, const int_tp stride_h, const int_tp stride_w, const int_tp pad_h,", // NOLINT "const int_tp pad_w,", // NOLINT "__global Dtype* top_data,", // NOLINT -"const int use_mask, __global int_tp* mask, __global Dtype* top_mask) {", // NOLINT +"const int use_mask, __global int_tp* mask, __global Dtype* top_mask, bool no_mask) {", // NOLINT "for (int_tp index = get_global_id(0); index < nthreads;", // NOLINT "index += get_global_size(0)) {", // NOLINT "const int_tp pw = index % pooled_width;", // NOLINT @@ -5664,6 +5774,7 @@ static std::vector> cl_kernels{ "}", // NOLINT "}", // NOLINT "top_data[index] = maxval;", // NOLINT +"if (!no_mask) {", // NOLINT "if (use_mask == 1) {", // NOLINT "mask[index] = maxidx;", // NOLINT "} else {", // NOLINT @@ -5671,8 +5782,40 @@ static std::vector> cl_kernels{ "}", // NOLINT "}", // NOLINT "}", // NOLINT +"}", // NOLINT +"", // NOLINT +"__kernel void TEMPLATE(max_pool_forward_no_mask, Dtype)(", // NOLINT +"const int_tp nthreads, __global const Dtype* bottom_data, const int_tp num,", // NOLINT +"const int_tp channels, 
const int_tp height, const int_tp width,", // NOLINT +"const int_tp pooled_height, const int_tp pooled_width, const int_tp kernel_h,", // NOLINT +"const int_tp kernel_w, const int_tp stride_h, const int_tp stride_w, const int_tp pad_h,", // NOLINT +"const int_tp pad_w,", // NOLINT +"__global Dtype* top_data) {", // NOLINT +"", // NOLINT +"TEMPLATE(max_pool_forward_impl, Dtype)(", // NOLINT +"nthreads, bottom_data, num, channels, height, width,", // NOLINT +"pooled_height, pooled_width, kernel_h,", // NOLINT +"kernel_w, stride_h, stride_w, pad_h, pad_w, top_data, 0, NULL, NULL, true", // NOLINT +");", // NOLINT +"}", // NOLINT +"", // NOLINT +"__kernel void TEMPLATE(max_pool_forward, Dtype)(", // NOLINT +"const int_tp nthreads, __global const Dtype* bottom_data, const int_tp num,", // NOLINT +"const int_tp channels, const int_tp height, const int_tp width,", // NOLINT +"const int_tp pooled_height, const int_tp pooled_width, const int_tp kernel_h,", // NOLINT +"const int_tp kernel_w, const int_tp stride_h, const int_tp stride_w, const int_tp pad_h,", // NOLINT +"const int_tp pad_w,", // NOLINT +"__global Dtype* top_data,", // NOLINT +"const int use_mask, __global int_tp* mask, __global Dtype* top_mask) {", // NOLINT +"", // NOLINT +"TEMPLATE(max_pool_forward_impl, Dtype)(", // NOLINT +"nthreads, bottom_data, num, channels, height, width,", // NOLINT +"pooled_height, pooled_width, kernel_h,", // NOLINT +"kernel_w, stride_h, stride_w, pad_h, pad_w, top_data, use_mask, mask, top_mask, false", // NOLINT +");", // NOLINT +"}", // NOLINT "", // NOLINT -"__kernel void TEMPLATE(ave_pool_forward,Dtype)(", // NOLINT +"__kernel void TEMPLATE(ave_pool_forward, Dtype)(", // NOLINT "const int_tp nthreads, __global const Dtype* const bottom_data, const int_tp num,", // NOLINT "const int_tp channels, const int_tp height, const int_tp width,", // NOLINT "const int_tp pooled_height, const int_tp pooled_width, const int_tp kernel_h,", // NOLINT @@ -5912,7 +6055,7 @@ static std::vector> 
cl_kernels{ "for (int_tp ph = phstart; ph < phend; ++ph) {", // NOLINT "for (int_tp pw = pwstart; pw < pwend; ++pw) {", // NOLINT "gradient += top_diff_slice[ph * pooled_width + pw]", // NOLINT -"* (index == (int_tp) (rand_idx_slice[ph * pooled_width + pw])?1.0:0.0);", // NOLINT +"* (Dtype)(index == (int_tp) (rand_idx_slice[ph * pooled_width + pw])?1.0:0.0);", // NOLINT "}", // NOLINT "}", // NOLINT "bottom_diff[index] = gradient;", // NOLINT diff --git a/src/caffe/greentea/cl_kernels/lrn.cl b/src/caffe/greentea/cl_kernels/lrn.cl index c548d022273..f4f38fcb5f2 100644 --- a/src/caffe/greentea/cl_kernels/lrn.cl +++ b/src/caffe/greentea/cl_kernels/lrn.cl @@ -120,6 +120,116 @@ __kernel void TEMPLATE(lrn_compute_diff,Dtype)(const int_tp nthreads, } } +#define SIMD_WIDTH 16 +#define TILE_W SIMD_WIDTH +#define TILE_H 8 + +#ifndef BEIGNET +__attribute__((intel_reqd_sub_group_size(SIMD_WIDTH))) +#endif +// Fuse pooling max layer into LRN across channel layer. +// Currently, only support non-padding, non-dilation mode and pool_w/h == pool_stride_w + 1. +// This kernel only get better performance on those Intel platforms with edram. 
+__kernel void TEMPLATE(lrn_fuse_pool_max,Dtype)( + __global const Dtype* in, + const int_tp channels, + const int_tp height, const int_tp width, + const int_tp tiled_height, int_tp tiled_width, + const int_tp size, + const Dtype alpha_over_size, const Dtype k, + __global Dtype* const out, + const Dtype negative_beta, + const int_tp pool_h, const int_tp pool_w, const int_tp pool_stride_h, int_tp pool_stride_w, + const int_tp pooled_height, const int_tp pooled_width, + const int_tp tile_pooled_block_h, const int_tp tile_pooled_block_w) { + // find out the local offset + const int_tp block_x = get_global_id(0) % tiled_width; + const int_tp block_y = (get_global_id(0) / tiled_width) % tiled_height; + const int_tp n = get_global_id(0) / (tiled_width * tiled_height); + + const int_tp w = block_x * tile_pooled_block_w * pool_stride_w; + const int_tp h = block_y * tile_pooled_block_h * pool_stride_h; + const int_tp offset = (n * channels * height + h) * width + w; + const int_tp out_h = block_y * tile_pooled_block_h; + const int_tp out_w = block_x * tile_pooled_block_w; + const int_tp out_offset = (n * channels * pooled_height + out_h) * pooled_width + out_w + get_local_id(1); + const int_tp step = height * width; + const int_tp out_step = pooled_height * pooled_width; + __global const Dtype* in_off = in + offset + get_local_id(1); + __global Dtype* out_off = out + out_offset; + Dtype scale_val; + int_tp head = 0; + const int_tp pre_pad = (size - 1) / 2; + const int_tp post_pad = size - pre_pad - 1; + Dtype accum_scale[TILE_H] = {0}; + if (w + get_local_id(1) >= width) + return; + + while ( head < channels + post_pad ) { + int ph = 0; + int cur_out_h = 0; + Dtype output_val = -FLT_MAX; + // fill the scale at [n, :, h, w] + // accumulate values + for( int lrn_out_h = 0; lrn_out_h < TILE_H && (lrn_out_h + h) < height; lrn_out_h++) { + Dtype prev_val = accum_scale[lrn_out_h]; + // add + if (head < channels) { + prev_val += in_off[head * step + width * lrn_out_h] * 
in_off[head * step + width * lrn_out_h]; + } + // subtract + if (head - size >= 0) { + prev_val -= in_off[(head - size) * step + width * lrn_out_h] * in_off[(head - size) * step + width * lrn_out_h]; + } + // compute output. + if (head >= post_pad) { + scale_val = k + prev_val * alpha_over_size; + Dtype tmp = -FLT_MAX; + //if (w + get_local_id(1) < width) + tmp = in_off[(head - post_pad) * step + width * lrn_out_h] * native_powr(scale_val, negative_beta); + + Dtype h_max_val = -FLT_MAX; + int index = (get_local_id(1) * pool_stride_w) % SIMD_WIDTH; + for(int i = 0; i < pool_w; i++) { + Dtype val = intel_sub_group_shuffle(tmp, index); + if (h_max_val < val && (index + w < width)) + h_max_val = val; + + index = (index + 1) % SIMD_WIDTH; + } + // update output value. + output_val = (output_val > h_max_val) ? + output_val : h_max_val; + // time to write previous output and move to next value + if (lrn_out_h - cur_out_h + 1 == pool_h) { + if (get_local_id(1) < tile_pooled_block_w && (out_w + get_local_id(1)) < pooled_width) { + out_off[(head - post_pad) * out_step + ph * pooled_width] = output_val; + + output_val = h_max_val; + } + ++ph; + cur_out_h += pool_stride_h; + } + } + accum_scale[lrn_out_h] = prev_val; + } + // Handle the incomplete pool box + // an incomplete tiling box and we are not hitting the end of the pooled output. 
+ if (head >= post_pad && + ph < tile_pooled_block_h && + ph + out_h < pooled_height && + get_local_id(1) < tile_pooled_block_w && + (out_w + get_local_id(1)) < pooled_width) { + out_off[(head - post_pad) * out_step + ph * pooled_width] = output_val; + } + head++; + } +} + +#undef TILE_W +#undef TILE_H +#undef SIMD_WIDTH + __kernel void TEMPLATE(lrn_full_no_scale,Dtype)(const int_tp nthreads, __global const Dtype* in, const int_tp num, const int_tp channels, const int_tp height, const int_tp width, const int_tp size, diff --git a/src/caffe/greentea/cl_kernels/pooling.cl b/src/caffe/greentea/cl_kernels/pooling.cl index cc56bab12d9..37400e6f84a 100644 --- a/src/caffe/greentea/cl_kernels/pooling.cl +++ b/src/caffe/greentea/cl_kernels/pooling.cl @@ -2,14 +2,14 @@ #include "header.cl" #endif -__kernel void TEMPLATE(max_pool_forward,Dtype)( +void TEMPLATE(max_pool_forward_impl, Dtype)( const int_tp nthreads, __global const Dtype* bottom_data, const int_tp num, const int_tp channels, const int_tp height, const int_tp width, const int_tp pooled_height, const int_tp pooled_width, const int_tp kernel_h, const int_tp kernel_w, const int_tp stride_h, const int_tp stride_w, const int_tp pad_h, const int_tp pad_w, __global Dtype* top_data, - const int use_mask, __global int_tp* mask, __global Dtype* top_mask) { + const int use_mask, __global int_tp* mask, __global Dtype* top_mask, bool no_mask) { for (int_tp index = get_global_id(0); index < nthreads; index += get_global_size(0)) { const int_tp pw = index % pooled_width; @@ -35,15 +35,48 @@ __kernel void TEMPLATE(max_pool_forward,Dtype)( } } top_data[index] = maxval; - if (use_mask == 1) { - mask[index] = maxidx; - } else { - top_mask[index] = maxidx; + if (!no_mask) { + if (use_mask == 1) { + mask[index] = maxidx; + } else { + top_mask[index] = maxidx; + } } } } -__kernel void TEMPLATE(ave_pool_forward,Dtype)( +__kernel void TEMPLATE(max_pool_forward_no_mask, Dtype)( + const int_tp nthreads, __global const Dtype* bottom_data, 
const int_tp num, + const int_tp channels, const int_tp height, const int_tp width, + const int_tp pooled_height, const int_tp pooled_width, const int_tp kernel_h, + const int_tp kernel_w, const int_tp stride_h, const int_tp stride_w, const int_tp pad_h, + const int_tp pad_w, + __global Dtype* top_data) { + + TEMPLATE(max_pool_forward_impl, Dtype)( + nthreads, bottom_data, num, channels, height, width, + pooled_height, pooled_width, kernel_h, + kernel_w, stride_h, stride_w, pad_h, pad_w, top_data, 0, NULL, NULL, true + ); +} + +__kernel void TEMPLATE(max_pool_forward, Dtype)( + const int_tp nthreads, __global const Dtype* bottom_data, const int_tp num, + const int_tp channels, const int_tp height, const int_tp width, + const int_tp pooled_height, const int_tp pooled_width, const int_tp kernel_h, + const int_tp kernel_w, const int_tp stride_h, const int_tp stride_w, const int_tp pad_h, + const int_tp pad_w, + __global Dtype* top_data, + const int use_mask, __global int_tp* mask, __global Dtype* top_mask) { + + TEMPLATE(max_pool_forward_impl, Dtype)( + nthreads, bottom_data, num, channels, height, width, + pooled_height, pooled_width, kernel_h, + kernel_w, stride_h, stride_w, pad_h, pad_w, top_data, use_mask, mask, top_mask, false + ); +} + +__kernel void TEMPLATE(ave_pool_forward, Dtype)( const int_tp nthreads, __global const Dtype* const bottom_data, const int_tp num, const int_tp channels, const int_tp height, const int_tp width, const int_tp pooled_height, const int_tp pooled_width, const int_tp kernel_h, @@ -283,7 +316,7 @@ __kernel void TEMPLATE(sto_pool_backward,Dtype)( for (int_tp ph = phstart; ph < phend; ++ph) { for (int_tp pw = pwstart; pw < pwend; ++pw) { gradient += top_diff_slice[ph * pooled_width + pw] - * (index == (int_tp) (rand_idx_slice[ph * pooled_width + pw])?1.0:0.0); + * (Dtype)(index == (int_tp) (rand_idx_slice[ph * pooled_width + pw])?1.0:0.0); } } bottom_diff[index] = gradient; diff --git a/src/caffe/layers/conv_layer_spatial.cpp 
b/src/caffe/layers/conv_layer_spatial.cpp index ff4493d1c5e..da6f824fe84 100644 --- a/src/caffe/layers/conv_layer_spatial.cpp +++ b/src/caffe/layers/conv_layer_spatial.cpp @@ -662,31 +662,6 @@ cl_int ConvolutionLayerSpatial::convolve( kernel.arg(argIdx++, (uint16_t)height_); kernel.arg(argIdx++, (uint16_t)output_w_); kernel.arg(argIdx++, (uint16_t)output_h_); - int out_pitch_y = output_w_ * output_h_; - int out_pitch_z = out_pitch_y * M_; - int aligned_input_size = height_ * width_ * channels_ / group_; - int slice_pitch = width_ * height_; - kernel.arg(argIdx++, (uint32_t)out_pitch_y); - kernel.arg(argIdx++, (uint32_t)out_pitch_z); - kernel.arg(argIdx++, (uint32_t)aligned_input_size); - kernel.arg(argIdx++, (uint32_t)slice_pitch); - - int blockM = config->workItem_output[0]; - int blockK = config->workItem_output[1]; - int blockN = config->workItem_output[2]; - int_tp alignedFilterWidth = ALIGN(M_, blockN); - int_tp alignedExpandHeight = ALIGN(output_w_ * output_h_, blockM); - int_tp globalWorkSizeDX = blockN; - int_tp globalWorkSizeDY = blockM; - size_t sgemm_m = alignedExpandHeight; - size_t sgemm_n = alignedFilterWidth; - size_t gx = (size_t) ceil( (float) sgemm_n / - (float) globalWorkSizeDX ); - size_t gy = (size_t) ceil( (float) sgemm_m / - (float) globalWorkSizeDY ); - gy = ALIGN(gy, blockK); - size_t global_size[3] = { gx, gy, config->global_work_size[2] }; - viennacl::ocl::context &ctx = viennacl::ocl::get_context(this->device_->id()); int out_pitch_y = output_w_ * output_h_; diff --git a/src/caffe/layers/lrn_layer.cpp b/src/caffe/layers/lrn_layer.cpp index 0bd0229664d..aad981e4c6e 100644 --- a/src/caffe/layers/lrn_layer.cpp +++ b/src/caffe/layers/lrn_layer.cpp @@ -63,6 +63,38 @@ void LRNLayer::LayerSetUp(const vector*>& bottom, product_layer_.reset(new EltwiseLayer(product_param)); product_layer_->SetUp(product_bottom_vec_, top); } + if (IsFused()) { + CHECK(this->layer_param_.lrn_param().norm_region() + == LRNParameter_NormRegion_ACROSS_CHANNELS); + 
CHECK(this->phase_ == caffe::TEST); + CHECK(this->layer_param_.lrn_param().pooling_param().pool() + == PoolingParameter_PoolMethod_MAX); + CHECK(this->layer_param_.lrn_param().pooling_param().kernel_size(0) + < 6); + CHECK(this->layer_param_.lrn_param().pooling_param().stride(0) < 4); + CHECK(this->layer_param_.lrn_param().pooling_param().dilation_size() == 0); + pool_w_ = this->layer_param_.lrn_param().pooling_param().kernel_size(0); + pool_h_ = this->layer_param_.lrn_param().pooling_param().kernel_size(0); + pool_stride_w_ = this->layer_param_.lrn_param().pooling_param().stride(0); + pool_stride_h_ = this->layer_param_.lrn_param().pooling_param().stride(0); + // currently, only support the stride == pool - 1 + CHECK(pool_w_ - pool_stride_w_ == 1 && pool_w_ < 8); + if (!this->layer_param_.lrn_param().unit_test_mode()) { + fuse_tuned_ = false; + tuned_use_fuse_ = false; + } else { + fuse_tuned_ = true; + tuned_use_fuse_ = this->layer_param_.lrn_param().unit_test_fuse_kernel(); + } + lrn_top_vec_.push_back(&lrn_top_blob_); + LayerParameter pooling_param; + pooling_param.mutable_pooling_param()->set_pool(PoolingParameter_PoolMethod_MAX); + pooling_param.mutable_pooling_param()->add_kernel_size(pool_w_); + pooling_param.mutable_pooling_param()->add_stride(pool_stride_w_); + pool_layer_.reset(new PoolingLayer(pooling_param)); + lrn_top_blob_.ReshapeLike(*bottom[0]); + pool_layer_->SetUp(lrn_top_vec_, top); + } } template @@ -76,7 +108,16 @@ void LRNLayer::Reshape(const vector*>& bottom, width_ = bottom[0]->width(); switch (this->layer_param_.lrn_param().norm_region()) { case LRNParameter_NormRegion_ACROSS_CHANNELS: - top[0]->Reshape(num_, channels_, height_, width_); + if (IsFused()) { + pooled_width_ = static_cast(ceil( + static_cast(width_ - pool_w_) / pool_stride_w_)) + 1; + pooled_height_ = static_cast(ceil( + static_cast(height_ - pool_h_) / pool_stride_h_)) + 1; + top[0]->Reshape(num_, channels_, pooled_width_, pooled_height_); + lrn_top_blob_.Reshape(num_, 
channels_, width_, height_); + } else { + top[0]->Reshape(num_, channels_, height_, width_); + } scale_.Reshape(num_, channels_, height_, width_); break; case LRNParameter_NormRegion_WITHIN_CHANNEL: diff --git a/src/caffe/layers/lrn_layer.cu b/src/caffe/layers/lrn_layer.cu index 3f7c0de7e20..cbe831172e6 100644 --- a/src/caffe/layers/lrn_layer.cu +++ b/src/caffe/layers/lrn_layer.cu @@ -2,6 +2,7 @@ #include "caffe/layers/lrn_layer.hpp" #include "caffe/util/math_functions.hpp" +#include "caffe/util/benchmark.hpp" namespace caffe { @@ -82,6 +83,98 @@ __global__ void LRNComputeOutput(const int_tp nthreads, const Dtype* const in, #endif // USE_CUDA template +void LRNLayer::CrossChannelForward_fuse_pooling_gpu( + const vector*>& bottom, + const vector*>& top, + bool use_fuse) { + + const Dtype* bottom_data = bottom[0]->gpu_data(); + Dtype* top_data = top[0]->mutable_gpu_data(); + CHECK(IsFusedWithPoolMax() && this->device_->backend() == BACKEND_OpenCL); + + viennacl::ocl::context &ctx = viennacl::ocl::get_context( + this->device_->id()); + viennacl::ocl::program &program = this->device_->program(); + + if (use_fuse) { + viennacl::ocl::kernel &oclk_lrn_fill = program.get_kernel( + CL_KERNEL_SELECT("lrn_fuse_pool_max")); + #define TILE_W 16 + #define TILE_H 8 + size_t simd_size = TILE_W; + cl_uint argIdx = 0; + const int_tp tile_pooled_block_h = (TILE_H - pool_h_) / pool_stride_h_ + 1; + const int_tp tile_pooled_block_w = (TILE_W - pool_w_) / pool_stride_w_ + 1; + const int tiled_width = (width_ + tile_pooled_block_w * pool_stride_w_ - 1) + / (tile_pooled_block_w * pool_stride_w_); + const int tiled_height = (height_ + tile_pooled_block_h * pool_stride_h_ - 1) + / (tile_pooled_block_h * pool_stride_h_); + int_tp n_threads = num_ * tiled_width * tiled_height; + size_t global_work_size_[2] = {(size_t)n_threads, simd_size}; + size_t local_work_size[2] = {1, simd_size}; + oclk_lrn_fill.arg(argIdx++, WrapHandle((cl_mem) bottom_data, &ctx)); + oclk_lrn_fill.arg(argIdx++, 
channels_); + oclk_lrn_fill.arg(argIdx++, height_); + oclk_lrn_fill.arg(argIdx++, width_); + oclk_lrn_fill.arg(argIdx++, tiled_height); + oclk_lrn_fill.arg(argIdx++, tiled_width); + oclk_lrn_fill.arg(argIdx++, size_); + oclk_lrn_fill.arg(argIdx++, alpha_ / size_); + oclk_lrn_fill.arg(argIdx++, k_); + oclk_lrn_fill.arg(argIdx++, WrapHandle((cl_mem) top_data, &ctx)); + oclk_lrn_fill.arg(argIdx++, -beta_); + oclk_lrn_fill.arg(argIdx++, pool_h_); + oclk_lrn_fill.arg(argIdx++, pool_w_); + oclk_lrn_fill.arg(argIdx++, pool_stride_h_); + oclk_lrn_fill.arg(argIdx++, pool_stride_w_); + oclk_lrn_fill.arg(argIdx++, pooled_height_); + oclk_lrn_fill.arg(argIdx++, pooled_width_); + oclk_lrn_fill.arg(argIdx++, tile_pooled_block_h); + oclk_lrn_fill.arg(argIdx++, tile_pooled_block_w); + OCL_CHECK(clEnqueueNDRangeKernel(ctx.get_queue().handle().get(), + oclk_lrn_fill.handle().get(), 2, NULL, + global_work_size_, local_work_size, 0, NULL, + NULL)); + } else { + Dtype* top_lrn_data = lrn_top_blob_.mutable_gpu_data(); + // Do LRN firstly. + cl_uint argIdx = 0; + int_tp n_threads = num_ * height_ * width_; + size_t global_work_size_[1] = {(size_t)n_threads}; + viennacl::ocl::kernel &oclk_lrn_fill = program.get_kernel( + CL_KERNEL_SELECT("lrn_full_no_scale")); + oclk_lrn_fill.arg(argIdx++, n_threads); + oclk_lrn_fill.arg(argIdx++, WrapHandle((cl_mem) bottom_data, &ctx)); + oclk_lrn_fill.arg(argIdx++, num_); + oclk_lrn_fill.arg(argIdx++, channels_); + oclk_lrn_fill.arg(argIdx++, height_); + oclk_lrn_fill.arg(argIdx++, width_); + oclk_lrn_fill.arg(argIdx++, size_); + oclk_lrn_fill.arg(argIdx++, alpha_ / size_); + oclk_lrn_fill.arg(argIdx++, k_); + oclk_lrn_fill.arg(argIdx++, WrapHandle((cl_mem) top_lrn_data, &ctx)); + oclk_lrn_fill.arg(argIdx++, -beta_); + OCL_CHECK(clEnqueueNDRangeKernel(ctx.get_queue().handle().get(), + oclk_lrn_fill.handle().get(), 1, NULL, + global_work_size_, NULL, 0, NULL, + NULL)); + // Do pooling. 
+ viennacl::ocl::kernel &oclk_max_pool_forward = program.get_kernel( + CL_KERNEL_SELECT("max_pool_forward_no_mask")); + + int_tp count = pooled_width_ * pooled_height_ * channels_ * num_; + viennacl::ocl::enqueue( + oclk_max_pool_forward(count, + WrapHandle((cl_mem) top_lrn_data, &ctx), + num_, channels_, height_, width_, + pooled_height_, pooled_width_, pool_h_, + pool_w_, pool_stride_h_, pool_stride_w_, 0, 0, + WrapHandle((cl_mem) top_data, &ctx)), + ctx.get_queue()); + } +} + +template void LRNLayer::CrossChannelForward_gpu( const vector*>& bottom, const vector*>& top) { // First, compute scale @@ -115,11 +208,11 @@ void LRNLayer::CrossChannelForward_gpu( this->device_->id()); viennacl::ocl::program &program = this->device_->program(); - int_tp n_threads = num_ * height_ * width_; - cl_uint argIdx = 0; - size_t global_work_size_[1] = {(size_t)n_threads}; if (this->phase_ == caffe::TRAIN) { + cl_uint argIdx = 0; + int_tp n_threads = num_ * height_ * width_; + size_t global_work_size_[1] = {(size_t)n_threads}; viennacl::ocl::kernel &oclk_lrn_fill = program.get_kernel( CL_KERNEL_SELECT("lrn_full")); @@ -141,26 +234,54 @@ void LRNLayer::CrossChannelForward_gpu( global_work_size_, NULL, 0, NULL, NULL)); } else { - viennacl::ocl::kernel &oclk_lrn_fill = program.get_kernel( - CL_KERNEL_SELECT("lrn_full_no_scale")); - - cl_uint argIdx = 0; - oclk_lrn_fill.arg(argIdx++, n_threads); - oclk_lrn_fill.arg(argIdx++, WrapHandle((cl_mem) bottom_data, &ctx)); - oclk_lrn_fill.arg(argIdx++, num_); - oclk_lrn_fill.arg(argIdx++, channels_); - oclk_lrn_fill.arg(argIdx++, height_); - oclk_lrn_fill.arg(argIdx++, width_); - oclk_lrn_fill.arg(argIdx++, size_); - oclk_lrn_fill.arg(argIdx++, alpha_ / size_); - oclk_lrn_fill.arg(argIdx++, k_); - oclk_lrn_fill.arg(argIdx++, WrapHandle((cl_mem) top_data, &ctx)); - oclk_lrn_fill.arg(argIdx++, -beta_); - OCL_CHECK(clEnqueueNDRangeKernel(ctx.get_queue().handle().get(), + if (!IsFused()) { + cl_uint argIdx = 0; + int_tp n_threads = num_ * height_ 
* width_; + size_t global_work_size_[1] = {(size_t)n_threads}; + viennacl::ocl::kernel &oclk_lrn_fill = program.get_kernel( + CL_KERNEL_SELECT("lrn_full_no_scale")); + oclk_lrn_fill.arg(argIdx++, n_threads); + oclk_lrn_fill.arg(argIdx++, WrapHandle((cl_mem) bottom_data, &ctx)); + oclk_lrn_fill.arg(argIdx++, num_); + oclk_lrn_fill.arg(argIdx++, channels_); + oclk_lrn_fill.arg(argIdx++, height_); + oclk_lrn_fill.arg(argIdx++, width_); + oclk_lrn_fill.arg(argIdx++, size_); + oclk_lrn_fill.arg(argIdx++, alpha_ / size_); + oclk_lrn_fill.arg(argIdx++, k_); + oclk_lrn_fill.arg(argIdx++, WrapHandle((cl_mem) top_data, &ctx)); + oclk_lrn_fill.arg(argIdx++, -beta_); + OCL_CHECK(clEnqueueNDRangeKernel(ctx.get_queue().handle().get(), oclk_lrn_fill.handle().get(), 1, NULL, global_work_size_, NULL, 0, NULL, NULL)); + } else if (IsFusedWithPoolMax()) { + // We can't make sure the fused kernel be the faster for all platforms. + // have to apply a simple tuning here. + if (fuse_tuned_) + CrossChannelForward_fuse_pooling_gpu(bottom, top, tuned_use_fuse_); + else { + float elapsedTime[2]; + bool use_fuse[2] = {true, false}; + // warm up. 
+ CrossChannelForward_fuse_pooling_gpu(bottom, top, true); + CrossChannelForward_fuse_pooling_gpu(bottom, top, false); + for (int i = 0; i < 2; i++) { + Timer timer; + timer.initted(); + timer.Start(); + int loop_cnt = 2; + for (int j = 0; j < loop_cnt; j++) { + CrossChannelForward_fuse_pooling_gpu(bottom, top, use_fuse[i]); + } + timer.Stop(); + elapsedTime[i] = timer.MilliSeconds() / loop_cnt; + } + tuned_use_fuse_ = elapsedTime[0] < elapsedTime[1]; + fuse_tuned_ = true; + } + } } #endif // USE_GREENTEA } @@ -170,6 +291,11 @@ template void LRNLayer::CrossChannelForward_gpu( template void LRNLayer::CrossChannelForward_gpu( const vector*>& bottom, const vector*>& top); +template void LRNLayer::CrossChannelForward_fuse_pooling_gpu( + const vector*>& bottom, const vector*>& top, bool); +template void LRNLayer::CrossChannelForward_fuse_pooling_gpu( + const vector*>& bottom, const vector*>& top, bool); + template void LRNLayer::Backward_gpu(const vector*>& top, const vector& propagate_down, diff --git a/src/caffe/test/test_lrn_layer.cpp b/src/caffe/test/test_lrn_layer.cpp index 0ec2d03c85f..8967d8efd28 100644 --- a/src/caffe/test/test_lrn_layer.cpp +++ b/src/caffe/test/test_lrn_layer.cpp @@ -458,4 +458,74 @@ TYPED_TEST(CuDNNLRNLayerTest, TestGradientAcrossChannelsLargeRegionCuDNN) { #endif +template +class LRNFuseLayerTest : public GPUDeviceTest { + protected: + LRNFuseLayerTest() + : epsilon_(Dtype(1e-3)), + blob_bottom_(new Blob()), + blob_top_(new Blob()) {} + virtual void SetUp() { + Caffe::set_random_seed(1701, Caffe::GetDefaultDevice()); + blob_bottom_->Reshape(1, 32, 55, 55); + // fill the values + FillerParameter filler_param; + GaussianFiller filler(filler_param); + filler.Fill(this->blob_bottom_); + blob_bottom_vec_.push_back(blob_bottom_); + blob_top_vec_.push_back(blob_top_); + } + virtual ~LRNFuseLayerTest() { delete blob_bottom_; delete blob_top_; } + + Dtype epsilon_; + Blob* const blob_bottom_; + Blob* const blob_top_; + vector*> blob_bottom_vec_; + 
vector*> blob_top_vec_; +}; + +TYPED_TEST_CASE(LRNFuseLayerTest, TestDtypes); + +TYPED_TEST(LRNFuseLayerTest, TestForwardAcrossChannelsFusePoolMax) { + LayerParameter layer_param; + + Blob top_reference; + LRNLayer lrnLayer(layer_param); + + // calculate reference value by lrn layer followed by pooling layer + lrnLayer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + lrnLayer.Forward(this->blob_bottom_vec_, this->blob_top_vec_); + LayerParameter pooling_param; + pooling_param.mutable_pooling_param()->set_pool(PoolingParameter_PoolMethod_MAX); + pooling_param.mutable_pooling_param()->add_kernel_size(3); + pooling_param.mutable_pooling_param()->add_stride(2); + PoolingLayer pooling_layer(pooling_param); + vector*> top_reference_vec; + top_reference_vec.push_back(&top_reference); + pooling_layer.SetUp(this->blob_top_vec_, top_reference_vec); + pooling_layer.Forward(this->blob_top_vec_, top_reference_vec); + // calculate result by lrn fused with pooling layer. + LayerParameter fused_layer_param; + fused_layer_param.set_phase(TEST); + fused_layer_param.mutable_lrn_param()->set_fuse_type(LRNParameter_FuseType_FUSED_POOL_MAX); + fused_layer_param.mutable_lrn_param()->set_unit_test_mode(true); + fused_layer_param.mutable_lrn_param()->mutable_pooling_param()->set_pool(PoolingParameter_PoolMethod_MAX); + fused_layer_param.mutable_lrn_param()->mutable_pooling_param()->add_kernel_size(3); + fused_layer_param.mutable_lrn_param()->mutable_pooling_param()->add_stride(2); + + bool test_fuse_kernel[2] = {true, false}; + for (int_tp index = 0; index < 2; index++) { + fused_layer_param.mutable_lrn_param()->set_unit_test_fuse_kernel(test_fuse_kernel[index]); + LRNLayer fused_layer(fused_layer_param); + fused_layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + fused_layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_); + + for (int_tp i = 0; i < top_reference.count(); ++i) { + EXPECT_NEAR(this->blob_top_->cpu_data()[i], top_reference.cpu_data()[i], + this->epsilon_); 
+ } + memset(this->blob_top_->mutable_cpu_data(), 0, top_reference.count()); + } +} + } // namespace caffe From a5dd29772f8b58142ea9b6f7efbe1f271bf395d7 Mon Sep 17 00:00:00 2001 From: Zhigang Gong Date: Sat, 1 Apr 2017 03:27:20 +0800 Subject: [PATCH 06/33] Enable image based GEMM interface for inner product layer. Signed-off-by: Zhigang Gong --- include/caffe/layers/inner_product_layer.hpp | 17 +- src/caffe/layers/inner_product_layer.cpp | 25 ++ src/caffe/layers/inner_product_layer.cu | 37 ++- src/caffe/test/test_inner_product_layer.cpp | 442 ++++++++++++++++++++++++++- 4 files changed, 513 insertions(+), 8 deletions(-) diff --git a/include/caffe/layers/inner_product_layer.hpp b/include/caffe/layers/inner_product_layer.hpp index 16e7cd15fb5..fa639a8c3f5 100644 --- a/include/caffe/layers/inner_product_layer.hpp +++ b/include/caffe/layers/inner_product_layer.hpp @@ -19,7 +19,11 @@ template class InnerProductLayer : public Layer { public: explicit InnerProductLayer(const LayerParameter& param) - : Layer(param) {} + : Layer(param) { +#ifdef USE_GREENTEA + weight_image_ = NULL; +#endif + } virtual void LayerSetUp(const vector*>& bottom, const vector*>& top); virtual void Reshape(const vector*>& bottom, @@ -28,7 +32,13 @@ class InnerProductLayer : public Layer { virtual inline const char* type() const { return "InnerProduct"; } virtual inline int_tp ExactNumBottomBlobs() const { return 1; } virtual inline int_tp ExactNumTopBlobs() const { return 1; } - +#ifdef USE_GREENTEA + ~InnerProductLayer() { + if (weight_image_) + clReleaseMemObject(weight_image_); + weight_image_ = NULL; + } +#endif protected: virtual void Forward_cpu(const vector*>& bottom, const vector*>& top); @@ -45,6 +55,9 @@ class InnerProductLayer : public Layer { bool bias_term_; Blob bias_multiplier_; bool transpose_; ///< if true, assume transposed weights +#ifdef USE_GREENTEA + cl_mem weight_image_; +#endif }; } // namespace caffe diff --git a/src/caffe/layers/inner_product_layer.cpp 
b/src/caffe/layers/inner_product_layer.cpp index 8cb9b63bb0a..70711887223 100644 --- a/src/caffe/layers/inner_product_layer.cpp +++ b/src/caffe/layers/inner_product_layer.cpp @@ -53,6 +53,31 @@ void InnerProductLayer::LayerSetUp(const vector*>& bottom, } } // parameter initialization this->param_propagate_down_.resize(this->blobs_.size(), true); + + if (this->device_->backend() == BACKEND_OpenCL && this->phase_ == TEST) { + viennacl::ocl::context &ctx = + viennacl::ocl::get_context(this->device_->id()); + size_t max_image_size = std::min(ctx.devices()[0].image2d_max_width(), + ctx.devices()[0].image2d_max_height()); + // For inference only, we can load the weights data to image on Intel platform. + // As image based GEMM is much faster than the buffer based GEMM for most cases. + if (N_ <= max_image_size && + K_ <= max_image_size && + std::is_same::value && + this->device_->CheckCapability("cl_intel_subgroups")) { + const Dtype* weight = this->blobs_[0]->gpu_data(); + int height = !transpose_ ? N_ : K_; + int width = !transpose_ ? K_ : N_; + int padded_height = !transpose_ ? height : (height + ((height & 7) ? 1 : 0)); + int padded_width = !transpose_ ? width : (width + ((width & 7) ? 1 : 0)); + greentea_gpu_gemm_copy_buffer_to_image(this->device_->id(), + &weight_image_, (cl_mem) weight, 0, + false, !transpose_, + true, padded_height, padded_width, + height, width, (int)0, NULL, NULL); + } + } + } template diff --git a/src/caffe/layers/inner_product_layer.cu b/src/caffe/layers/inner_product_layer.cu index ca999436775..744039fa7a2 100644 --- a/src/caffe/layers/inner_product_layer.cu +++ b/src/caffe/layers/inner_product_layer.cu @@ -46,11 +46,38 @@ void InnerProductLayer::Forward_gpu(const vector*>& bottom, (cl_mem) (this->blobs_[1]->gpu_data()), 0, (cl_mem) top_data, 0); } else { - greentea_gpu_gemm(this->device_->id(), CblasNoTrans, - transpose_ ? 
CblasNoTrans : CblasTrans, - M_, N_, K_, (Dtype) 1., - (cl_mem) bottom_data, 0, (cl_mem) weight, 0, - (Dtype) 0., (cl_mem) top_data, 0); + viennacl::ocl::context &ctx = + viennacl::ocl::get_context(this->device_->id()); + size_t max_image_size = std::min(ctx.devices()[0].image2d_max_width(), + ctx.devices()[0].image2d_max_height()); + if (M_ <= max_image_size && + N_ <= max_image_size && + K_ <= max_image_size && + std::is_same::value && + this->device_->CheckCapability("cl_intel_subgroups")) { + if (this->phase_ != TEST) { + int height = !transpose_ ? N_ : K_; + int width = !transpose_ ? K_ : N_; + int padded_height = !transpose_ ? height : (height + ((height & 7) ? 1 : 0)); + int padded_width = !transpose_ ? width : (width + ((width & 7) ? 1 : 0)); + greentea_gpu_gemm_copy_buffer_to_image(this->device_->id(), + &weight_image_, (cl_mem) weight, 0, + false, !transpose_, + true, padded_height, padded_width, + height, width, (int)0, NULL, NULL); + } + greentea_gpu_gemm(this->device_->id(), CblasNoTrans, + transpose_ ? CblasNoTrans : CblasTrans, + M_, N_, K_, (Dtype) 1., + (cl_mem) bottom_data, 0, (cl_mem) weight_image_, 0, + (Dtype) 0., (cl_mem) top_data, 0, false, true); + } else + greentea_gpu_gemm(this->device_->id(), CblasNoTrans, + transpose_ ? 
CblasNoTrans : CblasTrans, + M_, N_, K_, (Dtype) 1., + (cl_mem) bottom_data, 0, (cl_mem) weight, 0, + (Dtype) 0., (cl_mem) top_data, 0); + if (bias_term_) greentea_gpu_gemm(this->device_->id(), CblasNoTrans, CblasNoTrans, M_, N_, 1, (Dtype) 1., diff --git a/src/caffe/test/test_inner_product_layer.cpp b/src/caffe/test/test_inner_product_layer.cpp index be937f1cbf2..2cc780e3b3c 100644 --- a/src/caffe/test/test_inner_product_layer.cpp +++ b/src/caffe/test/test_inner_product_layer.cpp @@ -9,13 +9,16 @@ #include "caffe/test/test_caffe_main.hpp" #include "caffe/test/test_gradient_check_util.hpp" +#include "caffe/util/benchmark.hpp" +#include "caffe/util/math_functions.hpp" namespace caffe { template class InnerProductLayerTest : public MultiDeviceTest { typedef typename TypeParam::Dtype Dtype; - protected: + +protected: InnerProductLayerTest() : blob_bottom_(new Blob(2, 3, 4, 5)), blob_bottom_nobatch_(new Blob(1, 2, 3, 4)), @@ -31,11 +34,18 @@ class InnerProductLayerTest : public MultiDeviceTest { delete blob_bottom_nobatch_; delete blob_top_; } + + virtual Blob* MakeReferenceTop(Blob* top) { + this->ref_blob_top_.reset(new Blob()); + this->ref_blob_top_->ReshapeLike(*top); + return this->ref_blob_top_.get(); + } Blob* const blob_bottom_; Blob* const blob_bottom_nobatch_; Blob* const blob_top_; vector*> blob_bottom_vec_; vector*> blob_top_vec_; + shared_ptr > ref_blob_top_; }; TYPED_TEST_CASE(InnerProductLayerTest, TestDtypesAndDevices); @@ -122,6 +132,436 @@ TYPED_TEST(InnerProductLayerTest, TestForward) { } } +TYPED_TEST(InnerProductLayerTest, TestForwardVGGFC6) { + typedef typename TypeParam::Dtype Dtype; + FillerParameter filler_param; + UniformFiller filler(filler_param); + caffe::Caffe::SetDevice(0); + + for(auto i = 1; i <= 8; i*=2) { + Blob* const blob_bottom = new Blob(i, 392, 8, 8); + Blob* const blob_top = new Blob(); + filler.Fill(blob_bottom); + + this->blob_bottom_vec_.clear(); + this->blob_bottom_vec_.push_back(blob_bottom); + this->blob_top_vec_.clear(); 
+ this->blob_top_vec_.push_back(blob_top); + LayerParameter layer_param; + InnerProductParameter* inner_product_param = + layer_param.mutable_inner_product_param(); + inner_product_param->set_num_output(4096); + inner_product_param->set_bias_term(false); + inner_product_param->set_transpose(false); + inner_product_param->mutable_weight_filler()->set_type("uniform"); + shared_ptr > layer( + new InnerProductLayer(layer_param)); + layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + layer->Forward(this->blob_bottom_vec_, this->blob_top_vec_); + + this->MakeReferenceTop(blob_top); + const Dtype* A = blob_bottom->cpu_data(); + const Dtype* B = layer->blobs()[0]->cpu_data(); + Dtype* C = this->ref_blob_top_->mutable_cpu_data(); + int_tp M = blob_bottom->shape()[0]; + int_tp N = layer->blobs()[0]->shape(0); + int_tp K = layer->blobs()[0]->shape(1); + + caffe_cpu_gemm(CblasNoTrans, CblasTrans, M, N, K, (Dtype)1., A, B, (Dtype)0., C); + + const Dtype* data = blob_top->cpu_data(); + const int_tp count = blob_top->count(); + for (int_tp i = 0; i < count; ++i) { + EXPECT_NEAR(data[i], C[i], 1e-1); + } + { + Timer timer; + timer.initted(); + timer.Start(); + auto times = 10; + for (auto i = 0; i < times; ++i) { + layer->Forward(this->blob_bottom_vec_, this->blob_top_vec_); + } + timer.Stop(); + float elapsedTime = timer.MilliSeconds(); + elapsedTime /= times; + std::cout << "MNK(" << M << ","< filler(filler_param); + caffe::Caffe::SetDevice(0); + + for(auto i = 1; i <= 8; i*=2) { + Blob* const blob_bottom = new Blob(i, 25088+1, 1, 1); + Blob* const blob_top = new Blob(); + filler.Fill(blob_bottom); + + this->blob_bottom_vec_.clear(); + this->blob_bottom_vec_.push_back(blob_bottom); + this->blob_top_vec_.clear(); + this->blob_top_vec_.push_back(blob_top); + LayerParameter layer_param; + InnerProductParameter* inner_product_param = + layer_param.mutable_inner_product_param(); + inner_product_param->set_num_output(4096+1); + inner_product_param->set_bias_term(false); + 
inner_product_param->set_transpose(false); + inner_product_param->mutable_weight_filler()->set_type("uniform"); + shared_ptr > layer( + new InnerProductLayer(layer_param)); + layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + layer->Forward(this->blob_bottom_vec_, this->blob_top_vec_); + + this->MakeReferenceTop(blob_top); + const Dtype* A = blob_bottom->cpu_data(); + const Dtype* B = layer->blobs()[0]->cpu_data(); + Dtype* C = this->ref_blob_top_->mutable_cpu_data(); + int_tp M = blob_bottom->shape()[0]; + int_tp N = layer->blobs()[0]->shape(0); + int_tp K = layer->blobs()[0]->shape(1); + + caffe_cpu_gemm(CblasNoTrans, CblasTrans, M, N, K, (Dtype)1., A, B, (Dtype)0., C); + + const Dtype* data = blob_top->cpu_data(); + const int_tp count = blob_top->count(); + std::cout << blob_top->count() << std::endl; + for (int_tp i = 0; i < count; ++i) { + EXPECT_NEAR(data[i], C[i], 1e-1); + } + { + Timer timer; + timer.initted(); + timer.Start(); + auto times = 10; + for (auto i = 0; i < times; ++i) { + layer->Forward(this->blob_bottom_vec_, this->blob_top_vec_); + } + timer.Stop(); + float elapsedTime = timer.MilliSeconds(); + elapsedTime /= times; + std::cout << "MNK(" << M << ","< +void gemv(const vector > >& A, + const int offA, + const int M, + const int N, + const Blob* x, + const int offx, + Blob* y, + const int offy, + const float alpha, + const float beta) { + const unsigned rows = M; + const unsigned cols = N; + const Dtype* mat = A[0]->cpu_data() + offA; + const Dtype* vec = x->cpu_data() + offx; + Dtype* out_data = y->mutable_cpu_data() + offy; + + for (unsigned int r = 0; r < rows; r++) { + out_data[r] = beta * out_data[r]; + for (unsigned int c = 0; c < cols; c++) { + out_data[r] += alpha * mat[r * cols + c] * vec[c]; + } + } +} + +template void gemv(const vector > >& A, + const int offA, + const int M, + const int N, + const Blob* x, + const int offx, + Blob* y, + const int offy, + const float alpha, + const float beta); + +template void gemv(const 
vector > >& A, + const int offA, + const int M, + const int N, + const Blob* x, + const int offx, + Blob* y, + const int offy, + const float alpha, + const float beta); + +TYPED_TEST(InnerProductLayerTest, TestForwardGemvFC6) { + typedef typename TypeParam::Dtype Dtype; + + Blob* const blob_bottom = new Blob(1, 256, 6, 6); + Blob* const blob_top = new Blob(); + FillerParameter filler_param; + UniformFiller filler(filler_param); + filler.Fill(blob_bottom); + + this->blob_bottom_vec_.clear(); + this->blob_bottom_vec_.push_back(blob_bottom); + this->blob_top_vec_.clear(); + this->blob_top_vec_.push_back(blob_top); + LayerParameter layer_param; + InnerProductParameter* inner_product_param = + layer_param.mutable_inner_product_param(); + inner_product_param->set_num_output(4096); + inner_product_param->set_bias_term(false); + inner_product_param->mutable_weight_filler()->set_type("uniform"); + shared_ptr > layer( + new InnerProductLayer(layer_param)); + layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + layer->Forward(this->blob_bottom_vec_, this->blob_top_vec_); + + gemv(layer->blobs(), 0, layer->blobs()[0]->shape(0), + layer->blobs()[0]->shape(1), + blob_bottom, 0, + this->MakeReferenceTop(blob_top), 0, 1., 0.); + + const Dtype* data = blob_top->cpu_data(); + const Dtype* ref_data = this->ref_blob_top_->cpu_data(); + const int_tp count = blob_top->count(); + for (int_tp i = 0; i < count; ++i) { + EXPECT_NEAR(data[i], ref_data[i], 1e-1); + } + + Timer timer; + timer.initted(); + timer.Start(); + for (uint i = 0; i < 100; ++i) { + layer->Forward(this->blob_bottom_vec_, this->blob_top_vec_); + } + timer.Stop(); + float elapsedTime = timer.MilliSeconds(); + std::cout << "GEMV(4096x9216) Time is: " << elapsedTime / 100.f + <<" ms" << std::endl; + + delete blob_bottom; + delete blob_top; +} + +TYPED_TEST(InnerProductLayerTest, TestForwardGemvFC7) { + typedef typename TypeParam::Dtype Dtype; + + Blob* const blob_bottom = new Blob(1, 4096, 1, 1); + Blob* const 
blob_top = new Blob(); + FillerParameter filler_param; + UniformFiller filler(filler_param); + filler.Fill(blob_bottom); + + this->blob_bottom_vec_.clear(); + this->blob_bottom_vec_.push_back(blob_bottom); + this->blob_top_vec_.clear(); + this->blob_top_vec_.push_back(blob_top); + LayerParameter layer_param; + InnerProductParameter* inner_product_param = + layer_param.mutable_inner_product_param(); + inner_product_param->set_num_output(4096); + inner_product_param->set_bias_term(false); + inner_product_param->mutable_weight_filler()->set_type("uniform"); + shared_ptr > layer( + new InnerProductLayer(layer_param)); + layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + layer->Forward(this->blob_bottom_vec_, this->blob_top_vec_); + gemv(layer->blobs(), 0, layer->blobs()[0]->shape(0), + layer->blobs()[0]->shape(1), + blob_bottom, 0, + this->MakeReferenceTop(blob_top), 0, 1., 0.); + + const Dtype* data = blob_top->cpu_data(); + const Dtype* ref_data = this->ref_blob_top_->cpu_data(); + const int_tp count = blob_top->count(); + for (int_tp i = 0; i < count; ++i) { + EXPECT_NEAR(data[i], ref_data[i], 1e-1); + } + + Timer timer; + timer.initted(); + timer.Start(); + for (uint i = 0; i < 100; ++i) { + layer->Forward(this->blob_bottom_vec_, this->blob_top_vec_); + } + timer.Stop(); + float elapsedTime = timer.MilliSeconds(); + std::cout << "GEMV(4096x4096) Time is: " << elapsedTime / 100.f + <<" ms" << std::endl; + delete blob_bottom; + delete blob_top; +} + +TYPED_TEST(InnerProductLayerTest, TestForwardGemvFC8) { + typedef typename TypeParam::Dtype Dtype; + + Blob* const blob_bottom = new Blob(1, 4096, 1, 1); + Blob* const blob_top = new Blob(); + FillerParameter filler_param; + UniformFiller filler(filler_param); + filler.Fill(blob_bottom); + + this->blob_bottom_vec_.clear(); + this->blob_bottom_vec_.push_back(blob_bottom); + this->blob_top_vec_.clear(); + this->blob_top_vec_.push_back(blob_top); + LayerParameter layer_param; + InnerProductParameter* 
inner_product_param = + layer_param.mutable_inner_product_param(); + inner_product_param->set_num_output(1000); + inner_product_param->set_bias_term(false); + inner_product_param->mutable_weight_filler()->set_type("uniform"); + shared_ptr > layer( + new InnerProductLayer(layer_param)); + layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + layer->Forward(this->blob_bottom_vec_, this->blob_top_vec_); + gemv(layer->blobs(), 0, layer->blobs()[0]->shape(0), + layer->blobs()[0]->shape(1), + blob_bottom, 0, + this->MakeReferenceTop(blob_top), 0, 1., 0.); + + const Dtype* data = blob_top->cpu_data(); + const Dtype* ref_data = this->ref_blob_top_->cpu_data(); + const int_tp count = blob_top->count(); + for (int_tp i = 0; i < count; ++i) { + EXPECT_NEAR(data[i], ref_data[i], 1e-1); + } + Timer timer; + timer.initted(); + timer.Start(); + for (uint i = 0; i < 100; ++i) { + layer->Forward(this->blob_bottom_vec_, this->blob_top_vec_); + } + timer.Stop(); + float elapsedTime = timer.MilliSeconds(); + std::cout << "GEMV(1000x4096) Time is: " << elapsedTime / 100.f + <<" ms" << std::endl; + + delete blob_bottom; + delete blob_top; +} + +TYPED_TEST(InnerProductLayerTest, TestForwardGemvFC_dev1) { + typedef typename TypeParam::Dtype Dtype; + + Blob* const blob_bottom = new Blob(1, 4099, 1, 1); + Blob* const blob_top = new Blob(); + FillerParameter filler_param; + UniformFiller filler(filler_param); + filler.Fill(blob_bottom); + + this->blob_bottom_vec_.clear(); + this->blob_bottom_vec_.push_back(blob_bottom); + this->blob_top_vec_.clear(); + this->blob_top_vec_.push_back(blob_top); + LayerParameter layer_param; + InnerProductParameter* inner_product_param = + layer_param.mutable_inner_product_param(); + inner_product_param->set_num_output(1003); + inner_product_param->set_bias_term(false); + inner_product_param->mutable_weight_filler()->set_type("uniform"); + shared_ptr > layer( + new InnerProductLayer(layer_param)); + layer->SetUp(this->blob_bottom_vec_, 
this->blob_top_vec_); + layer->Forward(this->blob_bottom_vec_, this->blob_top_vec_); + gemv(layer->blobs(), 0, layer->blobs()[0]->shape(0), + layer->blobs()[0]->shape(1), + blob_bottom, 0, + this->MakeReferenceTop(blob_top), 0, 1., 0.); + + const Dtype* data = blob_top->cpu_data(); + const Dtype* ref_data = this->ref_blob_top_->cpu_data(); + const int_tp count = blob_top->count(); + for (int_tp i = 0; i < count; ++i) { + EXPECT_NEAR(data[i], ref_data[i], 1e-1); + } + + Timer timer; + timer.initted(); + timer.Start(); + for (uint i = 0; i < 100; ++i) { + layer->Forward(this->blob_bottom_vec_, this->blob_top_vec_); + } + timer.Stop(); + float elapsedTime = timer.MilliSeconds(); + std::cout << "GEMV(1003x4099) Time is: " << elapsedTime / 100.f + <<" ms" << std::endl; + delete blob_bottom; + delete blob_top; +} + +TYPED_TEST(InnerProductLayerTest, TestGEMV) { + typedef typename TypeParam::Dtype Dtype; + if (Caffe::mode() == Caffe::GPU) { + + Blob* const blob_bottom = new Blob(1, 4099, 1, 1); + Blob* const blob_top = new Blob(); + FillerParameter filler_param; + UniformFiller filler(filler_param); + filler.Fill(blob_bottom); + + this->blob_bottom_vec_.clear(); + this->blob_bottom_vec_.push_back(blob_bottom); + this->blob_top_vec_.clear(); + this->blob_top_vec_.push_back(blob_top); + LayerParameter layer_param; + InnerProductParameter* inner_product_param = + layer_param.mutable_inner_product_param(); + inner_product_param->set_num_output(1003); + inner_product_param->set_bias_term(false); + inner_product_param->mutable_weight_filler()->set_type("uniform"); + shared_ptr > layer( + new InnerProductLayer(layer_param)); + + layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + device *dc = Caffe::GetDefaultDevice(); + const Dtype* x = this->blob_bottom_vec_[0]->gpu_data(); + Dtype* y = this->blob_top_vec_[0]->mutable_gpu_data(); + Dtype alpha = 2.5; + Dtype beta = 2; + unsigned int M = layer->blobs()[0]->shape(0); + unsigned int N = layer->blobs()[0]->shape(1); + 
//add offset + unsigned int offA = M * N / 2; + unsigned int offx = 0; + unsigned int offy = M / 2; + M /= 2; + greentea_gpu_gemv(dc->id(), CblasNoTrans, + M, N, + alpha, + (cl_mem)layer->blobs()[0]->gpu_data(), offA, (cl_mem)x, + offx, beta, (cl_mem)y, + offy); + gemv(layer->blobs(), offA, M, N, + blob_bottom, offx, this->MakeReferenceTop(blob_top), offy, alpha, beta); + + const Dtype* data = blob_top->cpu_data(); + const Dtype* ref_data = this->ref_blob_top_->cpu_data(); + const int_tp count = blob_top->count(); + for (int_tp i = offy; i < count; ++i) { + EXPECT_NEAR(data[i], ref_data[i], 1e-1); + } + + delete blob_bottom; + delete blob_top; + } +} /** * @brief Init. an IP layer without transpose + random weights, * run Forward, save the result. From 6abe23d940179443bda4b87b4ffad9c1ba19724b Mon Sep 17 00:00:00 2001 From: Zhigang Gong Date: Sat, 1 Apr 2017 03:40:32 +0800 Subject: [PATCH 07/33] Optimize BN layer for inference only. Signed-off-by: Zhigang Gong --- src/caffe/greentea/cl_kernels/batch_norm.cl | 58 ++++++++---- src/caffe/layers/batch_norm_layer.cu | 132 +++++++++++++++++++++------- 2 files changed, 142 insertions(+), 48 deletions(-) diff --git a/src/caffe/greentea/cl_kernels/batch_norm.cl b/src/caffe/greentea/cl_kernels/batch_norm.cl index b8c5365eb93..08d0ddeff53 100644 --- a/src/caffe/greentea/cl_kernels/batch_norm.cl +++ b/src/caffe/greentea/cl_kernels/batch_norm.cl @@ -2,11 +2,12 @@ #include "header.cl" #endif -__kernel void TEMPLATE(batch_norm_use_global_stats_in_place,Dtype)(const int_tp num, const int_tp channels, const int_tp spatial_dim, - const Dtype scale, const Dtype eps, +Dtype TEMPLATE(bn_common,Dtype)(const int_tp num, const int_tp channels, const int_tp spatial_dim, + const Dtype scale, const Dtype eps, __global const Dtype* mean, __global const Dtype* variance, - __global Dtype* top) { + __global const Dtype* data, + int_tp *out_off) { const int_tp idx_num = get_global_id(0); const int_tp idx_chans = get_global_id(1); const int_tp 
idx_spatial_dim = get_global_id(2); @@ -17,26 +18,49 @@ __kernel void TEMPLATE(batch_norm_use_global_stats_in_place,Dtype)(const int_tp m = -scale * m; v = (Dtype)native_powr((float)mad(scale, v, eps), (float)-0.5); - const int_tp out_off = (idx_num * channels + idx_chans) * spatial_dim + idx_spatial_dim; - top[out_off] = v * (top[out_off] + m); + *out_off = (idx_num * channels + idx_chans) * spatial_dim + idx_spatial_dim; + return (v * (data[*out_off] + m)); } -__kernel void TEMPLATE(batch_norm_use_global_stats,Dtype)(const int_tp num, const int_tp channels, const int_tp spatial_dim, - const Dtype scale, const Dtype eps, + +__kernel void TEMPLATE(bn_use_global_stats_in_place,Dtype)(const int_tp num, const int_tp channels, const int_tp spatial_dim, + const Dtype scale, const Dtype eps, __global const Dtype* mean, __global const Dtype* variance, - __global const Dtype* bottom, __global Dtype* top) { - const int_tp idx_num = get_global_id(0); - const int_tp idx_chans = get_global_id(1); - const int_tp idx_spatial_dim = get_global_id(2); + int_tp out_off; + Dtype val = TEMPLATE(bn_common,Dtype)(num, channels, spatial_dim, scale, eps, mean, variance, top, &out_off); + top[out_off] = val; +} - Dtype m = mean[idx_chans]; - Dtype v = variance[idx_chans]; +__kernel void TEMPLATE(bn_use_global_stats_in_place_fused_relu,Dtype)(const int_tp num, const int_tp channels, const int_tp spatial_dim, + const Dtype scale, const Dtype eps, + __global const Dtype* mean, + __global const Dtype* variance, + __global Dtype* top) { + int_tp out_off; + Dtype val = TEMPLATE(bn_common,Dtype)(num, channels, spatial_dim, scale, eps, mean, variance, top, &out_off); + top[out_off] = val > 0.0f ? 
val : 0.0f; +} - m = -scale * m; - v = (Dtype)native_powr((float)mad(scale, v, eps), (float)-0.5); +__kernel void TEMPLATE(bn_use_global_stats,Dtype)(const int_tp num, const int_tp channels, const int_tp spatial_dim, + const Dtype scale, const Dtype eps, + __global const Dtype* mean, + __global const Dtype* variance, + __global const Dtype* bottom, + __global Dtype* top) { + int_tp out_off; + Dtype val = TEMPLATE(bn_common,Dtype)(num, channels, spatial_dim, scale, eps, mean, variance, bottom, &out_off); + top[out_off] = val; +} - const int_tp out_off = (idx_num * channels + idx_chans) * spatial_dim + idx_spatial_dim; - top[out_off] = v * (bottom[out_off] + m); +__kernel void TEMPLATE(bn_use_global_stats_fused_relu,Dtype)(const int_tp num, const int_tp channels, const int_tp spatial_dim, + const Dtype scale, const Dtype eps, + __global const Dtype* mean, + __global const Dtype* variance, + __global const Dtype* bottom, + __global Dtype* top) { + int_tp out_off; + Dtype val = TEMPLATE(bn_common,Dtype)(num, channels, spatial_dim, scale, eps, mean, variance, bottom, &out_off); + top[out_off] = val > 0.0f ? 
val : 0.0f; } diff --git a/src/caffe/layers/batch_norm_layer.cu b/src/caffe/layers/batch_norm_layer.cu index 2ebb00ec950..3fbf1f0bbff 100644 --- a/src/caffe/layers/batch_norm_layer.cu +++ b/src/caffe/layers/batch_norm_layer.cu @@ -6,6 +6,16 @@ namespace caffe { +#define SET_COMMON_KERNEL_PARAMS \ + oclk_bn_use_global_stats.arg(argIdx++, num); \ + oclk_bn_use_global_stats.arg(argIdx++, channels_); \ + oclk_bn_use_global_stats.arg(argIdx++, spatial_dim); \ + oclk_bn_use_global_stats.arg(argIdx++, scale_factor); \ + oclk_bn_use_global_stats.arg(argIdx++, eps_); \ + oclk_bn_use_global_stats.arg(argIdx++, WrapHandle((cl_mem) this->blobs_[0]->gpu_data(), &ctx)); \ + oclk_bn_use_global_stats.arg(argIdx++, WrapHandle((cl_mem) this->blobs_[1]->gpu_data(), &ctx)); + + template void BatchNormLayer::Forward_gpu(const vector*>& bottom, const vector*>& top) { @@ -95,35 +105,94 @@ void BatchNormLayer::Forward_gpu(const vector*>& bottom, viennacl::ocl::context &ctx = viennacl::ocl::get_context( this->device_->id()); - if (bottom[0] != top[0]) { - greentea_copy(bottom[0]->count(), (cl_mem) bottom_data, 0, - (cl_mem) top_data, 0, &ctx); - } - - if (use_global_stats_) { - // use the stored mean/variance estimates. - const Dtype scale_factor = this->blobs_[2]->cpu_data()[0] == 0 ? - 0 : 1 / this->blobs_[2]->cpu_data()[0]; - greentea_gpu_scale(this->device_->id(), variance_.count(), scale_factor, - (cl_mem) (this->blobs_[0]->gpu_data()), 0, - (cl_mem) (mean_.mutable_gpu_data()), 0); - greentea_gpu_scale(this->device_->id(), variance_.count(), scale_factor, - (cl_mem) (this->blobs_[1]->gpu_data()), 0, - (cl_mem) (variance_.mutable_gpu_data()), 0); + if (use_global_stats_) { + const Dtype scale_factor = + this->blobs_[2]->cpu_data()[0] == 0 ? 
+ 0 : 1 / this->blobs_[2]->cpu_data()[0]; + + viennacl::ocl::program &program = this->device_->program(); + + bool fused_relu = this->layer_param_.batch_norm_param().fused_relu(); + + cl_uint argIdx = 0; + size_t global_work_size_[3] = {(size_t)num, + (size_t)channels_, + (size_t)spatial_dim}; + if (bottom[0] == top[0]) { + if (fused_relu) { + viennacl::ocl::kernel &oclk_bn_use_global_stats = program.get_kernel( + CL_KERNEL_SELECT("bn_use_global_stats_in_place_fused_relu")); + + SET_COMMON_KERNEL_PARAMS + + oclk_bn_use_global_stats.arg(argIdx++, + WrapHandle((cl_mem) top_data, + &ctx)); + OCL_CHECK(clEnqueueNDRangeKernel(ctx.get_queue().handle().get(), + oclk_bn_use_global_stats.handle().get(), 3, NULL, + global_work_size_, NULL, 0, NULL, NULL)); + } + else { + viennacl::ocl::kernel &oclk_bn_use_global_stats = program.get_kernel( + CL_KERNEL_SELECT("bn_use_global_stats_in_place")); + + SET_COMMON_KERNEL_PARAMS + + oclk_bn_use_global_stats.arg(argIdx++, WrapHandle((cl_mem) top_data, &ctx)); + OCL_CHECK(clEnqueueNDRangeKernel(ctx.get_queue().handle().get(), + oclk_bn_use_global_stats.handle().get(), 3, NULL, global_work_size_, NULL, 0, NULL, NULL)); + } } else { + if (fused_relu) { + viennacl::ocl::kernel &oclk_bn_use_global_stats = + program.get_kernel( + CL_KERNEL_SELECT("bn_use_global_stats_fused_relu")); + + SET_COMMON_KERNEL_PARAMS + + oclk_bn_use_global_stats.arg(argIdx++, + WrapHandle((cl_mem) bottom_data, &ctx)); + oclk_bn_use_global_stats.arg(argIdx++, + WrapHandle((cl_mem) top_data, &ctx)); + OCL_CHECK(clEnqueueNDRangeKernel(ctx.get_queue().handle().get(), + oclk_bn_use_global_stats.handle().get(), 3, NULL, + global_work_size_, NULL, 0, NULL, NULL)); + } + else { + viennacl::ocl::kernel &oclk_bn_use_global_stats = + program.get_kernel(CL_KERNEL_SELECT("bn_use_global_stats")); + + SET_COMMON_KERNEL_PARAMS + + oclk_bn_use_global_stats.arg(argIdx++, + WrapHandle((cl_mem) bottom_data, &ctx)); + oclk_bn_use_global_stats.arg(argIdx++, + WrapHandle((cl_mem) 
top_data, &ctx)); + OCL_CHECK(clEnqueueNDRangeKernel(ctx.get_queue().handle().get(), + oclk_bn_use_global_stats.handle().get(), 3, NULL, + global_work_size_, NULL, 0, NULL, NULL)); + } + } + } else { // compute mean - greentea_gpu_gemv(this->device_->id(), - CblasNoTrans, channels_ * num, spatial_dim, - 1. / (num * spatial_dim), (cl_mem) bottom_data, 0, - (cl_mem) (spatial_sum_multiplier_.gpu_data()), 0, 0., - (cl_mem) (num_by_chans_.mutable_gpu_data()), 0); - greentea_gpu_gemv(this->device_->id(), - CblasTrans, num, channels_, 1., - (cl_mem) (num_by_chans_.gpu_data()), 0, - (cl_mem) (batch_sum_multiplier_.gpu_data()), 0, 0., - (cl_mem) (mean_.mutable_gpu_data()), 0); + if (bottom[0] != top[0]) { + greentea_copy(bottom[0]->count(), (cl_mem) bottom_data, 0, + (cl_mem) top_data, 0, &ctx); } + + greentea_gpu_gemv(this->device_->id(), + CblasNoTrans, channels_ * num, spatial_dim, + 1. / (num * spatial_dim), (cl_mem) bottom_data, 0, + (cl_mem) (spatial_sum_multiplier_.gpu_data()), 0, 0., + (cl_mem) (num_by_chans_.mutable_gpu_data()), 0); + greentea_gpu_gemv(this->device_->id(), + CblasTrans, num, channels_, 1., + (cl_mem) (num_by_chans_.gpu_data()), 0, + (cl_mem) (batch_sum_multiplier_.gpu_data()), 0, 0., + (cl_mem) (mean_.mutable_gpu_data()), 0); + + // subtract mean greentea_gpu_gemm(this->device_->id(), CblasNoTrans, CblasNoTrans, num, channels_, 1, 1, @@ -170,10 +239,10 @@ void BatchNormLayer::Forward_gpu(const vector*>& bottom, // normalize variance greentea_gpu_add_scalar(this->device_->id(), variance_.count(), - eps_, (cl_mem) (variance_.mutable_gpu_data()), 0); + eps_, (cl_mem) (variance_.mutable_gpu_data()), 0); greentea_gpu_sqrt(this->device_->id(), variance_.count(), - (cl_mem) (variance_.gpu_data()), 0, - (cl_mem) (variance_.mutable_gpu_data()), 0); + (cl_mem) (variance_.gpu_data()), 0, + (cl_mem) (variance_.mutable_gpu_data()), 0); // replicate variance to input size greentea_gpu_gemm(this->device_->id(), CblasNoTrans, CblasNoTrans, @@ -187,13 +256,14 @@ 
void BatchNormLayer::Forward_gpu(const vector*>& bottom, (cl_mem) (spatial_sum_multiplier_.gpu_data()), 0, 0., (cl_mem) (temp_.mutable_gpu_data()), 0); greentea_gpu_div(this->device_->id(), temp_.count(), - (cl_mem) top_data, 0, (cl_mem) (temp_.gpu_data()), 0, - (cl_mem) top_data, 0); + (cl_mem) top_data, 0, (cl_mem) (temp_.gpu_data()), 0, + (cl_mem) top_data, 0); // TODO(cdoersch): The caching is only needed // because later in-place layers might clobber the data. // Can we skip this if they won't? greentea_copy(x_norm_.count(), (cl_mem)top_data, 0, - (cl_mem) (x_norm_.mutable_gpu_data()), 0, &ctx); + (cl_mem) (x_norm_.mutable_gpu_data()), 0, &ctx); + } #endif // USE_GREENTEA } } From 052c332e4af9a65717405aa5c348ccd60a60e91d Mon Sep 17 00:00:00 2001 From: "Richman, Reuven" Date: Wed, 7 Sep 2016 16:15:37 +0300 Subject: [PATCH 08/33] softmax layer cpu fwd - no need to max values with themselves --- src/caffe/layers/softmax_layer.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/caffe/layers/softmax_layer.cpp b/src/caffe/layers/softmax_layer.cpp index bde200c82e1..5015c84fe44 100644 --- a/src/caffe/layers/softmax_layer.cpp +++ b/src/caffe/layers/softmax_layer.cpp @@ -43,7 +43,8 @@ void SoftmaxLayer::Forward_cpu(const vector*>& bottom, for (int_tp i = 0; i < outer_num_; ++i) { // initialize scale_data to the first plane caffe_cpu_copy(inner_num_, bottom_data + i * dim, scale_data); - for (int_tp j = 0; j < channels; j++) { + // start max after the first inner_num values (j=1) since they were just copied + for (int_tp j = 1; j < channels; j++) { for (int_tp k = 0; k < inner_num_; k++) { scale_data[k] = std::max(scale_data[k], bottom_data[i * dim + j * inner_num_ + k]); From 0e4994aa19c01b1cd609ba839e465420df0b4e90 Mon Sep 17 00:00:00 2001 From: Zhigang Gong Date: Sat, 1 Apr 2017 03:50:53 +0800 Subject: [PATCH 10/33] Add new lt option to caffe tool. By default, we will not measure per layer timing now. 
As the per layer timing measurement brings siginificant overhead for many net models. Signed-off-by: Zhigang Gong --- tools/caffe.cpp | 38 ++++++++++++++++++++++++++------------ 1 file changed, 26 insertions(+), 12 deletions(-) diff --git a/tools/caffe.cpp b/tools/caffe.cpp index d0bd5c2f9fb..6ceaf94fe46 100644 --- a/tools/caffe.cpp +++ b/tools/caffe.cpp @@ -61,6 +61,9 @@ DEFINE_string(sigint_effect, "stop", DEFINE_string(sighup_effect, "snapshot", "Optional; action to take when a SIGHUP signal is received: " "snapshot, stop or none."); +DEFINE_bool(lt, false, + "Optional; enable per layer timings"); + // A simple registry for caffe commands. typedef int (*BrewFunction)(); @@ -430,16 +433,24 @@ int time() { std::vector backward_time_per_layer(layers.size(), 0.0); double forward_time = 0.0; double backward_time = 0.0; + for (int_tp j = 0; j < FLAGS_iterations; ++j) { Timer iter_timer; iter_timer.Start(); forward_timer.Start(); for (int_tp i = 0; i < layers.size(); ++i) { - timer.Start(); + if (FLAGS_lt) { + timer.Start(); + } + layers[i]->Forward(bottom_vecs[i], top_vecs[i]); - Caffe::Synchronize(Caffe::GetDefaultDevice()->id()); - forward_time_per_layer[i] += timer.MicroSeconds(); + + if (FLAGS_lt) { + Caffe::Synchronize(Caffe::GetDefaultDevice()->id()); + forward_time_per_layer[i] += timer.MicroSeconds(); + } } + Caffe::Synchronize(Caffe::GetDefaultDevice()->id()); forward_time += forward_timer.MicroSeconds(); if (phase == caffe::TRAIN) { backward_timer.Start(); @@ -455,15 +466,18 @@ int time() { LOG(INFO) << "Iteration: " << j + 1 << " forward-backward time: " << iter_timer.MilliSeconds() << " ms."; } - LOG(INFO) << "Average time per layer: "; - for (int_tp i = 0; i < layers.size(); ++i) { - const caffe::string& layername = layers[i]->layer_param().name(); - LOG(INFO) << std::setfill(' ') << std::setw(10) << layername << - "\tforward: " << forward_time_per_layer[i] / 1000 / - FLAGS_iterations << " ms."; - LOG(INFO) << std::setfill(' ') << std::setw(10) << 
layername << - "\tbackward: " << backward_time_per_layer[i] / 1000 / - FLAGS_iterations << " ms."; + + if (FLAGS_lt) { + LOG(INFO) << "Average time per layer: "; + for (int_tp i = 0; i < layers.size(); ++i) { + const caffe::string& layername = layers[i]->layer_param().name(); + LOG(INFO) << std::setfill(' ') << std::setw(10) << layername << + "\tforward: " << forward_time_per_layer[i] / 1000 / + FLAGS_iterations << " ms."; + LOG(INFO) << std::setfill(' ') << std::setw(10) << layername << + "\tbackward: " << backward_time_per_layer[i] / 1000 / + FLAGS_iterations << " ms."; + } } total_timer.Stop(); LOG(INFO) << "Average Forward pass: " << forward_time / 1000 / From 89f631501eaacf239aefe144e5d54231150be55e Mon Sep 17 00:00:00 2001 From: Zhigang Gong Date: Sat, 1 Apr 2017 03:53:56 +0800 Subject: [PATCH 11/33] Use explicit constant value type rather than the default double type. For those compilers don't support double type, use implicit double type constant may bring some issues. Signed-off-by: Zhigang Gong --- src/caffe/greentea/cl_kernels/activation.cl | 12 ++++++------ src/caffe/greentea/cl_kernels/bnll.cl | 4 ++-- src/caffe/greentea/cl_kernels/contrastive_loss.cl | 2 +- src/caffe/greentea/cl_kernels/dropout.cl | 4 ++-- src/caffe/greentea/cl_kernels/elu.cl | 2 +- src/caffe/greentea/cl_kernels/solvers.cl | 4 ++-- 6 files changed, 14 insertions(+), 14 deletions(-) diff --git a/src/caffe/greentea/cl_kernels/activation.cl b/src/caffe/greentea/cl_kernels/activation.cl index 2a1a1c1ddec..6f0eaedc4e1 100644 --- a/src/caffe/greentea/cl_kernels/activation.cl +++ b/src/caffe/greentea/cl_kernels/activation.cl @@ -18,7 +18,7 @@ __kernel void TEMPLATE(relu_backward,Dtype)(const int_tp n, Dtype negative_slope) { for (int_tp index = get_global_id(0); index < n; index += get_global_size(0)) { out_diff[index] = in_diff[index] - * ((in_data[index] > 0?1.0:0.0) + (in_data[index] <= 0?1.0:0.0) * negative_slope); + * ((Dtype)(in_data[index] > 0?1.0:0.0) + (Dtype)(in_data[index] <= 
0?1.0:0.0) * negative_slope); } } @@ -36,7 +36,7 @@ __kernel void TEMPLATE(tanh_backward,Dtype)(const int_tp n, __global Dtype* out_diff) { for (int_tp index = get_global_id(0); index < n; index += get_global_size(0)) { Dtype tanhx = out_data[index]; - out_diff[index] = in_diff[index] * (1 - tanhx * tanhx); + out_diff[index] = in_diff[index] * ((Dtype)1 - tanhx * tanhx); } } @@ -44,7 +44,7 @@ __kernel void TEMPLATE(sigmoid_forward,Dtype)(const int_tp n, __global const Dtype* in, __global Dtype* out) { for (int_tp index = get_global_id(0); index < n; index += get_global_size(0)) { - out[index] = 1.0 / (1.0 + exp(-in[index])); + out[index] = (Dtype)1.0 / ((Dtype)1.0 + exp(-in[index])); } } @@ -54,7 +54,7 @@ __kernel void TEMPLATE(sigmoid_backward,Dtype)(const int_tp n, __global Dtype* out_diff) { for (int_tp index = get_global_id(0); index < n; index += get_global_size(0)) { const Dtype sigmoid_x = out_data[index]; - out_diff[index] = in_diff[index] * sigmoid_x * (1 - sigmoid_x); + out_diff[index] = in_diff[index] * sigmoid_x * ((Dtype)1 - sigmoid_x); } } @@ -98,11 +98,11 @@ __kernel void TEMPLATE(prelu_param_backward,Dtype)(const int_tp n, const int_tp __global const Dtype* in_data, __global Dtype* out_diff) { for (int_tp index = get_global_id(0); index < n; index += get_global_size(0)) { - out_diff[index] = in_diff[index] * in_data[index] * (in_data[index] <= 0?1.0:0.0); + out_diff[index] = in_diff[index] * in_data[index] * (Dtype)(in_data[index] <= 0?1.0:0.0); for (int k = 1; k < rows; k++) { out_diff[index] += in_diff[index + k * rowPitch] * in_data[index + k * rowPitch] - * (in_data[index + k * rowPitch] <= 0?1.0:0.0); + * (Dtype)(in_data[index + k * rowPitch] <= 0?1.0:0.0); } } } diff --git a/src/caffe/greentea/cl_kernels/bnll.cl b/src/caffe/greentea/cl_kernels/bnll.cl index a385484e857..2a23389a942 100644 --- a/src/caffe/greentea/cl_kernels/bnll.cl +++ b/src/caffe/greentea/cl_kernels/bnll.cl @@ -7,7 +7,7 @@ __kernel void TEMPLATE(bnll_forward,Dtype)(const 
int_tp n, __global Dtype* out) { for (int_tp index = get_global_id(0); index < n; index += get_global_size(0)) { if (in[index] > 0.0f) { - out[index] = in[index] + log((Dtype) (1.0 + exp(-in[index]))); + out[index] = in[index] + log((Dtype) ((Dtype)1.0 + exp(-in[index]))); } else { out[index] = log((Dtype) (1.0 + exp(in[index]))); } @@ -21,6 +21,6 @@ __kernel void TEMPLATE(bnll_backward,Dtype)(const int_tp n, Dtype kBNLL_THRESHOLD = 50.; for (int_tp index = get_global_id(0); index < n; index += get_global_size(0)) { Dtype expval = exp(min(in_data[index], kBNLL_THRESHOLD)); - out_diff[index] = in_diff[index] * expval / (expval + 1.); + out_diff[index] = in_diff[index] * expval / (expval + (Dtype)1.); } } diff --git a/src/caffe/greentea/cl_kernels/contrastive_loss.cl b/src/caffe/greentea/cl_kernels/contrastive_loss.cl index 73141d472be..bb71dfec8b0 100644 --- a/src/caffe/greentea/cl_kernels/contrastive_loss.cl +++ b/src/caffe/greentea/cl_kernels/contrastive_loss.cl @@ -16,7 +16,7 @@ __kernel void TEMPLATE(cll_backward,Dtype)(const int_tp count, const int_tp chan Dtype beta = 0.; Dtype dist = sqrt(dist_sq[n]); mdist = (margin - dist); - beta = -alpha * mdist / (dist + 1e-4) * diff[i]; + beta = -alpha * mdist / (dist + (Dtype)1e-4) * diff[i]; if (mdist > 0.) 
{ bottom_diff[i] = beta; } else { diff --git a/src/caffe/greentea/cl_kernels/dropout.cl b/src/caffe/greentea/cl_kernels/dropout.cl index a3debfa6d52..103ab889c56 100644 --- a/src/caffe/greentea/cl_kernels/dropout.cl +++ b/src/caffe/greentea/cl_kernels/dropout.cl @@ -9,7 +9,7 @@ __kernel void TEMPLATE(dropout_forward,Dtype)(const int_tp n, const Dtype scale, __global Dtype* out) { for (int_tp index = get_global_id(0); index < n; index += get_global_size(0)) { - out[index] = in[index] * ((mask[index] > threshold)?1.0:0.0) * scale; + out[index] = in[index] * (Dtype)((mask[index] > threshold)?1.0:0.0) * scale; } } @@ -19,6 +19,6 @@ __kernel void TEMPLATE(dropout_backward,Dtype)( const Dtype scale, __global Dtype* out_diff) { for (int_tp index = get_global_id(0); index < n; index += get_global_size(0)) { - out_diff[index] = in_diff[index] * ((mask[index] > threshold)?1.0:0.0) * scale; + out_diff[index] = in_diff[index] * (Dtype)((mask[index] > threshold)?1.0:0.0) * scale; } } diff --git a/src/caffe/greentea/cl_kernels/elu.cl b/src/caffe/greentea/cl_kernels/elu.cl index 08cd0e38bd5..0e3ef6f0d8c 100644 --- a/src/caffe/greentea/cl_kernels/elu.cl +++ b/src/caffe/greentea/cl_kernels/elu.cl @@ -6,7 +6,7 @@ __kernel void TEMPLATE(elu_forward,Dtype)(const int n, __global const Dtype* in, __global Dtype* out, Dtype alpha) { for (int_tp index = get_global_id(0); index < n; index += get_global_size(0)) { - out[index] = in[index] > 0 ? in[index] : alpha * (exp(in[index]) - 1.0); + out[index] = in[index] > 0 ? 
in[index] : alpha * (exp(in[index]) - (Dtype)1.0); } } diff --git a/src/caffe/greentea/cl_kernels/solvers.cl b/src/caffe/greentea/cl_kernels/solvers.cl index 5e5ca0cc57a..7d792cd9d5a 100644 --- a/src/caffe/greentea/cl_kernels/solvers.cl +++ b/src/caffe/greentea/cl_kernels/solvers.cl @@ -10,9 +10,9 @@ __kernel void TEMPLATE(ada_delta_update,Dtype)(int_tp N, __global Dtype* g, Dtype local_rate) { for (int_tp i = get_global_id(0); i < N; i += get_global_size(0)) { Dtype gi = g[i]; - Dtype hi = h[i] = momentum * h[i] + (1.0 - momentum) * gi * gi; + Dtype hi = h[i] = momentum * h[i] + ((Dtype)1.0 - momentum) * gi * gi; gi = gi * sqrt((h2[i] + delta) / (hi + delta)); - h2[i] = momentum * h2[i] + (1.0 - momentum) * gi * gi; + h2[i] = momentum * h2[i] + ((Dtype)1.0 - momentum) * gi * gi; g[i] = local_rate * gi; } } From 4eaf1087cd58465cb8f010b4a00af96a3655ba10 Mon Sep 17 00:00:00 2001 From: Zhigang Gong Date: Sat, 1 Apr 2017 07:36:00 +0800 Subject: [PATCH 12/33] Fix a bug in inner product layer. This bug is caused by two hacky cases: 1. SharedWeights is called from RNN layer's forwarding only path which may change some inner product layer's weights data with a hacky way after the layer setup. 2. The LSTM's gradient test case set the phase to TEST, but it will call into inner product's backward path. 
Signed-off-by: Zhigang Gong --- include/caffe/layers/inner_product_layer.hpp | 2 ++ src/caffe/layers/inner_product_layer.cpp | 4 ++++ src/caffe/layers/inner_product_layer.cu | 10 ++++++++-- 3 files changed, 14 insertions(+), 2 deletions(-) diff --git a/include/caffe/layers/inner_product_layer.hpp b/include/caffe/layers/inner_product_layer.hpp index fa639a8c3f5..0c4bd95038b 100644 --- a/include/caffe/layers/inner_product_layer.hpp +++ b/include/caffe/layers/inner_product_layer.hpp @@ -57,6 +57,8 @@ class InnerProductLayer : public Layer { bool transpose_; ///< if true, assume transposed weights #ifdef USE_GREENTEA cl_mem weight_image_; + const SyncedMemory * copied_weight_data_; + bool test_only_; #endif }; diff --git a/src/caffe/layers/inner_product_layer.cpp b/src/caffe/layers/inner_product_layer.cpp index 70711887223..cef42443b0e 100644 --- a/src/caffe/layers/inner_product_layer.cpp +++ b/src/caffe/layers/inner_product_layer.cpp @@ -75,9 +75,13 @@ void InnerProductLayer::LayerSetUp(const vector*>& bottom, false, !transpose_, true, padded_height, padded_width, height, width, (int)0, NULL, NULL); + copied_weight_data_ = this->blobs_[0]->data().get(); } + } else { + copied_weight_data_ = NULL; } + test_only_ = this->phase_ == TEST; } template diff --git a/src/caffe/layers/inner_product_layer.cu b/src/caffe/layers/inner_product_layer.cu index 744039fa7a2..c6e98bec391 100644 --- a/src/caffe/layers/inner_product_layer.cu +++ b/src/caffe/layers/inner_product_layer.cu @@ -12,7 +12,6 @@ void InnerProductLayer::Forward_gpu(const vector*>& bottom, const Dtype* bottom_data = bottom[0]->gpu_data(); Dtype* top_data = top[0]->mutable_gpu_data(); const Dtype* weight = this->blobs_[0]->gpu_data(); - if (this->device_->backend() == BACKEND_CUDA) { #ifdef USE_CUDA if (M_ == 1) { @@ -55,16 +54,21 @@ void InnerProductLayer::Forward_gpu(const vector*>& bottom, K_ <= max_image_size && std::is_same::value && this->device_->CheckCapability("cl_intel_subgroups")) { - if (this->phase_ != 
TEST) { + if (!test_only_ || copied_weight_data_ != this->blobs_[0]->data().get()) { int height = !transpose_ ? N_ : K_; int width = !transpose_ ? K_ : N_; int padded_height = !transpose_ ? height : (height + ((height & 7) ? 1 : 0)); int padded_width = !transpose_ ? width : (width + ((width & 7) ? 1 : 0)); + if (weight_image_) { + clReleaseMemObject((cl_mem)weight_image_); + weight_image_ = NULL; + } greentea_gpu_gemm_copy_buffer_to_image(this->device_->id(), &weight_image_, (cl_mem) weight, 0, false, !transpose_, true, padded_height, padded_width, height, width, (int)0, NULL, NULL); + copied_weight_data_ = this->blobs_[0]->data().get(); } greentea_gpu_gemm(this->device_->id(), CblasNoTrans, transpose_ ? CblasNoTrans : CblasTrans, @@ -93,6 +97,8 @@ template void InnerProductLayer::Backward_gpu( const vector*>& top, const vector& propagate_down, const vector*>& bottom) { + + test_only_ = false; if (this->device_->backend() == BACKEND_CUDA) { #ifdef USE_CUDA if (this->param_propagate_down_[0]) { From 742d803295f9ef57a34407db2a0cbb174c8aa528 Mon Sep 17 00:00:00 2001 From: Zhigang Gong Date: Sat, 1 Apr 2017 08:42:17 +0800 Subject: [PATCH 13/33] Reduce the maximum block size for spatial convolution engine. Signed-off-by: Zhigang Gong --- src/caffe/layers/conv_layer_spatial.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/caffe/layers/conv_layer_spatial.cpp b/src/caffe/layers/conv_layer_spatial.cpp index da6f824fe84..897f28db741 100644 --- a/src/caffe/layers/conv_layer_spatial.cpp +++ b/src/caffe/layers/conv_layer_spatial.cpp @@ -1221,7 +1221,7 @@ void ConvolutionLayerSpatial::setup_convolution( if (simd_size == 8) { width_max = 16; height_max = 16; - block_size_max = 64; + block_size_max = 48; } else { width_max = 14; height_max = 14; From db3da6eacaed5ada7afc60631a753e99e491294a Mon Sep 17 00:00:00 2001 From: Zhigang Gong Date: Wed, 5 Apr 2017 03:50:52 +0800 Subject: [PATCH 14/33] Simplify IDLF kernel's output logic. 
As we are using dynamic image size now, no need to use the last block width and height. This patch could fix some of the performance regression caused by the dynamic image size change. But still have some performance gap between the previous constant image size version. Signed-off-by: Zhigang Gong --- src/caffe/greentea/cl_kernels.cpp | 37 ++-------------------- .../greentea/cl_kernels/conv_layer_spatial.cl | 37 ++-------------------- src/caffe/layers/conv_layer_spatial.cpp | 6 ---- 3 files changed, 6 insertions(+), 74 deletions(-) diff --git a/src/caffe/greentea/cl_kernels.cpp b/src/caffe/greentea/cl_kernels.cpp index 474ce231e3b..7eb89544046 100644 --- a/src/caffe/greentea/cl_kernels.cpp +++ b/src/caffe/greentea/cl_kernels.cpp @@ -635,9 +635,7 @@ static std::vector> cl_kernels{ "const ushort input_width,", // NOLINT "const ushort input_height,", // NOLINT "const ushort output_width,", // NOLINT -"const ushort output_height,", // NOLINT -"const ushort last_block_width,", // NOLINT -"const ushort last_block_height)", // NOLINT +"const ushort output_height)", // NOLINT "{", // NOLINT "__global float* outputs = outputs_base;", // NOLINT "__global float* inputs = inputs_base;", // NOLINT @@ -801,43 +799,14 @@ static std::vector> cl_kernels{ "uint_tp out_addr = OUT_BUFF_OFFSET + ( num_in_batch * TOTAL_OUTPUT_DEPTH + fm ) * output_width * output_height;", // NOLINT "out_addr += or * output_width + oc;", // NOLINT "float bias = biases[(fm % ALIGNED_NUM_FILTERS)];", // NOLINT -"#ifndef WRITE_PADDED_VALUES", // NOLINT -"if (or + OUT_BLOCK_HEIGHT < output_height &&", // NOLINT -"oc + OUT_BLOCK_WIDTH < output_width)", // NOLINT -"{", // NOLINT -"#endif", // NOLINT "for(uint_tp r = 0; r < OUT_BLOCK_HEIGHT; r++) {", // NOLINT +"if (r + or >= output_height) break;", // NOLINT "for(uint_tp c = 0; c < OUT_BLOCK_WIDTH; c++) {", // NOLINT +"if (c + oc >= output_width) break;", // NOLINT "// this does a scattered write to SIMD_SIZE different feature maps, so that data within one map 
is contiguous, thus ready for input to next layer.", // NOLINT "ACTIVATION_FUNCTION(outputs, out_addr + r * output_width + c, bias + out[r * OUT_BLOCK_WIDTH + c]);", // NOLINT "}", // NOLINT "}", // NOLINT -"#ifndef WRITE_PADDED_VALUES", // NOLINT -"} else if ( or + OUT_BLOCK_HEIGHT < output_height )", // NOLINT -"{", // NOLINT -"for(uint_tp r = 0; r < OUT_BLOCK_HEIGHT; r++) {", // NOLINT -"for(uint_tp c = 0; c < last_block_width; c++) {", // NOLINT -"ACTIVATION_FUNCTION(outputs, out_addr + r * output_width + c, bias + out[r * OUT_BLOCK_WIDTH + c]);", // NOLINT -"}", // NOLINT -"}", // NOLINT -"}", // NOLINT -"else if ( oc + OUT_BLOCK_WIDTH < output_width )", // NOLINT -"{", // NOLINT -"for(uint_tp r = 0; r < last_block_height; r++) {", // NOLINT -"for(uint_tp c = 0; c < OUT_BLOCK_WIDTH; c++) {", // NOLINT -"ACTIVATION_FUNCTION(outputs, out_addr + r * output_width + c, bias + out[r * OUT_BLOCK_WIDTH + c]);", // NOLINT -"}", // NOLINT -"}", // NOLINT -"}", // NOLINT -"else", // NOLINT -"{", // NOLINT -"for(uint_tp r = 0; r < last_block_height; r++) {", // NOLINT -"for(uint_tp c = 0; c < last_block_width; c++) {", // NOLINT -"ACTIVATION_FUNCTION(outputs, out_addr + r * output_width + c, bias + out[r * OUT_BLOCK_WIDTH + c]);", // NOLINT -"}", // NOLINT -"}", // NOLINT -"}", // NOLINT -"#endif //#ifndef WRITE_PADDED_VALUES", // NOLINT "}", // NOLINT "}", // NOLINT "#endif", // NOLINT diff --git a/src/caffe/greentea/cl_kernels/conv_layer_spatial.cl b/src/caffe/greentea/cl_kernels/conv_layer_spatial.cl index cf9912c1838..77583f3a434 100644 --- a/src/caffe/greentea/cl_kernels/conv_layer_spatial.cl +++ b/src/caffe/greentea/cl_kernels/conv_layer_spatial.cl @@ -138,9 +138,7 @@ convolve_simd( // __global float *inputs, __global float* weights, __global flo const ushort input_width, const ushort input_height, const ushort output_width, - const ushort output_height, - const ushort last_block_width, - const ushort last_block_height) + const ushort output_height) { __global 
float* outputs = outputs_base; __global float* inputs = inputs_base; @@ -304,43 +302,14 @@ convolve_simd( // __global float *inputs, __global float* weights, __global flo uint_tp out_addr = OUT_BUFF_OFFSET + ( num_in_batch * TOTAL_OUTPUT_DEPTH + fm ) * output_width * output_height; out_addr += or * output_width + oc; float bias = biases[(fm % ALIGNED_NUM_FILTERS)]; -#ifndef WRITE_PADDED_VALUES - if (or + OUT_BLOCK_HEIGHT < output_height && - oc + OUT_BLOCK_WIDTH < output_width) - { -#endif for(uint_tp r = 0; r < OUT_BLOCK_HEIGHT; r++) { + if (r + or >= output_height) break; for(uint_tp c = 0; c < OUT_BLOCK_WIDTH; c++) { + if (c + oc >= output_width) break; // this does a scattered write to SIMD_SIZE different feature maps, so that data within one map is contiguous, thus ready for input to next layer. ACTIVATION_FUNCTION(outputs, out_addr + r * output_width + c, bias + out[r * OUT_BLOCK_WIDTH + c]); } } -#ifndef WRITE_PADDED_VALUES - } else if ( or + OUT_BLOCK_HEIGHT < output_height ) - { - for(uint_tp r = 0; r < OUT_BLOCK_HEIGHT; r++) { - for(uint_tp c = 0; c < last_block_width; c++) { - ACTIVATION_FUNCTION(outputs, out_addr + r * output_width + c, bias + out[r * OUT_BLOCK_WIDTH + c]); - } - } - } - else if ( oc + OUT_BLOCK_WIDTH < output_width ) - { - for(uint_tp r = 0; r < last_block_height; r++) { - for(uint_tp c = 0; c < OUT_BLOCK_WIDTH; c++) { - ACTIVATION_FUNCTION(outputs, out_addr + r * output_width + c, bias + out[r * OUT_BLOCK_WIDTH + c]); - } - } - } - else - { - for(uint_tp r = 0; r < last_block_height; r++) { - for(uint_tp c = 0; c < last_block_width; c++) { - ACTIVATION_FUNCTION(outputs, out_addr + r * output_width + c, bias + out[r * OUT_BLOCK_WIDTH + c]); - } - } - } -#endif //#ifndef WRITE_PADDED_VALUES } } #endif diff --git a/src/caffe/layers/conv_layer_spatial.cpp b/src/caffe/layers/conv_layer_spatial.cpp index 897f28db741..f3bbb4a5799 100644 --- a/src/caffe/layers/conv_layer_spatial.cpp +++ b/src/caffe/layers/conv_layer_spatial.cpp @@ -589,12 
+589,6 @@ cl_int ConvolutionLayerSpatial::convolve( kernel.arg(argIdx++, (uint16_t)output_h_); const int_tp output_block_w = config->workItem_output[0]; const int_tp output_block_h = config->workItem_output[1]; - const int_tp last_block_width = ((output_w_ % output_block_w) == 0) ? - output_block_w : output_w_ % output_block_w; - const int_tp last_block_height =((output_h_ % output_block_h) == 0) ? - output_block_h : output_h_ % output_block_h; - kernel.arg(argIdx++, (uint16_t)last_block_width); - kernel.arg(argIdx++, (uint16_t)last_block_height); size_t global_size[3] = { (size_t) (output_w_ + output_block_w - 1) / output_block_w, (size_t) (output_h_ + output_block_h - 1) / output_block_h, (size_t) config->global_work_size[2]}; From d121f01604a00741e02f89df76bc0d5aafc979bb Mon Sep 17 00:00:00 2001 From: Zhigang Gong Date: Wed, 5 Apr 2017 07:59:08 +0800 Subject: [PATCH 15/33] Add one infernece optimized model file for AlexNet. Signed-off-by: Zhigang Gong --- .../inference_optimize/AlexNet-merged-1.prototxt | 270 +++++++++++++++++++++ 1 file changed, 270 insertions(+) create mode 100644 models/inference_optimize/AlexNet-merged-1.prototxt diff --git a/models/inference_optimize/AlexNet-merged-1.prototxt b/models/inference_optimize/AlexNet-merged-1.prototxt new file mode 100644 index 00000000000..0c3274727e1 --- /dev/null +++ b/models/inference_optimize/AlexNet-merged-1.prototxt @@ -0,0 +1,270 @@ +name: "AlexNet" +input: "data" +input_dim: 1 +input_dim: 3 +input_dim: 227 +input_dim: 227 +layer { + name: "label" + type: "Input" + top: "label" + input_param { shape: { dim: 1} } +} +layer { + name: "conv1" + type: "Convolution" + bottom: "data" + top: "conv1" + param { + lr_mult: 1 + decay_mult: 1 + } + param { + lr_mult: 2 + decay_mult: 0 + } + convolution_param { + fuse_type: FUSED_CONV_RELU + num_output: 96 + kernel_size: 11 + stride: 4 + } +} +layer { + name: "norm1" + type: "LRN" + bottom: "conv1" + top: "pool1" + lrn_param { + fuse_type: FUSED_POOL_MAX + local_size: 
5 + alpha: 0.0001 + beta: 0.75 + pooling_param { + pool: MAX + kernel_size: 3 + stride: 2 + } + } +} +#layer { +# name: "pool1" +# type: "Pooling" +# bottom: "norm1" +# top: "pool1" +# pooling_param { +# pool: MAX +# kernel_size: 3 +# stride: 2 +# } +#} +layer { + name: "conv2" + type: "Convolution" + bottom: "pool1" + top: "conv2" + param { + lr_mult: 1 + decay_mult: 1 + } + param { + lr_mult: 2 + decay_mult: 0 + } + convolution_param { + fuse_type: FUSED_CONV_RELU + num_output: 256 + pad: 2 + kernel_size: 5 + group: 2 + } +} +layer { + name: "norm2" + type: "LRN" + bottom: "conv2" + top: "pool2" + lrn_param { + fuse_type: FUSED_POOL_MAX + local_size: 5 + alpha: 0.0001 + beta: 0.75 + pooling_param { + pool: MAX + kernel_size: 3 + stride: 2 + } + } +} +#layer { +# name: "pool2" +# type: "Pooling" +# bottom: "norm2" +# top: "pool2" +# pooling_param { +# pool: MAX +# kernel_size: 3 +# stride: 2 +# } +#} +layer { + name: "conv3" + type: "Convolution" + bottom: "pool2" + top: "conv3" + param { + lr_mult: 1 + decay_mult: 1 + } + param { + lr_mult: 2 + decay_mult: 0 + } + convolution_param { + fuse_type: FUSED_CONV_RELU + num_output: 384 + pad: 1 + kernel_size: 3 + } +} +layer { + name: "conv4" + type: "Convolution" + bottom: "conv3" + top: "conv4" + param { + lr_mult: 1 + decay_mult: 1 + } + param { + lr_mult: 2 + decay_mult: 0 + } + convolution_param { + fuse_type: FUSED_CONV_RELU + num_output: 384 + pad: 1 + kernel_size: 3 + group: 2 + } +} +layer { + name: "conv5" + type: "Convolution" + bottom: "conv4" + top: "conv5" + param { + lr_mult: 1 + decay_mult: 1 + } + param { + lr_mult: 2 + decay_mult: 0 + } + convolution_param { + fuse_type: FUSED_CONV_RELU + num_output: 256 + pad: 1 + kernel_size: 3 + group: 2 + } +} +layer { + name: "pool5" + type: "Pooling" + bottom: "conv5" + top: "pool5" + pooling_param { + pool: MAX + kernel_size: 3 + stride: 2 + } +} +layer { + name: "fc6" + type: "InnerProduct" + bottom: "pool5" + top: "fc6" + param { + lr_mult: 1 + decay_mult: 1 
+ } + param { + lr_mult: 2 + decay_mult: 0 + } + inner_product_param { + num_output: 4096 + } +} +layer { + name: "relu6" + type: "ReLU" + bottom: "fc6" + top: "fc6" +} +layer { + name: "drop6" + type: "Dropout" + bottom: "fc6" + top: "fc6" + dropout_param { + dropout_ratio: 0.5 + } +} +layer { + name: "fc7" + type: "InnerProduct" + bottom: "fc6" + top: "fc7" + param { + lr_mult: 1 + decay_mult: 1 + } + param { + lr_mult: 2 + decay_mult: 0 + } + inner_product_param { + num_output: 4096 + } +} +layer { + name: "relu7" + type: "ReLU" + bottom: "fc7" + top: "fc7" +} +layer { + name: "drop7" + type: "Dropout" + bottom: "fc7" + top: "fc7" + dropout_param { + dropout_ratio: 0.5 + } +} +layer { + name: "fc8" + type: "InnerProduct" + bottom: "fc7" + top: "fc8" + param { + lr_mult: 1 + decay_mult: 1 + } + param { + lr_mult: 2 + decay_mult: 0 + } + inner_product_param { + num_output: 1000 + } +} +layer { + name: "loss" + type: "SoftmaxWithLoss" + bottom: "fc8" + bottom: "label" + top: "loss" +} From 44b345a7c99a288b3e5590309dd3520a66f47359 Mon Sep 17 00:00:00 2001 From: Zhigang Gong Date: Thu, 4 May 2017 08:40:04 +0800 Subject: [PATCH 18/33] Add fused activation function. Forgot to add these macros in previous commit. 
Signed-off-by: Zhigang Gong --- src/caffe/greentea/cl_kernels.cpp | 12 +++++++++++- src/caffe/greentea/cl_kernels/conv_layer_spatial.cl | 12 +++++++++++- 2 files changed, 22 insertions(+), 2 deletions(-) diff --git a/src/caffe/greentea/cl_kernels.cpp b/src/caffe/greentea/cl_kernels.cpp index 7eb89544046..b19f46b53fb 100644 --- a/src/caffe/greentea/cl_kernels.cpp +++ b/src/caffe/greentea/cl_kernels.cpp @@ -503,7 +503,17 @@ static std::vector> cl_kernels{ "Dtype out = arg;", // NOLINT "}", // NOLINT "", // NOLINT -"#define ACTIVATION_FUNCTION(_dst_, _offset_, _data_) do { (_dst_)[(_offset_)] = (_data_);} while(0)", // NOLINT +"#ifdef FUSED_CONV_RELU", // NOLINT +"#define ACTIVATION_RELU_FUNCTION(x) max((Dtype)(x), (Dtype)0.0f)", // NOLINT +"#else", // NOLINT +"#define ACTIVATION_RELU_FUNCTION(x) (x)", // NOLINT +"#endif", // NOLINT +"", // NOLINT +"#ifdef FUSED_CONV_ELTWISE", // NOLINT +"#define ACTIVATION_FUNCTION(_dst_, _offset_, _data_) do { (_dst_)[(_offset_)] = ACTIVATION_RELU_FUNCTION(eltwise_data[(_offset_)] + (_data_));} while(0)", // NOLINT +"#else", // NOLINT +"#define ACTIVATION_FUNCTION(_dst_, _offset_, _data_) do { (_dst_)[(_offset_)] = ACTIVATION_RELU_FUNCTION(_data_);} while(0)", // NOLINT +"#endif", // NOLINT "", // NOLINT "#define __CAT(x, y) x##y", // NOLINT "#define CAT(x, y) __CAT(x, y)", // NOLINT diff --git a/src/caffe/greentea/cl_kernels/conv_layer_spatial.cl b/src/caffe/greentea/cl_kernels/conv_layer_spatial.cl index 77583f3a434..e35b632d1ce 100644 --- a/src/caffe/greentea/cl_kernels/conv_layer_spatial.cl +++ b/src/caffe/greentea/cl_kernels/conv_layer_spatial.cl @@ -6,7 +6,17 @@ __kernel void TEMPLATE(conv_layer_spatial_phony,Dtype)(Dtype arg) { Dtype out = arg; } -#define ACTIVATION_FUNCTION(_dst_, _offset_, _data_) do { (_dst_)[(_offset_)] = (_data_);} while(0) +#ifdef FUSED_CONV_RELU +#define ACTIVATION_RELU_FUNCTION(x) max((Dtype)(x), (Dtype)0.0f) +#else +#define ACTIVATION_RELU_FUNCTION(x) (x) +#endif + +#ifdef FUSED_CONV_ELTWISE +#define 
ACTIVATION_FUNCTION(_dst_, _offset_, _data_) do { (_dst_)[(_offset_)] = ACTIVATION_RELU_FUNCTION(eltwise_data[(_offset_)] + (_data_));} while(0) +#else +#define ACTIVATION_FUNCTION(_dst_, _offset_, _data_) do { (_dst_)[(_offset_)] = ACTIVATION_RELU_FUNCTION(_data_);} while(0) +#endif #define __CAT(x, y) x##y #define CAT(x, y) __CAT(x, y) From 1af4a95fd49ca98f041b307f3381c88327c783be Mon Sep 17 00:00:00 2001 From: "Lin, Lixiang" Date: Fri, 5 May 2017 10:29:23 +0800 Subject: [PATCH 19/33] Enable model fuse script to generate merged-model and adding an example to show the usage. --- .../inference-optimize/googlenet_inference_test.sh | 20 ++ examples/inference-optimize/readme.md | 18 ++ tools/inference-optimize/model_fuse.py | 279 +++++++++++++++++++++ 3 files changed, 317 insertions(+) create mode 100755 examples/inference-optimize/googlenet_inference_test.sh create mode 100644 examples/inference-optimize/readme.md create mode 100644 tools/inference-optimize/model_fuse.py diff --git a/examples/inference-optimize/googlenet_inference_test.sh b/examples/inference-optimize/googlenet_inference_test.sh new file mode 100755 index 00000000000..b16ca77c82c --- /dev/null +++ b/examples/inference-optimize/googlenet_inference_test.sh @@ -0,0 +1,20 @@ +#!/bin/sh +export CAFFE_ROOT="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"/../.. 
+export PYTHONPATH=$CAFFE_ROOT/"python"
+# generate new fused model
+python $CAFFE_ROOT/tools/inference-optimize/model_fuse.py \
+    --indefinition $CAFFE_ROOT/models/bvlc_googlenet/deploy.prototxt \
+    --inmodel $CAFFE_ROOT/models/bvlc_googlenet/bvlc_googlenet.caffemodel \
+    --outdefinition $CAFFE_ROOT/models/bvlc_googlenet/fused_deploy.prototxt \
+    --outmodel $CAFFE_ROOT/models/bvlc_googlenet/fused_bvlc_googlenet.caffemodel \
+    --half_precision_mode=HALF_NONE \
+
+#Use cpp_classification to test
+$CAFFE_ROOT/build/examples/cpp_classification/classification.bin \
+    $CAFFE_ROOT/models/bvlc_googlenet/fused_deploy.prototxt \
+    $CAFFE_ROOT/models/bvlc_googlenet/fused_bvlc_googlenet.caffemodel \
+    $CAFFE_ROOT/data/ilsvrc12/imagenet_mean.binaryproto \
+    $CAFFE_ROOT/data/ilsvrc12/synset_words.txt \
+    $CAFFE_ROOT/examples/images/cat.jpg
+
+
diff --git a/examples/inference-optimize/readme.md b/examples/inference-optimize/readme.md
new file mode 100644
index 00000000000..293c665f68c
--- /dev/null
+++ b/examples/inference-optimize/readme.md
@@ -0,0 +1,18 @@
+# Using model fuse to run inference-optimized caffe
+
+This example uses the fused-model prototxt and weight file to run layer-fused classification.
+
+Take googlenet as an example:
+
+1. Download the GoogleNet model from "Model Zoo" using the following script:
+```
+    $CAFFE_ROOT/scripts/download_model_binary.py models/bvlc_googlenet
+```
+2. The ImageNet label file is required; get it by running:
+```
+    $CAFFE_ROOT/data/ilsvrc12/get_ilsvrc_aux.sh
+```
+3. Use model_fuse.py to generate the fused model and cpp_classification to test the classification functionality with the script:
+```
+    ./googlenet_inference_test.sh
+```
diff --git a/tools/inference-optimize/model_fuse.py b/tools/inference-optimize/model_fuse.py
new file mode 100644
index 00000000000..9d86d7817aa
--- /dev/null
+++ b/tools/inference-optimize/model_fuse.py
@@ -0,0 +1,279 @@
+#!/usr/bin/env python
+
+import os,sys
+import caffe
+import numpy as np
+#import utils as ut
+import csv,math
+import subprocess, sys
+import string
+import copy
+import argparse
+from google.protobuf import text_format
+from pdb import set_trace
+
+def resnet_block_to_fuse_type(model, cur_conv_index):
+    actual = [model.layer[cur_conv_index+1].type, model.layer[cur_conv_index+2].type, model.layer[cur_conv_index+3].type, model.layer[cur_conv_index+4].type]
+    resnet = ['BatchNorm', 'Scale', 'ReLU']
+    resnet_merged = ['ReLU']
+    resnet_elt = ['BatchNorm', 'Scale', 'Eltwise', 'ReLU']
+    resnet_elt_merged = ['Eltwise', 'ReLU']
+    if actual[:1] == resnet_merged or actual[:3] == resnet:
+        return 2 #FUSED_CONV_RELU TODO: not magic number
+    if actual[:2] == resnet_elt_merged or actual == resnet_elt:
+        return 3 #FUSED_CONV_ELTWISE_RELU
+    return 0 #UNFUSED
+
+def find_fused_blob_names(model, cur_conv_index):
+    i = cur_conv_index + 1
+    new_top = None
+    elt_bottom = None
+    while model.layer[i].type in {'BatchNorm', 'Scale', 'Eltwise', 'ReLU'}:
+        if model.layer[i].type == 'Eltwise':
+            elt_bottom = model.layer[i].bottom[0]
+        i = i + 1
+    new_top = model.layer[i].bottom[0]
+    return new_top, elt_bottom
+
+def str_to_precision_mode(mode):
+    if mode == 'HALF_NONE':
+        return 0
+    if mode == 'HALF_FLOAT_DATA':
+        return 1
+    if mode == 'HALF_HALF_DATA':
+        return 2
+    if mode == 'HALF_ALL':
+        return 3
+
+def set_input(in_model, out_model, half_precision_mode):
+    out_model.name = in_model.name
+    if(half_precision_mode != 'HALF_NONE'):
+        out_model.half_precision_mode = str_to_precision_mode(half_precision_mode)
+    #For input
+    
for i in range(len(in_model.input)): + out_model.input.extend([in_model.input[i]]) + if len(in_model.input_shape) < i: + out_model.input_shape.extend([in_model.input_shape[i]]) + for i in range(len(in_model.input_dim)): + out_model.input_dim.extend([in_model.input_dim[i]]) + +def is_fused_layer(model, layer_index): + if model.layer[layer_index].type in {'BatchNorm', 'Scale'}: + return True + # Fuse with Conv layer. + elif model.layer[layer_index].type in {'ReLU', 'Eltwise'} and model.layer[layer_index-1].type not in {'InnerProduct'}: + return True # Skip ReLU in case of layers fusing + # Fuse with LRN layer. + elif model.layer[layer_index].type in {'Pooling'} and model.layer[layer_index-1].type in {'LRN'} and model.layer[layer_index].pooling_param.pool == 0: + return True + else: + return False + +def fuse_conv_layer(in_model, in_index, out_model, new_index): + if out_model.layer[new_index].type == 'Convolution': + fuse_mode = resnet_block_to_fuse_type(in_model, in_index) + if [in_model.layer[in_index+1].type, in_model.layer[in_index+2].type] == ['BatchNorm', 'Scale']: + out_model.layer[new_index].convolution_param.bias_term = True + out_model.layer[new_index].convolution_param.fuse_type = fuse_mode + if fuse_mode == 3: # FUSED_CONV_ELTWISE_RELU, need to change top name to orig ReLU's top name + new_top, elt_bottom = find_fused_blob_names(in_model, in_index) + out_model.layer[new_index].top.remove(out_model.layer[new_index].top[0]) + out_model.layer[new_index].top.append(new_top) + out_model.layer[new_index].bottom.append(elt_bottom) + +def fuse_lrn_layer(in_model, in_index, out_model, out_index): + if out_model.layer[out_index].type == 'LRN' and in_model.layer[in_index + 1].type == 'Pooling' and in_model.layer[ + in_index + 1].pooling_param.pool == 0: + new_top = in_model.layer[in_index + 1].top[0] + out_model.layer[out_index].top.remove(out_model.layer[out_index].top[0]) + out_model.layer[out_index].top.append(new_top) + 
out_model.layer[out_index].lrn_param.fuse_type = 1 # 'FUSED_POOL_MAX' + pooling_param = in_model.layer[in_index + 1].pooling_param + out_model.layer[out_index].lrn_param.pooling_param.pool = pooling_param.pool + out_model.layer[out_index].lrn_param.pooling_param.kernel_size.append(pooling_param.kernel_size[0]) + out_model.layer[out_index].lrn_param.pooling_param.stride.append(pooling_param.stride[0]) + +def set_layers(in_model, out_model): + out_index = 0 + for in_index in range(0, len(in_model.layer)): + if is_fused_layer(in_model, in_index): + continue + else: + out_model.layer.extend([in_model.layer[in_index]]) + fuse_conv_layer(in_model, in_index, out_model, out_index) + fuse_lrn_layer(in_model, in_index, out_model, out_index) + out_index = out_index + 1 + +def create_new_model(in_model, half_precision_mode): + out_model = caffe.proto.caffe_pb2.NetParameter() + set_input(in_model, out_model, half_precision_mode) + set_layers(in_model, out_model) + return out_model + +def load_model(filename): + model = caffe.proto.caffe_pb2.NetParameter() + input_file = open(filename, 'r') + text_format.Merge(str(input_file.read()), model) + input_file.close() + return model + +def save_model(model, filename): + output_file = open(filename, 'w') + text_format.PrintMessage(model, output_file) + output_file.close() + +def find_layerindex_by_name(model, layer_name): + k = 0 + while model.layer[k].name != layer_name: + k += 1 + if (k > len(model.layer)): + raise IOError('layer with name %s not found' % layer_name) + return k + +def define_arguments(parser): + parser.add_argument('--indefinition', type=str, + default='deploy.prototxt', + help='input network definition (prototxt)') + parser.add_argument('--inmodel', type=str, + default='bvlc_alexnet.caffemodel', + help='input network parameters (caffemodel)') + parser.add_argument('--outdefinition', type=str, + default='new_deploy.prototxt', + help='output network definition (prototxt)') + parser.add_argument('--outmodel', type=str, 
+ default='new_bvlc_alexnet.caffemodel', + help='output network parameters (caffemodel; will be overwritten)') + parser.add_argument('--half_precision_mode', type = str, + default='HALF_NONE', + help='float half precision mode') + parser.add_argument('--fuse_resnet_block', dest='fuse_resnet_block', action='store_true', + default=True, + help='indicates whether to fuse conv-(batchnorm-scale)-relu block into the conv') + parser.add_argument('--proto_only', dest='proto_only', action='store_true', + default=False, + help='indicates whether to generate merged network definition (prototxt) only, without touching the weights') + +def parse_args(): + parser = argparse.ArgumentParser(description='convert a network using ' + + 'batch normalization into an equivalent network that does not. Assumes that ' + + 'parameter layer names have the form \'conv*\' or \'fc*\' and are directly ' + + 'followed by their respective batch norm layers. These BatchNorm layers must ' + + 'have names which are identical to the names of the layers that they modify, ' + + 'except that \'conv\' or \'fc\' is replaced by \'bn\'. E.g. conv3 is directly ' + + 'followed by \'bn3\'. 
Does not copy fc7, fc8, or fc9.') + define_arguments(parser) + + args = parser.parse_args() + return args + +def generate_weights(in_model, args): + in_net = caffe.Net(args.indefinition, args.inmodel ,caffe.TEST) + #required for working with the fused layers + caffe.set_device(0) + caffe.set_mode_gpu() + out_net =caffe.Net(args.outdefinition,caffe.TEST) + tocopy=out_net.params + + for prm in tocopy: + k = find_layerindex_by_name(in_model, prm) + if in_model.layer[k].type in {'InnerProduct', 'Scale'}: + for i in range(0,len(in_net.params[prm])): + out_net.params[prm][i].data[...]=np.copy(in_net.params[prm][i].data[...]) + continue + #TODO:Need fix conv+bn + if (in_model.layer[k].type == 'Convolution'): # Assuming convolution is followed by bn and scale + next1type = in_model.layer[k + 1].type + next2type = in_model.layer[k + 2].type + else: + print 'Warning: ' + prm + ' has parameters but I can\'t infer its layer type.' + continue + if next2type not in {'Scale'}: + print next2type + ' not found, just ignoring scale ' + prm + isScale = False + else: + isScale = True + sclprm = in_model.layer[k + 2].name + if next1type not in {'BatchNorm'}: + print next1type + ' not found, just copying ' + prm + for i in range(0,len(in_net.params[prm])): + out_net.params[prm][i].data[...]=np.copy(in_net.params[prm][i].data[...]); + continue; + else: + bnprm = in_model.layer[k + 1].name + if in_net.params[prm][0].data.shape != out_net.params[prm][0].data.shape: + print 'Warning: ' + prm + ' has parameters but they are of different sizes in the different protos. skipping.' + continue; + print 'Removing batchnorm from ' + prm; + + #for i in range(0,len(net2.params[prm])): # first blob for conv layers is the weights, second is the bias. 
No need for the loop
+        i = 0
+        if True:
+            prmval=np.copy(in_net.params[prm][i].data).reshape(out_net.params[prm][i].data.shape);
+
+
+            meanval=np.copy(in_net.params[bnprm][0].data);
+            stdval=np.copy(in_net.params[bnprm][1].data);
+            scaleFactor =np.copy(in_net.params[bnprm][2].data);
+
+            meanval/=in_net.params[bnprm][2].data[...].reshape(-1);
+            stdval/=in_net.params[bnprm][2].data[...].reshape(-1);
+            eps=None;
+            for j in range(0, len(in_model.layer)):
+                if str(in_model.layer[j].name) == bnprm:
+                    eps=in_model.layer[j].batch_norm_param.eps;
+
+            if eps is None:
+                raise ValueError("Unable to get epsilon for layer " + bnprm);
+
+            stdval+=eps;
+
+            stdval=np.sqrt(stdval);
+
+            prmval /= stdval.reshape((-1,1,1,1));
+            bias1 = -meanval / stdval
+
+            if isScale:
+                print 'Removing Scale Layer'
+                Scale_Layer_param =np.copy(in_net.params[sclprm][0].data)
+                Scale_Layer_param_bias =np.copy(in_net.params[sclprm][1].data)
+
+                Scale_Layer_paramBeta =np.copy(in_net.params[sclprm][1].data)
+                prmval= prmval*Scale_Layer_param.reshape((-1,1,1,1))
+
+                mul_bias1_scale = [x * y for x, y in zip(bias1, Scale_Layer_param)]
+                bias1 = Scale_Layer_param_bias + mul_bias1_scale
+
+            out_net.params[prm][i].data[:]=prmval
+            no_prior_bias = False
+            if len(out_net.params[prm]) < 2 : #no bias
+                out_net.params[prm].add_blob()
+                out_net.params[prm][1].reshape(len(bias1))
+                no_prior_bias = True
+
+            if no_prior_bias:
+                out_net.params[prm][1].data[:] = bias1
+            else:
+                out_net.params[prm][1].data[:] = bias1 + out_net.params[prm][1].data[:]
+
+    print 'New caffemodel done'
+    out_net.save(args.outmodel);
+
+def generate_prototxt(in_proto, args):
+    out_model = create_new_model(in_proto, args.half_precision_mode)
+    save_model(out_model, args.outdefinition)
+    print 'New proto done'
+
+def generate_new_model(args):
+    in_model = load_model(args.indefinition)
+    generate_prototxt(in_model, args)
+    if not args.proto_only:
+        generate_weights(in_model, args)
+
+
+def main(argv):
+    # parse args
+    args = parse_args()
+    
generate_new_model(args) + +if __name__ == '__main__': + main(sys.argv) From 6ab99445b97ca5bc56d1fb5ce888a8ec4997bdec Mon Sep 17 00:00:00 2001 From: Zhigang Gong Date: Wed, 10 May 2017 09:57:20 +0800 Subject: [PATCH 20/33] Always allocate zero-copy capable memory. Signed-off-by: Zhigang Gong --- src/caffe/syncedmem.cpp | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/caffe/syncedmem.cpp b/src/caffe/syncedmem.cpp index e5dcda1b022..d8f3e3d2afc 100644 --- a/src/caffe/syncedmem.cpp +++ b/src/caffe/syncedmem.cpp @@ -46,9 +46,16 @@ void CaffeMallocHost(void** ptr, int_tp size, device* dev) { } #endif #ifdef USE_MKL +#ifndef USE_GREENTEA *ptr = mkl_malloc(size ? size:1, 64); #else - *ptr = malloc(size); + *ptr = mkl_malloc(size ? ALIGN(size, OPENCL_CACHE_ALIGN) : 64, OPENCL_PAGE_ALIGN); +#endif +#else + CHECK_EQ(0, posix_memalign(ptr, OPENCL_PAGE_ALIGN, + ((size - 1)/OPENCL_CACHE_ALIGN + 1) * OPENCL_CACHE_ALIGN)) + << "Host memory allocation error of size: " + << size << " B"; #endif // USE_MKL CHECK(*ptr) << "host allocation of size " << size << " failed"; } From 9681b26a5797b539e4b670c31a6204dd14abf3ef Mon Sep 17 00:00:00 2001 From: wzw Date: Sat, 10 Jun 2017 01:59:54 +0800 Subject: [PATCH 21/33] Fix "nan" value bug for matvec_mul.cl If the output buffer didn't initialized, the kernel code: "result[row_gid] = alpha * work[0] + beta * result[row_gid];". will make "result[row_gid]" to be "nan" no matter whether "beta" is zero. 
--- src/caffe/greentea/cl_kernels.cpp | 13 +++++++++++-- src/caffe/greentea/cl_kernels/matvec_mul.cl | 19 ++++++++++++++----- 2 files changed, 25 insertions(+), 7 deletions(-) diff --git a/src/caffe/greentea/cl_kernels.cpp b/src/caffe/greentea/cl_kernels.cpp index b19f46b53fb..3914696dd06 100644 --- a/src/caffe/greentea/cl_kernels.cpp +++ b/src/caffe/greentea/cl_kernels.cpp @@ -5437,10 +5437,16 @@ static std::vector> cl_kernels{ "if(lid < stride)", // NOLINT "work[lid] += work[lid+stride];", // NOLINT "}", // NOLINT -"if(lid == 0)", // NOLINT +"", // NOLINT +"if(lid == 0) {", // NOLINT +"if(beta == (Dtype)0)", // NOLINT +"result[row_gid] = alpha * work[0];", // NOLINT +"else", // NOLINT "result[row_gid] = alpha * work[0] + beta * result[row_gid];", // NOLINT "}", // NOLINT "", // NOLINT +"}", // NOLINT +"", // NOLINT "/* This kernel used for the trailing rows when row_of_A %4 !=0 */", // NOLINT "__kernel void TEMPLATE(matvec_mul1,Dtype)(", // NOLINT "__global const float * A,", // NOLINT @@ -5498,9 +5504,12 @@ static std::vector> cl_kernels{ "}", // NOLINT "", // NOLINT "if(lid == 0) {", // NOLINT +"if(beta == (Dtype)0) {", // NOLINT +"result[row_gid+row_offset] = alpha * work[0];", // NOLINT +"} else {", // NOLINT "result[row_gid+row_offset] *= beta;", // NOLINT "result[row_gid+row_offset] += alpha * work[0];", // NOLINT -"//result[row_gid+row_offset] = alpha * work[0] + beta * result[row_gid+row_offset];", // NOLINT +"}", // NOLINT "}", // NOLINT "}", // NOLINT ""}, // NOLINT diff --git a/src/caffe/greentea/cl_kernels/matvec_mul.cl b/src/caffe/greentea/cl_kernels/matvec_mul.cl index d0b9bacb385..dee7779ce9c 100644 --- a/src/caffe/greentea/cl_kernels/matvec_mul.cl +++ b/src/caffe/greentea/cl_kernels/matvec_mul.cl @@ -75,8 +75,14 @@ __kernel void TEMPLATE(matvec_mul4,Dtype)( if(lid < stride) work[lid] += work[lid+stride]; } - if(lid == 0) - result[row_gid] = alpha * work[0] + beta * result[row_gid]; + + if(lid == 0) { + if(beta == (Dtype)0) + result[row_gid] = 
alpha * work[0]; + else + result[row_gid] = alpha * work[0] + beta * result[row_gid]; + } + } /* This kernel used for the trailing rows when row_of_A %4 !=0 */ @@ -136,8 +142,11 @@ __kernel void TEMPLATE(matvec_mul1,Dtype)( } if(lid == 0) { - result[row_gid+row_offset] *= beta; - result[row_gid+row_offset] += alpha * work[0]; - //result[row_gid+row_offset] = alpha * work[0] + beta * result[row_gid+row_offset]; + if(beta == (Dtype)0) { + result[row_gid+row_offset] = alpha * work[0]; + } else { + result[row_gid+row_offset] *= beta; + result[row_gid+row_offset] += alpha * work[0]; + } } } From dd8555a2223bc7d445ee96f914b155132cc36948 Mon Sep 17 00:00:00 2001 From: "Lin, Lixiang" Date: Wed, 26 Apr 2017 13:32:42 +0800 Subject: [PATCH 22/33] 1, Enable gemm_fast_image blocks computing logic; 2, Refined innerprod autotuneing logic Change-Id: I20b74574845a2d0d0b33fb0de340c5346d763897 --- include/caffe/greentea/greentea_math_functions.hpp | 11 +- include/caffe/layers/inner_product_layer.hpp | 48 + src/caffe/greentea/cl_kernels.cpp | 1067 +++++++----- src/caffe/greentea/cl_kernels/gemm.cl | 1700 +++++++++++++------- src/caffe/greentea/greentea_math_functions.cpp | 906 ++--------- src/caffe/layers/inner_product_layer.cpp | 28 +- src/caffe/layers/inner_product_layer.cu | 741 ++++++++- src/caffe/test/test_inner_product_layer.cpp | 14 +- 8 files changed, 2599 insertions(+), 1916 deletions(-) diff --git a/include/caffe/greentea/greentea_math_functions.hpp b/include/caffe/greentea/greentea_math_functions.hpp index 1462b9b7c48..58364dca3ed 100644 --- a/include/caffe/greentea/greentea_math_functions.hpp +++ b/include/caffe/greentea/greentea_math_functions.hpp @@ -53,16 +53,7 @@ void greentea_gpu_gemm(const int_tp ctx_id, const CBLAS_TRANSPOSE TransA, const int_tp N, const int_tp K, const Dtype alpha, const cl_mem A, const int_tp offA, const cl_mem B, const int_tp offB, const Dtype beta, cl_mem C, - const int_tp offC , const bool is_image_a = false, - const bool is_image_b = false); - 
-void greentea_gpu_gemm_copy_buffer_to_image(int_tp ctx_id, - cl_mem *image, cl_mem buffer, int offset, - bool is_matrix_a, bool transpose, - bool padding, int padded_height, - int padded_width, int height, - int width, int wait_list_size, - cl_event *wait_list, cl_event *event); + const int_tp offC); template void greentea_gpu_gemv(const int_tp ctx_id, const CBLAS_TRANSPOSE TransA, diff --git a/include/caffe/layers/inner_product_layer.hpp b/include/caffe/layers/inner_product_layer.hpp index 0c4bd95038b..12a52adfb59 100644 --- a/include/caffe/layers/inner_product_layer.hpp +++ b/include/caffe/layers/inner_product_layer.hpp @@ -6,6 +6,9 @@ #include "caffe/blob.hpp" #include "caffe/layer.hpp" #include "caffe/proto/caffe.pb.h" +#ifdef USE_GREENTEA +#include +#endif namespace caffe { @@ -15,6 +18,14 @@ namespace caffe { * * TODO(dox): thorough documentation for Forward, Backward, and proto params. */ + +enum gemm_type_t { + GEMM_TYPE_DEFAULT = 0, + GEMM_TYPE_FAST_IMAGE_32_1, + GEMM_TYPE_FAST_IMAGE_32_2, + GEMM_TYPE_FAST_IMAGE_B_IMAGE, + GEMM_TYPE_FAST_BUFFER +}; template class InnerProductLayer : public Layer { public: @@ -22,6 +33,31 @@ class InnerProductLayer : public Layer { : Layer(param) { #ifdef USE_GREENTEA weight_image_ = NULL; + weight_image_seq_ = -1; + innerprod_type_ = GEMM_TYPE_DEFAULT; + tuned_ = false; + + if (std::getenv("CLCAFFE_CACHE_PATH")) + cache_path_ << std::getenv("CLCAFFE_CACHE_PATH"); + else if (std::getenv("VIENNACL_CACHE_PATH")) + cache_path_ << std::getenv("VIENNACL_CACHE_PATH") << "/clCaffe"; + else if (std::getenv("HOME")) { + cache_path_ << std::getenv("HOME") << "/.cache/clCaffe"; + } + cache_path_ << "/innerprod/"; + const boost::filesystem::path& path = cache_path_.str(); + const boost::filesystem::path& dir = + boost::filesystem::unique_path(path).string(); + bool hasCacheDir = false; + if (!boost::filesystem::exists(dir)) + hasCacheDir = boost::filesystem::create_directories(dir); + else + hasCacheDir = 
boost::filesystem::is_directory(dir); + + if (hasCacheDir != true) { + std::cout << "Failed to create cache directory," + << "will tune again for next running" << std::endl; + } #endif } virtual void LayerSetUp(const vector*>& bottom, @@ -48,6 +84,13 @@ class InnerProductLayer : public Layer { const vector& propagate_down, const vector*>& bottom); virtual void Backward_gpu(const vector*>& top, const vector& propagate_down, const vector*>& bottom); +#ifdef USE_GREENTEA + virtual void generate_key(); + virtual void tune_innerprod_type(const int_tp ctx_id, + const CBLAS_TRANSPOSE TransB, const cl_mem A, + const cl_mem B, const cl_mem B_image, const size_t max_image_size); + virtual bool load_cache(); +#endif int_tp M_; int_tp K_; @@ -59,6 +102,11 @@ class InnerProductLayer : public Layer { cl_mem weight_image_; const SyncedMemory * copied_weight_data_; bool test_only_; + uint64_t weight_image_seq_; + gemm_type_t innerprod_type_; + bool tuned_; + std::stringstream cache_path_; + std::string key_; #endif }; diff --git a/src/caffe/greentea/cl_kernels.cpp b/src/caffe/greentea/cl_kernels.cpp index 3914696dd06..455f03103b7 100644 --- a/src/caffe/greentea/cl_kernels.cpp +++ b/src/caffe/greentea/cl_kernels.cpp @@ -3024,35 +3024,67 @@ static std::vector> cl_kernels{ "#include \"header.cl\"", // NOLINT "#endif", // NOLINT "", // NOLINT +"#if TYPE != TYPE_DOUBLE", // NOLINT +"", // NOLINT "#define TILE_M 32", // NOLINT "#define TILE_K 8", // NOLINT -"#define TILE_N 8", // NOLINT "", // NOLINT "// common block to calculate (alpha * AxB + beta * C) and output to destination image.", // NOLINT "", // NOLINT +"#if TYPE == TYPE_HALF", // NOLINT +"#define SUBGROUP_BLOCK_READ8( __image, __coord ) intel_sub_group_block_read_us8( __image, __coord )", // NOLINT +"#define SHUFFLE_TYPE2(val) as_ushort2(val)", // NOLINT +"#define SHUFFLE_TYPE8(val) as_ushort8(val)", // NOLINT +"#define READ_IMAGE(__image, __coord) read_imageh(__image, sampler, __coord)", // NOLINT +"#define SIZE_OF_ELEMENT 
sizeof(ushort)", // NOLINT +"#define SIMD_SIZE_GEMM 16", // NOLINT +"#define TILE_N 16", // NOLINT +"#else", // NOLINT +"#define SUBGROUP_BLOCK_READ8( __image, __coord ) intel_sub_group_block_read8( __image, __coord )", // NOLINT +"#define SHUFFLE_TYPE2(val) val", // NOLINT +"#define SHUFFLE_TYPE8(val) val", // NOLINT +"#define READ_IMAGE(__image, __coord) read_imagef(__image, sampler, __coord)", // NOLINT +"#define SIZE_OF_ELEMENT sizeof(uint)", // NOLINT +"#define SIMD_SIZE_GEMM 8", // NOLINT +"#define TILE_N 8", // NOLINT +"#endif", // NOLINT +"", // NOLINT "//#define USE_IMAGE_C", // NOLINT "#ifdef USE_IMAGE_C", // NOLINT +"#if TYPE == TYPE_HALF", // NOLINT +"#define BLOCKC_READ8( _C, _coordC ) as_float8( intel_sub_group_block_read_us8( _C, _coordC ) )", // NOLINT +"#define BLOCKC_WRITE8( _C, _coordC, _val ) intel_sub_group_block_write_us8( _C, _coordC, as_ushort8( _val ) )", // NOLINT +"#else", // NOLINT "#define BLOCKC_READ8( _C, _coordC ) as_float8( intel_sub_group_block_read8( _C, _coordC ) )", // NOLINT "#define BLOCKC_WRITE8( _C, _coordC, _val ) intel_sub_group_block_write8( _C, _coordC, as_uint8( _val ) )", // NOLINT +"#endif", // NOLINT "#define MATC_PARAMETER __read_only image2d_t C, __write_only image2d_t dst", // NOLINT "#define GEMM_OUTPUT(ALPHA1, BETA_NOT0) GEMM_OUTPUT_EXT(ALPHA1, BETA_NOT0, C, dst, sizeof(uint))", // NOLINT "#else", // NOLINT -"#define BLOCKC_READ8( _C, _coordC ) (float8) ( (_coordC.x + get_local_id(0) < N && _coordC.y < M) ? _C[ _coordC.y * N + _coordC.x + get_local_id(0) ] : 0, (_coordC.x + get_local_id(0) < N && _coordC.y + 1 < M) ? _C[ ( _coordC.y + 1 ) * N + _coordC.x + get_local_id(0) ] : 0, (_coordC.x + get_local_id(0) < N && _coordC.y + 2 < M) ? _C[ ( _coordC.y + 2 ) * N + _coordC.x + get_local_id(0) ] : 0, (_coordC.x + get_local_id(0) < N && _coordC.y + 3 < M) ? _C[ ( _coordC.y + 3 ) * N + _coordC.x + get_local_id(0) ] : 0, (_coordC.x + get_local_id(0) < N && _coordC.y + 4 < M) ? 
_C[ ( _coordC.y + 4 ) * N + _coordC.x + get_local_id(0) ] : 0, (_coordC.x + get_local_id(0) < N && _coordC.y + 5 < M) ? _C[ ( _coordC.y + 5 ) * N + _coordC.x + get_local_id(0) ] : 0, (_coordC.x + get_local_id(0) < N && _coordC.y + 6 < M) ? _C[ ( _coordC.y + 6 ) * N + _coordC.x + get_local_id(0) ] : 0, (_coordC.x + get_local_id(0) < N && _coordC.y + 7 < M) ? _C[ ( _coordC.y + 7 ) * N + _coordC.x + get_local_id(0) ] : 0)", // NOLINT +"#define BLOCKC_READ8( _C, _coordC ) (float8) ( (_coordC.x + get_local_id(0) < N && _coordC.y < M) ? _C[ _coordC.y * ldc + _coordC.x + get_local_id(0) ] : 0, (_coordC.x + get_local_id(0) < N && _coordC.y + 1 < M) ? _C[ ( _coordC.y + 1 ) * ldc + _coordC.x + get_local_id(0) ] : 0, (_coordC.x + get_local_id(0) < N && _coordC.y + 2 < M) ? _C[ ( _coordC.y + 2 ) * ldc + _coordC.x + get_local_id(0) ] : 0, (_coordC.x + get_local_id(0) < N && _coordC.y + 3 < M) ? _C[ ( _coordC.y + 3 ) * ldc + _coordC.x + get_local_id(0) ] : 0, (_coordC.x + get_local_id(0) < N && _coordC.y + 4 < M) ? _C[ ( _coordC.y + 4 ) * ldc + _coordC.x + get_local_id(0) ] : 0, (_coordC.x + get_local_id(0) < N && _coordC.y + 5 < M) ? _C[ ( _coordC.y + 5 ) * ldc + _coordC.x + get_local_id(0) ] : 0, (_coordC.x + get_local_id(0) < N && _coordC.y + 6 < M) ? _C[ ( _coordC.y + 6 ) * ldc + _coordC.x + get_local_id(0) ] : 0, (_coordC.x + get_local_id(0) < N && _coordC.y + 7 < M) ? 
_C[ ( _coordC.y + 7 ) * ldc + _coordC.x + get_local_id(0) ] : 0)", // NOLINT "", // NOLINT -"#define BLOCKC_WRITE8( _C, _coordC, _val) do { if (_coordC.x + get_local_id(0) < N) { if (_coordC.y < M) _C[ _coordC.y * N + _coordC.x + get_local_id(0) ] = _val.s0; if (_coordC.y + 1 < M) _C[ ( _coordC.y + 1 )* N + _coordC.x + get_local_id(0) ] = _val.s1; if (_coordC.y + 2 < M) _C[ ( _coordC.y + 2 )* N + _coordC.x + get_local_id(0) ] = _val.s2; if (_coordC.y + 3 < M) _C[ ( _coordC.y + 3 )* N + _coordC.x + get_local_id(0) ] = _val.s3; if (_coordC.y + 4 < M) _C[ ( _coordC.y + 4 )* N + _coordC.x + get_local_id(0) ] = _val.s4; if (_coordC.y + 5 < M) _C[ ( _coordC.y + 5 )* N + _coordC.x + get_local_id(0) ] = _val.s5; if (_coordC.y + 6 < M) _C[ ( _coordC.y + 6 )* N + _coordC.x + get_local_id(0) ] = _val.s6; if (_coordC.y + 7 < M) _C[ ( _coordC.y + 7 )* N + _coordC.x + get_local_id(0) ] = _val.s7; }} while(0)", // NOLINT -"#define MATC_PARAMETER __global Dtype * C, const int offC, const int M, const int N", // NOLINT +"#define BLOCKC_WRITE8( _C, _coordC, _val) do { if (_coordC.x + get_local_id(0) < N) { if (_coordC.y < M) _C[ _coordC.y * ldc + _coordC.x + get_local_id(0) ] = _val.s0; if (_coordC.y + 1 < M) _C[ ( _coordC.y + 1 )* ldc + _coordC.x + get_local_id(0) ] = _val.s1; if (_coordC.y + 2 < M) _C[ ( _coordC.y + 2 )* ldc + _coordC.x + get_local_id(0) ] = _val.s2; if (_coordC.y + 3 < M) _C[ ( _coordC.y + 3 )* ldc + _coordC.x + get_local_id(0) ] = _val.s3; if (_coordC.y + 4 < M) _C[ ( _coordC.y + 4 )* ldc + _coordC.x + get_local_id(0) ] = _val.s4; if (_coordC.y + 5 < M) _C[ ( _coordC.y + 5 )* ldc + _coordC.x + get_local_id(0) ] = _val.s5; if (_coordC.y + 6 < M) _C[ ( _coordC.y + 6 )* ldc + _coordC.x + get_local_id(0) ] = _val.s6; if (_coordC.y + 7 < M) _C[ ( _coordC.y + 7 )* ldc + _coordC.x + get_local_id(0) ] = _val.s7; }} while(0)", // NOLINT +"#define MATC_PARAMETER __global float * C, const int offC, const int M, const int N, const int ldc", // NOLINT "#define 
GEMM_OUTPUT(ALPHA1, BETA_NOT0) GEMM_OUTPUT_EXT(ALPHA1, BETA_NOT0, (C + offC), (C + offC), 1)", // NOLINT "#endif", // NOLINT "", // NOLINT -"#define GEMM_OUTPUT_EXT(ALPHA1, BETA_NOT0, _C, _dst, _C_step) int2 coordDst = (int2)( ( group_x * TILE_N ) * _C_step, ( group_y * TILE_M ) ); int2 coordC = coordDst; float8 blockC00; float8 blockC01; float8 blockC02; float8 blockC03; if (BETA_NOT0) { blockC00 = BLOCKC_READ8( _C, coordC ); coordC.y += 8; blockC01 = BLOCKC_READ8( _C, coordC ); coordC.y += 8; blockC02 = BLOCKC_READ8( _C, coordC ); coordC.y += 8; blockC03 = BLOCKC_READ8( _C, coordC ); if (!ALPHA1) { blockC00 *= beta; blockC01 *= beta; blockC02 *= beta; blockC03 *= beta; blockC00 = mad(blockAxB00, (float8)alpha, blockC00); blockC01 = mad(blockAxB01, (float8)alpha, blockC01); blockC02 = mad(blockAxB02, (float8)alpha, blockC02); blockC03 = mad(blockAxB03, (float8)alpha, blockC03); } else { blockC00 = mad(blockC00, (float8)beta, blockAxB00); blockC01 = mad(blockC01, (float8)beta, blockAxB01); blockC02 = mad(blockC02, (float8)beta, blockAxB02); blockC03 = mad(blockC03, (float8)beta, blockAxB03); } } else { if (!ALPHA1) { blockC00 = blockAxB00 * alpha; blockC01 = blockAxB01 * alpha; blockC02 = blockAxB02 * alpha; blockC03 = blockAxB03 * alpha; } else { blockC00 = blockAxB00; blockC01 = blockAxB01; blockC02 = blockAxB02; blockC03 = blockAxB03; } } BLOCKC_WRITE8( _dst, coordDst, blockC00 ); coordDst.y += 8; BLOCKC_WRITE8( _dst, coordDst, blockC01 ); coordDst.y += 8; BLOCKC_WRITE8( _dst, coordDst, blockC02 ); coordDst.y += 8; BLOCKC_WRITE8( _dst, coordDst, blockC03 );", // NOLINT +"#define GEMM_OUTPUT_EXT(ALPHA1, BETA_NOT0, _C, _dst, _C_step) int2 coordDst = (int2)( ( group_x * TILE_N ) * _C_step, ( group_y * TILE_M ) ); int2 coordC = coordDst; float8 blockC00; float8 blockC01; float8 blockC02; float8 blockC03; if (BETA_NOT0) { blockC00 = isFirstColBlock ? BLOCKC_READ8( _C, coordC ) * beta : BLOCKC_READ8( _C, coordC ); coordC.y += 8; blockC01 = isFirstColBlock ? 
BLOCKC_READ8( _C, coordC ) * beta : BLOCKC_READ8( _C, coordC ); coordC.y += 8; blockC02 = isFirstColBlock ? BLOCKC_READ8( _C, coordC ) * beta : BLOCKC_READ8( _C, coordC ); coordC.y += 8; blockC03 = isFirstColBlock ? BLOCKC_READ8( _C, coordC ) * beta : BLOCKC_READ8( _C, coordC ); if (!ALPHA1) { blockC00 = mad(blockAxB00, (float8)alpha, blockC00); blockC01 = mad(blockAxB01, (float8)alpha, blockC01); blockC02 = mad(blockAxB02, (float8)alpha, blockC02); blockC03 = mad(blockAxB03, (float8)alpha, blockC03); } else { blockC00 += blockAxB00; blockC01 += blockAxB01; blockC02 += blockAxB02; blockC03 += blockAxB03; } } else { blockC00 = isFirstColBlock ? BLOCKC_READ8( _C, coordC ) * beta : BLOCKC_READ8( _C, coordC ); coordC.y += 8; blockC01 = isFirstColBlock ? BLOCKC_READ8( _C, coordC ) * beta : BLOCKC_READ8( _C, coordC ); coordC.y += 8; blockC02 = isFirstColBlock ? BLOCKC_READ8( _C, coordC ) * beta : BLOCKC_READ8( _C, coordC ); coordC.y += 8; blockC03 = isFirstColBlock ? BLOCKC_READ8( _C, coordC ) * beta : BLOCKC_READ8( _C, coordC ); if (!ALPHA1) { blockC00 = mad(blockAxB00, (float8)alpha, blockC00); blockC01 = mad(blockAxB01, (float8)alpha, blockC01); blockC02 = mad(blockAxB02, (float8)alpha, blockC02); blockC03 = mad(blockAxB03, (float8)alpha, blockC03); } else { blockC00 += blockAxB00; blockC01 += blockAxB01; blockC02 += blockAxB02; blockC03 += blockAxB03; } } BLOCKC_WRITE8( _dst, coordDst, blockC00 ); coordDst.y += 8; BLOCKC_WRITE8( _dst, coordDst, blockC01 ); coordDst.y += 8; BLOCKC_WRITE8( _dst, coordDst, blockC02 ); coordDst.y += 8; BLOCKC_WRITE8( _dst, coordDst, blockC03 );", // NOLINT "", // NOLINT "// Get the specified column of the block of the block", // NOLINT "#define TRANSPOSE_BLOCK_8( _block, _col ) (float8)( intel_sub_group_shuffle( _block.s0, _col ), intel_sub_group_shuffle( _block.s1, _col ), intel_sub_group_shuffle( _block.s2, _col ), intel_sub_group_shuffle( _block.s3, _col ), intel_sub_group_shuffle( _block.s4, _col ), intel_sub_group_shuffle( 
_block.s5, _col ), intel_sub_group_shuffle( _block.s6, _col ), intel_sub_group_shuffle( _block.s7, _col ) );", // NOLINT "", // NOLINT "// A's column block multiply B 's row block.", // NOLINT +"#if TYPE == TYPE_HALF", // NOLINT +"#define MULTIPLY_BLOCKS_8x8( _result, _blockA, _blockB00, _blockB01 ) { const float8 acol0 = TRANSPOSE_BLOCK_8( _blockA, 0 ); const float8 acol1 = TRANSPOSE_BLOCK_8( _blockA, 1 ); const float8 acol2 = TRANSPOSE_BLOCK_8( _blockA, 2 ); const float8 acol3 = TRANSPOSE_BLOCK_8( _blockA, 3 ); const float8 acol4 = TRANSPOSE_BLOCK_8( _blockA, 4 ); const float8 acol5 = TRANSPOSE_BLOCK_8( _blockA, 5 ); const float8 acol6 = TRANSPOSE_BLOCK_8( _blockA, 6 ); const float8 acol7 = TRANSPOSE_BLOCK_8( _blockA, 7 ); const float8 acol8 = TRANSPOSE_BLOCK_8( _blockA, 8 ); const float8 acol9 = TRANSPOSE_BLOCK_8( _blockA, 9 ); const float8 acola = TRANSPOSE_BLOCK_8( _blockA, 10 ); const float8 acolb = TRANSPOSE_BLOCK_8( _blockA, 11 ); const float8 acolc = TRANSPOSE_BLOCK_8( _blockA, 12 ); const float8 acold = TRANSPOSE_BLOCK_8( _blockA, 13 ); const float8 acole = TRANSPOSE_BLOCK_8( _blockA, 14 ); const float8 acolf = TRANSPOSE_BLOCK_8( _blockA, 15 ); _result = mad( (float8)(_blockB00.s0), acol0, _result ); _result = mad( (float8)(_blockB00.s1), acol1, _result ); _result = mad( (float8)(_blockB00.s2), acol2, _result ); _result = mad( (float8)(_blockB00.s3), acol3, _result ); _result = mad( (float8)(_blockB00.s4), acol4, _result ); _result = mad( (float8)(_blockB00.s5), acol5, _result ); _result = mad( (float8)(_blockB00.s6), acol6, _result ); _result = mad( (float8)(_blockB00.s7), acol7, _result ); _result = mad( (float8)(_blockB01.s0), acol8, _result ); _result = mad( (float8)(_blockB01.s1), acol9, _result ); _result = mad( (float8)(_blockB01.s2), acola, _result ); _result = mad( (float8)(_blockB01.s3), acolb, _result ); _result = mad( (float8)(_blockB01.s4), acolc, _result ); _result = mad( (float8)(_blockB01.s5), acold, _result ); _result = mad( 
(float8)(_blockB01.s6), acole, _result ); _result = mad( (float8)(_blockB01.s7), acolf, _result ); }", // NOLINT +"#else", // NOLINT "#define MULTIPLY_BLOCKS_8x8( _result, _blockA, _blockB ) { const float8 acol0 = TRANSPOSE_BLOCK_8( _blockA, 0 ); const float8 acol1 = TRANSPOSE_BLOCK_8( _blockA, 1 ); const float8 acol2 = TRANSPOSE_BLOCK_8( _blockA, 2 ); const float8 acol3 = TRANSPOSE_BLOCK_8( _blockA, 3 ); const float8 acol4 = TRANSPOSE_BLOCK_8( _blockA, 4 ); const float8 acol5 = TRANSPOSE_BLOCK_8( _blockA, 5 ); const float8 acol6 = TRANSPOSE_BLOCK_8( _blockA, 6 ); const float8 acol7 = TRANSPOSE_BLOCK_8( _blockA, 7 ); _result = mad( (float8)(_blockB.s0), acol0, _result ); _result = mad( (float8)(_blockB.s1), acol1, _result ); _result = mad( (float8)(_blockB.s2), acol2, _result ); _result = mad( (float8)(_blockB.s3), acol3, _result ); _result = mad( (float8)(_blockB.s4), acol4, _result ); _result = mad( (float8)(_blockB.s5), acol5, _result ); _result = mad( (float8)(_blockB.s6), acol6, _result ); _result = mad( (float8)(_blockB.s7), acol7, _result ); }", // NOLINT +"#endif", // NOLINT "", // NOLINT -"#define GEMM_NN(ALPHA1, BETA_NOT0) __attribute__((reqd_work_group_size(8, 1, 1))) __kernel void TEMPLATE(gemm_32_1_NN_ ##ALPHA1 ##_ ##BETA_NOT0,Dtype)( __read_only image2d_t A, __read_only image2d_t B, MATC_PARAMETER, float alpha, float beta, int width0) { const int group_x = get_group_id(0); const int group_y = get_group_id(1); float8 blockAxB00 = 0.0f; float8 blockAxB01 = 0.0f; float8 blockAxB02 = 0.0f; float8 blockAxB03 = 0.0f; int2 coordA = (int2)( 0, group_y * TILE_M ); int2 coordB = (int2)( ( group_x * TILE_N ) * sizeof(uint), 0 ); do { int2 coordBTemp = coordB; float8 blockB00 = as_float8( intel_sub_group_block_read8( B, coordBTemp ) ); coordB.y += TILE_K; int2 coordATemp = coordA; float8 blockA00 = as_float8( intel_sub_group_block_read8( A, coordATemp ) ); coordATemp.y += 8; float8 blockA01 = as_float8( intel_sub_group_block_read8( A, coordATemp ) ); coordATemp.y 
+= 8; float8 blockA02 = as_float8( intel_sub_group_block_read8( A, coordATemp ) ); coordATemp.y += 8; float8 blockA03 = as_float8( intel_sub_group_block_read8( A, coordATemp ) ); coordA.x += TILE_K * sizeof(uint); MULTIPLY_BLOCKS_8x8( blockAxB00, blockA00, blockB00 ); MULTIPLY_BLOCKS_8x8( blockAxB01, blockA01, blockB00 ); MULTIPLY_BLOCKS_8x8( blockAxB02, blockA02, blockB00 ); MULTIPLY_BLOCKS_8x8( blockAxB03, blockA03, blockB00 ); } while( coordB.y < width0 ); GEMM_OUTPUT(ALPHA1, BETA_NOT0); }", // NOLINT +"#if TYPE == TYPE_HALF", // NOLINT +"#define GEMM_NN(ALPHA1, BETA_NOT0) __attribute__((intel_reqd_sub_group_size(SIMD_SIZE_GEMM))) __attribute__((reqd_work_group_size(SIMD_SIZE_GEMM, 1, 1))) __kernel void TEMPLATE(gemm_32_1_NN_ ##ALPHA1 ##_ ##BETA_NOT0, Dtype)( __read_only image2d_t A, __read_only image2d_t B, MATC_PARAMETER, float alpha_in, float beta_in, int width0, int isFirstColBlock) { const float alpha = (float)alpha_in; const float beta = (float)beta_in; const int group_x = get_group_id(0); const int group_y = get_group_id(1); float8 blockAxB00 = 0; float8 blockAxB01 = 0; float8 blockAxB02 = 0; float8 blockAxB03 = 0; int2 coordA = (int2)( 0, group_y * TILE_M ); int2 coordB = (int2)( ( group_x * TILE_N ) * SIZE_OF_ELEMENT, 0 ); do { int2 coordBTemp = coordB; float8 blockB00 = as_float8( SUBGROUP_BLOCK_READ8( B, coordBTemp ) ); coordB.y += TILE_K; float8 blockB01 = as_float8( SUBGROUP_BLOCK_READ8( B, coordBTemp ) ); coordB.y += TILE_K; int2 coordATemp = coordA; float8 blockA00 = as_float8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.y += 8; float8 blockA01 = as_float8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.y += 8; float8 blockA02 = as_float8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.y += 8; float8 blockA03 = as_float8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordA.x += TILE_K * SIZE_OF_ELEMENT * 2; MULTIPLY_BLOCKS_8x8( blockAxB00, blockA00, blockB00, blockB01 ); MULTIPLY_BLOCKS_8x8( blockAxB01, blockA01, blockB00, blockB01 
); MULTIPLY_BLOCKS_8x8( blockAxB02, blockA02, blockB00, blockB01 ); MULTIPLY_BLOCKS_8x8( blockAxB03, blockA03, blockB00, blockB01 ); } while( coordB.y < width0 ); GEMM_OUTPUT(ALPHA1, BETA_NOT0); }", // NOLINT +"#else", // NOLINT +"#define GEMM_NN(ALPHA1, BETA_NOT0) __attribute__((intel_reqd_sub_group_size(SIMD_SIZE_GEMM))) __attribute__((reqd_work_group_size(SIMD_SIZE_GEMM, 1, 1))) __kernel void TEMPLATE(gemm_32_1_NN_ ##ALPHA1 ##_ ##BETA_NOT0, Dtype)( __read_only image2d_t A, __read_only image2d_t B, MATC_PARAMETER, float alpha_in, float beta_in, int width0, int isFirstColBlock) { const float alpha = (float)alpha_in; const float beta = (float)beta_in; const int group_x = get_group_id(0); const int group_y = get_group_id(1); float8 blockAxB00 = 0.0f; float8 blockAxB01 = 0.0f; float8 blockAxB02 = 0.0f; float8 blockAxB03 = 0.0f; int2 coordA = (int2)( 0, group_y * TILE_M ); int2 coordB = (int2)( ( group_x * TILE_N ) * SIZE_OF_ELEMENT, 0 ); do { int2 coordBTemp = coordB; float8 blockB00 = as_float8( SUBGROUP_BLOCK_READ8( B, coordBTemp ) ); coordB.y += TILE_K; int2 coordATemp = coordA; float8 blockA00 = as_float8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.y += 8; float8 blockA01 = as_float8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.y += 8; float8 blockA02 = as_float8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.y += 8; float8 blockA03 = as_float8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordA.x += TILE_K * SIZE_OF_ELEMENT; MULTIPLY_BLOCKS_8x8( blockAxB00, blockA00, blockB00 ); MULTIPLY_BLOCKS_8x8( blockAxB01, blockA01, blockB00 ); MULTIPLY_BLOCKS_8x8( blockAxB02, blockA02, blockB00 ); MULTIPLY_BLOCKS_8x8( blockAxB03, blockA03, blockB00 ); } while( coordB.y < width0 ); GEMM_OUTPUT(ALPHA1, BETA_NOT0); }", // NOLINT +"#endif", // NOLINT "", // NOLINT "GEMM_NN(1, 0) // ALPHA == 1, BETA == 0", // NOLINT "GEMM_NN(1, 1) // ALPHA == 1, BETA != 0", // NOLINT @@ -3063,11 +3095,15 @@ static std::vector> cl_kernels{ "#undef MULTIPLY_BLOCKS_8x8", // 
NOLINT "", // NOLINT "// replicate the first row to column block.", // NOLINT -"#define TRANSPOSE_BLOCK_8(_vec) (float8)( intel_sub_group_shuffle(_vec, 0), intel_sub_group_shuffle(_vec, 1), intel_sub_group_shuffle(_vec, 2), intel_sub_group_shuffle(_vec, 3), intel_sub_group_shuffle(_vec, 4), intel_sub_group_shuffle(_vec, 5), intel_sub_group_shuffle(_vec, 6), intel_sub_group_shuffle(_vec, 7) )", // NOLINT +"#define TRANSPOSE_BLOCK_8(_vec, _col) (float8)( intel_sub_group_shuffle(_vec, _col + 0), intel_sub_group_shuffle(_vec, _col + 1), intel_sub_group_shuffle(_vec, _col + 2), intel_sub_group_shuffle(_vec, _col + 3), intel_sub_group_shuffle(_vec, _col + 4), intel_sub_group_shuffle(_vec, _col + 5), intel_sub_group_shuffle(_vec, _col + 6), intel_sub_group_shuffle(_vec, _col + 7) )", // NOLINT "", // NOLINT -"#define MULTIPLY_BLOCKS_8x8( _result, _blockA, _blockB ) { _result = mad( (float8)(_blockB.s0), TRANSPOSE_BLOCK_8(_blockA.s0), _result ); _result = mad( (float8)(_blockB.s1), TRANSPOSE_BLOCK_8(_blockA.s1), _result ); _result = mad( (float8)(_blockB.s2), TRANSPOSE_BLOCK_8(_blockA.s2), _result ); _result = mad( (float8)(_blockB.s3), TRANSPOSE_BLOCK_8(_blockA.s3), _result ); _result = mad( (float8)(_blockB.s4), TRANSPOSE_BLOCK_8(_blockA.s4), _result ); _result = mad( (float8)(_blockB.s5), TRANSPOSE_BLOCK_8(_blockA.s5), _result ); _result = mad( (float8)(_blockB.s6), TRANSPOSE_BLOCK_8(_blockA.s6), _result ); _result = mad( (float8)(_blockB.s7), TRANSPOSE_BLOCK_8(_blockA.s7), _result ); }", // NOLINT +"#define MULTIPLY_BLOCKS_8x8( _result, _blockA, _blockB, _col ) { _result = mad( (float8)(_blockB.s0), TRANSPOSE_BLOCK_8(_blockA.s0, _col), _result ); _result = mad( (float8)(_blockB.s1), TRANSPOSE_BLOCK_8(_blockA.s1, _col), _result ); _result = mad( (float8)(_blockB.s2), TRANSPOSE_BLOCK_8(_blockA.s2, _col), _result ); _result = mad( (float8)(_blockB.s3), TRANSPOSE_BLOCK_8(_blockA.s3, _col), _result ); _result = mad( (float8)(_blockB.s4), TRANSPOSE_BLOCK_8(_blockA.s4, _col), 
_result ); _result = mad( (float8)(_blockB.s5), TRANSPOSE_BLOCK_8(_blockA.s5, _col), _result ); _result = mad( (float8)(_blockB.s6), TRANSPOSE_BLOCK_8(_blockA.s6, _col), _result ); _result = mad( (float8)(_blockB.s7), TRANSPOSE_BLOCK_8(_blockA.s7, _col), _result ); }", // NOLINT "", // NOLINT -"#define GEMM_TN(ALPHA1, BETA_NOT0) __attribute__((reqd_work_group_size(8, 1, 1))) __kernel void TEMPLATE(gemm_32_1_TN_ ##ALPHA1 ##_ ##BETA_NOT0,Dtype)( __read_only image2d_t A, __read_only image2d_t B, MATC_PARAMETER, float alpha, float beta, int width0) { const int group_x = get_group_id(0); const int group_y = get_group_id(1); float8 blockAxB00 = 0.0f; float8 blockAxB01 = 0.0f; float8 blockAxB02 = 0.0f; float8 blockAxB03 = 0.0f; int2 coordA = (int2)( group_y * TILE_M * sizeof(uint), 0 ); int2 coordB = (int2)( ( group_x * TILE_N ) * sizeof(uint), 0 ); do { int2 coordBTemp = coordB; float8 blockB00 = as_float8( intel_sub_group_block_read8( B, coordBTemp ) ); coordB.y += TILE_K; int2 coordATemp = coordA; float8 blockA00 = as_float8( intel_sub_group_block_read8( A, coordATemp ) ); coordATemp.x += 8 * sizeof(uint); float8 blockA01 = as_float8( intel_sub_group_block_read8( A, coordATemp ) ); coordATemp.x += 8 * sizeof(uint); float8 blockA02 = as_float8( intel_sub_group_block_read8( A, coordATemp ) ); coordATemp.x += 8 * sizeof(uint); float8 blockA03 = as_float8( intel_sub_group_block_read8( A, coordATemp ) ); coordA.y += TILE_K; MULTIPLY_BLOCKS_8x8( blockAxB00, blockA00, blockB00 ); MULTIPLY_BLOCKS_8x8( blockAxB01, blockA01, blockB00 ); MULTIPLY_BLOCKS_8x8( blockAxB02, blockA02, blockB00 ); MULTIPLY_BLOCKS_8x8( blockAxB03, blockA03, blockB00 ); } while( coordB.y < width0 ); GEMM_OUTPUT(ALPHA1, BETA_NOT0); }", // NOLINT +"#if TYPE == TYPE_HALF", // NOLINT +"#define GEMM_TN(ALPHA1, BETA_NOT0) __attribute__((intel_reqd_sub_group_size(SIMD_SIZE_GEMM))) __attribute__((reqd_work_group_size(SIMD_SIZE_GEMM, 1, 1))) __kernel void TEMPLATE(gemm_32_1_TN_ ##ALPHA1 ##_ ##BETA_NOT0,Dtype)( 
__read_only image2d_t A, __read_only image2d_t B, MATC_PARAMETER, float alpha_in, float beta_in, int width0, int isFirstColBlock) { const float alpha = (float)alpha_in; const float beta = (float)beta_in; const int group_x = get_group_id(0); const int group_y = get_group_id(1); float8 blockAxB00 = 0; float8 blockAxB01 = 0; float8 blockAxB02 = 0; float8 blockAxB03 = 0; int2 coordA = (int2)( group_y * TILE_M * SIZE_OF_ELEMENT, 0 ); int2 coordB = (int2)( ( group_x * TILE_N ) * SIZE_OF_ELEMENT, 0 ); do { int2 coordBTemp = coordB; float8 blockB00 = as_float8( SUBGROUP_BLOCK_READ8( B, coordBTemp ) ); coordB.y += TILE_K; int2 coordATemp = coordA; float8 blockA00 = as_half8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.x += 16 * SIZE_OF_ELEMENT; float8 blockA01 = as_half8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordA.y += TILE_K; MULTIPLY_BLOCKS_8x8( blockAxB00, blockA00, blockB00, 0); MULTIPLY_BLOCKS_8x8( blockAxB01, blockA00, blockB00, 8); MULTIPLY_BLOCKS_8x8( blockAxB02, blockA01, blockB00, 0); MULTIPLY_BLOCKS_8x8( blockAxB03, blockA01, blockB00, 8); } while( coordB.y < width0 ); GEMM_OUTPUT(ALPHA1, BETA_NOT0); }", // NOLINT +"#else", // NOLINT +"#define GEMM_TN(ALPHA1, BETA_NOT0) __attribute__((intel_reqd_sub_group_size(SIMD_SIZE_GEMM))) __attribute__((reqd_work_group_size(SIMD_SIZE_GEMM, 1, 1))) __kernel void TEMPLATE(gemm_32_1_TN_ ##ALPHA1 ##_ ##BETA_NOT0,Dtype)( __read_only image2d_t A, __read_only image2d_t B, MATC_PARAMETER, float alpha_in, float beta_in, int width0, int isFirstColBlock) { const float alpha = (float)alpha_in; const float beta = (float)beta_in; const int group_x = get_group_id(0); const int group_y = get_group_id(1); float8 blockAxB00 = 0.0f; float8 blockAxB01 = 0.0f; float8 blockAxB02 = 0.0f; float8 blockAxB03 = 0.0f; int2 coordA = (int2)( group_y * TILE_M * SIZE_OF_ELEMENT, 0 ); int2 coordB = (int2)( ( group_x * TILE_N ) * SIZE_OF_ELEMENT, 0 ); do { int2 coordBTemp = coordB; float8 blockB00 = as_float8( SUBGROUP_BLOCK_READ8( B, 
coordBTemp ) ); coordB.y += TILE_K; int2 coordATemp = coordA; float8 blockA00 = as_float8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.x += 8 * SIZE_OF_ELEMENT; float8 blockA01 = as_float8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.x += 8 * SIZE_OF_ELEMENT; float8 blockA02 = as_float8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.x += 8 * SIZE_OF_ELEMENT; float8 blockA03 = as_float8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordA.y += TILE_K; MULTIPLY_BLOCKS_8x8( blockAxB00, blockA00, blockB00, 0 ); MULTIPLY_BLOCKS_8x8( blockAxB01, blockA01, blockB00, 0 ); MULTIPLY_BLOCKS_8x8( blockAxB02, blockA02, blockB00, 0 ); MULTIPLY_BLOCKS_8x8( blockAxB03, blockA03, blockB00, 0 ); } while( coordB.y < width0 ); GEMM_OUTPUT(ALPHA1, BETA_NOT0); }", // NOLINT +"#endif", // NOLINT "", // NOLINT "GEMM_TN(1, 0) // ALPHA == 1, BETA == 0", // NOLINT "GEMM_TN(1, 1) // ALPHA == 1, BETA != 0", // NOLINT @@ -3080,14 +3116,23 @@ static std::vector> cl_kernels{ "// The same as GEMM_NN", // NOLINT "#define TRANSPOSE_BLOCK_8( _block, _col ) (float8)( intel_sub_group_shuffle( _block.s0, _col), intel_sub_group_shuffle( _block.s1, _col), intel_sub_group_shuffle( _block.s2, _col), intel_sub_group_shuffle( _block.s3, _col), intel_sub_group_shuffle( _block.s4, _col), intel_sub_group_shuffle( _block.s5, _col), intel_sub_group_shuffle( _block.s6, _col), intel_sub_group_shuffle( _block.s7, _col) )", // NOLINT "", // NOLINT +"#if TYPE == TYPE_HALF", // NOLINT +"#define MULTIPLY_BLOCKS_8x8( _result, _blockA, _blockB ) { const float8 acol0 = TRANSPOSE_BLOCK_8( _blockA, 0 ); const float8 acol1 = TRANSPOSE_BLOCK_8( _blockA, 1 ); const float8 acol2 = TRANSPOSE_BLOCK_8( _blockA, 2 ); const float8 acol3 = TRANSPOSE_BLOCK_8( _blockA, 3 ); const float8 acol4 = TRANSPOSE_BLOCK_8( _blockA, 4 ); const float8 acol5 = TRANSPOSE_BLOCK_8( _blockA, 5 ); const float8 acol6 = TRANSPOSE_BLOCK_8( _blockA, 6 ); const float8 acol7 = TRANSPOSE_BLOCK_8( _blockA, 7 ); const float8 acol8 = 
TRANSPOSE_BLOCK_8( _blockA, 8 ); const float8 acol9 = TRANSPOSE_BLOCK_8( _blockA, 9 ); const float8 acola = TRANSPOSE_BLOCK_8( _blockA, 10 ); const float8 acolb = TRANSPOSE_BLOCK_8( _blockA, 11 ); const float8 acolc = TRANSPOSE_BLOCK_8( _blockA, 12 ); const float8 acold = TRANSPOSE_BLOCK_8( _blockA, 13 ); const float8 acole = TRANSPOSE_BLOCK_8( _blockA, 14 ); const float8 acolf = TRANSPOSE_BLOCK_8( _blockA, 15 ); _result = mad( (float8)_blockB.s0, acol0, _result ); _result = mad( (float8)_blockB.s1, acol1, _result ); _result = mad( (float8)_blockB.s2, acol2, _result ); _result = mad( (float8)_blockB.s3, acol3, _result ); _result = mad( (float8)_blockB.s4, acol4, _result ); _result = mad( (float8)_blockB.s5, acol5, _result ); _result = mad( (float8)_blockB.s6, acol6, _result ); _result = mad( (float8)_blockB.s7, acol7, _result ); _result = mad( (float8)_blockB.s8, acol8, _result ); _result = mad( (float8)_blockB.s9, acol9, _result ); _result = mad( (float8)_blockB.sa, acola, _result ); _result = mad( (float8)_blockB.sb, acolb, _result ); _result = mad( (float8)_blockB.sc, acolc, _result ); _result = mad( (float8)_blockB.sd, acold, _result ); _result = mad( (float8)_blockB.se, acole, _result ); _result = mad( (float8)_blockB.sf, acolf, _result ); }", // NOLINT +"#else", // NOLINT "#define MULTIPLY_BLOCKS_8x8( _result, _blockA, _blockB ) { const float8 acol0 = TRANSPOSE_BLOCK_8( _blockA, 0 ); const float8 acol1 = TRANSPOSE_BLOCK_8( _blockA, 1 ); const float8 acol2 = TRANSPOSE_BLOCK_8( _blockA, 2 ); const float8 acol3 = TRANSPOSE_BLOCK_8( _blockA, 3 ); const float8 acol4 = TRANSPOSE_BLOCK_8( _blockA, 4 ); const float8 acol5 = TRANSPOSE_BLOCK_8( _blockA, 5 ); const float8 acol6 = TRANSPOSE_BLOCK_8( _blockA, 6 ); const float8 acol7 = TRANSPOSE_BLOCK_8( _blockA, 7 ); _result = mad( (float8)_blockB.s0, acol0, _result ); _result = mad( (float8)_blockB.s1, acol1, _result ); _result = mad( (float8)_blockB.s2, acol2, _result ); _result = mad( (float8)_blockB.s3, acol3, _result 
); _result = mad( (float8)_blockB.s4, acol4, _result ); _result = mad( (float8)_blockB.s5, acol5, _result ); _result = mad( (float8)_blockB.s6, acol6, _result ); _result = mad( (float8)_blockB.s7, acol7, _result ); }", // NOLINT +"#endif", // NOLINT "", // NOLINT +"#if TYPE == TYPE_HALF", // NOLINT +"#define GEMM_NT(ALPHA1, BETA_NOT0, VECSCALAR, VECSIZE) __attribute__((intel_reqd_sub_group_size(SIMD_SIZE_GEMM))) __attribute__((reqd_work_group_size(SIMD_SIZE_GEMM, 1, 1))) __kernel void TEMPLATE(gemm_32_1_NT_ ##VECSCALAR ##_ ##ALPHA1 ##_ ##BETA_NOT0,Dtype)( __read_only image2d_t A, MATB_PARAMETER, MATC_PARAMETER, float alpha_in, float beta_in, int padded_k, int k, int isFirstColBlock) { const float alpha = (float)alpha_in; const float beta = (float)beta_in; const int group_x = get_group_id(0); const int group_y = get_group_id(1); float8 blockAxB00 = 0; float8 blockAxB01 = 0; float8 blockAxB02 = 0; float8 blockAxB03 = 0; int2 coordA = (int2)( 0, group_y * TILE_M ); int2 coordB = (int2)( 0, ( group_x * TILE_N )); const sampler_t sampler = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; do { float16 blockB00; BLOCKB_READ8(blockB00, B, coordB); int2 coordATemp = coordA; float8 blockA00 = as_float8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.y += 8; float8 blockA01 = as_float8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.y += 8; float8 blockA02 = as_float8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.y += 8; float8 blockA03 = as_float8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordA.x += TILE_K * SIZE_OF_ELEMENT * 2; MULTIPLY_BLOCKS_8x8( blockAxB00, blockA00, blockB00 ); MULTIPLY_BLOCKS_8x8( blockAxB01, blockA01, blockB00 ); MULTIPLY_BLOCKS_8x8( blockAxB02, blockA02, blockB00 ); MULTIPLY_BLOCKS_8x8( blockAxB03, blockA03, blockB00 ); } while( coordB.x < padded_k / VECSIZE ); GEMM_OUTPUT(ALPHA1, BETA_NOT0); }", // NOLINT +"#else", // NOLINT +"#define GEMM_NT(ALPHA1, BETA_NOT0, VECSCALAR, VECSIZE) 
__attribute__((intel_reqd_sub_group_size(SIMD_SIZE_GEMM))) __attribute__((reqd_work_group_size(SIMD_SIZE_GEMM, 1, 1))) __kernel void TEMPLATE(gemm_32_1_NT_ ##VECSCALAR ##_ ##ALPHA1 ##_ ##BETA_NOT0,Dtype)( __read_only image2d_t A, MATB_PARAMETER, MATC_PARAMETER, float alpha_in, float beta_in, int padded_k, int k, int isFirstColBlock) { const float alpha = (float)alpha_in; const float beta = (float)beta_in; const int group_x = get_group_id(0); const int group_y = get_group_id(1); float8 blockAxB00 = 0.0f; float8 blockAxB01 = 0.0f; float8 blockAxB02 = 0.0f; float8 blockAxB03 = 0.0f; int2 coordA = (int2)( 0, group_y * TILE_M ); int2 coordB = (int2)( 0, ( group_x * TILE_N )); const sampler_t sampler = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; do { float8 blockB00; BLOCKB_READ8(blockB00, B, coordB); int2 coordATemp = coordA; float8 blockA00 = as_float8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.y += 8; float8 blockA01 = as_float8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.y += 8; float8 blockA02 = as_float8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.y += 8; float8 blockA03 = as_float8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordA.x += TILE_K * SIZE_OF_ELEMENT; MULTIPLY_BLOCKS_8x8( blockAxB00, blockA00, blockB00 ); MULTIPLY_BLOCKS_8x8( blockAxB01, blockA01, blockB00 ); MULTIPLY_BLOCKS_8x8( blockAxB02, blockA02, blockB00 ); MULTIPLY_BLOCKS_8x8( blockAxB03, blockA03, blockB00 ); } while( coordB.x < padded_k / VECSIZE ); GEMM_OUTPUT(ALPHA1, BETA_NOT0); }", // NOLINT +"#endif", // NOLINT "", // NOLINT -"", // NOLINT -"#define GEMM_NT(ALPHA1, BETA_NOT0, VECSCALAR, VECSIZE) __attribute__((reqd_work_group_size(8, 1, 1))) __kernel void TEMPLATE(gemm_32_1_NT_ ##VECSCALAR ##_ ##ALPHA1 ##_ ##BETA_NOT0,Dtype)( __read_only image2d_t A, MATB_PARAMETER, MATC_PARAMETER, float alpha, float beta, int padded_k, int k) { const int group_x = get_group_id(0); const int group_y = get_group_id(1); float8 blockAxB00 = 0.0f; float8 
blockAxB01 = 0.0f; float8 blockAxB02 = 0.0f; float8 blockAxB03 = 0.0f; int2 coordA = (int2)( 0, group_y * TILE_M ); int2 coordB = (int2)( 0, ( group_x * TILE_N )); const sampler_t sampler = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; do { float8 blockB00; BLOCKB_READ8(blockB00, B, coordB); int2 coordATemp = coordA; float8 blockA00 = as_float8( intel_sub_group_block_read8( A, coordATemp ) ); coordATemp.y += 8; float8 blockA01 = as_float8( intel_sub_group_block_read8( A, coordATemp ) ); coordATemp.y += 8; float8 blockA02 = as_float8( intel_sub_group_block_read8( A, coordATemp ) ); coordATemp.y += 8; float8 blockA03 = as_float8( intel_sub_group_block_read8( A, coordATemp ) ); coordA.x += TILE_K * sizeof(uint); MULTIPLY_BLOCKS_8x8( blockAxB00, blockA00, blockB00 ); MULTIPLY_BLOCKS_8x8( blockAxB01, blockA01, blockB00 ); MULTIPLY_BLOCKS_8x8( blockAxB02, blockA02, blockB00 ); MULTIPLY_BLOCKS_8x8( blockAxB03, blockA03, blockB00 ); } while( coordB.x < padded_k / VECSIZE ); GEMM_OUTPUT(ALPHA1, BETA_NOT0); }", // NOLINT -"", // NOLINT -"", // NOLINT -"#define BLOCKB_READ8(_blockb, _B, _coordB) int2 _coordBTemp = _coordB; _coordBTemp.y += get_local_id(0); _blockb.s0123 = read_imagef(_B, sampler, _coordBTemp); _coordBTemp.x += 1; _blockb.s4567 = read_imagef(_B, sampler, _coordBTemp); _coordB.x += 2;", // NOLINT +"#if TYPE == TYPE_HALF", // NOLINT +"#define BLOCKB_READ8(_blockb, _B, _coordB) int2 _coordBTemp = _coordB; _coordBTemp.y += get_local_id(0); _blockb.s0123 = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; _blockb.s4567 = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; _blockb.s89ab = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; _blockb.scdef = READ_IMAGE(_B, _coordBTemp); _coordB.x += 4;", // NOLINT +"#else", // NOLINT +"#define BLOCKB_READ8(_blockb, _B, _coordB) int2 _coordBTemp = _coordB; _coordBTemp.y += get_local_id(0); _blockb.s0123 = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; _blockb.s4567 = READ_IMAGE(_B, _coordBTemp); 
_coordB.x += 2;", // NOLINT +"#endif", // NOLINT "", // NOLINT "#define MATB_PARAMETER __read_only image2d_t B", // NOLINT "", // NOLINT @@ -3098,9 +3143,13 @@ static std::vector> cl_kernels{ "#undef BLOCKB_READ8", // NOLINT "#undef MATB_PARAMETER", // NOLINT "", // NOLINT -"#define BLOCKB_READ8(_blockb, _B, _coordB) int2 _coordBTemp = _coordB; _coordBTemp.y += get_local_id(0); _blockb = *(__global float8*)&_B[_coordBTemp.y * k + _coordBTemp.x + offB]; _coordB.x += TILE_K;", // NOLINT +"#if TYPE == TYPE_HALF", // NOLINT +"#define BLOCKB_READ8(_blockb, _B, _coordB) int2 _coordBTemp = _coordB; _coordBTemp.y += get_local_id(0); const __global float *B_read = (__global float *)(_B + (_coordBTemp.y * ldb) + _coordBTemp.x + offB); _blockb = as_half16(as_ushort16(vload8(0, B_read))); _coordB.x += TILE_K * 2;", // NOLINT +"#else", // NOLINT +"#define BLOCKB_READ8(_blockb, _B, _coordB) int2 _coordBTemp = _coordB; _coordBTemp.y += get_local_id(0); const __global float *B_read = (__global float *)(_B + (_coordBTemp.y * ldb) + _coordBTemp.x + offB); _blockb = vload8(0, B_read); _coordB.x += TILE_K;", // NOLINT +"#endif", // NOLINT "", // NOLINT -"#define MATB_PARAMETER __global float *B, int offB", // NOLINT +"#define MATB_PARAMETER __global float *B, int offB, int ldb", // NOLINT "", // NOLINT "GEMM_NT(1, 0, BUFFER, 1) // ALPHA == 1, BETA == 0", // NOLINT "GEMM_NT(1, 1, BUFFER, 1) // ALPHA == 1, BETA != 0", // NOLINT @@ -3109,8 +3158,11 @@ static std::vector> cl_kernels{ "#undef BLOCKB_READ8", // NOLINT "#undef MATB_PARAMETER", // NOLINT "", // NOLINT -"", // NOLINT -"#define BLOCKB_READ8(_blockb, _B, _coordB) int2 _coordBTemp = _coordB; _coordBTemp.y += get_local_id(0); float4 temp; temp = read_imagef(_B, _coordBTemp); _coordBTemp.x += 1; _blockb.s0 = temp.s0; temp = read_imagef(_B, _coordBTemp); _coordBTemp.x += 1; _blockb.s1 = temp.s0; temp = read_imagef(_B, _coordBTemp); _coordBTemp.x += 1; _blockb.s2 = temp.s0; temp = read_imagef(_B, _coordBTemp); _coordBTemp.x += 1; 
_blockb.s3 = temp.s0; temp = read_imagef(_B, _coordBTemp); _coordBTemp.x += 1; _blockb.s4 = temp.s0; temp = read_imagef(_B, _coordBTemp); _coordBTemp.x += 1; _blockb.s5 = temp.s0; temp = read_imagef(_B, _coordBTemp); _coordBTemp.x += 1; _blockb.s6 = temp.s0; temp = read_imagef(_B, _coordBTemp); _coordBTemp.x += 1; _blockb.s7 = temp.s0; _coordB.x += 8;", // NOLINT +"#if TYPE == TYPE_HALF", // NOLINT +"#define BLOCKB_READ8(_blockb, _B, _coordB) int2 _coordBTemp = _coordB; _coordBTemp.y += get_local_id(0); float4 temp; temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; _blockb.s0 = temp.s0; temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; _blockb.s1 = temp.s0; temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; _blockb.s2 = temp.s0; temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; _blockb.s3 = temp.s0; temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; _blockb.s4 = temp.s0; temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; _blockb.s5 = temp.s0; temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; _blockb.s6 = temp.s0; temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; _blockb.s7 = temp.s0; temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; _blockb.s8 = temp.s0; temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; _blockb.s9 = temp.s0; temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; _blockb.sa = temp.s0; temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; _blockb.sb = temp.s0; temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; _blockb.sc = temp.s0; temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; _blockb.sd = temp.s0; temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; _blockb.se = temp.s0; temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; _blockb.sf = temp.s0; _coordB.x += 16;", // NOLINT +"#else", // NOLINT +"#define BLOCKB_READ8(_blockb, _B, _coordB) int2 _coordBTemp = _coordB; _coordBTemp.y += get_local_id(0); float4 temp; temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; 
_blockb.s0 = temp.s0; temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; _blockb.s1 = temp.s0; temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; _blockb.s2 = temp.s0; temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; _blockb.s3 = temp.s0; temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; _blockb.s4 = temp.s0; temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; _blockb.s5 = temp.s0; temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; _blockb.s6 = temp.s0; temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; _blockb.s7 = temp.s0; _coordB.x += 8;", // NOLINT +"#endif", // NOLINT "", // NOLINT "#define MATB_PARAMETER __read_only image2d_t B", // NOLINT "", // NOLINT @@ -3125,13 +3177,17 @@ static std::vector> cl_kernels{ "#undef TRANSPOSE_BLOCK_8", // NOLINT "", // NOLINT "//The same as GEMM_TN.", // NOLINT -"#define TRANSPOSE_BLOCK_8(_vec) (float8)( intel_sub_group_shuffle(_vec, 0), intel_sub_group_shuffle(_vec, 1), intel_sub_group_shuffle(_vec, 2), intel_sub_group_shuffle(_vec, 3), intel_sub_group_shuffle(_vec, 4), intel_sub_group_shuffle(_vec, 5), intel_sub_group_shuffle(_vec, 6), intel_sub_group_shuffle(_vec, 7) );", // NOLINT +"#define TRANSPOSE_BLOCK_8(_vec, _col) (float8)( intel_sub_group_shuffle(_vec, _col + 0), intel_sub_group_shuffle(_vec, _col + 1), intel_sub_group_shuffle(_vec, _col + 2), intel_sub_group_shuffle(_vec, _col + 3), intel_sub_group_shuffle(_vec, _col + 4), intel_sub_group_shuffle(_vec, _col + 5), intel_sub_group_shuffle(_vec, _col + 6), intel_sub_group_shuffle(_vec, _col + 7) );", // NOLINT "", // NOLINT -"#define MULTIPLY_BLOCKS_8x8( _result, _blockA, _blockB ) { const float8 acol0 = TRANSPOSE_BLOCK_8( _blockA.s0 ); const float8 acol1 = TRANSPOSE_BLOCK_8( _blockA.s1 ); const float8 acol2 = TRANSPOSE_BLOCK_8( _blockA.s2 ); const float8 acol3 = TRANSPOSE_BLOCK_8( _blockA.s3 ); const float8 acol4 = TRANSPOSE_BLOCK_8( _blockA.s4 ); const float8 acol5 = TRANSPOSE_BLOCK_8( _blockA.s5 ); const float8 acol6 = 
TRANSPOSE_BLOCK_8( _blockA.s6 ); const float8 acol7 = TRANSPOSE_BLOCK_8( _blockA.s7 ); _result = mad( (float8)_blockB.s0, acol0, _result ); _result = mad( (float8)_blockB.s1, acol1, _result ); _result = mad( (float8)_blockB.s2, acol2, _result ); _result = mad( (float8)_blockB.s3, acol3, _result ); _result = mad( (float8)_blockB.s4, acol4, _result ); _result = mad( (float8)_blockB.s5, acol5, _result ); _result = mad( (float8)_blockB.s6, acol6, _result ); _result = mad( (float8)_blockB.s7, acol7, _result ); }", // NOLINT +"#define MULTIPLY_BLOCKS_8x8( _result, _blockA, _blockB, _col ) { const float8 acol0 = TRANSPOSE_BLOCK_8( _blockA.s0, _col ); const float8 acol1 = TRANSPOSE_BLOCK_8( _blockA.s1, _col ); const float8 acol2 = TRANSPOSE_BLOCK_8( _blockA.s2, _col ); const float8 acol3 = TRANSPOSE_BLOCK_8( _blockA.s3, _col ); const float8 acol4 = TRANSPOSE_BLOCK_8( _blockA.s4, _col ); const float8 acol5 = TRANSPOSE_BLOCK_8( _blockA.s5, _col ); const float8 acol6 = TRANSPOSE_BLOCK_8( _blockA.s6, _col ); const float8 acol7 = TRANSPOSE_BLOCK_8( _blockA.s7, _col ); _result = mad( (float8)_blockB.s0, acol0, _result ); _result = mad( (float8)_blockB.s1, acol1, _result ); _result = mad( (float8)_blockB.s2, acol2, _result ); _result = mad( (float8)_blockB.s3, acol3, _result ); _result = mad( (float8)_blockB.s4, acol4, _result ); _result = mad( (float8)_blockB.s5, acol5, _result ); _result = mad( (float8)_blockB.s6, acol6, _result ); _result = mad( (float8)_blockB.s7, acol7, _result ); }", // NOLINT "", // NOLINT -"#define GEMM_TT(ALPHA1, BETA_NOT0, VECSCALAR, VECSIZE) __attribute__((reqd_work_group_size(8, 1, 1))) __kernel void TEMPLATE(gemm_32_1_TT_ ##VECSCALAR ##_ ##ALPHA1 ##_ ##BETA_NOT0, Dtype)( __read_only image2d_t A, MATB_PARAMETER, MATC_PARAMETER, float alpha, float beta, int padded_k, int k) { const int group_x = get_group_id(0); const int group_y = get_group_id(1); float8 blockAxB00 = 0.0f; float8 blockAxB01 = 0.0f; float8 blockAxB02 = 0.0f; float8 blockAxB03 = 0.0f; 
int2 coordA = (int2)( group_y * TILE_M * sizeof(uint), 0 ); int2 coordB = (int2)( 0, ( group_x * TILE_N )); const sampler_t sampler = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; do { float8 blockB00; BLOCKB_READ8(blockB00, B, coordB); int2 coordATemp = coordA; float8 blockA00 = as_float8( intel_sub_group_block_read8( A, coordATemp ) ); coordATemp.x += 8 * sizeof(uint); float8 blockA01 = as_float8( intel_sub_group_block_read8( A, coordATemp ) ); coordATemp.x += 8 * sizeof(uint); float8 blockA02 = as_float8( intel_sub_group_block_read8( A, coordATemp ) ); coordATemp.x += 8 * sizeof(uint); float8 blockA03 = as_float8( intel_sub_group_block_read8( A, coordATemp ) ); coordA.y += TILE_K; MULTIPLY_BLOCKS_8x8( blockAxB00, blockA00 , blockB00 ); MULTIPLY_BLOCKS_8x8( blockAxB01, blockA01 , blockB00 ); MULTIPLY_BLOCKS_8x8( blockAxB02, blockA02 , blockB00 ); MULTIPLY_BLOCKS_8x8( blockAxB03, blockA03 , blockB00 ); } while( coordB.x < padded_k / VECSIZE ); GEMM_OUTPUT(ALPHA1, BETA_NOT0);}", // NOLINT +"#if TYPE == TYPE_HALF", // NOLINT +"#define GEMM_TT(ALPHA1, BETA_NOT0, VECSCALAR, VECSIZE) __attribute__((intel_reqd_sub_group_size(SIMD_SIZE_GEMM))) __attribute__((reqd_work_group_size(SIMD_SIZE_GEMM, 1, 1))) __kernel void TEMPLATE(gemm_32_1_TT_ ##VECSCALAR ##_ ##ALPHA1 ##_ ##BETA_NOT0, Dtype)( __read_only image2d_t A, MATB_PARAMETER, MATC_PARAMETER, float alpha_in, float beta_in, int padded_k, int k, int isFirstColBlock) { const float alpha = (float)alpha_in; const float beta = (float)beta_in; const int group_x = get_group_id(0); const int group_y = get_group_id(1); float8 blockAxB00 = 0; float8 blockAxB01 = 0; float8 blockAxB02 = 0; float8 blockAxB03 = 0; int2 coordA = (int2)( group_y * TILE_M * SIZE_OF_ELEMENT, 0 ); int2 coordB = (int2)( 0, ( group_x * TILE_N )); const sampler_t sampler = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; do { float8 blockB00; BLOCKB_READ8(blockB00, B, coordB); int2 coordATemp = coordA; float8 
blockA00 = as_float8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.x += 16 * SIZE_OF_ELEMENT; float8 blockA01 = as_float8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordA.y += TILE_K; MULTIPLY_BLOCKS_8x8( blockAxB00, blockA00, blockB00, 0); MULTIPLY_BLOCKS_8x8( blockAxB01, blockA00, blockB00, 8); MULTIPLY_BLOCKS_8x8( blockAxB02, blockA01, blockB00, 0); MULTIPLY_BLOCKS_8x8( blockAxB03, blockA01, blockB00, 8); } while( coordB.x < padded_k / VECSIZE ); GEMM_OUTPUT(ALPHA1, BETA_NOT0);}", // NOLINT +"#else", // NOLINT +"#define GEMM_TT(ALPHA1, BETA_NOT0, VECSCALAR, VECSIZE) __attribute__((intel_reqd_sub_group_size(SIMD_SIZE_GEMM))) __attribute__((reqd_work_group_size(SIMD_SIZE_GEMM, 1, 1))) __kernel void TEMPLATE(gemm_32_1_TT_ ##VECSCALAR ##_ ##ALPHA1 ##_ ##BETA_NOT0, Dtype)( __read_only image2d_t A, MATB_PARAMETER, MATC_PARAMETER, float alpha_in, float beta_in, int padded_k, int k, int isFirstColBlock) { const float alpha = (float)alpha_in; const float beta = (float)beta_in; const int group_x = get_group_id(0); const int group_y = get_group_id(1); float8 blockAxB00 = 0.0f; float8 blockAxB01 = 0.0f; float8 blockAxB02 = 0.0f; float8 blockAxB03 = 0.0f; int2 coordA = (int2)( group_y * TILE_M * SIZE_OF_ELEMENT, 0 ); int2 coordB = (int2)( 0, ( group_x * TILE_N )); const sampler_t sampler = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; do { float8 blockB00; BLOCKB_READ8(blockB00, B, coordB); int2 coordATemp = coordA; float8 blockA00 = as_float8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.x += 8 * SIZE_OF_ELEMENT; float8 blockA01 = as_float8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.x += 8 * SIZE_OF_ELEMENT; float8 blockA02 = as_float8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.x += 8 * SIZE_OF_ELEMENT; float8 blockA03 = as_float8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordA.y += TILE_K; MULTIPLY_BLOCKS_8x8( blockAxB00, blockA00 , blockB00, 0 ); MULTIPLY_BLOCKS_8x8( blockAxB01, blockA01 , blockB00, 0 ); 
MULTIPLY_BLOCKS_8x8( blockAxB02, blockA02 , blockB00, 0 ); MULTIPLY_BLOCKS_8x8( blockAxB03, blockA03 , blockB00, 0 ); } while( coordB.x < padded_k / VECSIZE ); GEMM_OUTPUT(ALPHA1, BETA_NOT0);}", // NOLINT +"#endif", // NOLINT "", // NOLINT -"#define BLOCKB_READ8(_blockb, _B, _coordB) int2 _coordBTemp = _coordB; _coordBTemp.y += get_local_id(0); blockB00.s0123 = read_imagef(B, _coordBTemp); _coordBTemp.x += 1; blockB00.s4567 = read_imagef(B, _coordBTemp); _coordB.x += 2;", // NOLINT +"#define BLOCKB_READ8(_blockb, _B, _coordB) int2 _coordBTemp = _coordB; _coordBTemp.y += get_local_id(0); _blockb.s0123 = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; _blockb.s4567 = READ_IMAGE(_B, _coordBTemp); _coordB.x += 2;", // NOLINT "", // NOLINT "#define MATB_PARAMETER __read_only image2d_t B", // NOLINT "", // NOLINT @@ -3142,9 +3198,13 @@ static std::vector> cl_kernels{ "#undef BLOCKB_READ8", // NOLINT "#undef MATB_PARAMETER", // NOLINT "", // NOLINT -"#define BLOCKB_READ8(_blockb, _B, _coordB) int2 _coordBTemp = _coordB; _coordBTemp.y += get_local_id(0); _blockb = *(__global float8*)&_B[_coordBTemp.y * k + _coordBTemp.x + offB]; _coordB.x += TILE_K;", // NOLINT +"#if TYPE == TYPE_HALF", // NOLINT +"#define BLOCKB_READ8(_blockb, _B, _coordB) int2 _coordBTemp = _coordB; _coordBTemp.y += get_local_id(0); const __global float *B_read = (__global float *)(_B + (_coordBTemp.y * k) + _coordBTemp.x + offB); _blockb = as_half8(as_ushort8(vload4(0, B_read))); _coordB.x += TILE_K;", // NOLINT +"#else", // NOLINT +"#define BLOCKB_READ8(_blockb, _B, _coordB) int2 _coordBTemp = _coordB; _coordBTemp.y += get_local_id(0); const __global float *B_read = (__global float *)(_B + (_coordBTemp.y * k) + _coordBTemp.x + offB); _blockb = vload8(0, B_read); _coordB.x += TILE_K;", // NOLINT +"#endif", // NOLINT "", // NOLINT -"#define MATB_PARAMETER __global float *B, int offB", // NOLINT +"#define MATB_PARAMETER __global float *B, int offB, int ldb", // NOLINT "", // NOLINT "GEMM_TT(1, 0, BUFFER, 
1) // ALPHA == 1, BETA == 0", // NOLINT "GEMM_TT(1, 1, BUFFER, 1) // ALPHA == 1, BETA != 0", // NOLINT @@ -3153,7 +3213,7 @@ static std::vector> cl_kernels{ "#undef BLOCKB_READ8", // NOLINT "#undef MATB_PARAMETER", // NOLINT "", // NOLINT -"#define BLOCKB_READ8(_blockb, _B, _coordB) int2 _coordBTemp = _coordB; _coordBTemp.y += get_local_id(0); float4 temp; temp = read_imagef(B, _coordBTemp); _coordBTemp.x += 1; _blockb.s0 = temp.s0; temp = read_imagef(B, _coordBTemp); _coordBTemp.x += 1; _blockb.s1 = temp.s0; temp = read_imagef(B, _coordBTemp); _coordBTemp.x += 1; _blockb.s2 = temp.s0; temp = read_imagef(B, _coordBTemp); _coordBTemp.x += 1; _blockb.s3 = temp.s0; temp = read_imagef(B, _coordBTemp); _coordBTemp.x += 1; _blockb.s4 = temp.s0; temp = read_imagef(B, _coordBTemp); _coordBTemp.x += 1; _blockb.s5 = temp.s0; temp = read_imagef(B, _coordBTemp); _coordBTemp.x += 1; _blockb.s6 = temp.s0; temp = read_imagef(B, _coordBTemp); _coordBTemp.x += 1; _blockb.s7 = temp.s0; _coordB.x += 8;", // NOLINT +"#define BLOCKB_READ8(_blockb, _B, _coordB) int2 _coordBTemp = _coordB; _coordBTemp.y += get_local_id(0); float4 temp; temp = READ_IMAGE(B, _coordBTemp); _coordBTemp.x += 1; _blockb.s0 = temp.s0; temp = READ_IMAGE(B, _coordBTemp); _coordBTemp.x += 1; _blockb.s1 = temp.s0; temp = READ_IMAGE(B, _coordBTemp); _coordBTemp.x += 1; _blockb.s2 = temp.s0; temp = READ_IMAGE(B, _coordBTemp); _coordBTemp.x += 1; _blockb.s3 = temp.s0; temp = READ_IMAGE(B, _coordBTemp); _coordBTemp.x += 1; _blockb.s4 = temp.s0; temp = READ_IMAGE(B, _coordBTemp); _coordBTemp.x += 1; _blockb.s5 = temp.s0; temp = READ_IMAGE(B, _coordBTemp); _coordBTemp.x += 1; _blockb.s6 = temp.s0; temp = READ_IMAGE(B, _coordBTemp); _coordBTemp.x += 1; _blockb.s7 = temp.s0; _coordB.x += 8;", // NOLINT "", // NOLINT "#define MATB_PARAMETER __read_only image2d_t B", // NOLINT "", // NOLINT @@ -3171,32 +3231,69 @@ static std::vector> cl_kernels{ "#undef TILE_K", // NOLINT "#undef TILE_N", // NOLINT "", // NOLINT -"__kernel 
void TEMPLATE(gemm_buffer_copy_image,Dtype)(", // NOLINT +"__kernel void TEMPLATE(gemm_buffer_copy_image_transpose, Dtype)(", // NOLINT +"__global float* A,", // NOLINT +"__write_only image2d_t ImA,", // NOLINT +"int offA,", // NOLINT +"int width,", // NOLINT +"int height,", // NOLINT +"int ldA)", // NOLINT +"{", // NOLINT +"const int gidx = get_global_id(0);", // NOLINT +"const int gidy = get_global_id(1);", // NOLINT +"int2 coord_dst = (int2)(gidx, gidy);", // NOLINT +"__global float* A_off = A + offA;", // NOLINT +"float srcA = A_off[gidy * ldA + gidx];", // NOLINT +"#if TYPE == TYPE_HALF", // NOLINT +"write_imageh(ImA, coord_dst, (float4)srcA);", // NOLINT +"#else", // NOLINT +"write_imagef(ImA, coord_dst, (float4)srcA);", // NOLINT +"#endif", // NOLINT +"}", // NOLINT +"", // NOLINT +"__kernel void TEMPLATE(gemm_buffer_copy_image_no_transpose, Dtype)(", // NOLINT "__global float* A,", // NOLINT "__write_only image2d_t ImA,", // NOLINT "int offA,", // NOLINT "int width,", // NOLINT -"int height)", // NOLINT +"int height,", // NOLINT +"int ldA)", // NOLINT "{", // NOLINT "const int gidx = get_global_id(0);", // NOLINT "const int gidy = get_global_id(1);", // NOLINT "int2 coord_dst = (int2)(gidx, gidy);", // NOLINT +"#if TYPE == TYPE_HALF", // NOLINT +"if (gidx >= width || gidy >= height) {", // NOLINT +"write_imageh(ImA, coord_dst, 0);", // NOLINT +"return;", // NOLINT +"}", // NOLINT +"__global float* A_off = A + offA;", // NOLINT +"write_imageh(ImA, coord_dst, A_off[gidy * ldA + gidx]);", // NOLINT +"#else", // NOLINT "if (gidx >= width || gidy >= height) {", // NOLINT "write_imageui(ImA, coord_dst, (uint4)0);", // NOLINT "return;", // NOLINT "}", // NOLINT "__global float* A_off = A + offA;", // NOLINT -"uint4 srcA = convert_uint4(as_uchar4(A_off[gidy * width + gidx]));", // NOLINT +"uint4 srcA = convert_uint4(as_uchar4(A_off[gidy * ldA + gidx]));", // NOLINT "write_imageui(ImA, coord_dst, srcA);", // NOLINT +"#endif", // NOLINT "}", // NOLINT "", // NOLINT 
+"", // NOLINT "#define VEC_SIZE 4", // NOLINT "#define LWG_HEIGHT 4", // NOLINT "#define TILE_M 8", // NOLINT +"#if TYPE == TYPE_HALF", // NOLINT +"#define TILE_K 32", // NOLINT +"#define TILE_N 64", // NOLINT +"#else", // NOLINT "#define TILE_K 16", // NOLINT "#define TILE_N 32", // NOLINT +"#endif", // NOLINT "", // NOLINT -"__attribute__((reqd_work_group_size(8, LWG_HEIGHT, 1)))", // NOLINT +"__attribute__((reqd_work_group_size(SIMD_SIZE_GEMM, LWG_HEIGHT, 1)))", // NOLINT +"__attribute__((intel_reqd_sub_group_size(SIMD_SIZE_GEMM)))", // NOLINT "__kernel void TEMPLATE(gemm_buffer_NN, Dtype)(", // NOLINT "const __global float *src0, int off0,", // NOLINT "const __global float *src1, int off1,", // NOLINT @@ -3204,10 +3301,12 @@ static std::vector> cl_kernels{ "int M,", // NOLINT "int N,", // NOLINT "int K,", // NOLINT -"float alpha,", // NOLINT -"float beta,", // NOLINT +"float alpha_in,", // NOLINT +"float beta_in,", // NOLINT "int start_index)", // NOLINT "{", // NOLINT +"const float alpha = (float)alpha_in;", // NOLINT +"const float beta = (float)beta_in;", // NOLINT "const int group_x = get_group_id(0);", // NOLINT "const int group_y = get_group_id(1);", // NOLINT "const int local_x = get_local_id(0);", // NOLINT @@ -3218,11 +3317,11 @@ static std::vector> cl_kernels{ "float4 brow;", // NOLINT "float2 arow0, arow1, arow2, arow3, arow4, arow5, arow6, arow7;", // NOLINT "", // NOLINT -"__global float *dst_write0 = dst + local_x * VEC_SIZE + ( group_x * TILE_N ) + ( group_y * LWG_HEIGHT * TILE_M + local_y * TILE_M) * N + offd;", // NOLINT +"__global float *dst_write0 = dst + local_x * VEC_SIZE + (group_x * TILE_N) + (group_y * LWG_HEIGHT * TILE_M + local_y * TILE_M) * N + offd;", // NOLINT "", // NOLINT -"const __global float *src0_read = src0 + local_x * ( TILE_K / 8 ) + ( group_y * LWG_HEIGHT * TILE_M + local_y * TILE_M ) * K + start_index + off0;", // NOLINT +"const __global float *src0_read = src0 + local_x * (TILE_K / SIMD_SIZE_GEMM) + (group_y * LWG_HEIGHT 
* TILE_M + local_y * TILE_M) * K + start_index + off0;", // NOLINT "", // NOLINT -"const __global float *src1_read0 = src1 + local_x * VEC_SIZE + ( group_x * TILE_N ) + start_index * N + off1;", // NOLINT +"const __global float *src1_read0 = src1 + local_x * VEC_SIZE + (group_x * TILE_N) + start_index * N + off1;", // NOLINT "", // NOLINT "int border = -(group_y * LWG_HEIGHT * TILE_M + local_y * TILE_M);", // NOLINT "", // NOLINT @@ -3235,28 +3334,28 @@ static std::vector> cl_kernels{ "int row6 = mad24(global_y, TILE_M, 6) < M ? 6 : border;", // NOLINT "int row7 = mad24(global_y, TILE_M, 7) < M ? 7 : border;", // NOLINT "", // NOLINT -"float4 dot00 = (start_index != 0) ? ((__global float4 *)dst_write0)[0] : beta * ((__global float4 *)dst_write0)[0];", // NOLINT -"float4 dot01 = (start_index != 0) ? ((__global float4 *)(dst_write0 + 1 * N))[0] : beta * ((__global float4 *)(dst_write0 + 1 * N))[0];", // NOLINT -"float4 dot02 = (start_index != 0) ? ((__global float4 *)(dst_write0 + 2 * N))[0] : beta * ((__global float4 *)(dst_write0 + 2 * N))[0];", // NOLINT -"float4 dot03 = (start_index != 0) ? ((__global float4 *)(dst_write0 + 3 * N))[0] : beta * ((__global float4 *)(dst_write0 + 3 * N))[0];", // NOLINT -"float4 dot04 = (start_index != 0) ? ((__global float4 *)(dst_write0 + 4 * N))[0] : beta * ((__global float4 *)(dst_write0 + 4 * N))[0];", // NOLINT -"float4 dot05 = (start_index != 0) ? ((__global float4 *)(dst_write0 + 5 * N))[0] : beta * ((__global float4 *)(dst_write0 + 5 * N))[0];", // NOLINT -"float4 dot06 = (start_index != 0) ? ((__global float4 *)(dst_write0 + 6 * N))[0] : beta * ((__global float4 *)(dst_write0 + 6 * N))[0];", // NOLINT -"float4 dot07 = (start_index != 0) ? ((__global float4 *)(dst_write0 + 7 * N))[0] : beta * ((__global float4 *)(dst_write0 + 7 * N))[0];", // NOLINT +"float4 dot00 = (start_index != 0) ? vload4(0, dst_write0) : beta * vload4(0, dst_write0);", // NOLINT +"float4 dot01 = (start_index != 0) ? 
vload4(0, dst_write0 + 1 * N) : beta * vload4(0, dst_write0 + 1 * N);", // NOLINT +"float4 dot02 = (start_index != 0) ? vload4(0, dst_write0 + 2 * N) : beta * vload4(0, dst_write0 + 2 * N);", // NOLINT +"float4 dot03 = (start_index != 0) ? vload4(0, dst_write0 + 3 * N) : beta * vload4(0, dst_write0 + 3 * N);", // NOLINT +"float4 dot04 = (start_index != 0) ? vload4(0, dst_write0 + 4 * N) : beta * vload4(0, dst_write0 + 4 * N);", // NOLINT +"float4 dot05 = (start_index != 0) ? vload4(0, dst_write0 + 5 * N) : beta * vload4(0, dst_write0 + 5 * N);", // NOLINT +"float4 dot06 = (start_index != 0) ? vload4(0, dst_write0 + 6 * N) : beta * vload4(0, dst_write0 + 6 * N);", // NOLINT +"float4 dot07 = (start_index != 0) ? vload4(0, dst_write0 + 7 * N) : beta * vload4(0, dst_write0 + 7 * N);", // NOLINT "", // NOLINT "int end_index = min(start_index + 256, K);", // NOLINT "int w = start_index;", // NOLINT "while( w + TILE_K <= end_index ) {", // NOLINT -"arow0 = alpha * ((__global float2 *)(src0_read + row0 * K))[0];", // NOLINT -"arow1 = alpha * ((__global float2 *)(src0_read + row1 * K))[0];", // NOLINT -"arow2 = alpha * ((__global float2 *)(src0_read + row2 * K))[0];", // NOLINT -"arow3 = alpha * ((__global float2 *)(src0_read + row3 * K))[0];", // NOLINT -"arow4 = alpha * ((__global float2 *)(src0_read + row4 * K))[0];", // NOLINT -"arow5 = alpha * ((__global float2 *)(src0_read + row5 * K))[0];", // NOLINT -"arow6 = alpha * ((__global float2 *)(src0_read + row6 * K))[0];", // NOLINT -"arow7 = alpha * ((__global float2 *)(src0_read + row7 * K))[0];", // NOLINT -"", // NOLINT -"#define MM_DOT_PRODUCT( index, suffix ) brow = ((__global float4 *)src1_read0)[0]; src1_read0 += N; dot00 = mad( (float4)(intel_sub_group_shuffle( arow0, index ).s##suffix), brow, dot00 ); dot01 = mad( (float4)(intel_sub_group_shuffle( arow1, index ).s##suffix), brow, dot01 ); dot02 = mad( (float4)(intel_sub_group_shuffle( arow2, index ).s##suffix), brow, dot02 ); dot03 = mad( 
(float4)(intel_sub_group_shuffle( arow3, index ).s##suffix), brow, dot03 ); dot04 = mad( (float4)(intel_sub_group_shuffle( arow4, index ).s##suffix), brow, dot04 ); dot05 = mad( (float4)(intel_sub_group_shuffle( arow5, index ).s##suffix), brow, dot05 ); dot06 = mad( (float4)(intel_sub_group_shuffle( arow6, index ).s##suffix), brow, dot06 ); dot07 = mad( (float4)(intel_sub_group_shuffle( arow7, index ).s##suffix), brow, dot07 );", // NOLINT +"arow0 = alpha * vload2(0, src0_read + row0 * K);", // NOLINT +"arow1 = alpha * vload2(0, src0_read + row1 * K);", // NOLINT +"arow2 = alpha * vload2(0, src0_read + row2 * K);", // NOLINT +"arow3 = alpha * vload2(0, src0_read + row3 * K);", // NOLINT +"arow4 = alpha * vload2(0, src0_read + row4 * K);", // NOLINT +"arow5 = alpha * vload2(0, src0_read + row5 * K);", // NOLINT +"arow6 = alpha * vload2(0, src0_read + row6 * K);", // NOLINT +"arow7 = alpha * vload2(0, src0_read + row7 * K);", // NOLINT +"", // NOLINT +"#define MM_DOT_PRODUCT( index, suffix ) brow = vload4(0, src1_read0); src1_read0 += N; dot00 = mad( (float4)(as_float2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow0), index )).s##suffix), brow, dot00 ); dot01 = mad( (float4)(as_float2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow1), index )).s##suffix), brow, dot01 ); dot02 = mad( (float4)(as_float2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow2), index )).s##suffix), brow, dot02 ); dot03 = mad( (float4)(as_float2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow3), index )).s##suffix), brow, dot03 ); dot04 = mad( (float4)(as_float2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow4), index )).s##suffix), brow, dot04 ); dot05 = mad( (float4)(as_float2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow5), index )).s##suffix), brow, dot05 ); dot06 = mad( (float4)(as_float2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow6), index )).s##suffix), brow, dot06 ); dot07 = mad( (float4)(as_float2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow7), index )).s##suffix), brow, dot07 );", // NOLINT 
"MM_DOT_PRODUCT(0, 0);", // NOLINT "MM_DOT_PRODUCT(0, 1);", // NOLINT "MM_DOT_PRODUCT(1, 0);", // NOLINT @@ -3273,6 +3372,24 @@ static std::vector> cl_kernels{ "MM_DOT_PRODUCT(6, 1);", // NOLINT "MM_DOT_PRODUCT(7, 0);", // NOLINT "MM_DOT_PRODUCT(7, 1);", // NOLINT +"#if TYPE == TYPE_HALF", // NOLINT +"MM_DOT_PRODUCT(8, 0);", // NOLINT +"MM_DOT_PRODUCT(8, 1);", // NOLINT +"MM_DOT_PRODUCT(9, 0);", // NOLINT +"MM_DOT_PRODUCT(9, 1);", // NOLINT +"MM_DOT_PRODUCT(10, 0);", // NOLINT +"MM_DOT_PRODUCT(10, 1);", // NOLINT +"MM_DOT_PRODUCT(11, 0);", // NOLINT +"MM_DOT_PRODUCT(11, 1);", // NOLINT +"MM_DOT_PRODUCT(12, 0);", // NOLINT +"MM_DOT_PRODUCT(12, 1);", // NOLINT +"MM_DOT_PRODUCT(13, 0);", // NOLINT +"MM_DOT_PRODUCT(13, 1);", // NOLINT +"MM_DOT_PRODUCT(14, 0);", // NOLINT +"MM_DOT_PRODUCT(14, 1);", // NOLINT +"MM_DOT_PRODUCT(15, 0);", // NOLINT +"MM_DOT_PRODUCT(15, 1);", // NOLINT +"#endif", // NOLINT "#undef MM_DOT_PRODUCT", // NOLINT "", // NOLINT "src0_read += TILE_K;", // NOLINT @@ -3297,7 +3414,7 @@ static std::vector> cl_kernels{ "arow7.x = ((w + local_x * 2) < K) ? alpha * (src0_read + row7 * K)[0] : 0.0f;", // NOLINT "arow7.y = ((w + local_x * 2 + 1) < K) ? alpha * (src0_read + row7 * K)[1] : 0.0f;", // NOLINT "", // NOLINT -"#define MM_DOT_PRODUCT( index, suffix ) brow = (w < K) ? 
((__global float4 *)src1_read0)[0] : 0.0f; src1_read0 += N; w++; dot00 = mad( (float4)(intel_sub_group_shuffle( arow0, index ).s##suffix), brow, dot00 ); dot01 = mad( (float4)(intel_sub_group_shuffle( arow1, index ).s##suffix), brow, dot01 ); dot02 = mad( (float4)(intel_sub_group_shuffle( arow2, index ).s##suffix), brow, dot02 ); dot03 = mad( (float4)(intel_sub_group_shuffle( arow3, index ).s##suffix), brow, dot03 ); dot04 = mad( (float4)(intel_sub_group_shuffle( arow4, index ).s##suffix), brow, dot04 ); dot05 = mad( (float4)(intel_sub_group_shuffle( arow5, index ).s##suffix), brow, dot05 ); dot06 = mad( (float4)(intel_sub_group_shuffle( arow6, index ).s##suffix), brow, dot06 ); dot07 = mad( (float4)(intel_sub_group_shuffle( arow7, index ).s##suffix), brow, dot07 );", // NOLINT +"#define MM_DOT_PRODUCT( index, suffix ) brow = (w < K) ? vload4(0, src1_read0) : (float4)0.0f; src1_read0 += N; w++; dot00 = mad( (float4)(as_float2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow0), index )).s##suffix), brow, dot00 ); dot01 = mad( (float4)(as_float2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow1), index )).s##suffix), brow, dot01 ); dot02 = mad( (float4)(as_float2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow2), index )).s##suffix), brow, dot02 ); dot03 = mad( (float4)(as_float2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow3), index )).s##suffix), brow, dot03 ); dot04 = mad( (float4)(as_float2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow4), index )).s##suffix), brow, dot04 ); dot05 = mad( (float4)(as_float2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow5), index )).s##suffix), brow, dot05 ); dot06 = mad( (float4)(as_float2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow6), index )).s##suffix), brow, dot06 ); dot07 = mad( (float4)(as_float2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow7), index )).s##suffix), brow, dot07 );", // NOLINT "MM_DOT_PRODUCT(0, 0);", // NOLINT "MM_DOT_PRODUCT(0, 1);", // NOLINT "MM_DOT_PRODUCT(1, 0);", // NOLINT @@ -3314,87 +3431,102 @@ static std::vector> 
cl_kernels{ "MM_DOT_PRODUCT(6, 1);", // NOLINT "MM_DOT_PRODUCT(7, 0);", // NOLINT "MM_DOT_PRODUCT(7, 1);", // NOLINT +"#if TYPE == TYPE_HALF", // NOLINT +"MM_DOT_PRODUCT(8, 0);", // NOLINT +"MM_DOT_PRODUCT(8, 1);", // NOLINT +"MM_DOT_PRODUCT(9, 0);", // NOLINT +"MM_DOT_PRODUCT(9, 1);", // NOLINT +"MM_DOT_PRODUCT(10, 0);", // NOLINT +"MM_DOT_PRODUCT(10, 1);", // NOLINT +"MM_DOT_PRODUCT(11, 0);", // NOLINT +"MM_DOT_PRODUCT(11, 1);", // NOLINT +"MM_DOT_PRODUCT(12, 0);", // NOLINT +"MM_DOT_PRODUCT(12, 1);", // NOLINT +"MM_DOT_PRODUCT(13, 0);", // NOLINT +"MM_DOT_PRODUCT(13, 1);", // NOLINT +"MM_DOT_PRODUCT(14, 0);", // NOLINT +"MM_DOT_PRODUCT(14, 1);", // NOLINT +"MM_DOT_PRODUCT(15, 0);", // NOLINT +"MM_DOT_PRODUCT(15, 1);", // NOLINT +"#endif", // NOLINT "#undef MM_DOT_PRODUCT", // NOLINT "}", // NOLINT "", // NOLINT "if(global_x * 4 < N && global_y * 8 < M) {", // NOLINT "if(mad24(global_x, 4, 3) < N) {", // NOLINT -"__global float4 *dst_write = (__global float4 *)dst_write0;", // NOLINT -"dst_write[0] = dot00; dst_write0 += N; dst_write = (__global float4 *)dst_write0;", // NOLINT -"if(mad24(global_y, 8, 1) < M) { dst_write[0] = dot01; dst_write0 += N; dst_write = (__global float4 *)dst_write0; }", // NOLINT +"vstore4(dot00, 0, dst_write0); dst_write0 += N;", // NOLINT +"if(mad24(global_y, 8, 1) < M) { vstore4(dot01, 0, dst_write0); dst_write0 += N; }", // NOLINT "else return;", // NOLINT -"if(mad24(global_y, 8, 2) < M) { dst_write[0] = dot02; dst_write0 += N; dst_write = (__global float4 *)dst_write0; }", // NOLINT +"if(mad24(global_y, 8, 2) < M) { vstore4(dot02, 0, dst_write0); dst_write0 += N; }", // NOLINT "else return;", // NOLINT -"if(mad24(global_y, 8, 3) < M) { dst_write[0] = dot03; dst_write0 += N; dst_write = (__global float4 *)dst_write0; }", // NOLINT +"if(mad24(global_y, 8, 3) < M) { vstore4(dot03, 0, dst_write0); dst_write0 += N; }", // NOLINT "else return;", // NOLINT -"if(mad24(global_y, 8, 4) < M) { dst_write[0] = dot04; dst_write0 += N; dst_write = 
(__global float4 *)dst_write0; }", // NOLINT +"if(mad24(global_y, 8, 4) < M) { vstore4(dot04, 0, dst_write0); dst_write0 += N; }", // NOLINT "else return;", // NOLINT -"if(mad24(global_y, 8, 5) < M) { dst_write[0] = dot05; dst_write0 += N; dst_write = (__global float4 *)dst_write0; }", // NOLINT +"if(mad24(global_y, 8, 5) < M) { vstore4(dot05, 0, dst_write0); dst_write0 += N; }", // NOLINT "else return;", // NOLINT -"if(mad24(global_y, 8, 6) < M) { dst_write[0] = dot06; dst_write0 += N; dst_write = (__global float4 *)dst_write0; }", // NOLINT +"if(mad24(global_y, 8, 6) < M) { vstore4(dot06, 0, dst_write0); dst_write0 += N; }", // NOLINT "else return;", // NOLINT -"if(mad24(global_y, 8, 7) < M) { dst_write[0] = dot07; }", // NOLINT +"if(mad24(global_y, 8, 7) < M) { vstore4(dot07, 0, dst_write0); }", // NOLINT "} else if(mad24(global_x, 4, 2) < N) {", // NOLINT -"__global float2 *dst_write = (__global float2 *)dst_write0;", // NOLINT -"dst_write[0] = dot00.xy;", // NOLINT +"vstore2(dot00.xy, 0, dst_write0);", // NOLINT "dst_write0[2] = dot00.z;", // NOLINT -"dst_write0 += N; dst_write = (__global float2 *)dst_write0;", // NOLINT +"dst_write0 += N;", // NOLINT "if(mad24(global_y, 8, 1) < M) {", // NOLINT -"dst_write[0] = dot01.xy;", // NOLINT +"vstore2(dot01.xy, 0, dst_write0);", // NOLINT "dst_write0[2] = dot01.z;", // NOLINT -"dst_write0 += N; dst_write = (__global float2 *)dst_write0;", // NOLINT +"dst_write0 += N;", // NOLINT "} else", // NOLINT "return;", // NOLINT "if(mad24(global_y, 8, 2) < M) {", // NOLINT -"dst_write[0] = dot02.xy;", // NOLINT +"vstore2(dot02.xy, 0, dst_write0);", // NOLINT "dst_write0[2] = dot02.z;", // NOLINT -"dst_write0 += N; dst_write = (__global float2 *)dst_write0;", // NOLINT +"dst_write0 += N;", // NOLINT "} else", // NOLINT "return;", // NOLINT "if(mad24(global_y, 8, 3) < M) {", // NOLINT -"dst_write[0] = dot03.xy;", // NOLINT +"vstore2(dot03.xy, 0, dst_write0);", // NOLINT "dst_write0[2] = dot03.z;", // NOLINT -"dst_write0 += N; 
dst_write = (__global float2 *)dst_write0;", // NOLINT +"dst_write0 += N;", // NOLINT "} else", // NOLINT "return;", // NOLINT "if(mad24(global_y, 8, 4) < M) {", // NOLINT -"dst_write[0] = dot04.xy;", // NOLINT +"vstore2(dot04.xy, 0, dst_write0);", // NOLINT "dst_write0[2] = dot04.z;", // NOLINT -"dst_write0 += N; dst_write = (__global float2 *)dst_write0;", // NOLINT +"dst_write0 += N;", // NOLINT "} else", // NOLINT "return;", // NOLINT "if(mad24(global_y, 8, 5) < M) {", // NOLINT -"dst_write[0] = dot05.xy;", // NOLINT +"vstore2(dot05.xy, 0, dst_write0);", // NOLINT "dst_write0[2] = dot05.z;", // NOLINT -"dst_write0 += N; dst_write = (__global float2 *)dst_write0;", // NOLINT +"dst_write0 += N;", // NOLINT "} else", // NOLINT "return;", // NOLINT "if(mad24(global_y, 8, 6) < M) {", // NOLINT -"dst_write[0] = dot06.xy;", // NOLINT +"vstore2(dot06.xy, 0, dst_write0);", // NOLINT "dst_write0[2] = dot06.z;", // NOLINT -"dst_write0 += N; dst_write = (__global float2 *)dst_write0;", // NOLINT +"dst_write0 += N;", // NOLINT "} else", // NOLINT "return;", // NOLINT "if(mad24(global_y, 8, 7) < M) {", // NOLINT -"dst_write[0] = dot07.xy;", // NOLINT +"vstore2(dot07.xy, 0, dst_write0);", // NOLINT "dst_write0[2] = dot07.z;", // NOLINT "}", // NOLINT "} else if(mad24(global_x, 4, 1) < N) {", // NOLINT -"__global float2 *dst_write = (__global float2 *)dst_write0;", // NOLINT -"dst_write[0] = dot00.xy; dst_write0 += N; dst_write = (__global float2 *)dst_write0;", // NOLINT -"if(mad24(global_y, 8, 1) < M) { dst_write[0] = dot01.xy; dst_write0 += N; dst_write = (__global float2 *)dst_write0; }", // NOLINT +"vstore2(dot00.xy, 0, dst_write0); dst_write0 += N;", // NOLINT +"if(mad24(global_y, 8, 1) < M) { vstore2(dot01.xy, 0, dst_write0); dst_write0 += N; }", // NOLINT "else return;", // NOLINT -"if(mad24(global_y, 8, 2) < M) { dst_write[0] = dot02.xy; dst_write0 += N; dst_write = (__global float2 *)dst_write0; }", // NOLINT +"if(mad24(global_y, 8, 2) < M) { vstore2(dot02.xy, 0, 
dst_write0); dst_write0 += N; }", // NOLINT "else return;", // NOLINT -"if(mad24(global_y, 8, 3) < M) { dst_write[0] = dot03.xy; dst_write0 += N; dst_write = (__global float2 *)dst_write0; }", // NOLINT +"if(mad24(global_y, 8, 3) < M) { vstore2(dot03.xy, 0, dst_write0); dst_write0 += N; }", // NOLINT "else return;", // NOLINT -"if(mad24(global_y, 8, 4) < M) { dst_write[0] = dot04.xy; dst_write0 += N; dst_write = (__global float2 *)dst_write0; }", // NOLINT +"if(mad24(global_y, 8, 4) < M) { vstore2(dot04.xy, 0, dst_write0); dst_write0 += N; }", // NOLINT "else return;", // NOLINT -"if(mad24(global_y, 8, 5) < M) { dst_write[0] = dot05.xy; dst_write0 += N; dst_write = (__global float2 *)dst_write0; }", // NOLINT +"if(mad24(global_y, 8, 5) < M) { vstore2(dot05.xy, 0, dst_write0); dst_write0 += N; }", // NOLINT "else return;", // NOLINT -"if(mad24(global_y, 8, 6) < M) { dst_write[0] = dot06.xy; dst_write0 += N; dst_write = (__global float2 *)dst_write0; }", // NOLINT +"if(mad24(global_y, 8, 6) < M) { vstore2(dot06.xy, 0, dst_write0); dst_write0 += N; }", // NOLINT "else return;", // NOLINT -"if(mad24(global_y, 8, 7) < M) { dst_write[0] = dot07.xy; }", // NOLINT +"if(mad24(global_y, 8, 7) < M) { vstore2(dot07.xy, 0, dst_write0); }", // NOLINT "} else {", // NOLINT "dst_write0[0] = dot00.x; dst_write0 += N;", // NOLINT "if(mad24(global_y, 8, 1) < M) { dst_write0[0] = dot01.x; dst_write0 += N; }", // NOLINT @@ -3420,6 +3552,7 @@ static std::vector> cl_kernels{ "#undef TILE_K", // NOLINT "#undef TILE_N", // NOLINT "", // NOLINT +"", // NOLINT "#define VEC_SIZE 1", // NOLINT "#define LWG_HEIGHT 16", // NOLINT "#define TILE_M 8", // NOLINT @@ -3428,6 +3561,7 @@ static std::vector> cl_kernels{ "#define SLM_BLOCK 512", // NOLINT "", // NOLINT "__attribute__((reqd_work_group_size(8, LWG_HEIGHT, 1)))", // NOLINT +"__attribute__((intel_reqd_sub_group_size(8)))", // NOLINT "__kernel void TEMPLATE(gemm_buffer_NT, Dtype)(", // NOLINT "const __global float *src0, int off0,", // NOLINT 
"const __global float *src1, int off1,", // NOLINT @@ -3435,9 +3569,11 @@ static std::vector> cl_kernels{ "int M,", // NOLINT "int N,", // NOLINT "int K,", // NOLINT -"float alpha,", // NOLINT -"float beta)", // NOLINT +"float alpha_in,", // NOLINT +"float beta_in)", // NOLINT "{", // NOLINT +"const float alpha = (float)alpha_in;", // NOLINT +"const float beta = (float)beta_in;", // NOLINT "const int group_x = get_group_id(0);", // NOLINT "const int group_y = get_group_id(1);", // NOLINT "const int local_x = get_local_id(0);", // NOLINT @@ -3463,11 +3599,11 @@ static std::vector> cl_kernels{ "float4 brow6;", // NOLINT "float4 brow7;", // NOLINT "", // NOLINT -"__global float *dst_write0 = dst + local_x * VEC_SIZE + ( group_x * TILE_N ) + ( group_y * LWG_HEIGHT * TILE_M + local_y * TILE_M) * N + offd;", // NOLINT +"__global float *dst_write0 = dst + local_x * VEC_SIZE + (group_x * TILE_N) + (group_y * LWG_HEIGHT * TILE_M + local_y * TILE_M) * N + offd;", // NOLINT "", // NOLINT -"const __global float *src0_read = src0 + local_x * ( TILE_K / 8 ) + ( group_y * LWG_HEIGHT * TILE_M + local_y * TILE_M ) * K + off0;", // NOLINT +"const __global float *src0_read = src0 + local_x * (TILE_K / 8) + (group_y * LWG_HEIGHT * TILE_M + local_y * TILE_M) * K + off0;", // NOLINT "", // NOLINT -"const __global float *src1_read0 = src1 + ( group_x * TILE_N ) * K + off1;", // NOLINT +"const __global float *src1_read0 = src1 + (group_x * TILE_N) * K + off1;", // NOLINT "", // NOLINT "__local float slm_brow[8 * SLM_BLOCK];", // NOLINT "__local float* slm_brow0;", // NOLINT @@ -3476,14 +3612,14 @@ static std::vector> cl_kernels{ "int w;", // NOLINT "for(int b_tile = 0; b_tile < K; b_tile += SLM_BLOCK) {", // NOLINT "barrier(CLK_LOCAL_MEM_FENCE);", // NOLINT -"((__local float4 *)(slm_brow + mad24(0, SLM_BLOCK, local_index)))[0] = ((__global float4 *)(src1_read0 + mad24(0, K, local_index)))[0];", // NOLINT -"((__local float4 *)(slm_brow + mad24(1, SLM_BLOCK, local_index)))[0] = ((__global 
float4 *)(src1_read0 + mad24(1, K, local_index)))[0];", // NOLINT -"((__local float4 *)(slm_brow + mad24(2, SLM_BLOCK, local_index)))[0] = ((__global float4 *)(src1_read0 + mad24(2, K, local_index)))[0];", // NOLINT -"((__local float4 *)(slm_brow + mad24(3, SLM_BLOCK, local_index)))[0] = ((__global float4 *)(src1_read0 + mad24(3, K, local_index)))[0];", // NOLINT -"((__local float4 *)(slm_brow + mad24(4, SLM_BLOCK, local_index)))[0] = ((__global float4 *)(src1_read0 + mad24(4, K, local_index)))[0];", // NOLINT -"((__local float4 *)(slm_brow + mad24(5, SLM_BLOCK, local_index)))[0] = ((__global float4 *)(src1_read0 + mad24(5, K, local_index)))[0];", // NOLINT -"((__local float4 *)(slm_brow + mad24(6, SLM_BLOCK, local_index)))[0] = ((__global float4 *)(src1_read0 + mad24(6, K, local_index)))[0];", // NOLINT -"((__local float4 *)(slm_brow + mad24(7, SLM_BLOCK, local_index)))[0] = ((__global float4 *)(src1_read0 + mad24(7, K, local_index)))[0];", // NOLINT +"vstore4(vload4(0, src1_read0 + mad24(0, K, local_index)), 0, slm_brow + mad24(0, SLM_BLOCK, local_index));", // NOLINT +"vstore4(vload4(0, src1_read0 + mad24(1, K, local_index)), 0, slm_brow + mad24(1, SLM_BLOCK, local_index));", // NOLINT +"vstore4(vload4(0, src1_read0 + mad24(2, K, local_index)), 0, slm_brow + mad24(2, SLM_BLOCK, local_index));", // NOLINT +"vstore4(vload4(0, src1_read0 + mad24(3, K, local_index)), 0, slm_brow + mad24(3, SLM_BLOCK, local_index));", // NOLINT +"vstore4(vload4(0, src1_read0 + mad24(4, K, local_index)), 0, slm_brow + mad24(4, SLM_BLOCK, local_index));", // NOLINT +"vstore4(vload4(0, src1_read0 + mad24(5, K, local_index)), 0, slm_brow + mad24(5, SLM_BLOCK, local_index));", // NOLINT +"vstore4(vload4(0, src1_read0 + mad24(6, K, local_index)), 0, slm_brow + mad24(6, SLM_BLOCK, local_index));", // NOLINT +"vstore4(vload4(0, src1_read0 + mad24(7, K, local_index)), 0, slm_brow + mad24(7, SLM_BLOCK, local_index));", // NOLINT "barrier(CLK_LOCAL_MEM_FENCE);", // NOLINT "", // NOLINT 
"slm_brow0 = slm_brow + local_x * (TILE_K / 8);", // NOLINT @@ -3492,16 +3628,16 @@ static std::vector> cl_kernels{ "while( w + TILE_K <= end_w ) {", // NOLINT "float4 arow;", // NOLINT "", // NOLINT -"brow0 = ((__local float4 *)(slm_brow0 + 0 * SLM_BLOCK))[0];", // NOLINT -"brow1 = ((__local float4 *)(slm_brow0 + 1 * SLM_BLOCK))[0];", // NOLINT -"brow2 = ((__local float4 *)(slm_brow0 + 2 * SLM_BLOCK))[0];", // NOLINT -"brow3 = ((__local float4 *)(slm_brow0 + 3 * SLM_BLOCK))[0];", // NOLINT -"brow4 = ((__local float4 *)(slm_brow0 + 4 * SLM_BLOCK))[0];", // NOLINT -"brow5 = ((__local float4 *)(slm_brow0 + 5 * SLM_BLOCK))[0];", // NOLINT -"brow6 = ((__local float4 *)(slm_brow0 + 6 * SLM_BLOCK))[0];", // NOLINT -"brow7 = ((__local float4 *)(slm_brow0 + 7 * SLM_BLOCK))[0];", // NOLINT +"brow0 = vload4(0, slm_brow0 + 0 * SLM_BLOCK);", // NOLINT +"brow1 = vload4(0, slm_brow0 + 1 * SLM_BLOCK);", // NOLINT +"brow2 = vload4(0, slm_brow0 + 2 * SLM_BLOCK);", // NOLINT +"brow3 = vload4(0, slm_brow0 + 3 * SLM_BLOCK);", // NOLINT +"brow4 = vload4(0, slm_brow0 + 4 * SLM_BLOCK);", // NOLINT +"brow5 = vload4(0, slm_brow0 + 5 * SLM_BLOCK);", // NOLINT +"brow6 = vload4(0, slm_brow0 + 6 * SLM_BLOCK);", // NOLINT +"brow7 = vload4(0, slm_brow0 + 7 * SLM_BLOCK);", // NOLINT "", // NOLINT -"#define MM_DOT_PRODUCT( _row, _dot ) arow = ((__global float4 *)(src0_read + _row * K))[0]; _dot = mad( (float8)(arow.x), (float8)(brow0.x, brow1.x, brow2.x, brow3.x, brow4.x, brow5.x, brow6.x, brow7.x), _dot ); _dot = mad( (float8)(arow.y), (float8)(brow0.y, brow1.y, brow2.y, brow3.y, brow4.y, brow5.y, brow6.y, brow7.y), _dot ); _dot = mad( (float8)(arow.z), (float8)(brow0.z, brow1.z, brow2.z, brow3.z, brow4.z, brow5.z, brow6.z, brow7.z), _dot ); _dot = mad( (float8)(arow.w), (float8)(brow0.w, brow1.w, brow2.w, brow3.w, brow4.w, brow5.w, brow6.w, brow7.w), _dot );", // NOLINT +"#define MM_DOT_PRODUCT( _row, _dot ) arow = vload4(0, src0_read + _row * K); _dot = mad( (float8)(arow.x), (float8)(brow0.x, 
brow1.x, brow2.x, brow3.x, brow4.x, brow5.x, brow6.x, brow7.x), _dot ); _dot = mad( (float8)(arow.y), (float8)(brow0.y, brow1.y, brow2.y, brow3.y, brow4.y, brow5.y, brow6.y, brow7.y), _dot ); _dot = mad( (float8)(arow.z), (float8)(brow0.z, brow1.z, brow2.z, brow3.z, brow4.z, brow5.z, brow6.z, brow7.z), _dot ); _dot = mad( (float8)(arow.w), (float8)(brow0.w, brow1.w, brow2.w, brow3.w, brow4.w, brow5.w, brow6.w, brow7.w), _dot );", // NOLINT "MM_DOT_PRODUCT( 0, dot00 );", // NOLINT "MM_DOT_PRODUCT( 1, dot01 );", // NOLINT "MM_DOT_PRODUCT( 2, dot02 );", // NOLINT @@ -3522,7 +3658,7 @@ static std::vector> cl_kernels{ "if(w < K) {", // NOLINT "float4 arow;", // NOLINT "", // NOLINT -"#define READ_BROW(_brow, _row) _brow = ((__local float4 *)(slm_brow0 + _row * SLM_BLOCK))[0]; _brow.x = (mad24(local_x, 4, w) < K) ? _brow.x : 0.0f; _brow.y = (mad24(local_x, 4, w + 1) < K) ? _brow.y : 0.0f; _brow.z = (mad24(local_x, 4, w + 2) < K) ? _brow.z : 0.0f; _brow.w = (mad24(local_x, 4, w + 3) < K) ? _brow.w : 0.0f;", // NOLINT +"#define READ_BROW(_brow, _row) _brow = vload4(0, slm_brow0 + _row * SLM_BLOCK); _brow.x = (mad24(local_x, 4, w) < K) ? _brow.x : 0.0f; _brow.y = (mad24(local_x, 4, w + 1) < K) ? _brow.y : 0.0f; _brow.z = (mad24(local_x, 4, w + 2) < K) ? _brow.z : 0.0f; _brow.w = (mad24(local_x, 4, w + 3) < K) ? _brow.w : 0.0f;", // NOLINT "READ_BROW(brow0, 0);", // NOLINT "READ_BROW(brow1, 1);", // NOLINT "READ_BROW(brow2, 2);", // NOLINT @@ -3532,7 +3668,7 @@ static std::vector> cl_kernels{ "READ_BROW(brow6, 6);", // NOLINT "READ_BROW(brow7, 7);", // NOLINT "", // NOLINT -"#define MM_DOT_PRODUCT( _row, _dot ) arow = ((__global float4 *)(src0_read + _row * K))[0]; arow.x = (mad24(local_x, 4, w) < K) ? arow.x : 0.0f; arow.y = (mad24(local_x, 4, w + 1) < K) ? arow.y : 0.0f; arow.z = (mad24(local_x, 4, w + 2) < K) ? arow.z : 0.0f; arow.w = (mad24(local_x, 4, w + 3) < K) ? 
arow.w : 0.0f; _dot = mad( (float8)(arow.x), (float8)(brow0.x, brow1.x, brow2.x, brow3.x, brow4.x, brow5.x, brow6.x, brow7.x), _dot ); _dot = mad( (float8)(arow.y), (float8)(brow0.y, brow1.y, brow2.y, brow3.y, brow4.y, brow5.y, brow6.y, brow7.y), _dot ); _dot = mad( (float8)(arow.z), (float8)(brow0.z, brow1.z, brow2.z, brow3.z, brow4.z, brow5.z, brow6.z, brow7.z), _dot ); _dot = mad( (float8)(arow.w), (float8)(brow0.w, brow1.w, brow2.w, brow3.w, brow4.w, brow5.w, brow6.w, brow7.w), _dot );", // NOLINT +"#define MM_DOT_PRODUCT( _row, _dot ) arow = vload4(0, src0_read + _row * K); arow.x = (mad24(local_x, 4, w) < K) ? arow.x : 0.0f; arow.y = (mad24(local_x, 4, w + 1) < K) ? arow.y : 0.0f; arow.z = (mad24(local_x, 4, w + 2) < K) ? arow.z : 0.0f; arow.w = (mad24(local_x, 4, w + 3) < K) ? arow.w : 0.0f; _dot = mad( (float8)(arow.x), (float8)(brow0.x, brow1.x, brow2.x, brow3.x, brow4.x, brow5.x, brow6.x, brow7.x), _dot ); _dot = mad( (float8)(arow.y), (float8)(brow0.y, brow1.y, brow2.y, brow3.y, brow4.y, brow5.y, brow6.y, brow7.y), _dot ); _dot = mad( (float8)(arow.z), (float8)(brow0.z, brow1.z, brow2.z, brow3.z, brow4.z, brow5.z, brow6.z, brow7.z), _dot ); _dot = mad( (float8)(arow.w), (float8)(brow0.w, brow1.w, brow2.w, brow3.w, brow4.w, brow5.w, brow6.w, brow7.w), _dot );", // NOLINT "MM_DOT_PRODUCT( 0, dot00 );", // NOLINT "MM_DOT_PRODUCT( 1, dot01 );", // NOLINT "MM_DOT_PRODUCT( 2, dot02 );", // NOLINT @@ -3544,7 +3680,7 @@ static std::vector> cl_kernels{ "#undef MM_DOT_PRODUCT", // NOLINT "}", // NOLINT "", // NOLINT -"#define REDUCE(_dot) _dot = intel_sub_group_shuffle(_dot, 0) + intel_sub_group_shuffle(_dot, 1) + intel_sub_group_shuffle(_dot, 2) + intel_sub_group_shuffle(_dot, 3) + intel_sub_group_shuffle(_dot, 4) + intel_sub_group_shuffle(_dot, 5) + intel_sub_group_shuffle(_dot, 6) + intel_sub_group_shuffle(_dot, 7);", // NOLINT +"#define REDUCE(_dot) _dot = as_float8(intel_sub_group_shuffle(SHUFFLE_TYPE8(_dot), 0)) + 
as_float8(intel_sub_group_shuffle(SHUFFLE_TYPE8(_dot), 1)) + as_float8(intel_sub_group_shuffle(SHUFFLE_TYPE8(_dot), 2)) + as_float8(intel_sub_group_shuffle(SHUFFLE_TYPE8(_dot), 3)) + as_float8(intel_sub_group_shuffle(SHUFFLE_TYPE8(_dot), 4)) + as_float8(intel_sub_group_shuffle(SHUFFLE_TYPE8(_dot), 5)) + as_float8(intel_sub_group_shuffle(SHUFFLE_TYPE8(_dot), 6)) + as_float8(intel_sub_group_shuffle(SHUFFLE_TYPE8(_dot), 7));", // NOLINT "REDUCE(dot00);", // NOLINT "REDUCE(dot01);", // NOLINT "REDUCE(dot02);", // NOLINT @@ -3580,32 +3716,32 @@ static std::vector> cl_kernels{ "", // NOLINT "#define SLM_SIZE 64", // NOLINT "void TEMPLATE(gemm_buffer_NT_M_2_edgerows,Dtype)(", // NOLINT -"const __global Dtype* srca_read0,", // NOLINT -"const __global Dtype* srca_read1,", // NOLINT -"const __global Dtype* srcb_read,", // NOLINT -"__local Dtype4* work0,", // NOLINT -"__local Dtype4* work1,", // NOLINT +"const __global float* srca_read0,", // NOLINT +"const __global float* srca_read1,", // NOLINT +"const __global float* srcb_read,", // NOLINT +"__local float4* work0,", // NOLINT +"__local float4* work1,", // NOLINT "int N,", // NOLINT "int K,", // NOLINT "int x_gid,", // NOLINT "int lid,", // NOLINT -"Dtype alpha,", // NOLINT -"Dtype beta,", // NOLINT -"__global Dtype* dstc0,", // NOLINT -"__global Dtype* dstc1)", // NOLINT +"float alpha,", // NOLINT +"float beta,", // NOLINT +"__global float* dstc0,", // NOLINT +"__global float* dstc1)", // NOLINT "{", // NOLINT -"__local Dtype* work_each0 = (__local Dtype*)work0;", // NOLINT -"__local Dtype* work_each1 = (__local Dtype*)work1;", // NOLINT +"__local float* work_each0 = (__local float*)work0;", // NOLINT +"__local float* work_each1 = (__local float*)work1;", // NOLINT "", // NOLINT "int rows = N - x_gid * 4;", // NOLINT "", // NOLINT -"Dtype4 dot0[3] = {(Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.)};", // NOLINT -"Dtype4 dot1[3] = {(Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.)};", // NOLINT +"float4 dot0[3] = {(float4)(0.), 
(float4)(0.), (float4)(0.)};", // NOLINT +"float4 dot1[3] = {(float4)(0.), (float4)(0.), (float4)(0.)};", // NOLINT "", // NOLINT "int i = lid;", // NOLINT "while( i < K / 4) {", // NOLINT -"const Dtype4 b0 = {srca_read0[i*4], srca_read0[(i*4+1)], srca_read0[(i*4+2)], srca_read0[(i*4+3)]};", // NOLINT -"const Dtype4 b1 = {srca_read1[i*4], srca_read1[(i*4+1)], srca_read1[(i*4+2)], srca_read1[(i*4+3)]};", // NOLINT +"const float4 b0 = {srca_read0[i*4], srca_read0[(i*4+1)], srca_read0[(i*4+2)], srca_read0[(i*4+3)]};", // NOLINT +"const float4 b1 = {srca_read1[i*4], srca_read1[(i*4+1)], srca_read1[(i*4+2)], srca_read1[(i*4+3)]};", // NOLINT "#pragma unroll", // NOLINT "for(int j = 0; j < rows; ++j) {", // NOLINT "dot0[j] += b0 * vload4(i, srcb_read + j * K);", // NOLINT @@ -3624,13 +3760,13 @@ static std::vector> cl_kernels{ "short tail_items = K % 4;", // NOLINT "", // NOLINT "if(tail_items != 0) {", // NOLINT -"const __global Dtype *srcb_tail = srcb_read + i * 4;", // NOLINT -"const __global Dtype *srca_tail0 = srca_read0 + i * 4;", // NOLINT -"const __global Dtype *srca_tail1 = srca_read1 + i * 4;", // NOLINT +"const __global float *srcb_tail = srcb_read + i * 4;", // NOLINT +"const __global float *srca_tail0 = srca_read0 + i * 4;", // NOLINT +"const __global float *srca_tail1 = srca_read1 + i * 4;", // NOLINT "#pragma unroll", // NOLINT "for(short i = 0; i < tail_items; ++i) {", // NOLINT -"const Dtype at0 = srca_tail0[i];", // NOLINT -"const Dtype at1 = srca_tail1[i];", // NOLINT +"const float at0 = srca_tail0[i];", // NOLINT +"const float at1 = srca_tail1[i];", // NOLINT "#pragma unroll", // NOLINT "for(int j = 0; j < rows; ++j) {", // NOLINT "work_each0[lid * 4 + j] += at0 * srcb_tail[i + j * K];", // NOLINT @@ -3658,11 +3794,11 @@ static std::vector> cl_kernels{ "}", // NOLINT "", // NOLINT "__kernel void TEMPLATE(gemm_buffer_NT_M_2,Dtype)(", // NOLINT -"__global const Dtype * A,", // NOLINT +"__global const float * A,", // NOLINT "int offA,", // NOLINT 
-"__global const Dtype * B,", // NOLINT +"__global const float * B,", // NOLINT "int offB,", // NOLINT -"__global Dtype * C,", // NOLINT +"__global float * C,", // NOLINT "int offC,", // NOLINT "int M,", // NOLINT "int N,", // NOLINT @@ -3670,36 +3806,36 @@ static std::vector> cl_kernels{ "float alpha_f,", // NOLINT "float beta_f)", // NOLINT "{", // NOLINT -"Dtype alpha = (Dtype)alpha_f;", // NOLINT -"Dtype beta = (Dtype)beta_f;", // NOLINT +"float alpha = (float)alpha_f;", // NOLINT +"float beta = (float)beta_f;", // NOLINT "int x_gid = get_group_id(0);", // NOLINT "int lid = get_local_id(0);", // NOLINT "", // NOLINT -"const __global Dtype *srca_read0 = A + offA;", // NOLINT -"const __global Dtype *srca_read1 = srca_read0 + K;", // NOLINT +"const __global float *srca_read0 = A + offA;", // NOLINT +"const __global float *srca_read1 = srca_read0 + K;", // NOLINT "", // NOLINT -"const __global Dtype *srcb_read = B + x_gid * 4 * K + offB;", // NOLINT +"const __global float *srcb_read = B + x_gid * 4 * K + offB;", // NOLINT "", // NOLINT -"__global Dtype4 *dstc0 = (__global Dtype4*)(C + offC);", // NOLINT -"__global Dtype4 *dstc1 = (__global Dtype4*)((__global Dtype*)(dstc0) + N);", // NOLINT +"__global float4 *dstc0 = (__global float4*)(C + offC);", // NOLINT +"__global float4 *dstc1 = (__global float4*)((__global float*)(dstc0) + N);", // NOLINT "", // NOLINT -"__local Dtype4 work0[SLM_SIZE];", // NOLINT -"__local Dtype4 work1[SLM_SIZE];", // NOLINT -"__local Dtype* work_each0 = (__local Dtype*)work0;", // NOLINT -"__local Dtype* work_each1 = (__local Dtype*)work1;", // NOLINT +"__local float4 work0[SLM_SIZE];", // NOLINT +"__local float4 work1[SLM_SIZE];", // NOLINT +"__local float* work_each0 = (__local float*)work0;", // NOLINT +"__local float* work_each1 = (__local float*)work1;", // NOLINT "", // NOLINT "if(x_gid == N / 4) {", // NOLINT -"TEMPLATE(gemm_buffer_NT_M_2_edgerows,Dtype) (srca_read0, srca_read1, srcb_read, work0, work1, N, K, x_gid, lid, alpha, 
beta, (__global Dtype*)dstc0, (__global Dtype*)dstc1);", // NOLINT +"TEMPLATE(gemm_buffer_NT_M_2_edgerows,Dtype) (srca_read0, srca_read1, srcb_read, work0, work1, N, K, x_gid, lid, alpha, beta, (__global float*)dstc0, (__global float*)dstc1);", // NOLINT "} else {", // NOLINT -"Dtype4 dot0[4] = {(Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.)};", // NOLINT -"Dtype4 dot1[4] = {(Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.)};", // NOLINT +"float4 dot0[4] = {(float4)(0.), (float4)(0.), (float4)(0.), (float4)(0.)};", // NOLINT +"float4 dot1[4] = {(float4)(0.), (float4)(0.), (float4)(0.), (float4)(0.)};", // NOLINT "int i = lid;", // NOLINT "while( i < K / 4) {", // NOLINT -"const Dtype4 b0 = vload4(i, srca_read0);", // NOLINT -"const Dtype4 b1 = vload4(i, srca_read1);", // NOLINT +"const float4 b0 = vload4(i, srca_read0);", // NOLINT +"const float4 b1 = vload4(i, srca_read1);", // NOLINT "#pragma unroll", // NOLINT "for(int j = 0; j < 4; ++j) {", // NOLINT -"Dtype4 a = vload4(i, srcb_read + j * K);", // NOLINT +"float4 a = vload4(i, srcb_read + j * K);", // NOLINT "dot0[j] += b0 * a;", // NOLINT "dot1[j] += b1 * a;", // NOLINT "}", // NOLINT @@ -3715,14 +3851,14 @@ static std::vector> cl_kernels{ "if(i == K / 4) {", // NOLINT "short tail_items = K % 4;", // NOLINT "if(tail_items != 0) {", // NOLINT -"const __global Dtype *srcb_tail = srcb_read + i * 4;", // NOLINT +"const __global float *srcb_tail = srcb_read + i * 4;", // NOLINT "", // NOLINT -"const __global Dtype *srca_tail0 = srca_read0 + i * 4;", // NOLINT -"const __global Dtype *srca_tail1 = srca_read1 + i * 4;", // NOLINT +"const __global float *srca_tail0 = srca_read0 + i * 4;", // NOLINT +"const __global float *srca_tail1 = srca_read1 + i * 4;", // NOLINT "#pragma unroll", // NOLINT "for(short i = 0; i < tail_items; ++i) {", // NOLINT -"const Dtype at0 = srca_tail0[i];", // NOLINT -"const Dtype at1 = srca_tail1[i];", // NOLINT +"const float at0 = srca_tail0[i];", // NOLINT +"const float at1 = 
srca_tail1[i];", // NOLINT "#pragma unroll", // NOLINT "for(int j = 0; j < 4; ++j) {", // NOLINT "work_each0[lid * 4 + j] += at0 * srcb_tail[i + j * K];", // NOLINT @@ -3750,44 +3886,44 @@ static std::vector> cl_kernels{ "", // NOLINT "#define SLM_SIZE 32", // NOLINT "void TEMPLATE(gemm_buffer_NT_M_4_edgerows,Dtype)(", // NOLINT -"const __global Dtype* srca_read0,", // NOLINT -"const __global Dtype* srca_read1,", // NOLINT -"const __global Dtype* srca_read2,", // NOLINT -"const __global Dtype* srca_read3,", // NOLINT -"const __global Dtype* srcb_read,", // NOLINT -"__local Dtype4* work0,", // NOLINT -"__local Dtype4* work1,", // NOLINT -"__local Dtype4* work2,", // NOLINT -"__local Dtype4* work3,", // NOLINT +"const __global float* srca_read0,", // NOLINT +"const __global float* srca_read1,", // NOLINT +"const __global float* srca_read2,", // NOLINT +"const __global float* srca_read3,", // NOLINT +"const __global float* srcb_read,", // NOLINT +"__local float4* work0,", // NOLINT +"__local float4* work1,", // NOLINT +"__local float4* work2,", // NOLINT +"__local float4* work3,", // NOLINT "int N,", // NOLINT "int K,", // NOLINT "int x_gid,", // NOLINT "int lid,", // NOLINT -"Dtype alpha,", // NOLINT -"Dtype beta,", // NOLINT -"__global Dtype* dstc0,", // NOLINT -"__global Dtype* dstc1,", // NOLINT -"__global Dtype* dstc2,", // NOLINT -"__global Dtype* dstc3)", // NOLINT +"float alpha,", // NOLINT +"float beta,", // NOLINT +"__global float* dstc0,", // NOLINT +"__global float* dstc1,", // NOLINT +"__global float* dstc2,", // NOLINT +"__global float* dstc3)", // NOLINT "{", // NOLINT -"__local Dtype* work_each0 = (__local Dtype*)(work0 + lid);", // NOLINT -"__local Dtype* work_each1 = (__local Dtype*)(work1 + lid);", // NOLINT -"__local Dtype* work_each2 = (__local Dtype*)(work2 + lid);", // NOLINT -"__local Dtype* work_each3 = (__local Dtype*)(work3 + lid);", // NOLINT +"__local float* work_each0 = (__local float*)(work0 + lid);", // NOLINT +"__local float* 
work_each1 = (__local float*)(work1 + lid);", // NOLINT +"__local float* work_each2 = (__local float*)(work2 + lid);", // NOLINT +"__local float* work_each3 = (__local float*)(work3 + lid);", // NOLINT "", // NOLINT "int rows = N - x_gid * 4;", // NOLINT "", // NOLINT -"Dtype4 dot0[3] = {(Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.)};", // NOLINT -"Dtype4 dot1[3] = {(Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.)};", // NOLINT -"Dtype4 dot2[3] = {(Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.)};", // NOLINT -"Dtype4 dot3[3] = {(Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.)};", // NOLINT +"float4 dot0[3] = {(float4)(0.), (float4)(0.), (float4)(0.)};", // NOLINT +"float4 dot1[3] = {(float4)(0.), (float4)(0.), (float4)(0.)};", // NOLINT +"float4 dot2[3] = {(float4)(0.), (float4)(0.), (float4)(0.)};", // NOLINT +"float4 dot3[3] = {(float4)(0.), (float4)(0.), (float4)(0.)};", // NOLINT "", // NOLINT "int i = lid;", // NOLINT "while( i < K / 4) {", // NOLINT -"const Dtype4 a0 = {srca_read0[i*4], srca_read0[(i*4+1)], srca_read0[(i*4+2)], srca_read0[(i*4+3)]};", // NOLINT -"const Dtype4 a1 = {srca_read1[i*4], srca_read1[(i*4+1)], srca_read1[(i*4+2)], srca_read1[(i*4+3)]};", // NOLINT -"const Dtype4 a2 = {srca_read2[i*4], srca_read2[(i*4+1)], srca_read2[(i*4+2)], srca_read2[(i*4+3)]};", // NOLINT -"const Dtype4 a3 = {srca_read3[i*4], srca_read3[(i*4+1)], srca_read3[(i*4+2)], srca_read3[(i*4+3)]};", // NOLINT +"const float4 a0 = {srca_read0[i*4], srca_read0[(i*4+1)], srca_read0[(i*4+2)], srca_read0[(i*4+3)]};", // NOLINT +"const float4 a1 = {srca_read1[i*4], srca_read1[(i*4+1)], srca_read1[(i*4+2)], srca_read1[(i*4+3)]};", // NOLINT +"const float4 a2 = {srca_read2[i*4], srca_read2[(i*4+1)], srca_read2[(i*4+2)], srca_read2[(i*4+3)]};", // NOLINT +"const float4 a3 = {srca_read3[i*4], srca_read3[(i*4+1)], srca_read3[(i*4+2)], srca_read3[(i*4+3)]};", // NOLINT "#pragma unrol", // NOLINT "for(int j = 0; j < rows; ++j) {", // NOLINT "dot0[j] += a0 * vload4(i, srcb_read + j * K);", // NOLINT @@ -3810,18 
+3946,18 @@ static std::vector> cl_kernels{ "short tail_items = K % 4;", // NOLINT "", // NOLINT "if(tail_items != 0) {", // NOLINT -"const __global Dtype *srcb_tail = srcb_read + i * 4;", // NOLINT +"const __global float *srcb_tail = srcb_read + i * 4;", // NOLINT "", // NOLINT -"const __global Dtype *srca_tail0 = srca_read0 + i * 4;", // NOLINT -"const __global Dtype *srca_tail1 = srca_read1 + i * 4;", // NOLINT -"const __global Dtype *srca_tail2 = srca_read2 + i * 4;", // NOLINT -"const __global Dtype *srca_tail3 = srca_read3 + i * 4;", // NOLINT +"const __global float *srca_tail0 = srca_read0 + i * 4;", // NOLINT +"const __global float *srca_tail1 = srca_read1 + i * 4;", // NOLINT +"const __global float *srca_tail2 = srca_read2 + i * 4;", // NOLINT +"const __global float *srca_tail3 = srca_read3 + i * 4;", // NOLINT "#pragma unroll", // NOLINT "for(short i = 0; i < tail_items; ++i) {", // NOLINT -"const Dtype at0 = srca_tail0[i];", // NOLINT -"const Dtype at1 = srca_tail1[i];", // NOLINT -"const Dtype at2 = srca_tail2[i];", // NOLINT -"const Dtype at3 = srca_tail3[i];", // NOLINT +"const float at0 = srca_tail0[i];", // NOLINT +"const float at1 = srca_tail1[i];", // NOLINT +"const float at2 = srca_tail2[i];", // NOLINT +"const float at3 = srca_tail3[i];", // NOLINT "#pragma unroll", // NOLINT "for(int j = 0; j < rows; ++j) {", // NOLINT "work_each0[j] += at0 * srcb_tail[i + j * K];", // NOLINT @@ -3855,11 +3991,11 @@ static std::vector> cl_kernels{ "}", // NOLINT "", // NOLINT "__kernel void TEMPLATE(gemm_buffer_NT_M_4,Dtype)(", // NOLINT -"__global const Dtype * A,", // NOLINT +"__global const float * A,", // NOLINT "int offA,", // NOLINT -"__global const Dtype * B,", // NOLINT +"__global const float * B,", // NOLINT "int offB,", // NOLINT -"__global Dtype * C,", // NOLINT +"__global float * C,", // NOLINT "int offC,", // NOLINT "int M,", // NOLINT "int N,", // NOLINT @@ -3867,50 +4003,50 @@ static std::vector> cl_kernels{ "float alpha_f,", // NOLINT "float 
beta_f)", // NOLINT "{", // NOLINT -"Dtype alpha = (Dtype)alpha_f;", // NOLINT -"Dtype beta = (Dtype)beta_f;", // NOLINT +"float alpha = (float)alpha_f;", // NOLINT +"float beta = (float)beta_f;", // NOLINT "int x_gid = get_group_id(0);", // NOLINT "int lid = get_local_id(0);", // NOLINT "int lsize = get_local_size(0);", // NOLINT "", // NOLINT -"const __global Dtype *srca_read0 = A + offA;", // NOLINT -"const __global Dtype *srca_read1 = srca_read0 + K;", // NOLINT -"const __global Dtype *srca_read2 = srca_read1 + K;", // NOLINT -"const __global Dtype *srca_read3 = srca_read2 + K;", // NOLINT +"const __global float *srca_read0 = A + offA;", // NOLINT +"const __global float *srca_read1 = srca_read0 + K;", // NOLINT +"const __global float *srca_read2 = srca_read1 + K;", // NOLINT +"const __global float *srca_read3 = srca_read2 + K;", // NOLINT "", // NOLINT -"const __global Dtype *srcb_read = B + x_gid * 4 * K + offB;", // NOLINT +"const __global float *srcb_read = B + x_gid * 4 * K + offB;", // NOLINT "", // NOLINT -"__global Dtype4 *dstc0 = (__global Dtype4*)(C + offC);", // NOLINT -"__global Dtype4 *dstc1 = (__global Dtype4*)((__global Dtype*)(dstc0) + N);", // NOLINT -"__global Dtype4 *dstc2 = (__global Dtype4*)((__global Dtype*)(dstc1) + N);", // NOLINT -"__global Dtype4 *dstc3 = (__global Dtype4*)((__global Dtype*)(dstc2) + N);", // NOLINT +"__global float4 *dstc0 = (__global float4*)(C + offC);", // NOLINT +"__global float4 *dstc1 = (__global float4*)((__global float*)(dstc0) + N);", // NOLINT +"__global float4 *dstc2 = (__global float4*)((__global float*)(dstc1) + N);", // NOLINT +"__global float4 *dstc3 = (__global float4*)((__global float*)(dstc2) + N);", // NOLINT "", // NOLINT -"__local Dtype4 work0[SLM_SIZE];", // NOLINT -"__local Dtype4 work1[SLM_SIZE];", // NOLINT -"__local Dtype4 work2[SLM_SIZE];", // NOLINT -"__local Dtype4 work3[SLM_SIZE];", // NOLINT -"__local Dtype* work_each0 = (__local Dtype*)(work0 + lid);", // NOLINT -"__local Dtype* 
work_each1 = (__local Dtype*)(work1 + lid);", // NOLINT -"__local Dtype* work_each2 = (__local Dtype*)(work2 + lid);", // NOLINT -"__local Dtype* work_each3 = (__local Dtype*)(work3 + lid);", // NOLINT +"__local float4 work0[SLM_SIZE];", // NOLINT +"__local float4 work1[SLM_SIZE];", // NOLINT +"__local float4 work2[SLM_SIZE];", // NOLINT +"__local float4 work3[SLM_SIZE];", // NOLINT +"__local float* work_each0 = (__local float*)(work0 + lid);", // NOLINT +"__local float* work_each1 = (__local float*)(work1 + lid);", // NOLINT +"__local float* work_each2 = (__local float*)(work2 + lid);", // NOLINT +"__local float* work_each3 = (__local float*)(work3 + lid);", // NOLINT "", // NOLINT "if(x_gid == N / 4) {", // NOLINT -"TEMPLATE(gemm_buffer_NT_M_4_edgerows,Dtype) (srca_read0, srca_read1, srca_read2, srca_read3, srcb_read, work0, work1, work2, work3, N, K, x_gid, lid, alpha, beta, (__global Dtype*)dstc0, (__global Dtype*)dstc1, (__global Dtype*)dstc2, (__global Dtype*)dstc3);", // NOLINT +"TEMPLATE(gemm_buffer_NT_M_4_edgerows,Dtype) (srca_read0, srca_read1, srca_read2, srca_read3, srcb_read, work0, work1, work2, work3, N, K, x_gid, lid, alpha, beta, (__global float*)dstc0, (__global float*)dstc1, (__global float*)dstc2, (__global float*)dstc3);", // NOLINT "} else {", // NOLINT -"Dtype4 dot0[4] = {(Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.)};", // NOLINT -"Dtype4 dot1[4] = {(Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.)};", // NOLINT -"Dtype4 dot2[4] = {(Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.)};", // NOLINT -"Dtype4 dot3[4] = {(Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.)};", // NOLINT +"float4 dot0[4] = {(float4)(0.), (float4)(0.), (float4)(0.), (float4)(0.)};", // NOLINT +"float4 dot1[4] = {(float4)(0.), (float4)(0.), (float4)(0.), (float4)(0.)};", // NOLINT +"float4 dot2[4] = {(float4)(0.), (float4)(0.), (float4)(0.), (float4)(0.)};", // NOLINT +"float4 dot3[4] = {(float4)(0.), (float4)(0.), (float4)(0.), (float4)(0.)};", 
// NOLINT "", // NOLINT "int kid = lid;", // NOLINT "while( kid < K / 4) {", // NOLINT -"const Dtype4 b0 = vload4(kid, srca_read0);", // NOLINT -"const Dtype4 b1 = vload4(kid, srca_read1);", // NOLINT -"const Dtype4 b2 = vload4(kid, srca_read2);", // NOLINT -"const Dtype4 b3 = vload4(kid, srca_read3);", // NOLINT +"const float4 b0 = vload4(kid, srca_read0);", // NOLINT +"const float4 b1 = vload4(kid, srca_read1);", // NOLINT +"const float4 b2 = vload4(kid, srca_read2);", // NOLINT +"const float4 b3 = vload4(kid, srca_read3);", // NOLINT "#pragma unroll", // NOLINT "for(int j = 0; j < 4; ++j) {", // NOLINT -"Dtype4 a = vload4(kid, srcb_read + j * K);", // NOLINT +"float4 a = vload4(kid, srcb_read + j * K);", // NOLINT "dot0[j] += b0 * a;", // NOLINT "dot1[j] += b1 * a;", // NOLINT "dot2[j] += b2 * a;", // NOLINT @@ -3930,18 +4066,18 @@ static std::vector> cl_kernels{ "short tail_items = K % 4;", // NOLINT "if(tail_items != 0) {", // NOLINT "int offset = kid << 2;", // NOLINT -"const __global Dtype *srcb_tail = srcb_read + offset;", // NOLINT +"const __global float *srcb_tail = srcb_read + offset;", // NOLINT "", // NOLINT -"const __global Dtype *srca_tail0 = srca_read0 + offset;", // NOLINT -"const __global Dtype *srca_tail1 = srca_read1 + offset;", // NOLINT -"const __global Dtype *srca_tail2 = srca_read2 + offset;", // NOLINT -"const __global Dtype *srca_tail3 = srca_read3 + offset;", // NOLINT +"const __global float *srca_tail0 = srca_read0 + offset;", // NOLINT +"const __global float *srca_tail1 = srca_read1 + offset;", // NOLINT +"const __global float *srca_tail2 = srca_read2 + offset;", // NOLINT +"const __global float *srca_tail3 = srca_read3 + offset;", // NOLINT "#pragma unroll", // NOLINT "for(short i = 0; i < tail_items; ++i) {", // NOLINT -"const Dtype at0 = srca_tail0[i];", // NOLINT -"const Dtype at1 = srca_tail1[i];", // NOLINT -"const Dtype at2 = srca_tail2[i];", // NOLINT -"const Dtype at3 = srca_tail3[i];", // NOLINT +"const float at0 = 
srca_tail0[i];", // NOLINT +"const float at1 = srca_tail1[i];", // NOLINT +"const float at2 = srca_tail2[i];", // NOLINT +"const float at3 = srca_tail3[i];", // NOLINT "#pragma unroll", // NOLINT "for(int j = 0; j < 4; ++j) {", // NOLINT "work_each0[j] += at0 * srcb_tail[i + j * K];", // NOLINT @@ -3975,11 +4111,11 @@ static std::vector> cl_kernels{ "", // NOLINT "#define SLM_SIZE 16", // NOLINT "__kernel void TEMPLATE(gemm_buffer_NT_M_8,Dtype)(", // NOLINT -"__global const Dtype * A,", // NOLINT +"__global const float * A,", // NOLINT "int offA,", // NOLINT -"__global const Dtype * B,", // NOLINT +"__global const float * B,", // NOLINT "int offB,", // NOLINT -"__global Dtype * C,", // NOLINT +"__global float * C,", // NOLINT "int offC,", // NOLINT "int M,", // NOLINT "int N,", // NOLINT @@ -3987,61 +4123,61 @@ static std::vector> cl_kernels{ "float alpha_f,", // NOLINT "float beta_f)", // NOLINT "{", // NOLINT -"Dtype alpha = (Dtype)alpha_f;", // NOLINT -"Dtype beta = (Dtype)beta_f;", // NOLINT +"float alpha = (float)alpha_f;", // NOLINT +"float beta = (float)beta_f;", // NOLINT "int x_gid = get_group_id(0);", // NOLINT "int lid = get_local_id(0);", // NOLINT "int lsize = get_local_size(0);", // NOLINT "", // NOLINT -"const __global Dtype *srca_read0 = A + offA;", // NOLINT -"const __global Dtype *srca_read1 = srca_read0 + K;", // NOLINT -"const __global Dtype *srca_read2 = srca_read1 + K;", // NOLINT -"const __global Dtype *srca_read3 = srca_read2 + K;", // NOLINT -"const __global Dtype *srca_read4 = srca_read3 + K;", // NOLINT -"const __global Dtype *srca_read5 = srca_read4 + K;", // NOLINT -"const __global Dtype *srca_read6 = srca_read5 + K;", // NOLINT -"const __global Dtype *srca_read7 = srca_read6 + K;", // NOLINT -"", // NOLINT -"const __global Dtype *srcb_read = B + x_gid * K + offB;", // NOLINT -"", // NOLINT -"__global Dtype *dstc0 = C + offC;", // NOLINT -"__global Dtype *dstc1 = dstc0 + N;", // NOLINT -"__global Dtype *dstc2 = dstc1 + N;", // NOLINT 
-"__global Dtype *dstc3 = dstc2 + N;", // NOLINT -"__global Dtype *dstc4 = dstc3 + N;", // NOLINT -"__global Dtype *dstc5 = dstc4 + N;", // NOLINT -"__global Dtype *dstc6 = dstc5 + N;", // NOLINT -"__global Dtype *dstc7 = dstc6 + N;", // NOLINT -"", // NOLINT -"__local Dtype work0[SLM_SIZE];", // NOLINT -"__local Dtype work1[SLM_SIZE];", // NOLINT -"__local Dtype work2[SLM_SIZE];", // NOLINT -"__local Dtype work3[SLM_SIZE];", // NOLINT -"__local Dtype work4[SLM_SIZE];", // NOLINT -"__local Dtype work5[SLM_SIZE];", // NOLINT -"__local Dtype work6[SLM_SIZE];", // NOLINT -"__local Dtype work7[SLM_SIZE];", // NOLINT -"", // NOLINT -"Dtype4 dot0 = (Dtype4)(0.);", // NOLINT -"Dtype4 dot1 = (Dtype4)(0.);", // NOLINT -"Dtype4 dot2 = (Dtype4)(0.);", // NOLINT -"Dtype4 dot3 = (Dtype4)(0.);", // NOLINT -"Dtype4 dot4 = (Dtype4)(0.);", // NOLINT -"Dtype4 dot5 = (Dtype4)(0.);", // NOLINT -"Dtype4 dot6 = (Dtype4)(0.);", // NOLINT -"Dtype4 dot7 = (Dtype4)(0.);", // NOLINT +"const __global float *srca_read0 = A + offA;", // NOLINT +"const __global float *srca_read1 = srca_read0 + K;", // NOLINT +"const __global float *srca_read2 = srca_read1 + K;", // NOLINT +"const __global float *srca_read3 = srca_read2 + K;", // NOLINT +"const __global float *srca_read4 = srca_read3 + K;", // NOLINT +"const __global float *srca_read5 = srca_read4 + K;", // NOLINT +"const __global float *srca_read6 = srca_read5 + K;", // NOLINT +"const __global float *srca_read7 = srca_read6 + K;", // NOLINT +"", // NOLINT +"const __global float *srcb_read = B + x_gid * K + offB;", // NOLINT +"", // NOLINT +"__global float *dstc0 = C + offC;", // NOLINT +"__global float *dstc1 = dstc0 + N;", // NOLINT +"__global float *dstc2 = dstc1 + N;", // NOLINT +"__global float *dstc3 = dstc2 + N;", // NOLINT +"__global float *dstc4 = dstc3 + N;", // NOLINT +"__global float *dstc5 = dstc4 + N;", // NOLINT +"__global float *dstc6 = dstc5 + N;", // NOLINT +"__global float *dstc7 = dstc6 + N;", // NOLINT +"", // NOLINT 
+"__local float work0[SLM_SIZE];", // NOLINT +"__local float work1[SLM_SIZE];", // NOLINT +"__local float work2[SLM_SIZE];", // NOLINT +"__local float work3[SLM_SIZE];", // NOLINT +"__local float work4[SLM_SIZE];", // NOLINT +"__local float work5[SLM_SIZE];", // NOLINT +"__local float work6[SLM_SIZE];", // NOLINT +"__local float work7[SLM_SIZE];", // NOLINT +"", // NOLINT +"float4 dot0 = (float4)(0.);", // NOLINT +"float4 dot1 = (float4)(0.);", // NOLINT +"float4 dot2 = (float4)(0.);", // NOLINT +"float4 dot3 = (float4)(0.);", // NOLINT +"float4 dot4 = (float4)(0.);", // NOLINT +"float4 dot5 = (float4)(0.);", // NOLINT +"float4 dot6 = (float4)(0.);", // NOLINT +"float4 dot7 = (float4)(0.);", // NOLINT "", // NOLINT "int kid = lid;", // NOLINT "while( kid < K / 4) {", // NOLINT -"const Dtype4 a0 = vload4(kid, srca_read0);", // NOLINT -"const Dtype4 a1 = vload4(kid, srca_read1);", // NOLINT -"const Dtype4 a2 = vload4(kid, srca_read2);", // NOLINT -"const Dtype4 a3 = vload4(kid, srca_read3);", // NOLINT -"const Dtype4 a4 = vload4(kid, srca_read4);", // NOLINT -"const Dtype4 a5 = vload4(kid, srca_read5);", // NOLINT -"const Dtype4 a6 = vload4(kid, srca_read6);", // NOLINT -"const Dtype4 a7 = vload4(kid, srca_read7);", // NOLINT -"Dtype4 b = vload4(kid, srcb_read);", // NOLINT +"const float4 a0 = vload4(kid, srca_read0);", // NOLINT +"const float4 a1 = vload4(kid, srca_read1);", // NOLINT +"const float4 a2 = vload4(kid, srca_read2);", // NOLINT +"const float4 a3 = vload4(kid, srca_read3);", // NOLINT +"const float4 a4 = vload4(kid, srca_read4);", // NOLINT +"const float4 a5 = vload4(kid, srca_read5);", // NOLINT +"const float4 a6 = vload4(kid, srca_read6);", // NOLINT +"const float4 a7 = vload4(kid, srca_read7);", // NOLINT +"float4 b = vload4(kid, srcb_read);", // NOLINT "dot0 += a0 * b;", // NOLINT "dot1 += a1 * b;", // NOLINT "dot2 += a2 * b;", // NOLINT @@ -4066,16 +4202,16 @@ static std::vector> cl_kernels{ "short tail_items = K % 4;", // NOLINT "if(tail_items != 
0) {", // NOLINT "int offset = kid << 2;", // NOLINT -"const __global Dtype *srcb_tail = srcb_read + offset;", // NOLINT -"", // NOLINT -"const __global Dtype *srca_tail0 = srca_read0 + offset;", // NOLINT -"const __global Dtype *srca_tail1 = srca_read1 + offset;", // NOLINT -"const __global Dtype *srca_tail2 = srca_read2 + offset;", // NOLINT -"const __global Dtype *srca_tail3 = srca_read3 + offset;", // NOLINT -"const __global Dtype *srca_tail4 = srca_read4 + offset;", // NOLINT -"const __global Dtype *srca_tail5 = srca_read5 + offset;", // NOLINT -"const __global Dtype *srca_tail6 = srca_read6 + offset;", // NOLINT -"const __global Dtype *srca_tail7 = srca_read7 + offset;", // NOLINT +"const __global float *srcb_tail = srcb_read + offset;", // NOLINT +"", // NOLINT +"const __global float *srca_tail0 = srca_read0 + offset;", // NOLINT +"const __global float *srca_tail1 = srca_read1 + offset;", // NOLINT +"const __global float *srca_tail2 = srca_read2 + offset;", // NOLINT +"const __global float *srca_tail3 = srca_read3 + offset;", // NOLINT +"const __global float *srca_tail4 = srca_read4 + offset;", // NOLINT +"const __global float *srca_tail5 = srca_read5 + offset;", // NOLINT +"const __global float *srca_tail6 = srca_read6 + offset;", // NOLINT +"const __global float *srca_tail7 = srca_read7 + offset;", // NOLINT "#pragma unroll", // NOLINT "for(short item = 0; item < tail_items; ++item) {", // NOLINT "work0[lid] += srca_tail0[item] * srcb_tail[item];", // NOLINT @@ -4120,10 +4256,16 @@ static std::vector> cl_kernels{ "#define VEC_SIZE 4", // NOLINT "#define LWG_HEIGHT 4", // NOLINT "#define TILE_M 8", // NOLINT +"#if TYPE == TYPE_HALF", // NOLINT +"#define TILE_K 32", // NOLINT +"#define TILE_N 64", // NOLINT +"#else", // NOLINT "#define TILE_K 16", // NOLINT "#define TILE_N 32", // NOLINT +"#endif", // NOLINT "", // NOLINT -"__attribute__((reqd_work_group_size(8, LWG_HEIGHT, 1)))", // NOLINT +"__attribute__((reqd_work_group_size(SIMD_SIZE_GEMM, LWG_HEIGHT, 
1)))", // NOLINT +"__attribute__((intel_reqd_sub_group_size(SIMD_SIZE_GEMM)))", // NOLINT "__kernel void TEMPLATE(gemm_buffer_TN, Dtype)(", // NOLINT "const __global float *src0, int off0,", // NOLINT "const __global float *src1, int off1,", // NOLINT @@ -4131,11 +4273,13 @@ static std::vector> cl_kernels{ "int M,", // NOLINT "int N,", // NOLINT "int K,", // NOLINT -"float alpha,", // NOLINT -"float beta,", // NOLINT +"float alpha_in,", // NOLINT +"float beta_in,", // NOLINT "int start_index)", // NOLINT "", // NOLINT "{", // NOLINT +"const float alpha = (float)alpha_in;", // NOLINT +"const float beta = (float)beta_in;", // NOLINT "const int group_x = get_group_id(0);", // NOLINT "const int group_y = get_group_id(1);", // NOLINT "const int local_x = get_local_id(0);", // NOLINT @@ -4145,43 +4289,61 @@ static std::vector> cl_kernels{ "", // NOLINT "float4 brow;", // NOLINT "", // NOLINT -"__global float *dst_write0 = dst + local_x * VEC_SIZE + ( group_x * TILE_N ) + ( group_y * LWG_HEIGHT * TILE_M + local_y * TILE_M) * N + offd;", // NOLINT +"__global float *dst_write0 = dst + local_x * VEC_SIZE + (group_x * TILE_N) + (group_y * LWG_HEIGHT * TILE_M + local_y * TILE_M) * N + offd;", // NOLINT "", // NOLINT -"const __global float *src0_read = src0 + (local_x * ( TILE_K / 8 ) + start_index) * M + group_y * LWG_HEIGHT * TILE_M + local_y * TILE_M + off0;", // NOLINT +"const __global float *src0_read = src0 + (local_x * (TILE_K / SIMD_SIZE_GEMM) + start_index) * M + group_y * LWG_HEIGHT * TILE_M + local_y * TILE_M + off0;", // NOLINT "", // NOLINT -"const __global float *src1_read0 = src1 + local_x * VEC_SIZE + ( group_x * TILE_N ) + start_index * N + off1;", // NOLINT +"const __global float *src1_read0 = src1 + local_x * VEC_SIZE + (group_x * TILE_N) + start_index * N + off1;", // NOLINT "", // NOLINT -"float4 dot00 = (start_index != 0) ? ((__global float4 *)dst_write0)[0] : (beta * ((__global float4 *)dst_write0)[0]);", // NOLINT -"float4 dot01 = (start_index != 0) ? 
((__global float4 *)(dst_write0 + N))[0] : (beta * ((__global float4 *)(dst_write0 + N))[0]);", // NOLINT -"float4 dot02 = (start_index != 0) ? ((__global float4 *)(dst_write0 + 2 * N))[0] : (beta * ((__global float4 *)(dst_write0 + 2 * N))[0]);", // NOLINT -"float4 dot03 = (start_index != 0) ? ((__global float4 *)(dst_write0 + 3 * N))[0] : (beta * ((__global float4 *)(dst_write0 + 3 * N))[0]);", // NOLINT -"float4 dot04 = (start_index != 0) ? ((__global float4 *)(dst_write0 + 4 * N))[0] : (beta * ((__global float4 *)(dst_write0 + 4 * N))[0]);", // NOLINT -"float4 dot05 = (start_index != 0) ? ((__global float4 *)(dst_write0 + 5 * N))[0] : (beta * ((__global float4 *)(dst_write0 + 5 * N))[0]);", // NOLINT -"float4 dot06 = (start_index != 0) ? ((__global float4 *)(dst_write0 + 6 * N))[0] : (beta * ((__global float4 *)(dst_write0 + 6 * N))[0]);", // NOLINT -"float4 dot07 = (start_index != 0) ? ((__global float4 *)(dst_write0 + 7 * N))[0] : (beta * ((__global float4 *)(dst_write0 + 7 * N))[0]);", // NOLINT +"float4 dot00 = (start_index != 0) ? vload4(0, dst_write0) : beta * vload4(0, dst_write0);", // NOLINT +"float4 dot01 = (start_index != 0) ? vload4(0, dst_write0 + N) : beta * vload4(0, dst_write0 + N);", // NOLINT +"float4 dot02 = (start_index != 0) ? vload4(0, dst_write0 + 2 * N) : beta * vload4(0, dst_write0 + 2 * N);", // NOLINT +"float4 dot03 = (start_index != 0) ? vload4(0, dst_write0 + 3 * N) : beta * vload4(0, dst_write0 + 3 * N);", // NOLINT +"float4 dot04 = (start_index != 0) ? vload4(0, dst_write0 + 4 * N) : beta * vload4(0, dst_write0 + 4 * N);", // NOLINT +"float4 dot05 = (start_index != 0) ? vload4(0, dst_write0 + 5 * N) : beta * vload4(0, dst_write0 + 5 * N);", // NOLINT +"float4 dot06 = (start_index != 0) ? vload4(0, dst_write0 + 6 * N) : beta * vload4(0, dst_write0 + 6 * N);", // NOLINT +"float4 dot07 = (start_index != 0) ? 
vload4(0, dst_write0 + 7 * N) : beta * vload4(0, dst_write0 + 7 * N);", // NOLINT "", // NOLINT "int end_index = min(start_index + 256, K);", // NOLINT "while( start_index + TILE_K <= end_index ) {", // NOLINT -"float8 arow0 = alpha * ((__global float8 *)src0_read)[0];", // NOLINT -"float8 arow1 = alpha * ((__global float8 *)(src0_read + M))[0];", // NOLINT -"", // NOLINT -"#define MM_DOT_PRODUCT( _arow ) brow = ((__global float4 *)src1_read0)[0]; src1_read0 += N; dot00 = mad( (float4)(_arow.s0), brow, dot00 ); dot01 = mad( (float4)(_arow.s1), brow, dot01 ); dot02 = mad( (float4)(_arow.s2), brow, dot02 ); dot03 = mad( (float4)(_arow.s3), brow, dot03 ); dot04 = mad( (float4)(_arow.s4), brow, dot04 ); dot05 = mad( (float4)(_arow.s5), brow, dot05 ); dot06 = mad( (float4)(_arow.s6), brow, dot06 ); dot07 = mad( (float4)(_arow.s7), brow, dot07 );", // NOLINT -"MM_DOT_PRODUCT( intel_sub_group_shuffle( arow0, 0 ) );", // NOLINT -"MM_DOT_PRODUCT( intel_sub_group_shuffle( arow1, 0 ) );", // NOLINT -"MM_DOT_PRODUCT( intel_sub_group_shuffle( arow0, 1 ) );", // NOLINT -"MM_DOT_PRODUCT( intel_sub_group_shuffle( arow1, 1 ) );", // NOLINT -"MM_DOT_PRODUCT( intel_sub_group_shuffle( arow0, 2 ) );", // NOLINT -"MM_DOT_PRODUCT( intel_sub_group_shuffle( arow1, 2 ) );", // NOLINT -"MM_DOT_PRODUCT( intel_sub_group_shuffle( arow0, 3 ) );", // NOLINT -"MM_DOT_PRODUCT( intel_sub_group_shuffle( arow1, 3 ) );", // NOLINT -"MM_DOT_PRODUCT( intel_sub_group_shuffle( arow0, 4 ) );", // NOLINT -"MM_DOT_PRODUCT( intel_sub_group_shuffle( arow1, 4 ) );", // NOLINT -"MM_DOT_PRODUCT( intel_sub_group_shuffle( arow0, 5 ) );", // NOLINT -"MM_DOT_PRODUCT( intel_sub_group_shuffle( arow1, 5 ) );", // NOLINT -"MM_DOT_PRODUCT( intel_sub_group_shuffle( arow0, 6 ) );", // NOLINT -"MM_DOT_PRODUCT( intel_sub_group_shuffle( arow1, 6 ) );", // NOLINT -"MM_DOT_PRODUCT( intel_sub_group_shuffle( arow0, 7 ) );", // NOLINT -"MM_DOT_PRODUCT( intel_sub_group_shuffle( arow1, 7 ) );", // NOLINT +"float8 arow0 = alpha * 
vload8(0, src0_read);", // NOLINT +"float8 arow1 = alpha * vload8(0, src0_read + M);", // NOLINT +"", // NOLINT +"#define MM_DOT_PRODUCT( _arow ) brow = vload4(0, src1_read0); src1_read0 += N; dot00 = mad( (float4)(_arow.s0), brow, dot00 ); dot01 = mad( (float4)(_arow.s1), brow, dot01 ); dot02 = mad( (float4)(_arow.s2), brow, dot02 ); dot03 = mad( (float4)(_arow.s3), brow, dot03 ); dot04 = mad( (float4)(_arow.s4), brow, dot04 ); dot05 = mad( (float4)(_arow.s5), brow, dot05 ); dot06 = mad( (float4)(_arow.s6), brow, dot06 ); dot07 = mad( (float4)(_arow.s7), brow, dot07 );", // NOLINT +"MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 0 )) );", // NOLINT +"MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 0 )) );", // NOLINT +"MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 1 )) );", // NOLINT +"MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 1 )) );", // NOLINT +"MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 2 )) );", // NOLINT +"MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 2 )) );", // NOLINT +"MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 3 )) );", // NOLINT +"MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 3 )) );", // NOLINT +"MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 4 )) );", // NOLINT +"MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 4 )) );", // NOLINT +"MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 5 )) );", // NOLINT +"MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 5 )) );", // NOLINT +"MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 6 )) );", // NOLINT +"MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 6 )) );", // NOLINT +"MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 7 )) 
);", // NOLINT +"MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 7 )) );", // NOLINT +"#if TYPE == TYPE_HALF", // NOLINT +"MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 8 )) );", // NOLINT +"MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 8 )) );", // NOLINT +"MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 9 )) );", // NOLINT +"MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 9 )) );", // NOLINT +"MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 10 )) );", // NOLINT +"MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 10 )) );", // NOLINT +"MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 11 )) );", // NOLINT +"MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 11 )) );", // NOLINT +"MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 12 )) );", // NOLINT +"MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 12 )) );", // NOLINT +"MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 13 )) );", // NOLINT +"MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 13 )) );", // NOLINT +"MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 14 )) );", // NOLINT +"MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 14 )) );", // NOLINT +"MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 15 )) );", // NOLINT +"MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 15 )) );", // NOLINT +"#endif", // NOLINT "#undef MM_DOT_PRODUCT", // NOLINT "", // NOLINT "src0_read += TILE_K * M;", // NOLINT @@ -4189,99 +4351,114 @@ static std::vector> cl_kernels{ "}", // NOLINT "", // NOLINT "if(start_index < end_index) {", // NOLINT -"float8 arow0 = ((start_index + local_x * 2) < K) ? 
(alpha * ((__global float8 *)src0_read)[0]) : 0.0f;", // NOLINT -"float8 arow1 = ((start_index + local_x * 2 + 1) < K) ? (alpha * ((__global float8 *)(src0_read + M))[0]) : 0.0f;", // NOLINT -"", // NOLINT -"#define MM_DOT_PRODUCT( _arow ) brow = (start_index < K) ? ((__global float4 *)src1_read0)[0] : 0.0f; src1_read0 += N; start_index++; dot00 = mad( (float4)(_arow.s0), brow, dot00 ); dot01 = mad( (float4)(_arow.s1), brow, dot01 ); dot02 = mad( (float4)(_arow.s2), brow, dot02 ); dot03 = mad( (float4)(_arow.s3), brow, dot03 ); dot04 = mad( (float4)(_arow.s4), brow, dot04 ); dot05 = mad( (float4)(_arow.s5), brow, dot05 ); dot06 = mad( (float4)(_arow.s6), brow, dot06 ); dot07 = mad( (float4)(_arow.s7), brow, dot07 );", // NOLINT -"MM_DOT_PRODUCT( intel_sub_group_shuffle( arow0, 0 ) );", // NOLINT -"MM_DOT_PRODUCT( intel_sub_group_shuffle( arow1, 0 ) );", // NOLINT -"MM_DOT_PRODUCT( intel_sub_group_shuffle( arow0, 1 ) );", // NOLINT -"MM_DOT_PRODUCT( intel_sub_group_shuffle( arow1, 1 ) );", // NOLINT -"MM_DOT_PRODUCT( intel_sub_group_shuffle( arow0, 2 ) );", // NOLINT -"MM_DOT_PRODUCT( intel_sub_group_shuffle( arow1, 2 ) );", // NOLINT -"MM_DOT_PRODUCT( intel_sub_group_shuffle( arow0, 3 ) );", // NOLINT -"MM_DOT_PRODUCT( intel_sub_group_shuffle( arow1, 3 ) );", // NOLINT -"MM_DOT_PRODUCT( intel_sub_group_shuffle( arow0, 4 ) );", // NOLINT -"MM_DOT_PRODUCT( intel_sub_group_shuffle( arow1, 4 ) );", // NOLINT -"MM_DOT_PRODUCT( intel_sub_group_shuffle( arow0, 5 ) );", // NOLINT -"MM_DOT_PRODUCT( intel_sub_group_shuffle( arow1, 5 ) );", // NOLINT -"MM_DOT_PRODUCT( intel_sub_group_shuffle( arow0, 6 ) );", // NOLINT -"MM_DOT_PRODUCT( intel_sub_group_shuffle( arow1, 6 ) );", // NOLINT -"MM_DOT_PRODUCT( intel_sub_group_shuffle( arow0, 7 ) );", // NOLINT -"MM_DOT_PRODUCT( intel_sub_group_shuffle( arow1, 7 ) );", // NOLINT +"float8 arow0 = ((start_index + local_x * 2) < K) ? 
alpha * vload8(0, src0_read) : (float8)0.0f;", // NOLINT +"float8 arow1 = ((start_index + local_x * 2 + 1) < K) ? alpha * vload8(0, src0_read + M) : (float8)0.0f;", // NOLINT +"", // NOLINT +"#define MM_DOT_PRODUCT( _arow ) brow = (start_index < K) ? vload4(0, src1_read0) : (float4)0.0f; src1_read0 += N; start_index++; dot00 = mad( (float4)(_arow.s0), brow, dot00 ); dot01 = mad( (float4)(_arow.s1), brow, dot01 ); dot02 = mad( (float4)(_arow.s2), brow, dot02 ); dot03 = mad( (float4)(_arow.s3), brow, dot03 ); dot04 = mad( (float4)(_arow.s4), brow, dot04 ); dot05 = mad( (float4)(_arow.s5), brow, dot05 ); dot06 = mad( (float4)(_arow.s6), brow, dot06 ); dot07 = mad( (float4)(_arow.s7), brow, dot07 );", // NOLINT +"MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 0 )) );", // NOLINT +"MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 0 )) );", // NOLINT +"MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 1 )) );", // NOLINT +"MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 1 )) );", // NOLINT +"MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 2 )) );", // NOLINT +"MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 2 )) );", // NOLINT +"MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 3 )) );", // NOLINT +"MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 3 )) );", // NOLINT +"MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 4 )) );", // NOLINT +"MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 4 )) );", // NOLINT +"MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 5 )) );", // NOLINT +"MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 5 )) );", // NOLINT +"MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 6 )) );", // NOLINT +"MM_DOT_PRODUCT( 
as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 6 )) );", // NOLINT +"MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 7 )) );", // NOLINT +"MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 7 )) );", // NOLINT +"#if TYPE == TYPE_HALF", // NOLINT +"MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 8 )) );", // NOLINT +"MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 8 )) );", // NOLINT +"MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 9 )) );", // NOLINT +"MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 9 )) );", // NOLINT +"MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 10 )) );", // NOLINT +"MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 10 )) );", // NOLINT +"MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 11 )) );", // NOLINT +"MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 11 )) );", // NOLINT +"MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 12 )) );", // NOLINT +"MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 12 )) );", // NOLINT +"MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 13 )) );", // NOLINT +"MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 13 )) );", // NOLINT +"MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 14 )) );", // NOLINT +"MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 14 )) );", // NOLINT +"MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 15 )) );", // NOLINT +"MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 15 )) );", // NOLINT +"#endif", // NOLINT "#undef MM_DOT_PRODUCT", // NOLINT "}", // NOLINT "", // NOLINT "if(global_x * 4 < N && global_y * 8 < M) {", // NOLINT "if(mad24(global_x, 4, 
3) < N) {", // NOLINT -"__global float4 *dst_write = (__global float4 *)dst_write0;", // NOLINT -"dst_write[0] = dot00; dst_write0 += N; dst_write = (__global float4 *)dst_write0;", // NOLINT -"if(mad24(global_y, 8, 1) < M) { dst_write[0] = dot01; dst_write0 += N; dst_write = (__global float4 *)dst_write0; }", // NOLINT +"vstore4(dot00, 0, dst_write0); dst_write0 += N;", // NOLINT +"if(mad24(global_y, 8, 1) < M) { vstore4(dot01, 0, dst_write0); dst_write0 += N; }", // NOLINT "else return;", // NOLINT -"if(mad24(global_y, 8, 2) < M) { dst_write[0] = dot02; dst_write0 += N; dst_write = (__global float4 *)dst_write0; }", // NOLINT +"if(mad24(global_y, 8, 2) < M) { vstore4(dot02, 0, dst_write0); dst_write0 += N; }", // NOLINT "else return;", // NOLINT -"if(mad24(global_y, 8, 3) < M) { dst_write[0] = dot03; dst_write0 += N; dst_write = (__global float4 *)dst_write0; }", // NOLINT +"if(mad24(global_y, 8, 3) < M) { vstore4(dot03, 0, dst_write0); dst_write0 += N; }", // NOLINT "else return;", // NOLINT -"if(mad24(global_y, 8, 4) < M) { dst_write[0] = dot04; dst_write0 += N; dst_write = (__global float4 *)dst_write0; }", // NOLINT +"if(mad24(global_y, 8, 4) < M) { vstore4(dot04, 0, dst_write0); dst_write0 += N; }", // NOLINT "else return;", // NOLINT -"if(mad24(global_y, 8, 5) < M) { dst_write[0] = dot05; dst_write0 += N; dst_write = (__global float4 *)dst_write0; }", // NOLINT +"if(mad24(global_y, 8, 5) < M) { vstore4(dot05, 0, dst_write0); dst_write0 += N; }", // NOLINT "else return;", // NOLINT -"if(mad24(global_y, 8, 6) < M) { dst_write[0] = dot06; dst_write0 += N; dst_write = (__global float4 *)dst_write0; }", // NOLINT +"if(mad24(global_y, 8, 6) < M) { vstore4(dot06, 0, dst_write0); dst_write0 += N; }", // NOLINT "else return;", // NOLINT -"if(mad24(global_y, 8, 7) < M) { dst_write[0] = dot07; }", // NOLINT +"if(mad24(global_y, 8, 7) < M) { vstore4(dot07, 0, dst_write0); }", // NOLINT "} else if(mad24(global_x, 4, 2) < N) {", // NOLINT -"__global float2 *dst_write = 
(__global float2 *)dst_write0;", // NOLINT -"dst_write[0] = dot00.xy; dst_write0[2] = dot00.z;", // NOLINT -"dst_write0 += N; dst_write = (__global float2 *)dst_write0;", // NOLINT +"vstore2(dot00.xy, 0, dst_write0); dst_write0[2] = dot00.z;", // NOLINT +"dst_write0 += N;", // NOLINT "if(mad24(global_y, 8, 1) < M) {", // NOLINT -"dst_write[0] = dot01.xy; dst_write0[2] = dot01.z;", // NOLINT -"dst_write0 += N; dst_write = (__global float2 *)dst_write0;", // NOLINT +"vstore2(dot01.xy, 0, dst_write0); dst_write0[2] = dot01.z;", // NOLINT +"dst_write0 += N;", // NOLINT "} else", // NOLINT "return;", // NOLINT "if(mad24(global_y, 8, 2) < M) {", // NOLINT -"dst_write[0] = dot02.xy; dst_write0[2] = dot02.z;", // NOLINT -"dst_write0 += N; dst_write = (__global float2 *)dst_write0;", // NOLINT +"vstore2(dot02.xy, 0, dst_write0); dst_write0[2] = dot02.z;", // NOLINT +"dst_write0 += N;", // NOLINT "} else", // NOLINT "return;", // NOLINT "if(mad24(global_y, 8, 3) < M) {", // NOLINT -"dst_write[0] = dot03.xy; dst_write0[2] = dot03.z;", // NOLINT -"dst_write0 += N; dst_write = (__global float2 *)dst_write0;", // NOLINT +"vstore2(dot03.xy, 0, dst_write0); dst_write0[2] = dot03.z;", // NOLINT +"dst_write0 += N;", // NOLINT "} else", // NOLINT "return;", // NOLINT "if(mad24(global_y, 8, 4) < M) {", // NOLINT -"dst_write[0] = dot04.xy; dst_write0[2] = dot04.z;", // NOLINT -"dst_write0 += N; dst_write = (__global float2 *)dst_write0;", // NOLINT +"vstore2(dot04.xy, 0, dst_write0); dst_write0[2] = dot04.z;", // NOLINT +"dst_write0 += N;", // NOLINT "} else", // NOLINT "return;", // NOLINT "if(mad24(global_y, 8, 5) < M) {", // NOLINT -"dst_write[0] = dot05.xy; dst_write0[2] = dot05.z;", // NOLINT -"dst_write0 += N; dst_write = (__global float2 *)dst_write0;", // NOLINT +"vstore2(dot05.xy, 0, dst_write0); dst_write0[2] = dot05.z;", // NOLINT +"dst_write0 += N;", // NOLINT "} else", // NOLINT "return;", // NOLINT "if(mad24(global_y, 8, 6) < M) {", // NOLINT -"dst_write[0] = dot06.xy; 
dst_write0[2] = dot06.z;", // NOLINT -"dst_write0 += N; dst_write = (__global float2 *)dst_write0;", // NOLINT +"vstore2(dot06.xy, 0, dst_write0); dst_write0[2] = dot06.z;", // NOLINT +"dst_write0 += N;", // NOLINT "} else", // NOLINT "return;", // NOLINT "if(mad24(global_y, 8, 7) < M) {", // NOLINT -"dst_write[0] = dot07.xy; dst_write0[2] = dot07.z;", // NOLINT +"vstore2(dot07.xy, 0, dst_write0); dst_write0[2] = dot07.z;", // NOLINT "}", // NOLINT "} else if(mad24(global_x, 4, 1) < N) {", // NOLINT -"__global float2 *dst_write = (__global float2 *)dst_write0;", // NOLINT -"dst_write[0] = dot00.xy; dst_write0 += N; dst_write = (__global float2 *)dst_write0;", // NOLINT -"if(mad24(global_y, 8, 1) < M) { dst_write[0] = dot01.xy; dst_write0 += N; dst_write = (__global float2 *)dst_write0; }", // NOLINT +"vstore2(dot00.xy, 0, dst_write0); dst_write0 += N;", // NOLINT +"if(mad24(global_y, 8, 1) < M) { vstore2(dot01.xy, 0, dst_write0); dst_write0 += N; }", // NOLINT "else return;", // NOLINT -"if(mad24(global_y, 8, 2) < M) { dst_write[0] = dot02.xy; dst_write0 += N; dst_write = (__global float2 *)dst_write0; }", // NOLINT +"if(mad24(global_y, 8, 2) < M) { vstore2(dot02.xy, 0, dst_write0); dst_write0 += N; }", // NOLINT "else return;", // NOLINT -"if(mad24(global_y, 8, 3) < M) { dst_write[0] = dot03.xy; dst_write0 += N; dst_write = (__global float2 *)dst_write0; }", // NOLINT +"if(mad24(global_y, 8, 3) < M) { vstore2(dot03.xy, 0, dst_write0); dst_write0 += N; }", // NOLINT "else return;", // NOLINT -"if(mad24(global_y, 8, 4) < M) { dst_write[0] = dot04.xy; dst_write0 += N; dst_write = (__global float2 *)dst_write0; }", // NOLINT +"if(mad24(global_y, 8, 4) < M) { vstore2(dot04.xy, 0, dst_write0); dst_write0 += N; }", // NOLINT "else return;", // NOLINT -"if(mad24(global_y, 8, 5) < M) { dst_write[0] = dot05.xy; dst_write0 += N; dst_write = (__global float2 *)dst_write0; }", // NOLINT +"if(mad24(global_y, 8, 5) < M) { vstore2(dot05.xy, 0, dst_write0); dst_write0 += N; }", // 
NOLINT "else return;", // NOLINT -"if(mad24(global_y, 8, 6) < M) { dst_write[0] = dot06.xy; dst_write0 += N; dst_write = (__global float2 *)dst_write0; }", // NOLINT +"if(mad24(global_y, 8, 6) < M) { vstore2(dot06.xy, 0, dst_write0); dst_write0 += N; }", // NOLINT "else return;", // NOLINT -"if(mad24(global_y, 8, 7) < M) { dst_write[0] = dot07.xy; }", // NOLINT +"if(mad24(global_y, 8, 7) < M) { vstore2(dot07.xy, 0, dst_write0); }", // NOLINT "} else {", // NOLINT "dst_write0[0] = dot00.x; dst_write0 += N;", // NOLINT "if(mad24(global_y, 8, 1) < M) { dst_write0[0] = dot01.x; dst_write0 += N; }", // NOLINT @@ -4314,6 +4491,7 @@ static std::vector> cl_kernels{ "#define TILE_N 32", // NOLINT "", // NOLINT "__attribute__((reqd_work_group_size(8, LWG_HEIGHT, 1)))", // NOLINT +"__attribute__((intel_reqd_sub_group_size(8)))", // NOLINT "__kernel void TEMPLATE(gemm_buffer_TT, Dtype)(", // NOLINT "const __global float *src0, int off0,", // NOLINT "const __global float *src1, int off1,", // NOLINT @@ -4321,11 +4499,13 @@ static std::vector> cl_kernels{ "int M,", // NOLINT "int N,", // NOLINT "int K,", // NOLINT -"float alpha,", // NOLINT -"float beta,", // NOLINT +"float alpha_in,", // NOLINT +"float beta_in,", // NOLINT "int start_index)", // NOLINT "", // NOLINT "{", // NOLINT +"const float alpha = (float)alpha_in;", // NOLINT +"const float beta = (float)beta_in;", // NOLINT "const int group_x = get_group_id(0);", // NOLINT "const int group_y = get_group_id(1);", // NOLINT "const int local_x = get_local_id(0);", // NOLINT @@ -4343,32 +4523,32 @@ static std::vector> cl_kernels{ "float16 brow2;", // NOLINT "float16 brow3;", // NOLINT "", // NOLINT -"__global float *dst_write0 = dst + local_x * VEC_SIZE + ( group_x * TILE_N ) + ( group_y * LWG_HEIGHT * TILE_M + local_y * TILE_M) * N + offd;", // NOLINT +"__global float *dst_write0 = dst + local_x * VEC_SIZE + (group_x * TILE_N) + (group_y * LWG_HEIGHT * TILE_M + local_y * TILE_M) * N + offd;", // NOLINT "", // NOLINT -"const 
__global float *src0_read = src0 + (local_x * ( TILE_K / 8 ) + start_index) * M + group_y * LWG_HEIGHT * TILE_M + local_y * TILE_M + off0;", // NOLINT +"const __global float *src0_read = src0 + (local_x * (TILE_K / 8) + start_index) * M + group_y * LWG_HEIGHT * TILE_M + local_y * TILE_M + off0;", // NOLINT "", // NOLINT -"const __global float *src1_read0 = src1 + (local_x * VEC_SIZE + ( group_x * TILE_N )) * K + start_index + off1;", // NOLINT +"const __global float *src1_read0 = src1 + (local_x * VEC_SIZE + (group_x * TILE_N)) * K + start_index + off1;", // NOLINT "", // NOLINT -"float4 dot00 = (start_index != 0) ? ((__global float4 *)dst_write0)[0] : (beta * ((__global float4 *)dst_write0)[0]);", // NOLINT -"float4 dot01 = (start_index != 0) ? ((__global float4 *)(dst_write0 + N))[0] : (beta * ((__global float4 *)(dst_write0 + N))[0]);", // NOLINT -"float4 dot02 = (start_index != 0) ? ((__global float4 *)(dst_write0 + 2 * N))[0] : (beta * ((__global float4 *)(dst_write0 + 2 * N))[0]);", // NOLINT -"float4 dot03 = (start_index != 0) ? ((__global float4 *)(dst_write0 + 3 * N))[0] : (beta * ((__global float4 *)(dst_write0 + 3 * N))[0]);", // NOLINT -"float4 dot04 = (start_index != 0) ? ((__global float4 *)(dst_write0 + 4 * N))[0] : (beta * ((__global float4 *)(dst_write0 + 4 * N))[0]);", // NOLINT -"float4 dot05 = (start_index != 0) ? ((__global float4 *)(dst_write0 + 5 * N))[0] : (beta * ((__global float4 *)(dst_write0 + 5 * N))[0]);", // NOLINT -"float4 dot06 = (start_index != 0) ? ((__global float4 *)(dst_write0 + 6 * N))[0] : (beta * ((__global float4 *)(dst_write0 + 6 * N))[0]);", // NOLINT -"float4 dot07 = (start_index != 0) ? ((__global float4 *)(dst_write0 + 7 * N))[0] : (beta * ((__global float4 *)(dst_write0 + 7 * N))[0]);", // NOLINT +"float4 dot00 = (start_index != 0) ? vload4(0, dst_write0) : beta * vload4(0, dst_write0);", // NOLINT +"float4 dot01 = (start_index != 0) ? 
vload4(0, dst_write0 + N) : beta * vload4(0, dst_write0 + N);", // NOLINT +"float4 dot02 = (start_index != 0) ? vload4(0, dst_write0 + 2 * N) : beta * vload4(0, dst_write0 + 2 * N);", // NOLINT +"float4 dot03 = (start_index != 0) ? vload4(0, dst_write0 + 3 * N) : beta * vload4(0, dst_write0 + 3 * N);", // NOLINT +"float4 dot04 = (start_index != 0) ? vload4(0, dst_write0 + 4 * N) : beta * vload4(0, dst_write0 + 4 * N);", // NOLINT +"float4 dot05 = (start_index != 0) ? vload4(0, dst_write0 + 5 * N) : beta * vload4(0, dst_write0 + 5 * N);", // NOLINT +"float4 dot06 = (start_index != 0) ? vload4(0, dst_write0 + 6 * N) : beta * vload4(0, dst_write0 + 6 * N);", // NOLINT +"float4 dot07 = (start_index != 0) ? vload4(0, dst_write0 + 7 * N) : beta * vload4(0, dst_write0 + 7 * N);", // NOLINT "", // NOLINT "int end_index = min(start_index + 256, K);", // NOLINT "while( start_index + TILE_K <= end_index ) {", // NOLINT -"brow0 = ((__global float16 *)src1_read0)[0];", // NOLINT -"brow1 = ((__global float16 *)(src1_read0 + K))[0];", // NOLINT -"brow2 = ((__global float16 *)(src1_read0 + 2 * K))[0];", // NOLINT -"brow3 = ((__global float16 *)(src1_read0 + 3 * K))[0];", // NOLINT +"brow0 = vload16(0, src1_read0);", // NOLINT +"brow1 = vload16(0, src1_read0 + K);", // NOLINT +"brow2 = vload16(0, src1_read0 + 2 * K);", // NOLINT +"brow3 = vload16(0, src1_read0 + 3 * K);", // NOLINT "", // NOLINT -"float8 arow0 = alpha * ((__global float8 *)src0_read)[0];", // NOLINT -"float8 arow1 = alpha * ((__global float8 *)(src0_read + M))[0];", // NOLINT +"float8 arow0 = alpha * vload8(0, src0_read);", // NOLINT +"float8 arow1 = alpha * vload8(0, src0_read + M);", // NOLINT "", // NOLINT -"#define MM_DOT_PRODUCT( _brow, _dot) _dot = mad( intel_sub_group_shuffle( arow0, 0 ), (float8)_brow.s0, _dot ); _dot = mad( intel_sub_group_shuffle( arow1, 0 ), (float8)_brow.s1, _dot ); _dot = mad( intel_sub_group_shuffle( arow0, 1 ), (float8)_brow.s2, _dot ); _dot = mad( intel_sub_group_shuffle( arow1, 1 
), (float8)_brow.s3, _dot ); _dot = mad( intel_sub_group_shuffle( arow0, 2 ), (float8)_brow.s4, _dot ); _dot = mad( intel_sub_group_shuffle( arow1, 2 ), (float8)_brow.s5, _dot ); _dot = mad( intel_sub_group_shuffle( arow0, 3 ), (float8)_brow.s6, _dot ); _dot = mad( intel_sub_group_shuffle( arow1, 3 ), (float8)_brow.s7, _dot ); _dot = mad( intel_sub_group_shuffle( arow0, 4 ), (float8)_brow.s8, _dot ); _dot = mad( intel_sub_group_shuffle( arow1, 4 ), (float8)_brow.s9, _dot ); _dot = mad( intel_sub_group_shuffle( arow0, 5 ), (float8)_brow.sa, _dot ); _dot = mad( intel_sub_group_shuffle( arow1, 5 ), (float8)_brow.sb, _dot ); _dot = mad( intel_sub_group_shuffle( arow0, 6 ), (float8)_brow.sc, _dot ); _dot = mad( intel_sub_group_shuffle( arow1, 6 ), (float8)_brow.sd, _dot ); _dot = mad( intel_sub_group_shuffle( arow0, 7 ), (float8)_brow.se, _dot ); _dot = mad( intel_sub_group_shuffle( arow1, 7 ), (float8)_brow.sf, _dot );", // NOLINT +"#define MM_DOT_PRODUCT( _brow, _dot) _dot = mad( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 0 )), (float8)_brow.s0, _dot ); _dot = mad( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 0 )), (float8)_brow.s1, _dot ); _dot = mad( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 1 )), (float8)_brow.s2, _dot ); _dot = mad( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 1 )), (float8)_brow.s3, _dot ); _dot = mad( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 2 )), (float8)_brow.s4, _dot ); _dot = mad( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 2 )), (float8)_brow.s5, _dot ); _dot = mad( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 3 )), (float8)_brow.s6, _dot ); _dot = mad( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 3 )), (float8)_brow.s7, _dot ); _dot = mad( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 4 )), (float8)_brow.s8, _dot ); _dot = mad( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 4 )), (float8)_brow.s9, _dot ); _dot = 
mad( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 5 )), (float8)_brow.sa, _dot ); _dot = mad( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 5 )), (float8)_brow.sb, _dot ); _dot = mad( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 6 )), (float8)_brow.sc, _dot ); _dot = mad( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 6 )), (float8)_brow.sd, _dot ); _dot = mad( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 7 )), (float8)_brow.se, _dot ); _dot = mad( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 7 )), (float8)_brow.sf, _dot );", // NOLINT "MM_DOT_PRODUCT( brow0, dot0 );", // NOLINT "MM_DOT_PRODUCT( brow1, dot1 );", // NOLINT "MM_DOT_PRODUCT( brow2, dot2 );", // NOLINT @@ -4381,15 +4561,15 @@ static std::vector> cl_kernels{ "}", // NOLINT "", // NOLINT "if(start_index < end_index) {", // NOLINT -"brow0 = ((__global float16 *)src1_read0)[0]; src1_read0 += K;", // NOLINT -"brow1 = ((__global float16 *)src1_read0)[0]; src1_read0 += K;", // NOLINT -"brow2 = ((__global float16 *)src1_read0)[0]; src1_read0 += K;", // NOLINT -"brow3 = ((__global float16 *)src1_read0)[0];", // NOLINT +"brow0 = vload16(0, src1_read0); src1_read0 += K;", // NOLINT +"brow1 = vload16(0, src1_read0); src1_read0 += K;", // NOLINT +"brow2 = vload16(0, src1_read0); src1_read0 += K;", // NOLINT +"brow3 = vload16(0, src1_read0);", // NOLINT "", // NOLINT -"float8 arow0 = alpha * ((__global float8 *)src0_read)[0];", // NOLINT -"float8 arow1 = alpha * ((__global float8 *)(src0_read + M))[0];", // NOLINT +"float8 arow0 = alpha * vload8(0, src0_read);", // NOLINT +"float8 arow1 = alpha * vload8(0, src0_read + M);", // NOLINT "", // NOLINT -"#define MM_DOT_PRODUCT( _brow, _dot) _dot = (w++ < K) ? mad( intel_sub_group_shuffle( arow0, 0 ), (float8)_brow.s0, _dot ) : _dot; _dot = (w++ < K) ? mad( intel_sub_group_shuffle( arow1, 0 ), (float8)_brow.s1, _dot ) : _dot; _dot = (w++ < K) ? 
mad( intel_sub_group_shuffle( arow0, 1 ), (float8)_brow.s2, _dot ) : _dot; _dot = (w++ < K) ? mad( intel_sub_group_shuffle( arow1, 1 ), (float8)_brow.s3, _dot ) : _dot; _dot = (w++ < K) ? mad( intel_sub_group_shuffle( arow0, 2 ), (float8)_brow.s4, _dot ) : _dot; _dot = (w++ < K) ? mad( intel_sub_group_shuffle( arow1, 2 ), (float8)_brow.s5, _dot ) : _dot; _dot = (w++ < K) ? mad( intel_sub_group_shuffle( arow0, 3 ), (float8)_brow.s6, _dot ) : _dot; _dot = (w++ < K) ? mad( intel_sub_group_shuffle( arow1, 3 ), (float8)_brow.s7, _dot ) : _dot; _dot = (w++ < K) ? mad( intel_sub_group_shuffle( arow0, 4 ), (float8)_brow.s8, _dot ) : _dot; _dot = (w++ < K) ? mad( intel_sub_group_shuffle( arow1, 4 ), (float8)_brow.s9, _dot ) : _dot; _dot = (w++ < K) ? mad( intel_sub_group_shuffle( arow0, 5 ), (float8)_brow.sa, _dot ) : _dot; _dot = (w++ < K) ? mad( intel_sub_group_shuffle( arow1, 5 ), (float8)_brow.sb, _dot ) : _dot; _dot = (w++ < K) ? mad( intel_sub_group_shuffle( arow0, 6 ), (float8)_brow.sc, _dot ) : _dot; _dot = (w++ < K) ? mad( intel_sub_group_shuffle( arow1, 6 ), (float8)_brow.sd, _dot ) : _dot; _dot = (w++ < K) ? mad( intel_sub_group_shuffle( arow0, 7 ), (float8)_brow.se, _dot ) : _dot; _dot = (w++ < K) ? mad( intel_sub_group_shuffle( arow1, 7 ), (float8)_brow.sf, _dot ) : _dot;", // NOLINT +"#define MM_DOT_PRODUCT( _brow, _dot) _dot = (w++ < K) ? mad( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 0 )), (float8)_brow.s0, _dot ) : _dot; _dot = (w++ < K) ? mad( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 0 )), (float8)_brow.s1, _dot ) : _dot; _dot = (w++ < K) ? mad( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 1 )), (float8)_brow.s2, _dot ) : _dot; _dot = (w++ < K) ? mad( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 1 )), (float8)_brow.s3, _dot ) : _dot; _dot = (w++ < K) ? mad( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 2 )), (float8)_brow.s4, _dot ) : _dot; _dot = (w++ < K) ? 
mad( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 2 )), (float8)_brow.s5, _dot ) : _dot; _dot = (w++ < K) ? mad( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 3 )), (float8)_brow.s6, _dot ) : _dot; _dot = (w++ < K) ? mad( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 3 )), (float8)_brow.s7, _dot ) : _dot; _dot = (w++ < K) ? mad( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 4 )), (float8)_brow.s8, _dot ) : _dot; _dot = (w++ < K) ? mad( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 4 )), (float8)_brow.s9, _dot ) : _dot; _dot = (w++ < K) ? mad( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 5 )), (float8)_brow.sa, _dot ) : _dot; _dot = (w++ < K) ? mad( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 5 )), (float8)_brow.sb, _dot ) : _dot; _dot = (w++ < K) ? mad( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 6 )), (float8)_brow.sc, _dot ) : _dot; _dot = (w++ < K) ? mad( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 6 )), (float8)_brow.sd, _dot ) : _dot; _dot = (w++ < K) ? mad( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 7 )), (float8)_brow.se, _dot ) : _dot; _dot = (w++ < K) ? 
mad( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 7 )), (float8)_brow.sf, _dot ) : _dot;", // NOLINT "int w = start_index;", // NOLINT "MM_DOT_PRODUCT( brow0, dot0 );", // NOLINT "w = start_index;", // NOLINT @@ -4412,74 +4592,71 @@ static std::vector> cl_kernels{ "", // NOLINT "if(global_x * 4 < N && global_y * 8 < M) {", // NOLINT "if(mad24(global_x, 4, 3) < N) {", // NOLINT -"__global float4 *dst_write = (__global float4 *)dst_write0;", // NOLINT -"dst_write[0] = dot00; dst_write0 += N; dst_write = (__global float4 *)dst_write0;", // NOLINT -"if(mad24(global_y, 8, 1) < M) { dst_write[0] = dot01; dst_write0 += N; dst_write = (__global float4 *)dst_write0; }", // NOLINT +"vstore4(dot00, 0, dst_write0); dst_write0 += N;", // NOLINT +"if(mad24(global_y, 8, 1) < M) { vstore4(dot01, 0, dst_write0); dst_write0 += N; }", // NOLINT "else return;", // NOLINT -"if(mad24(global_y, 8, 2) < M) { dst_write[0] = dot02; dst_write0 += N; dst_write = (__global float4 *)dst_write0; }", // NOLINT +"if(mad24(global_y, 8, 2) < M) { vstore4(dot02, 0, dst_write0); dst_write0 += N; }", // NOLINT "else return;", // NOLINT -"if(mad24(global_y, 8, 3) < M) { dst_write[0] = dot03; dst_write0 += N; dst_write = (__global float4 *)dst_write0; }", // NOLINT +"if(mad24(global_y, 8, 3) < M) { vstore4(dot03, 0, dst_write0); dst_write0 += N; }", // NOLINT "else return;", // NOLINT -"if(mad24(global_y, 8, 4) < M) { dst_write[0] = dot04; dst_write0 += N; dst_write = (__global float4 *)dst_write0; }", // NOLINT +"if(mad24(global_y, 8, 4) < M) { vstore4(dot04, 0, dst_write0); dst_write0 += N; }", // NOLINT "else return;", // NOLINT -"if(mad24(global_y, 8, 5) < M) { dst_write[0] = dot05; dst_write0 += N; dst_write = (__global float4 *)dst_write0; }", // NOLINT +"if(mad24(global_y, 8, 5) < M) { vstore4(dot05, 0, dst_write0); dst_write0 += N; }", // NOLINT "else return;", // NOLINT -"if(mad24(global_y, 8, 6) < M) { dst_write[0] = dot06; dst_write0 += N; dst_write = (__global float4 *)dst_write0; 
}", // NOLINT +"if(mad24(global_y, 8, 6) < M) { vstore4(dot06, 0, dst_write0); dst_write0 += N; }", // NOLINT "else return;", // NOLINT -"if(mad24(global_y, 8, 7) < M) { dst_write[0] = dot07; }", // NOLINT +"if(mad24(global_y, 8, 7) < M) { vstore4(dot07, 0, dst_write0); }", // NOLINT "} else if(mad24(global_x, 4, 2) < N) {", // NOLINT -"__global float2 *dst_write = (__global float2 *)dst_write0;", // NOLINT -"dst_write[0] = dot00.xy; dst_write0[2] = dot00.z;", // NOLINT -"dst_write0 += N; dst_write = (__global float2 *)dst_write0;", // NOLINT +"vstore2(dot00.xy, 0, dst_write0); dst_write0[2] = dot00.z;", // NOLINT +"dst_write0 += N;", // NOLINT "if(mad24(global_y, 8, 1) < M) {", // NOLINT -"dst_write[0] = dot01.xy; dst_write0[2] = dot01.z;", // NOLINT -"dst_write0 += N; dst_write = (__global float2 *)dst_write0;", // NOLINT +"vstore2(dot01.xy, 0, dst_write0); dst_write0[2] = dot01.z;", // NOLINT +"dst_write0 += N;", // NOLINT "} else", // NOLINT "return;", // NOLINT "if(mad24(global_y, 8, 2) < M) {", // NOLINT -"dst_write[0] = dot02.xy; dst_write0[2] = dot02.z;", // NOLINT -"dst_write0 += N; dst_write = (__global float2 *)dst_write0;", // NOLINT +"vstore2(dot02.xy, 0, dst_write0); dst_write0[2] = dot02.z;", // NOLINT +"dst_write0 += N;", // NOLINT "} else", // NOLINT "return;", // NOLINT "if(mad24(global_y, 8, 3) < M) {", // NOLINT -"dst_write[0] = dot03.xy; dst_write0[2] = dot03.z;", // NOLINT -"dst_write0 += N; dst_write = (__global float2 *)dst_write0;", // NOLINT +"vstore2(dot03.xy, 0, dst_write0); dst_write0[2] = dot03.z;", // NOLINT +"dst_write0 += N;", // NOLINT "} else", // NOLINT "return;", // NOLINT "if(mad24(global_y, 8, 4) < M) {", // NOLINT -"dst_write[0] = dot04.xy; dst_write0[2] = dot04.z;", // NOLINT -"dst_write0 += N; dst_write = (__global float2 *)dst_write0;", // NOLINT +"vstore2(dot04.xy, 0, dst_write0); dst_write0[2] = dot04.z;", // NOLINT +"dst_write0 += N;", // NOLINT "} else", // NOLINT "return;", // NOLINT "if(mad24(global_y, 8, 5) < M) {", 
// NOLINT -"dst_write[0] = dot05.xy; dst_write0[2] = dot05.z;", // NOLINT -"dst_write0 += N; dst_write = (__global float2 *)dst_write0;", // NOLINT +"vstore2(dot05.xy, 0, dst_write0); dst_write0[2] = dot05.z;", // NOLINT +"dst_write0 += N;", // NOLINT "} else", // NOLINT "return;", // NOLINT "if(mad24(global_y, 8, 6) < M) {", // NOLINT -"dst_write[0] = dot06.xy; dst_write0[2] = dot06.z;", // NOLINT -"dst_write0 += N; dst_write = (__global float2 *)dst_write0;", // NOLINT +"vstore2(dot06.xy, 0, dst_write0); dst_write0[2] = dot06.z;", // NOLINT +"dst_write0 += N;", // NOLINT "} else", // NOLINT "return;", // NOLINT "if(mad24(global_y, 8, 7) < M) {", // NOLINT -"dst_write[0] = dot07.xy; dst_write0[2] = dot07.z;", // NOLINT +"vstore2(dot07.xy, 0, dst_write0); dst_write0[2] = dot07.z;", // NOLINT "}", // NOLINT "} else if(mad24(global_x, 4, 1) < N) {", // NOLINT -"__global float2 *dst_write = (__global float2 *)dst_write0;", // NOLINT -"dst_write[0] = dot00.xy; dst_write0 += N; dst_write = (__global float2 *)dst_write0;", // NOLINT -"if(mad24(global_y, 8, 1) < M) { dst_write[0] = dot01.xy; dst_write0 += N; dst_write = (__global float2 *)dst_write0; }", // NOLINT +"vstore2(dot00.xy, 0, dst_write0); dst_write0 += N;", // NOLINT +"if(mad24(global_y, 8, 1) < M) { vstore2(dot01.xy, 0, dst_write0); dst_write0 += N; }", // NOLINT "else return;", // NOLINT -"if(mad24(global_y, 8, 2) < M) { dst_write[0] = dot02.xy; dst_write0 += N; dst_write = (__global float2 *)dst_write0; }", // NOLINT +"if(mad24(global_y, 8, 2) < M) { vstore2(dot02.xy, 0, dst_write0); dst_write0 += N; }", // NOLINT "else return;", // NOLINT -"if(mad24(global_y, 8, 3) < M) { dst_write[0] = dot03.xy; dst_write0 += N; dst_write = (__global float2 *)dst_write0; }", // NOLINT +"if(mad24(global_y, 8, 3) < M) { vstore2(dot03.xy, 0, dst_write0); dst_write0 += N; }", // NOLINT "else return;", // NOLINT -"if(mad24(global_y, 8, 4) < M) { dst_write[0] = dot04.xy; dst_write0 += N; dst_write = (__global float2 
*)dst_write0; }", // NOLINT +"if(mad24(global_y, 8, 4) < M) { vstore2(dot04.xy, 0, dst_write0); dst_write0 += N; }", // NOLINT "else return;", // NOLINT -"if(mad24(global_y, 8, 5) < M) { dst_write[0] = dot05.xy; dst_write0 += N; dst_write = (__global float2 *)dst_write0; }", // NOLINT +"if(mad24(global_y, 8, 5) < M) { vstore2(dot05.xy, 0, dst_write0); dst_write0 += N; }", // NOLINT "else return;", // NOLINT -"if(mad24(global_y, 8, 6) < M) { dst_write[0] = dot06.xy; dst_write0 += N; dst_write = (__global float2 *)dst_write0; }", // NOLINT +"if(mad24(global_y, 8, 6) < M) { vstore2(dot06.xy, 0, dst_write0); dst_write0 += N; }", // NOLINT "else return;", // NOLINT -"if(mad24(global_y, 8, 7) < M) { dst_write[0] = dot07.xy; }", // NOLINT +"if(mad24(global_y, 8, 7) < M) { vstore2(dot07.xy, 0, dst_write0); }", // NOLINT "} else {", // NOLINT "dst_write0[0] = dot00.x; dst_write0 += N;", // NOLINT "if(mad24(global_y, 8, 1) < M) { dst_write0[0] = dot01.x; dst_write0 += N; }", // NOLINT @@ -4504,6 +4681,8 @@ static std::vector> cl_kernels{ "#undef TILE_M", // NOLINT "#undef TILE_K", // NOLINT "#undef TILE_N", // NOLINT +"", // NOLINT +"#endif", // NOLINT ""}, // NOLINT {"#ifndef __OPENCL_VERSION__", // NOLINT "#include \"header.cl\"", // NOLINT diff --git a/src/caffe/greentea/cl_kernels/gemm.cl b/src/caffe/greentea/cl_kernels/gemm.cl index de5fc39c3fd..3767107761a 100644 --- a/src/caffe/greentea/cl_kernels/gemm.cl +++ b/src/caffe/greentea/cl_kernels/gemm.cl @@ -2,49 +2,73 @@ #include "header.cl" #endif +#if TYPE != TYPE_DOUBLE + #define TILE_M 32 #define TILE_K 8 -#define TILE_N 8 // common block to calculate (alpha * AxB + beta * C) and output to destination image. 
+#if TYPE == TYPE_HALF +#define SUBGROUP_BLOCK_READ8( __image, __coord ) intel_sub_group_block_read_us8( __image, __coord ) +#define SHUFFLE_TYPE2(val) as_ushort2(val) +#define SHUFFLE_TYPE8(val) as_ushort8(val) +#define READ_IMAGE(__image, __coord) read_imageh(__image, sampler, __coord) +#define SIZE_OF_ELEMENT sizeof(ushort) +#define SIMD_SIZE_GEMM 16 +#define TILE_N 16 +#else +#define SUBGROUP_BLOCK_READ8( __image, __coord ) intel_sub_group_block_read8( __image, __coord ) +#define SHUFFLE_TYPE2(val) val +#define SHUFFLE_TYPE8(val) val +#define READ_IMAGE(__image, __coord) read_imagef(__image, sampler, __coord) +#define SIZE_OF_ELEMENT sizeof(uint) +#define SIMD_SIZE_GEMM 8 +#define TILE_N 8 +#endif + //#define USE_IMAGE_C #ifdef USE_IMAGE_C +#if TYPE == TYPE_HALF +#define BLOCKC_READ8( _C, _coordC ) as_float8( intel_sub_group_block_read_us8( _C, _coordC ) ) +#define BLOCKC_WRITE8( _C, _coordC, _val ) intel_sub_group_block_write_us8( _C, _coordC, as_ushort8( _val ) ) +#else #define BLOCKC_READ8( _C, _coordC ) as_float8( intel_sub_group_block_read8( _C, _coordC ) ) #define BLOCKC_WRITE8( _C, _coordC, _val ) intel_sub_group_block_write8( _C, _coordC, as_uint8( _val ) ) +#endif #define MATC_PARAMETER __read_only image2d_t C, __write_only image2d_t dst #define GEMM_OUTPUT(ALPHA1, BETA_NOT0) GEMM_OUTPUT_EXT(ALPHA1, BETA_NOT0, C, dst, sizeof(uint)) #else #define BLOCKC_READ8( _C, _coordC ) \ - (float8) ( (_coordC.x + get_local_id(0) < N && _coordC.y < M) ? _C[ _coordC.y * N + _coordC.x + get_local_id(0) ] : 0, \ - (_coordC.x + get_local_id(0) < N && _coordC.y + 1 < M) ? _C[ ( _coordC.y + 1 ) * N + _coordC.x + get_local_id(0) ] : 0, \ - (_coordC.x + get_local_id(0) < N && _coordC.y + 2 < M) ? _C[ ( _coordC.y + 2 ) * N + _coordC.x + get_local_id(0) ] : 0, \ - (_coordC.x + get_local_id(0) < N && _coordC.y + 3 < M) ? _C[ ( _coordC.y + 3 ) * N + _coordC.x + get_local_id(0) ] : 0, \ - (_coordC.x + get_local_id(0) < N && _coordC.y + 4 < M) ? 
_C[ ( _coordC.y + 4 ) * N + _coordC.x + get_local_id(0) ] : 0, \ - (_coordC.x + get_local_id(0) < N && _coordC.y + 5 < M) ? _C[ ( _coordC.y + 5 ) * N + _coordC.x + get_local_id(0) ] : 0, \ - (_coordC.x + get_local_id(0) < N && _coordC.y + 6 < M) ? _C[ ( _coordC.y + 6 ) * N + _coordC.x + get_local_id(0) ] : 0, \ - (_coordC.x + get_local_id(0) < N && _coordC.y + 7 < M) ? _C[ ( _coordC.y + 7 ) * N + _coordC.x + get_local_id(0) ] : 0) + (float8) ( (_coordC.x + get_local_id(0) < N && _coordC.y < M) ? _C[ _coordC.y * ldc + _coordC.x + get_local_id(0) ] : 0, \ + (_coordC.x + get_local_id(0) < N && _coordC.y + 1 < M) ? _C[ ( _coordC.y + 1 ) * ldc + _coordC.x + get_local_id(0) ] : 0, \ + (_coordC.x + get_local_id(0) < N && _coordC.y + 2 < M) ? _C[ ( _coordC.y + 2 ) * ldc + _coordC.x + get_local_id(0) ] : 0, \ + (_coordC.x + get_local_id(0) < N && _coordC.y + 3 < M) ? _C[ ( _coordC.y + 3 ) * ldc + _coordC.x + get_local_id(0) ] : 0, \ + (_coordC.x + get_local_id(0) < N && _coordC.y + 4 < M) ? _C[ ( _coordC.y + 4 ) * ldc + _coordC.x + get_local_id(0) ] : 0, \ + (_coordC.x + get_local_id(0) < N && _coordC.y + 5 < M) ? _C[ ( _coordC.y + 5 ) * ldc + _coordC.x + get_local_id(0) ] : 0, \ + (_coordC.x + get_local_id(0) < N && _coordC.y + 6 < M) ? _C[ ( _coordC.y + 6 ) * ldc + _coordC.x + get_local_id(0) ] : 0, \ + (_coordC.x + get_local_id(0) < N && _coordC.y + 7 < M) ? 
_C[ ( _coordC.y + 7 ) * ldc + _coordC.x + get_local_id(0) ] : 0) #define BLOCKC_WRITE8( _C, _coordC, _val) do {\ if (_coordC.x + get_local_id(0) < N) { \ if (_coordC.y < M) \ - _C[ _coordC.y * N + _coordC.x + get_local_id(0) ] = _val.s0; \ + _C[ _coordC.y * ldc + _coordC.x + get_local_id(0) ] = _val.s0; \ if (_coordC.y + 1 < M) \ - _C[ ( _coordC.y + 1 )* N + _coordC.x + get_local_id(0) ] = _val.s1; \ + _C[ ( _coordC.y + 1 )* ldc + _coordC.x + get_local_id(0) ] = _val.s1; \ if (_coordC.y + 2 < M) \ - _C[ ( _coordC.y + 2 )* N + _coordC.x + get_local_id(0) ] = _val.s2; \ + _C[ ( _coordC.y + 2 )* ldc + _coordC.x + get_local_id(0) ] = _val.s2; \ if (_coordC.y + 3 < M) \ - _C[ ( _coordC.y + 3 )* N + _coordC.x + get_local_id(0) ] = _val.s3; \ + _C[ ( _coordC.y + 3 )* ldc + _coordC.x + get_local_id(0) ] = _val.s3; \ if (_coordC.y + 4 < M) \ - _C[ ( _coordC.y + 4 )* N + _coordC.x + get_local_id(0) ] = _val.s4; \ + _C[ ( _coordC.y + 4 )* ldc + _coordC.x + get_local_id(0) ] = _val.s4; \ if (_coordC.y + 5 < M) \ - _C[ ( _coordC.y + 5 )* N + _coordC.x + get_local_id(0) ] = _val.s5; \ + _C[ ( _coordC.y + 5 )* ldc + _coordC.x + get_local_id(0) ] = _val.s5; \ if (_coordC.y + 6 < M) \ - _C[ ( _coordC.y + 6 )* N + _coordC.x + get_local_id(0) ] = _val.s6; \ + _C[ ( _coordC.y + 6 )* ldc + _coordC.x + get_local_id(0) ] = _val.s6; \ if (_coordC.y + 7 < M) \ - _C[ ( _coordC.y + 7 )* N + _coordC.x + get_local_id(0) ] = _val.s7; \ + _C[ ( _coordC.y + 7 )* ldc + _coordC.x + get_local_id(0) ] = _val.s7; \ }} while(0) -#define MATC_PARAMETER __global Dtype * C, const int offC, const int M, const int N +#define MATC_PARAMETER __global float * C, const int offC, const int M, const int N, const int ldc #define GEMM_OUTPUT(ALPHA1, BETA_NOT0) GEMM_OUTPUT_EXT(ALPHA1, BETA_NOT0, (C + offC), (C + offC), 1) #endif @@ -56,36 +80,36 @@ float8 blockC02; \ float8 blockC03; \ if (BETA_NOT0) { \ - blockC00 = BLOCKC_READ8( _C, coordC ); coordC.y += 8; \ - blockC01 = BLOCKC_READ8( _C, coordC ); coordC.y += 8; 
\ - blockC02 = BLOCKC_READ8( _C, coordC ); coordC.y += 8; \ - blockC03 = BLOCKC_READ8( _C, coordC ); \ + blockC00 = isFirstColBlock ? BLOCKC_READ8( _C, coordC ) * beta : BLOCKC_READ8( _C, coordC ); coordC.y += 8; \ + blockC01 = isFirstColBlock ? BLOCKC_READ8( _C, coordC ) * beta : BLOCKC_READ8( _C, coordC ); coordC.y += 8; \ + blockC02 = isFirstColBlock ? BLOCKC_READ8( _C, coordC ) * beta : BLOCKC_READ8( _C, coordC ); coordC.y += 8; \ + blockC03 = isFirstColBlock ? BLOCKC_READ8( _C, coordC ) * beta : BLOCKC_READ8( _C, coordC ); \ if (!ALPHA1) { \ - blockC00 *= beta; \ - blockC01 *= beta; \ - blockC02 *= beta; \ - blockC03 *= beta; \ blockC00 = mad(blockAxB00, (float8)alpha, blockC00); \ blockC01 = mad(blockAxB01, (float8)alpha, blockC01); \ blockC02 = mad(blockAxB02, (float8)alpha, blockC02); \ blockC03 = mad(blockAxB03, (float8)alpha, blockC03); \ } else { \ - blockC00 = mad(blockC00, (float8)beta, blockAxB00); \ - blockC01 = mad(blockC01, (float8)beta, blockAxB01); \ - blockC02 = mad(blockC02, (float8)beta, blockAxB02); \ - blockC03 = mad(blockC03, (float8)beta, blockAxB03); \ + blockC00 += blockAxB00; \ + blockC01 += blockAxB01; \ + blockC02 += blockAxB02; \ + blockC03 += blockAxB03; \ } \ } else { \ + blockC00 = isFirstColBlock ? BLOCKC_READ8( _C, coordC ) * beta : BLOCKC_READ8( _C, coordC ); coordC.y += 8; \ + blockC01 = isFirstColBlock ? BLOCKC_READ8( _C, coordC ) * beta : BLOCKC_READ8( _C, coordC ); coordC.y += 8; \ + blockC02 = isFirstColBlock ? BLOCKC_READ8( _C, coordC ) * beta : BLOCKC_READ8( _C, coordC ); coordC.y += 8; \ + blockC03 = isFirstColBlock ? 
BLOCKC_READ8( _C, coordC ) * beta : BLOCKC_READ8( _C, coordC ); \ if (!ALPHA1) { \ - blockC00 = blockAxB00 * alpha; \ - blockC01 = blockAxB01 * alpha; \ - blockC02 = blockAxB02 * alpha; \ - blockC03 = blockAxB03 * alpha; \ + blockC00 = mad(blockAxB00, (float8)alpha, blockC00); \ + blockC01 = mad(blockAxB01, (float8)alpha, blockC01); \ + blockC02 = mad(blockAxB02, (float8)alpha, blockC02); \ + blockC03 = mad(blockAxB03, (float8)alpha, blockC03); \ } else { \ - blockC00 = blockAxB00; \ - blockC01 = blockAxB01; \ - blockC02 = blockAxB02; \ - blockC03 = blockAxB03; \ + blockC00 += blockAxB00; \ + blockC01 += blockAxB01; \ + blockC02 += blockAxB02; \ + blockC03 += blockAxB03; \ } \ } \ BLOCKC_WRITE8( _dst, coordDst, blockC00 ); coordDst.y += 8; \ @@ -105,6 +129,43 @@ intel_sub_group_shuffle( _block.s7, _col ) ); // A's column block multiply B 's row block. +#if TYPE == TYPE_HALF +#define MULTIPLY_BLOCKS_8x8( _result, _blockA, _blockB00, _blockB01 ) \ + { \ + const float8 acol0 = TRANSPOSE_BLOCK_8( _blockA, 0 ); \ + const float8 acol1 = TRANSPOSE_BLOCK_8( _blockA, 1 ); \ + const float8 acol2 = TRANSPOSE_BLOCK_8( _blockA, 2 ); \ + const float8 acol3 = TRANSPOSE_BLOCK_8( _blockA, 3 ); \ + const float8 acol4 = TRANSPOSE_BLOCK_8( _blockA, 4 ); \ + const float8 acol5 = TRANSPOSE_BLOCK_8( _blockA, 5 ); \ + const float8 acol6 = TRANSPOSE_BLOCK_8( _blockA, 6 ); \ + const float8 acol7 = TRANSPOSE_BLOCK_8( _blockA, 7 ); \ + const float8 acol8 = TRANSPOSE_BLOCK_8( _blockA, 8 ); \ + const float8 acol9 = TRANSPOSE_BLOCK_8( _blockA, 9 ); \ + const float8 acola = TRANSPOSE_BLOCK_8( _blockA, 10 ); \ + const float8 acolb = TRANSPOSE_BLOCK_8( _blockA, 11 ); \ + const float8 acolc = TRANSPOSE_BLOCK_8( _blockA, 12 ); \ + const float8 acold = TRANSPOSE_BLOCK_8( _blockA, 13 ); \ + const float8 acole = TRANSPOSE_BLOCK_8( _blockA, 14 ); \ + const float8 acolf = TRANSPOSE_BLOCK_8( _blockA, 15 ); \ + _result = mad( (float8)(_blockB00.s0), acol0, _result ); \ + _result = mad( 
(float8)(_blockB00.s1), acol1, _result ); \ + _result = mad( (float8)(_blockB00.s2), acol2, _result ); \ + _result = mad( (float8)(_blockB00.s3), acol3, _result ); \ + _result = mad( (float8)(_blockB00.s4), acol4, _result ); \ + _result = mad( (float8)(_blockB00.s5), acol5, _result ); \ + _result = mad( (float8)(_blockB00.s6), acol6, _result ); \ + _result = mad( (float8)(_blockB00.s7), acol7, _result ); \ + _result = mad( (float8)(_blockB01.s0), acol8, _result ); \ + _result = mad( (float8)(_blockB01.s1), acol9, _result ); \ + _result = mad( (float8)(_blockB01.s2), acola, _result ); \ + _result = mad( (float8)(_blockB01.s3), acolb, _result ); \ + _result = mad( (float8)(_blockB01.s4), acolc, _result ); \ + _result = mad( (float8)(_blockB01.s5), acold, _result ); \ + _result = mad( (float8)(_blockB01.s6), acole, _result ); \ + _result = mad( (float8)(_blockB01.s7), acolf, _result ); \ + } +#else #define MULTIPLY_BLOCKS_8x8( _result, _blockA, _blockB ) \ { \ const float8 acol0 = TRANSPOSE_BLOCK_8( _blockA, 0 ); \ @@ -124,17 +185,64 @@ _result = mad( (float8)(_blockB.s6), acol6, _result ); \ _result = mad( (float8)(_blockB.s7), acol7, _result ); \ } +#endif +#if TYPE == TYPE_HALF #define GEMM_NN(ALPHA1, BETA_NOT0) \ -__attribute__((reqd_work_group_size(8, 1, 1))) \ -__kernel void TEMPLATE(gemm_32_1_NN_ ##ALPHA1 ##_ ##BETA_NOT0,Dtype)( \ +__attribute__((intel_reqd_sub_group_size(SIMD_SIZE_GEMM))) \ +__attribute__((reqd_work_group_size(SIMD_SIZE_GEMM, 1, 1))) \ +__kernel void TEMPLATE(gemm_32_1_NN_ ##ALPHA1 ##_ ##BETA_NOT0, Dtype)( \ __read_only image2d_t A, \ __read_only image2d_t B, \ MATC_PARAMETER, \ - float alpha, \ - float beta, \ - int width0) \ + float alpha_in, \ + float beta_in, \ + int width0, \ + int isFirstColBlock) \ { \ + const float alpha = (float)alpha_in; \ + const float beta = (float)beta_in; \ + const int group_x = get_group_id(0); \ + const int group_y = get_group_id(1); \ + float8 blockAxB00 = 0; \ + float8 blockAxB01 = 0; \ + float8 blockAxB02 = 
0; \ + float8 blockAxB03 = 0; \ + int2 coordA = (int2)( 0, group_y * TILE_M ); \ + int2 coordB = (int2)( ( group_x * TILE_N ) * SIZE_OF_ELEMENT, 0 ); \ + do \ + { \ + int2 coordBTemp = coordB; \ + float8 blockB00 = as_float8( SUBGROUP_BLOCK_READ8( B, coordBTemp ) ); coordB.y += TILE_K; \ + float8 blockB01 = as_float8( SUBGROUP_BLOCK_READ8( B, coordBTemp ) ); coordB.y += TILE_K; \ + int2 coordATemp = coordA; \ + float8 blockA00 = as_float8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.y += 8; \ + float8 blockA01 = as_float8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.y += 8; \ + float8 blockA02 = as_float8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.y += 8; \ + float8 blockA03 = as_float8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordA.x += TILE_K * SIZE_OF_ELEMENT * 2; \ + MULTIPLY_BLOCKS_8x8( blockAxB00, blockA00, blockB00, blockB01 ); \ + MULTIPLY_BLOCKS_8x8( blockAxB01, blockA01, blockB00, blockB01 ); \ + MULTIPLY_BLOCKS_8x8( blockAxB02, blockA02, blockB00, blockB01 ); \ + MULTIPLY_BLOCKS_8x8( blockAxB03, blockA03, blockB00, blockB01 ); \ + } \ + while( coordB.y < width0 ); \ + GEMM_OUTPUT(ALPHA1, BETA_NOT0); \ +} +#else +#define GEMM_NN(ALPHA1, BETA_NOT0) \ +__attribute__((intel_reqd_sub_group_size(SIMD_SIZE_GEMM))) \ +__attribute__((reqd_work_group_size(SIMD_SIZE_GEMM, 1, 1))) \ +__kernel void TEMPLATE(gemm_32_1_NN_ ##ALPHA1 ##_ ##BETA_NOT0, Dtype)( \ + __read_only image2d_t A, \ + __read_only image2d_t B, \ + MATC_PARAMETER, \ + float alpha_in, \ + float beta_in, \ + int width0, \ + int isFirstColBlock) \ +{ \ + const float alpha = (float)alpha_in; \ + const float beta = (float)beta_in; \ const int group_x = get_group_id(0); \ const int group_y = get_group_id(1); \ float8 blockAxB00 = 0.0f; \ @@ -142,16 +250,16 @@ __kernel void TEMPLATE(gemm_32_1_NN_ ##ALPHA1 ##_ ##BETA_NOT0,Dtype)( \ float8 blockAxB02 = 0.0f; \ float8 blockAxB03 = 0.0f; \ int2 coordA = (int2)( 0, group_y * TILE_M ); \ - int2 coordB = (int2)( ( group_x * TILE_N ) * 
sizeof(uint), 0 ); \ + int2 coordB = (int2)( ( group_x * TILE_N ) * SIZE_OF_ELEMENT, 0 ); \ do \ { \ int2 coordBTemp = coordB; \ - float8 blockB00 = as_float8( intel_sub_group_block_read8( B, coordBTemp ) ); coordB.y += TILE_K; \ + float8 blockB00 = as_float8( SUBGROUP_BLOCK_READ8( B, coordBTemp ) ); coordB.y += TILE_K; \ int2 coordATemp = coordA; \ - float8 blockA00 = as_float8( intel_sub_group_block_read8( A, coordATemp ) ); coordATemp.y += 8; \ - float8 blockA01 = as_float8( intel_sub_group_block_read8( A, coordATemp ) ); coordATemp.y += 8; \ - float8 blockA02 = as_float8( intel_sub_group_block_read8( A, coordATemp ) ); coordATemp.y += 8; \ - float8 blockA03 = as_float8( intel_sub_group_block_read8( A, coordATemp ) ); coordA.x += TILE_K * sizeof(uint); \ + float8 blockA00 = as_float8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.y += 8; \ + float8 blockA01 = as_float8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.y += 8; \ + float8 blockA02 = as_float8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.y += 8; \ + float8 blockA03 = as_float8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordA.x += TILE_K * SIZE_OF_ELEMENT; \ MULTIPLY_BLOCKS_8x8( blockAxB00, blockA00, blockB00 ); \ MULTIPLY_BLOCKS_8x8( blockAxB01, blockA01, blockB00 ); \ MULTIPLY_BLOCKS_8x8( blockAxB02, blockA02, blockB00 ); \ @@ -160,6 +268,7 @@ __kernel void TEMPLATE(gemm_32_1_NN_ ##ALPHA1 ##_ ##BETA_NOT0,Dtype)( \ while( coordB.y < width0 ); \ GEMM_OUTPUT(ALPHA1, BETA_NOT0); \ } +#endif GEMM_NN(1, 0) // ALPHA == 1, BETA == 0 GEMM_NN(1, 1) // ALPHA == 1, BETA != 0 @@ -170,63 +279,107 @@ GEMM_NN(0, 1) // ALPHA != 1, BETA != 0 #undef MULTIPLY_BLOCKS_8x8 // replicate the first row to column block. 
-#define TRANSPOSE_BLOCK_8(_vec) \ - (float8)( intel_sub_group_shuffle(_vec, 0), \ - intel_sub_group_shuffle(_vec, 1), \ - intel_sub_group_shuffle(_vec, 2), \ - intel_sub_group_shuffle(_vec, 3), \ - intel_sub_group_shuffle(_vec, 4), \ - intel_sub_group_shuffle(_vec, 5), \ - intel_sub_group_shuffle(_vec, 6), \ - intel_sub_group_shuffle(_vec, 7) ) - -#define MULTIPLY_BLOCKS_8x8( _result, _blockA, _blockB ) \ +#define TRANSPOSE_BLOCK_8(_vec, _col) \ + (float8)( intel_sub_group_shuffle(_vec, _col + 0), \ + intel_sub_group_shuffle(_vec, _col + 1), \ + intel_sub_group_shuffle(_vec, _col + 2), \ + intel_sub_group_shuffle(_vec, _col + 3), \ + intel_sub_group_shuffle(_vec, _col + 4), \ + intel_sub_group_shuffle(_vec, _col + 5), \ + intel_sub_group_shuffle(_vec, _col + 6), \ + intel_sub_group_shuffle(_vec, _col + 7) ) + +#define MULTIPLY_BLOCKS_8x8( _result, _blockA, _blockB, _col ) \ { \ - _result = mad( (float8)(_blockB.s0), TRANSPOSE_BLOCK_8(_blockA.s0), _result ); \ - _result = mad( (float8)(_blockB.s1), TRANSPOSE_BLOCK_8(_blockA.s1), _result ); \ - _result = mad( (float8)(_blockB.s2), TRANSPOSE_BLOCK_8(_blockA.s2), _result ); \ - _result = mad( (float8)(_blockB.s3), TRANSPOSE_BLOCK_8(_blockA.s3), _result ); \ - _result = mad( (float8)(_blockB.s4), TRANSPOSE_BLOCK_8(_blockA.s4), _result ); \ - _result = mad( (float8)(_blockB.s5), TRANSPOSE_BLOCK_8(_blockA.s5), _result ); \ - _result = mad( (float8)(_blockB.s6), TRANSPOSE_BLOCK_8(_blockA.s6), _result ); \ - _result = mad( (float8)(_blockB.s7), TRANSPOSE_BLOCK_8(_blockA.s7), _result ); \ + _result = mad( (float8)(_blockB.s0), TRANSPOSE_BLOCK_8(_blockA.s0, _col), _result ); \ + _result = mad( (float8)(_blockB.s1), TRANSPOSE_BLOCK_8(_blockA.s1, _col), _result ); \ + _result = mad( (float8)(_blockB.s2), TRANSPOSE_BLOCK_8(_blockA.s2, _col), _result ); \ + _result = mad( (float8)(_blockB.s3), TRANSPOSE_BLOCK_8(_blockA.s3, _col), _result ); \ + _result = mad( (float8)(_blockB.s4), TRANSPOSE_BLOCK_8(_blockA.s4, _col), _result ); 
\ + _result = mad( (float8)(_blockB.s5), TRANSPOSE_BLOCK_8(_blockA.s5, _col), _result ); \ + _result = mad( (float8)(_blockB.s6), TRANSPOSE_BLOCK_8(_blockA.s6, _col), _result ); \ + _result = mad( (float8)(_blockB.s7), TRANSPOSE_BLOCK_8(_blockA.s7, _col), _result ); \ } +#if TYPE == TYPE_HALF #define GEMM_TN(ALPHA1, BETA_NOT0) \ -__attribute__((reqd_work_group_size(8, 1, 1))) \ +__attribute__((intel_reqd_sub_group_size(SIMD_SIZE_GEMM))) \ +__attribute__((reqd_work_group_size(SIMD_SIZE_GEMM, 1, 1))) \ __kernel void TEMPLATE(gemm_32_1_TN_ ##ALPHA1 ##_ ##BETA_NOT0,Dtype)( \ __read_only image2d_t A, \ __read_only image2d_t B, \ MATC_PARAMETER, \ - float alpha, \ - float beta, \ - int width0) \ + float alpha_in, \ + float beta_in, \ + int width0, \ + int isFirstColBlock) \ { \ + const float alpha = (float)alpha_in; \ + const float beta = (float)beta_in; \ + const int group_x = get_group_id(0);\ + const int group_y = get_group_id(1);\ + float8 blockAxB00 = 0;\ + float8 blockAxB01 = 0;\ + float8 blockAxB02 = 0;\ + float8 blockAxB03 = 0;\ + int2 coordA = (int2)( group_y * TILE_M * SIZE_OF_ELEMENT, 0 );\ + int2 coordB = (int2)( ( group_x * TILE_N ) * SIZE_OF_ELEMENT, 0 );\ + do\ + {\ + int2 coordBTemp = coordB;\ + float8 blockB00 = as_float8( SUBGROUP_BLOCK_READ8( B, coordBTemp ) ); coordB.y += TILE_K;\ + int2 coordATemp = coordA;\ + float8 blockA00 = as_half8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.x += 16 * SIZE_OF_ELEMENT;\ + float8 blockA01 = as_half8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordA.y += TILE_K;\ + MULTIPLY_BLOCKS_8x8( blockAxB00, blockA00, blockB00, 0); \ + MULTIPLY_BLOCKS_8x8( blockAxB01, blockA00, blockB00, 8); \ + MULTIPLY_BLOCKS_8x8( blockAxB02, blockA01, blockB00, 0); \ + MULTIPLY_BLOCKS_8x8( blockAxB03, blockA01, blockB00, 8); \ + } \ + while( coordB.y < width0 ); \ + GEMM_OUTPUT(ALPHA1, BETA_NOT0); \ +} +#else +#define GEMM_TN(ALPHA1, BETA_NOT0) \ +__attribute__((intel_reqd_sub_group_size(SIMD_SIZE_GEMM))) \ 
+__attribute__((reqd_work_group_size(SIMD_SIZE_GEMM, 1, 1))) \ +__kernel void TEMPLATE(gemm_32_1_TN_ ##ALPHA1 ##_ ##BETA_NOT0,Dtype)( \ + __read_only image2d_t A, \ + __read_only image2d_t B, \ + MATC_PARAMETER, \ + float alpha_in, \ + float beta_in, \ + int width0, \ + int isFirstColBlock) \ +{ \ + const float alpha = (float)alpha_in; \ + const float beta = (float)beta_in; \ const int group_x = get_group_id(0);\ const int group_y = get_group_id(1);\ float8 blockAxB00 = 0.0f;\ float8 blockAxB01 = 0.0f;\ float8 blockAxB02 = 0.0f;\ float8 blockAxB03 = 0.0f;\ - int2 coordA = (int2)( group_y * TILE_M * sizeof(uint), 0 );\ - int2 coordB = (int2)( ( group_x * TILE_N ) * sizeof(uint), 0 );\ + int2 coordA = (int2)( group_y * TILE_M * SIZE_OF_ELEMENT, 0 );\ + int2 coordB = (int2)( ( group_x * TILE_N ) * SIZE_OF_ELEMENT, 0 );\ do\ {\ int2 coordBTemp = coordB;\ - float8 blockB00 = as_float8( intel_sub_group_block_read8( B, coordBTemp ) ); coordB.y += TILE_K;\ + float8 blockB00 = as_float8( SUBGROUP_BLOCK_READ8( B, coordBTemp ) ); coordB.y += TILE_K;\ int2 coordATemp = coordA;\ - float8 blockA00 = as_float8( intel_sub_group_block_read8( A, coordATemp ) ); coordATemp.x += 8 * sizeof(uint);\ - float8 blockA01 = as_float8( intel_sub_group_block_read8( A, coordATemp ) ); coordATemp.x += 8 * sizeof(uint);\ - float8 blockA02 = as_float8( intel_sub_group_block_read8( A, coordATemp ) ); coordATemp.x += 8 * sizeof(uint);\ - float8 blockA03 = as_float8( intel_sub_group_block_read8( A, coordATemp ) ); coordA.y += TILE_K;\ - MULTIPLY_BLOCKS_8x8( blockAxB00, blockA00, blockB00 ); \ - MULTIPLY_BLOCKS_8x8( blockAxB01, blockA01, blockB00 ); \ - MULTIPLY_BLOCKS_8x8( blockAxB02, blockA02, blockB00 ); \ - MULTIPLY_BLOCKS_8x8( blockAxB03, blockA03, blockB00 ); \ + float8 blockA00 = as_float8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.x += 8 * SIZE_OF_ELEMENT;\ + float8 blockA01 = as_float8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.x += 8 * SIZE_OF_ELEMENT;\ + float8 blockA02 
= as_float8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.x += 8 * SIZE_OF_ELEMENT;\ + float8 blockA03 = as_float8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordA.y += TILE_K;\ + MULTIPLY_BLOCKS_8x8( blockAxB00, blockA00, blockB00, 0 ); \ + MULTIPLY_BLOCKS_8x8( blockAxB01, blockA01, blockB00, 0 ); \ + MULTIPLY_BLOCKS_8x8( blockAxB02, blockA02, blockB00, 0 ); \ + MULTIPLY_BLOCKS_8x8( blockAxB03, blockA03, blockB00, 0 ); \ } \ while( coordB.y < width0 ); \ GEMM_OUTPUT(ALPHA1, BETA_NOT0); \ } +#endif GEMM_TN(1, 0) // ALPHA == 1, BETA == 0 GEMM_TN(1, 1) // ALPHA == 1, BETA != 0 @@ -247,6 +400,7 @@ GEMM_TN(0, 1) // ALPHA != 1, BETA != 0 intel_sub_group_shuffle( _block.s6, _col), \ intel_sub_group_shuffle( _block.s7, _col) ) +#if TYPE == TYPE_HALF #define MULTIPLY_BLOCKS_8x8( _result, _blockA, _blockB ) \ { \ const float8 acol0 = TRANSPOSE_BLOCK_8( _blockA, 0 ); \ @@ -257,6 +411,14 @@ GEMM_TN(0, 1) // ALPHA != 1, BETA != 0 const float8 acol5 = TRANSPOSE_BLOCK_8( _blockA, 5 ); \ const float8 acol6 = TRANSPOSE_BLOCK_8( _blockA, 6 ); \ const float8 acol7 = TRANSPOSE_BLOCK_8( _blockA, 7 ); \ + const float8 acol8 = TRANSPOSE_BLOCK_8( _blockA, 8 ); \ + const float8 acol9 = TRANSPOSE_BLOCK_8( _blockA, 9 ); \ + const float8 acola = TRANSPOSE_BLOCK_8( _blockA, 10 ); \ + const float8 acolb = TRANSPOSE_BLOCK_8( _blockA, 11 ); \ + const float8 acolc = TRANSPOSE_BLOCK_8( _blockA, 12 ); \ + const float8 acold = TRANSPOSE_BLOCK_8( _blockA, 13 ); \ + const float8 acole = TRANSPOSE_BLOCK_8( _blockA, 14 ); \ + const float8 acolf = TRANSPOSE_BLOCK_8( _blockA, 15 ); \ _result = mad( (float8)_blockB.s0, acol0, _result ); \ _result = mad( (float8)_blockB.s1, acol1, _result ); \ _result = mad( (float8)_blockB.s2, acol2, _result ); \ @@ -265,21 +427,95 @@ GEMM_TN(0, 1) // ALPHA != 1, BETA != 0 _result = mad( (float8)_blockB.s5, acol5, _result ); \ _result = mad( (float8)_blockB.s6, acol6, _result ); \ _result = mad( (float8)_blockB.s7, acol7, _result ); \ + _result = mad( 
(float8)_blockB.s8, acol8, _result ); \ + _result = mad( (float8)_blockB.s9, acol9, _result ); \ + _result = mad( (float8)_blockB.sa, acola, _result ); \ + _result = mad( (float8)_blockB.sb, acolb, _result ); \ + _result = mad( (float8)_blockB.sc, acolc, _result ); \ + _result = mad( (float8)_blockB.sd, acold, _result ); \ + _result = mad( (float8)_blockB.se, acole, _result ); \ + _result = mad( (float8)_blockB.sf, acolf, _result ); \ } +#else +#define MULTIPLY_BLOCKS_8x8( _result, _blockA, _blockB ) \ + { \ + const float8 acol0 = TRANSPOSE_BLOCK_8( _blockA, 0 ); \ + const float8 acol1 = TRANSPOSE_BLOCK_8( _blockA, 1 ); \ + const float8 acol2 = TRANSPOSE_BLOCK_8( _blockA, 2 ); \ + const float8 acol3 = TRANSPOSE_BLOCK_8( _blockA, 3 ); \ + const float8 acol4 = TRANSPOSE_BLOCK_8( _blockA, 4 ); \ + const float8 acol5 = TRANSPOSE_BLOCK_8( _blockA, 5 ); \ + const float8 acol6 = TRANSPOSE_BLOCK_8( _blockA, 6 ); \ + const float8 acol7 = TRANSPOSE_BLOCK_8( _blockA, 7 ); \ + _result = mad( (float8)_blockB.s0, acol0, _result ); \ + _result = mad( (float8)_blockB.s1, acol1, _result ); \ + _result = mad( (float8)_blockB.s2, acol2, _result ); \ + _result = mad( (float8)_blockB.s3, acol3, _result ); \ + _result = mad( (float8)_blockB.s4, acol4, _result ); \ + _result = mad( (float8)_blockB.s5, acol5, _result ); \ + _result = mad( (float8)_blockB.s6, acol6, _result ); \ + _result = mad( (float8)_blockB.s7, acol7, _result ); \ + } +#endif - - +#if TYPE == TYPE_HALF #define GEMM_NT(ALPHA1, BETA_NOT0, VECSCALAR, VECSIZE) \ -__attribute__((reqd_work_group_size(8, 1, 1))) \ +__attribute__((intel_reqd_sub_group_size(SIMD_SIZE_GEMM))) \ +__attribute__((reqd_work_group_size(SIMD_SIZE_GEMM, 1, 1))) \ __kernel void TEMPLATE(gemm_32_1_NT_ ##VECSCALAR ##_ ##ALPHA1 ##_ ##BETA_NOT0,Dtype)( \ __read_only image2d_t A, \ MATB_PARAMETER, \ MATC_PARAMETER, \ - float alpha, \ - float beta, \ + float alpha_in, \ + float beta_in, \ int padded_k, \ - int k) \ + int k, \ + int isFirstColBlock) \ { \ + 
const float alpha = (float)alpha_in; \ + const float beta = (float)beta_in; \ + const int group_x = get_group_id(0); \ + const int group_y = get_group_id(1); \ + float8 blockAxB00 = 0; \ + float8 blockAxB01 = 0; \ + float8 blockAxB02 = 0; \ + float8 blockAxB03 = 0; \ + int2 coordA = (int2)( 0, group_y * TILE_M ); \ + int2 coordB = (int2)( 0, ( group_x * TILE_N )); \ + const sampler_t sampler = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; \ + do \ + { \ + float16 blockB00; \ + BLOCKB_READ8(blockB00, B, coordB); \ + int2 coordATemp = coordA; \ + float8 blockA00 = as_float8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.y += 8; \ + float8 blockA01 = as_float8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.y += 8; \ + float8 blockA02 = as_float8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.y += 8; \ + float8 blockA03 = as_float8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordA.x += TILE_K * SIZE_OF_ELEMENT * 2; \ + MULTIPLY_BLOCKS_8x8( blockAxB00, blockA00, blockB00 ); \ + MULTIPLY_BLOCKS_8x8( blockAxB01, blockA01, blockB00 ); \ + MULTIPLY_BLOCKS_8x8( blockAxB02, blockA02, blockB00 ); \ + MULTIPLY_BLOCKS_8x8( blockAxB03, blockA03, blockB00 ); \ + } \ + while( coordB.x < padded_k / VECSIZE ); \ + GEMM_OUTPUT(ALPHA1, BETA_NOT0); \ +} +#else +#define GEMM_NT(ALPHA1, BETA_NOT0, VECSCALAR, VECSIZE) \ +__attribute__((intel_reqd_sub_group_size(SIMD_SIZE_GEMM))) \ +__attribute__((reqd_work_group_size(SIMD_SIZE_GEMM, 1, 1))) \ +__kernel void TEMPLATE(gemm_32_1_NT_ ##VECSCALAR ##_ ##ALPHA1 ##_ ##BETA_NOT0,Dtype)( \ + __read_only image2d_t A, \ + MATB_PARAMETER, \ + MATC_PARAMETER, \ + float alpha_in, \ + float beta_in, \ + int padded_k, \ + int k, \ + int isFirstColBlock) \ +{ \ + const float alpha = (float)alpha_in; \ + const float beta = (float)beta_in; \ const int group_x = get_group_id(0); \ const int group_y = get_group_id(1); \ float8 blockAxB00 = 0.0f; \ @@ -291,13 +527,13 @@ __kernel void TEMPLATE(gemm_32_1_NT_ ##VECSCALAR ##_ 
##ALPHA1 ##_ ##BETA_NOT0,Dt const sampler_t sampler = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; \ do \ { \ - float8 blockB00; \ + float8 blockB00; \ BLOCKB_READ8(blockB00, B, coordB); \ int2 coordATemp = coordA; \ - float8 blockA00 = as_float8( intel_sub_group_block_read8( A, coordATemp ) ); coordATemp.y += 8; \ - float8 blockA01 = as_float8( intel_sub_group_block_read8( A, coordATemp ) ); coordATemp.y += 8; \ - float8 blockA02 = as_float8( intel_sub_group_block_read8( A, coordATemp ) ); coordATemp.y += 8; \ - float8 blockA03 = as_float8( intel_sub_group_block_read8( A, coordATemp ) ); coordA.x += TILE_K * sizeof(uint); \ + float8 blockA00 = as_float8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.y += 8; \ + float8 blockA01 = as_float8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.y += 8; \ + float8 blockA02 = as_float8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.y += 8; \ + float8 blockA03 = as_float8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordA.x += TILE_K * SIZE_OF_ELEMENT; \ MULTIPLY_BLOCKS_8x8( blockAxB00, blockA00, blockB00 ); \ MULTIPLY_BLOCKS_8x8( blockAxB01, blockA01, blockB00 ); \ MULTIPLY_BLOCKS_8x8( blockAxB02, blockA02, blockB00 ); \ @@ -306,13 +542,23 @@ __kernel void TEMPLATE(gemm_32_1_NT_ ##VECSCALAR ##_ ##ALPHA1 ##_ ##BETA_NOT0,Dt while( coordB.x < padded_k / VECSIZE ); \ GEMM_OUTPUT(ALPHA1, BETA_NOT0); \ } +#endif - +#if TYPE == TYPE_HALF #define BLOCKB_READ8(_blockb, _B, _coordB) \ int2 _coordBTemp = _coordB; \ _coordBTemp.y += get_local_id(0); \ - _blockb.s0123 = read_imagef(_B, sampler, _coordBTemp); _coordBTemp.x += 1; \ - _blockb.s4567 = read_imagef(_B, sampler, _coordBTemp); _coordB.x += 2; + _blockb.s0123 = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \ + _blockb.s4567 = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \ + _blockb.s89ab = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \ + _blockb.scdef = READ_IMAGE(_B, _coordBTemp); _coordB.x += 4; +#else +#define 
BLOCKB_READ8(_blockb, _B, _coordB) \ + int2 _coordBTemp = _coordB; \ + _coordBTemp.y += get_local_id(0); \ + _blockb.s0123 = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \ + _blockb.s4567 = READ_IMAGE(_B, _coordBTemp); _coordB.x += 2; +#endif #define MATB_PARAMETER __read_only image2d_t B @@ -323,13 +569,23 @@ GEMM_NT(0, 1, VEC4, 4) // ALPHA != 1, BETA != 0 #undef BLOCKB_READ8 #undef MATB_PARAMETER +#if TYPE == TYPE_HALF +#define BLOCKB_READ8(_blockb, _B, _coordB) \ + int2 _coordBTemp = _coordB; \ + _coordBTemp.y += get_local_id(0); \ + const __global float *B_read = (__global float *)(_B + (_coordBTemp.y * ldb) + _coordBTemp.x + offB); \ + _blockb = as_half16(as_ushort16(vload8(0, B_read))); \ + _coordB.x += TILE_K * 2; +#else #define BLOCKB_READ8(_blockb, _B, _coordB) \ int2 _coordBTemp = _coordB; \ _coordBTemp.y += get_local_id(0); \ - _blockb = *(__global float8*)&_B[_coordBTemp.y * k + _coordBTemp.x + offB];\ + const __global float *B_read = (__global float *)(_B + (_coordBTemp.y * ldb) + _coordBTemp.x + offB); \ + _blockb = vload8(0, B_read); \ _coordB.x += TILE_K; +#endif -#define MATB_PARAMETER __global float *B, int offB +#define MATB_PARAMETER __global float *B, int offB, int ldb GEMM_NT(1, 0, BUFFER, 1) // ALPHA == 1, BETA == 0 GEMM_NT(1, 1, BUFFER, 1) // ALPHA == 1, BETA != 0 @@ -338,28 +594,67 @@ GEMM_NT(0, 1, BUFFER, 1) // ALPHA != 1, BETA != 0 #undef BLOCKB_READ8 #undef MATB_PARAMETER - +#if TYPE == TYPE_HALF +#define BLOCKB_READ8(_blockb, _B, _coordB) \ + int2 _coordBTemp = _coordB; \ + _coordBTemp.y += get_local_id(0); \ + float4 temp; \ + temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \ + _blockb.s0 = temp.s0; \ + temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \ + _blockb.s1 = temp.s0; \ + temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \ + _blockb.s2 = temp.s0; \ + temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \ + _blockb.s3 = temp.s0; \ + temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \ + 
_blockb.s4 = temp.s0; \ + temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \ + _blockb.s5 = temp.s0; \ + temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \ + _blockb.s6 = temp.s0; \ + temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \ + _blockb.s7 = temp.s0; \ + temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \ + _blockb.s8 = temp.s0; \ + temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \ + _blockb.s9 = temp.s0; \ + temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \ + _blockb.sa = temp.s0; \ + temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \ + _blockb.sb = temp.s0; \ + temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \ + _blockb.sc = temp.s0; \ + temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \ + _blockb.sd = temp.s0; \ + temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \ + _blockb.se = temp.s0; \ + temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \ + _blockb.sf = temp.s0; \ + _coordB.x += 16; +#else #define BLOCKB_READ8(_blockb, _B, _coordB) \ int2 _coordBTemp = _coordB; \ _coordBTemp.y += get_local_id(0); \ float4 temp; \ - temp = read_imagef(_B, _coordBTemp); _coordBTemp.x += 1; \ + temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \ _blockb.s0 = temp.s0; \ - temp = read_imagef(_B, _coordBTemp); _coordBTemp.x += 1; \ + temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \ _blockb.s1 = temp.s0; \ - temp = read_imagef(_B, _coordBTemp); _coordBTemp.x += 1; \ + temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \ _blockb.s2 = temp.s0; \ - temp = read_imagef(_B, _coordBTemp); _coordBTemp.x += 1; \ + temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \ _blockb.s3 = temp.s0; \ - temp = read_imagef(_B, _coordBTemp); _coordBTemp.x += 1; \ + temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \ _blockb.s4 = temp.s0; \ - temp = read_imagef(_B, _coordBTemp); _coordBTemp.x += 1; \ + temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \ _blockb.s5 = temp.s0; \ - temp 
= read_imagef(_B, _coordBTemp); _coordBTemp.x += 1; \ + temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \ _blockb.s6 = temp.s0; \ - temp = read_imagef(_B, _coordBTemp); _coordBTemp.x += 1; \ + temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \ _blockb.s7 = temp.s0; \ _coordB.x += 8; +#endif #define MATB_PARAMETER __read_only image2d_t B @@ -374,26 +669,26 @@ GEMM_NT(0, 1, SCALAR, 1) // ALPHA != 1, BETA != 0 #undef TRANSPOSE_BLOCK_8 //The same as GEMM_TN. -#define TRANSPOSE_BLOCK_8(_vec) \ - (float8)( intel_sub_group_shuffle(_vec, 0), \ - intel_sub_group_shuffle(_vec, 1), \ - intel_sub_group_shuffle(_vec, 2), \ - intel_sub_group_shuffle(_vec, 3), \ - intel_sub_group_shuffle(_vec, 4), \ - intel_sub_group_shuffle(_vec, 5), \ - intel_sub_group_shuffle(_vec, 6), \ - intel_sub_group_shuffle(_vec, 7) ); - -#define MULTIPLY_BLOCKS_8x8( _result, _blockA, _blockB ) \ +#define TRANSPOSE_BLOCK_8(_vec, _col) \ + (float8)( intel_sub_group_shuffle(_vec, _col + 0), \ + intel_sub_group_shuffle(_vec, _col + 1), \ + intel_sub_group_shuffle(_vec, _col + 2), \ + intel_sub_group_shuffle(_vec, _col + 3), \ + intel_sub_group_shuffle(_vec, _col + 4), \ + intel_sub_group_shuffle(_vec, _col + 5), \ + intel_sub_group_shuffle(_vec, _col + 6), \ + intel_sub_group_shuffle(_vec, _col + 7) ); + +#define MULTIPLY_BLOCKS_8x8( _result, _blockA, _blockB, _col ) \ { \ - const float8 acol0 = TRANSPOSE_BLOCK_8( _blockA.s0 ); \ - const float8 acol1 = TRANSPOSE_BLOCK_8( _blockA.s1 ); \ - const float8 acol2 = TRANSPOSE_BLOCK_8( _blockA.s2 ); \ - const float8 acol3 = TRANSPOSE_BLOCK_8( _blockA.s3 ); \ - const float8 acol4 = TRANSPOSE_BLOCK_8( _blockA.s4 ); \ - const float8 acol5 = TRANSPOSE_BLOCK_8( _blockA.s5 ); \ - const float8 acol6 = TRANSPOSE_BLOCK_8( _blockA.s6 ); \ - const float8 acol7 = TRANSPOSE_BLOCK_8( _blockA.s7 ); \ + const float8 acol0 = TRANSPOSE_BLOCK_8( _blockA.s0, _col ); \ + const float8 acol1 = TRANSPOSE_BLOCK_8( _blockA.s1, _col ); \ + const float8 acol2 = 
TRANSPOSE_BLOCK_8( _blockA.s2, _col ); \ + const float8 acol3 = TRANSPOSE_BLOCK_8( _blockA.s3, _col ); \ + const float8 acol4 = TRANSPOSE_BLOCK_8( _blockA.s4, _col ); \ + const float8 acol5 = TRANSPOSE_BLOCK_8( _blockA.s5, _col ); \ + const float8 acol6 = TRANSPOSE_BLOCK_8( _blockA.s6, _col ); \ + const float8 acol7 = TRANSPOSE_BLOCK_8( _blockA.s7, _col ); \ _result = mad( (float8)_blockB.s0, acol0, _result ); \ _result = mad( (float8)_blockB.s1, acol1, _result ); \ _result = mad( (float8)_blockB.s2, acol2, _result ); \ @@ -404,24 +699,69 @@ GEMM_NT(0, 1, SCALAR, 1) // ALPHA != 1, BETA != 0 _result = mad( (float8)_blockB.s7, acol7, _result ); \ } +#if TYPE == TYPE_HALF +#define GEMM_TT(ALPHA1, BETA_NOT0, VECSCALAR, VECSIZE) \ +__attribute__((intel_reqd_sub_group_size(SIMD_SIZE_GEMM))) \ +__attribute__((reqd_work_group_size(SIMD_SIZE_GEMM, 1, 1))) \ +__kernel void TEMPLATE(gemm_32_1_TT_ ##VECSCALAR ##_ ##ALPHA1 ##_ ##BETA_NOT0, Dtype)( \ + __read_only image2d_t A, \ + MATB_PARAMETER, \ + MATC_PARAMETER, \ + float alpha_in, \ + float beta_in, \ + int padded_k, \ + int k, \ + int isFirstColBlock) \ +{ \ + const float alpha = (float)alpha_in; \ + const float beta = (float)beta_in; \ + const int group_x = get_group_id(0); \ + const int group_y = get_group_id(1); \ + float8 blockAxB00 = 0; \ + float8 blockAxB01 = 0; \ + float8 blockAxB02 = 0; \ + float8 blockAxB03 = 0; \ + int2 coordA = (int2)( group_y * TILE_M * SIZE_OF_ELEMENT, 0 ); \ + int2 coordB = (int2)( 0, ( group_x * TILE_N )); \ + const sampler_t sampler = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; \ + do \ + { \ + float8 blockB00; \ + BLOCKB_READ8(blockB00, B, coordB); \ + int2 coordATemp = coordA; \ + float8 blockA00 = as_float8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.x += 16 * SIZE_OF_ELEMENT;\ + float8 blockA01 = as_float8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordA.y += TILE_K;\ + MULTIPLY_BLOCKS_8x8( blockAxB00, blockA00, blockB00, 0); \ + MULTIPLY_BLOCKS_8x8( 
blockAxB01, blockA00, blockB00, 8); \ + MULTIPLY_BLOCKS_8x8( blockAxB02, blockA01, blockB00, 0); \ + MULTIPLY_BLOCKS_8x8( blockAxB03, blockA01, blockB00, 8); \ + } \ + while( coordB.x < padded_k / VECSIZE ); \ + GEMM_OUTPUT(ALPHA1, BETA_NOT0);\ +} +#else #define GEMM_TT(ALPHA1, BETA_NOT0, VECSCALAR, VECSIZE) \ -__attribute__((reqd_work_group_size(8, 1, 1))) \ +__attribute__((intel_reqd_sub_group_size(SIMD_SIZE_GEMM))) \ +__attribute__((reqd_work_group_size(SIMD_SIZE_GEMM, 1, 1))) \ __kernel void TEMPLATE(gemm_32_1_TT_ ##VECSCALAR ##_ ##ALPHA1 ##_ ##BETA_NOT0, Dtype)( \ __read_only image2d_t A, \ MATB_PARAMETER, \ MATC_PARAMETER, \ - float alpha, \ - float beta, \ + float alpha_in, \ + float beta_in, \ int padded_k, \ - int k) \ + int k, \ + int isFirstColBlock) \ { \ + const float alpha = (float)alpha_in; \ + const float beta = (float)beta_in; \ const int group_x = get_group_id(0); \ const int group_y = get_group_id(1); \ float8 blockAxB00 = 0.0f; \ float8 blockAxB01 = 0.0f; \ float8 blockAxB02 = 0.0f; \ float8 blockAxB03 = 0.0f; \ - int2 coordA = (int2)( group_y * TILE_M * sizeof(uint), 0 ); \ + int2 coordA = (int2)( group_y * TILE_M * SIZE_OF_ELEMENT, 0 ); \ int2 coordB = (int2)( 0, ( group_x * TILE_N )); \ const sampler_t sampler = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; \ do \ @@ -429,24 +769,25 @@ __kernel void TEMPLATE(gemm_32_1_TT_ ##VECSCALAR ##_ ##ALPHA1 ##_ ##BETA_NOT0, D float8 blockB00; \ BLOCKB_READ8(blockB00, B, coordB); \ int2 coordATemp = coordA; \ - float8 blockA00 = as_float8( intel_sub_group_block_read8( A, coordATemp ) ); coordATemp.x += 8 * sizeof(uint); \ - float8 blockA01 = as_float8( intel_sub_group_block_read8( A, coordATemp ) ); coordATemp.x += 8 * sizeof(uint); \ - float8 blockA02 = as_float8( intel_sub_group_block_read8( A, coordATemp ) ); coordATemp.x += 8 * sizeof(uint); \ - float8 blockA03 = as_float8( intel_sub_group_block_read8( A, coordATemp ) ); coordA.y += TILE_K; \ - MULTIPLY_BLOCKS_8x8( blockAxB00, 
blockA00 , blockB00 ); \ - MULTIPLY_BLOCKS_8x8( blockAxB01, blockA01 , blockB00 ); \ - MULTIPLY_BLOCKS_8x8( blockAxB02, blockA02 , blockB00 ); \ - MULTIPLY_BLOCKS_8x8( blockAxB03, blockA03 , blockB00 ); \ + float8 blockA00 = as_float8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.x += 8 * SIZE_OF_ELEMENT; \ + float8 blockA01 = as_float8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.x += 8 * SIZE_OF_ELEMENT; \ + float8 blockA02 = as_float8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.x += 8 * SIZE_OF_ELEMENT; \ + float8 blockA03 = as_float8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordA.y += TILE_K; \ + MULTIPLY_BLOCKS_8x8( blockAxB00, blockA00 , blockB00, 0 ); \ + MULTIPLY_BLOCKS_8x8( blockAxB01, blockA01 , blockB00, 0 ); \ + MULTIPLY_BLOCKS_8x8( blockAxB02, blockA02 , blockB00, 0 ); \ + MULTIPLY_BLOCKS_8x8( blockAxB03, blockA03 , blockB00, 0 ); \ } \ while( coordB.x < padded_k / VECSIZE ); \ GEMM_OUTPUT(ALPHA1, BETA_NOT0);\ } +#endif #define BLOCKB_READ8(_blockb, _B, _coordB) \ int2 _coordBTemp = _coordB; \ _coordBTemp.y += get_local_id(0); \ - blockB00.s0123 = read_imagef(B, _coordBTemp); _coordBTemp.x += 1; \ - blockB00.s4567 = read_imagef(B, _coordBTemp); _coordB.x += 2; + _blockb.s0123 = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \ + _blockb.s4567 = READ_IMAGE(_B, _coordBTemp); _coordB.x += 2; #define MATB_PARAMETER __read_only image2d_t B @@ -457,13 +798,23 @@ GEMM_TT(0, 1, VEC4, 4) // ALPHA != 1, BETA != 0 #undef BLOCKB_READ8 #undef MATB_PARAMETER +#if TYPE == TYPE_HALF #define BLOCKB_READ8(_blockb, _B, _coordB) \ int2 _coordBTemp = _coordB; \ _coordBTemp.y += get_local_id(0); \ - _blockb = *(__global float8*)&_B[_coordBTemp.y * k + _coordBTemp.x + offB];\ + const __global float *B_read = (__global float *)(_B + (_coordBTemp.y * k) + _coordBTemp.x + offB); \ + _blockb = as_half8(as_ushort8(vload4(0, B_read))); \ _coordB.x += TILE_K; +#else +#define BLOCKB_READ8(_blockb, _B, _coordB) \ + int2 _coordBTemp = _coordB; \ + 
_coordBTemp.y += get_local_id(0); \ + const __global float *B_read = (__global float *)(_B + (_coordBTemp.y * k) + _coordBTemp.x + offB); \ + _blockb = vload8(0, B_read); \ + _coordB.x += TILE_K; +#endif -#define MATB_PARAMETER __global float *B, int offB +#define MATB_PARAMETER __global float *B, int offB, int ldb GEMM_TT(1, 0, BUFFER, 1) // ALPHA == 1, BETA == 0 GEMM_TT(1, 1, BUFFER, 1) // ALPHA == 1, BETA != 0 @@ -476,21 +827,21 @@ GEMM_TT(0, 1, BUFFER, 1) // ALPHA != 1, BETA != 0 int2 _coordBTemp = _coordB; \ _coordBTemp.y += get_local_id(0); \ float4 temp; \ - temp = read_imagef(B, _coordBTemp); _coordBTemp.x += 1; \ + temp = READ_IMAGE(B, _coordBTemp); _coordBTemp.x += 1; \ _blockb.s0 = temp.s0; \ - temp = read_imagef(B, _coordBTemp); _coordBTemp.x += 1; \ + temp = READ_IMAGE(B, _coordBTemp); _coordBTemp.x += 1; \ _blockb.s1 = temp.s0; \ - temp = read_imagef(B, _coordBTemp); _coordBTemp.x += 1; \ + temp = READ_IMAGE(B, _coordBTemp); _coordBTemp.x += 1; \ _blockb.s2 = temp.s0; \ - temp = read_imagef(B, _coordBTemp); _coordBTemp.x += 1; \ + temp = READ_IMAGE(B, _coordBTemp); _coordBTemp.x += 1; \ _blockb.s3 = temp.s0; \ - temp = read_imagef(B, _coordBTemp); _coordBTemp.x += 1; \ + temp = READ_IMAGE(B, _coordBTemp); _coordBTemp.x += 1; \ _blockb.s4 = temp.s0; \ - temp = read_imagef(B, _coordBTemp); _coordBTemp.x += 1; \ + temp = READ_IMAGE(B, _coordBTemp); _coordBTemp.x += 1; \ _blockb.s5 = temp.s0; \ - temp = read_imagef(B, _coordBTemp); _coordBTemp.x += 1; \ + temp = READ_IMAGE(B, _coordBTemp); _coordBTemp.x += 1; \ _blockb.s6 = temp.s0; \ - temp = read_imagef(B, _coordBTemp); _coordBTemp.x += 1; \ + temp = READ_IMAGE(B, _coordBTemp); _coordBTemp.x += 1; \ _blockb.s7 = temp.s0; \ _coordB.x += 8; @@ -510,32 +861,69 @@ GEMM_TT(0, 1, SCALAR, 1) // ALPHA != 1, BETA != 0 #undef TILE_K #undef TILE_N -__kernel void TEMPLATE(gemm_buffer_copy_image,Dtype)( +__kernel void TEMPLATE(gemm_buffer_copy_image_transpose, Dtype)( + __global float* A, + __write_only image2d_t 
ImA, + int offA, + int width, + int height, + int ldA) +{ + const int gidx = get_global_id(0); + const int gidy = get_global_id(1); + int2 coord_dst = (int2)(gidx, gidy); + __global float* A_off = A + offA; + float srcA = A_off[gidy * ldA + gidx]; +#if TYPE == TYPE_HALF + write_imageh(ImA, coord_dst, (float4)srcA); +#else + write_imagef(ImA, coord_dst, (float4)srcA); +#endif +} + +__kernel void TEMPLATE(gemm_buffer_copy_image_no_transpose, Dtype)( __global float* A, __write_only image2d_t ImA, int offA, int width, - int height) + int height, + int ldA) { const int gidx = get_global_id(0); const int gidy = get_global_id(1); int2 coord_dst = (int2)(gidx, gidy); +#if TYPE == TYPE_HALF + if (gidx >= width || gidy >= height) { + write_imageh(ImA, coord_dst, 0); + return; + } + __global float* A_off = A + offA; + write_imageh(ImA, coord_dst, A_off[gidy * ldA + gidx]); +#else if (gidx >= width || gidy >= height) { write_imageui(ImA, coord_dst, (uint4)0); return; } __global float* A_off = A + offA; - uint4 srcA = convert_uint4(as_uchar4(A_off[gidy * width + gidx])); + uint4 srcA = convert_uint4(as_uchar4(A_off[gidy * ldA + gidx])); write_imageui(ImA, coord_dst, srcA); +#endif } + #define VEC_SIZE 4 #define LWG_HEIGHT 4 #define TILE_M 8 +#if TYPE == TYPE_HALF +#define TILE_K 32 +#define TILE_N 64 +#else #define TILE_K 16 #define TILE_N 32 +#endif -__attribute__((reqd_work_group_size(8, LWG_HEIGHT, 1))) +__attribute__((reqd_work_group_size(SIMD_SIZE_GEMM, LWG_HEIGHT, 1))) +__attribute__((intel_reqd_sub_group_size(SIMD_SIZE_GEMM))) __kernel void TEMPLATE(gemm_buffer_NN, Dtype)( const __global float *src0, int off0, const __global float *src1, int off1, @@ -543,10 +931,12 @@ __kernel void TEMPLATE(gemm_buffer_NN, Dtype)( int M, int N, int K, - float alpha, - float beta, + float alpha_in, + float beta_in, int start_index) { + const float alpha = (float)alpha_in; + const float beta = (float)beta_in; const int group_x = get_group_id(0); const int group_y = get_group_id(1); const 
int local_x = get_local_id(0); @@ -557,11 +947,11 @@ __kernel void TEMPLATE(gemm_buffer_NN, Dtype)( float4 brow; float2 arow0, arow1, arow2, arow3, arow4, arow5, arow6, arow7; - __global float *dst_write0 = dst + local_x * VEC_SIZE + ( group_x * TILE_N ) + ( group_y * LWG_HEIGHT * TILE_M + local_y * TILE_M) * N + offd; + __global float *dst_write0 = dst + local_x * VEC_SIZE + (group_x * TILE_N) + (group_y * LWG_HEIGHT * TILE_M + local_y * TILE_M) * N + offd; - const __global float *src0_read = src0 + local_x * ( TILE_K / 8 ) + ( group_y * LWG_HEIGHT * TILE_M + local_y * TILE_M ) * K + start_index + off0; + const __global float *src0_read = src0 + local_x * (TILE_K / SIMD_SIZE_GEMM) + (group_y * LWG_HEIGHT * TILE_M + local_y * TILE_M) * K + start_index + off0; - const __global float *src1_read0 = src1 + local_x * VEC_SIZE + ( group_x * TILE_N ) + start_index * N + off1; + const __global float *src1_read0 = src1 + local_x * VEC_SIZE + (group_x * TILE_N) + start_index * N + off1; int border = -(group_y * LWG_HEIGHT * TILE_M + local_y * TILE_M); @@ -574,37 +964,37 @@ __kernel void TEMPLATE(gemm_buffer_NN, Dtype)( int row6 = mad24(global_y, TILE_M, 6) < M ? 6 : border; int row7 = mad24(global_y, TILE_M, 7) < M ? 7 : border; - float4 dot00 = (start_index != 0) ? ((__global float4 *)dst_write0)[0] : beta * ((__global float4 *)dst_write0)[0]; - float4 dot01 = (start_index != 0) ? ((__global float4 *)(dst_write0 + 1 * N))[0] : beta * ((__global float4 *)(dst_write0 + 1 * N))[0]; - float4 dot02 = (start_index != 0) ? ((__global float4 *)(dst_write0 + 2 * N))[0] : beta * ((__global float4 *)(dst_write0 + 2 * N))[0]; - float4 dot03 = (start_index != 0) ? ((__global float4 *)(dst_write0 + 3 * N))[0] : beta * ((__global float4 *)(dst_write0 + 3 * N))[0]; - float4 dot04 = (start_index != 0) ? ((__global float4 *)(dst_write0 + 4 * N))[0] : beta * ((__global float4 *)(dst_write0 + 4 * N))[0]; - float4 dot05 = (start_index != 0) ? 
((__global float4 *)(dst_write0 + 5 * N))[0] : beta * ((__global float4 *)(dst_write0 + 5 * N))[0]; - float4 dot06 = (start_index != 0) ? ((__global float4 *)(dst_write0 + 6 * N))[0] : beta * ((__global float4 *)(dst_write0 + 6 * N))[0]; - float4 dot07 = (start_index != 0) ? ((__global float4 *)(dst_write0 + 7 * N))[0] : beta * ((__global float4 *)(dst_write0 + 7 * N))[0]; + float4 dot00 = (start_index != 0) ? vload4(0, dst_write0) : beta * vload4(0, dst_write0); + float4 dot01 = (start_index != 0) ? vload4(0, dst_write0 + 1 * N) : beta * vload4(0, dst_write0 + 1 * N); + float4 dot02 = (start_index != 0) ? vload4(0, dst_write0 + 2 * N) : beta * vload4(0, dst_write0 + 2 * N); + float4 dot03 = (start_index != 0) ? vload4(0, dst_write0 + 3 * N) : beta * vload4(0, dst_write0 + 3 * N); + float4 dot04 = (start_index != 0) ? vload4(0, dst_write0 + 4 * N) : beta * vload4(0, dst_write0 + 4 * N); + float4 dot05 = (start_index != 0) ? vload4(0, dst_write0 + 5 * N) : beta * vload4(0, dst_write0 + 5 * N); + float4 dot06 = (start_index != 0) ? vload4(0, dst_write0 + 6 * N) : beta * vload4(0, dst_write0 + 6 * N); + float4 dot07 = (start_index != 0) ? 
vload4(0, dst_write0 + 7 * N) : beta * vload4(0, dst_write0 + 7 * N); int end_index = min(start_index + 256, K); int w = start_index; while( w + TILE_K <= end_index ) { - arow0 = alpha * ((__global float2 *)(src0_read + row0 * K))[0]; - arow1 = alpha * ((__global float2 *)(src0_read + row1 * K))[0]; - arow2 = alpha * ((__global float2 *)(src0_read + row2 * K))[0]; - arow3 = alpha * ((__global float2 *)(src0_read + row3 * K))[0]; - arow4 = alpha * ((__global float2 *)(src0_read + row4 * K))[0]; - arow5 = alpha * ((__global float2 *)(src0_read + row5 * K))[0]; - arow6 = alpha * ((__global float2 *)(src0_read + row6 * K))[0]; - arow7 = alpha * ((__global float2 *)(src0_read + row7 * K))[0]; + arow0 = alpha * vload2(0, src0_read + row0 * K); + arow1 = alpha * vload2(0, src0_read + row1 * K); + arow2 = alpha * vload2(0, src0_read + row2 * K); + arow3 = alpha * vload2(0, src0_read + row3 * K); + arow4 = alpha * vload2(0, src0_read + row4 * K); + arow5 = alpha * vload2(0, src0_read + row5 * K); + arow6 = alpha * vload2(0, src0_read + row6 * K); + arow7 = alpha * vload2(0, src0_read + row7 * K); #define MM_DOT_PRODUCT( index, suffix ) \ - brow = ((__global float4 *)src1_read0)[0]; src1_read0 += N; \ - dot00 = mad( (float4)(intel_sub_group_shuffle( arow0, index ).s##suffix), brow, dot00 ); \ - dot01 = mad( (float4)(intel_sub_group_shuffle( arow1, index ).s##suffix), brow, dot01 ); \ - dot02 = mad( (float4)(intel_sub_group_shuffle( arow2, index ).s##suffix), brow, dot02 ); \ - dot03 = mad( (float4)(intel_sub_group_shuffle( arow3, index ).s##suffix), brow, dot03 ); \ - dot04 = mad( (float4)(intel_sub_group_shuffle( arow4, index ).s##suffix), brow, dot04 ); \ - dot05 = mad( (float4)(intel_sub_group_shuffle( arow5, index ).s##suffix), brow, dot05 ); \ - dot06 = mad( (float4)(intel_sub_group_shuffle( arow6, index ).s##suffix), brow, dot06 ); \ - dot07 = mad( (float4)(intel_sub_group_shuffle( arow7, index ).s##suffix), brow, dot07 ); \ + brow = vload4(0, src1_read0); src1_read0 
+= N; \ + dot00 = mad( (float4)(as_float2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow0), index )).s##suffix), brow, dot00 ); \ + dot01 = mad( (float4)(as_float2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow1), index )).s##suffix), brow, dot01 ); \ + dot02 = mad( (float4)(as_float2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow2), index )).s##suffix), brow, dot02 ); \ + dot03 = mad( (float4)(as_float2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow3), index )).s##suffix), brow, dot03 ); \ + dot04 = mad( (float4)(as_float2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow4), index )).s##suffix), brow, dot04 ); \ + dot05 = mad( (float4)(as_float2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow5), index )).s##suffix), brow, dot05 ); \ + dot06 = mad( (float4)(as_float2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow6), index )).s##suffix), brow, dot06 ); \ + dot07 = mad( (float4)(as_float2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow7), index )).s##suffix), brow, dot07 ); \ MM_DOT_PRODUCT(0, 0); MM_DOT_PRODUCT(0, 1); @@ -622,6 +1012,24 @@ __kernel void TEMPLATE(gemm_buffer_NN, Dtype)( MM_DOT_PRODUCT(6, 1); MM_DOT_PRODUCT(7, 0); MM_DOT_PRODUCT(7, 1); +#if TYPE == TYPE_HALF + MM_DOT_PRODUCT(8, 0); + MM_DOT_PRODUCT(8, 1); + MM_DOT_PRODUCT(9, 0); + MM_DOT_PRODUCT(9, 1); + MM_DOT_PRODUCT(10, 0); + MM_DOT_PRODUCT(10, 1); + MM_DOT_PRODUCT(11, 0); + MM_DOT_PRODUCT(11, 1); + MM_DOT_PRODUCT(12, 0); + MM_DOT_PRODUCT(12, 1); + MM_DOT_PRODUCT(13, 0); + MM_DOT_PRODUCT(13, 1); + MM_DOT_PRODUCT(14, 0); + MM_DOT_PRODUCT(14, 1); + MM_DOT_PRODUCT(15, 0); + MM_DOT_PRODUCT(15, 1); +#endif #undef MM_DOT_PRODUCT src0_read += TILE_K; @@ -647,15 +1055,15 @@ __kernel void TEMPLATE(gemm_buffer_NN, Dtype)( arow7.y = ((w + local_x * 2 + 1) < K) ? alpha * (src0_read + row7 * K)[1] : 0.0f; #define MM_DOT_PRODUCT( index, suffix ) \ - brow = (w < K) ? 
((__global float4 *)src1_read0)[0] : 0.0f; src1_read0 += N; w++; \ - dot00 = mad( (float4)(intel_sub_group_shuffle( arow0, index ).s##suffix), brow, dot00 ); \ - dot01 = mad( (float4)(intel_sub_group_shuffle( arow1, index ).s##suffix), brow, dot01 ); \ - dot02 = mad( (float4)(intel_sub_group_shuffle( arow2, index ).s##suffix), brow, dot02 ); \ - dot03 = mad( (float4)(intel_sub_group_shuffle( arow3, index ).s##suffix), brow, dot03 ); \ - dot04 = mad( (float4)(intel_sub_group_shuffle( arow4, index ).s##suffix), brow, dot04 ); \ - dot05 = mad( (float4)(intel_sub_group_shuffle( arow5, index ).s##suffix), brow, dot05 ); \ - dot06 = mad( (float4)(intel_sub_group_shuffle( arow6, index ).s##suffix), brow, dot06 ); \ - dot07 = mad( (float4)(intel_sub_group_shuffle( arow7, index ).s##suffix), brow, dot07 ); \ + brow = (w < K) ? vload4(0, src1_read0) : (float4)0.0f; src1_read0 += N; w++; \ + dot00 = mad( (float4)(as_float2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow0), index )).s##suffix), brow, dot00 ); \ + dot01 = mad( (float4)(as_float2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow1), index )).s##suffix), brow, dot01 ); \ + dot02 = mad( (float4)(as_float2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow2), index )).s##suffix), brow, dot02 ); \ + dot03 = mad( (float4)(as_float2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow3), index )).s##suffix), brow, dot03 ); \ + dot04 = mad( (float4)(as_float2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow4), index )).s##suffix), brow, dot04 ); \ + dot05 = mad( (float4)(as_float2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow5), index )).s##suffix), brow, dot05 ); \ + dot06 = mad( (float4)(as_float2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow6), index )).s##suffix), brow, dot06 ); \ + dot07 = mad( (float4)(as_float2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow7), index )).s##suffix), brow, dot07 ); \ MM_DOT_PRODUCT(0, 0); MM_DOT_PRODUCT(0, 1); @@ -673,87 +1081,102 @@ __kernel void TEMPLATE(gemm_buffer_NN, Dtype)( MM_DOT_PRODUCT(6, 1); MM_DOT_PRODUCT(7, 
0); MM_DOT_PRODUCT(7, 1); +#if TYPE == TYPE_HALF + MM_DOT_PRODUCT(8, 0); + MM_DOT_PRODUCT(8, 1); + MM_DOT_PRODUCT(9, 0); + MM_DOT_PRODUCT(9, 1); + MM_DOT_PRODUCT(10, 0); + MM_DOT_PRODUCT(10, 1); + MM_DOT_PRODUCT(11, 0); + MM_DOT_PRODUCT(11, 1); + MM_DOT_PRODUCT(12, 0); + MM_DOT_PRODUCT(12, 1); + MM_DOT_PRODUCT(13, 0); + MM_DOT_PRODUCT(13, 1); + MM_DOT_PRODUCT(14, 0); + MM_DOT_PRODUCT(14, 1); + MM_DOT_PRODUCT(15, 0); + MM_DOT_PRODUCT(15, 1); +#endif #undef MM_DOT_PRODUCT } if(global_x * 4 < N && global_y * 8 < M) { if(mad24(global_x, 4, 3) < N) { - __global float4 *dst_write = (__global float4 *)dst_write0; - dst_write[0] = dot00; dst_write0 += N; dst_write = (__global float4 *)dst_write0; - if(mad24(global_y, 8, 1) < M) { dst_write[0] = dot01; dst_write0 += N; dst_write = (__global float4 *)dst_write0; } + vstore4(dot00, 0, dst_write0); dst_write0 += N; + if(mad24(global_y, 8, 1) < M) { vstore4(dot01, 0, dst_write0); dst_write0 += N; } else return; - if(mad24(global_y, 8, 2) < M) { dst_write[0] = dot02; dst_write0 += N; dst_write = (__global float4 *)dst_write0; } + if(mad24(global_y, 8, 2) < M) { vstore4(dot02, 0, dst_write0); dst_write0 += N; } else return; - if(mad24(global_y, 8, 3) < M) { dst_write[0] = dot03; dst_write0 += N; dst_write = (__global float4 *)dst_write0; } + if(mad24(global_y, 8, 3) < M) { vstore4(dot03, 0, dst_write0); dst_write0 += N; } else return; - if(mad24(global_y, 8, 4) < M) { dst_write[0] = dot04; dst_write0 += N; dst_write = (__global float4 *)dst_write0; } + if(mad24(global_y, 8, 4) < M) { vstore4(dot04, 0, dst_write0); dst_write0 += N; } else return; - if(mad24(global_y, 8, 5) < M) { dst_write[0] = dot05; dst_write0 += N; dst_write = (__global float4 *)dst_write0; } + if(mad24(global_y, 8, 5) < M) { vstore4(dot05, 0, dst_write0); dst_write0 += N; } else return; - if(mad24(global_y, 8, 6) < M) { dst_write[0] = dot06; dst_write0 += N; dst_write = (__global float4 *)dst_write0; } + if(mad24(global_y, 8, 6) < M) { vstore4(dot06, 0, 
dst_write0); dst_write0 += N; } else return; - if(mad24(global_y, 8, 7) < M) { dst_write[0] = dot07; } + if(mad24(global_y, 8, 7) < M) { vstore4(dot07, 0, dst_write0); } } else if(mad24(global_x, 4, 2) < N) { - __global float2 *dst_write = (__global float2 *)dst_write0; - dst_write[0] = dot00.xy; + vstore2(dot00.xy, 0, dst_write0); dst_write0[2] = dot00.z; - dst_write0 += N; dst_write = (__global float2 *)dst_write0; + dst_write0 += N; if(mad24(global_y, 8, 1) < M) { - dst_write[0] = dot01.xy; + vstore2(dot01.xy, 0, dst_write0); dst_write0[2] = dot01.z; - dst_write0 += N; dst_write = (__global float2 *)dst_write0; + dst_write0 += N; } else return; if(mad24(global_y, 8, 2) < M) { - dst_write[0] = dot02.xy; + vstore2(dot02.xy, 0, dst_write0); dst_write0[2] = dot02.z; - dst_write0 += N; dst_write = (__global float2 *)dst_write0; + dst_write0 += N; } else return; if(mad24(global_y, 8, 3) < M) { - dst_write[0] = dot03.xy; + vstore2(dot03.xy, 0, dst_write0); dst_write0[2] = dot03.z; - dst_write0 += N; dst_write = (__global float2 *)dst_write0; + dst_write0 += N; } else return; if(mad24(global_y, 8, 4) < M) { - dst_write[0] = dot04.xy; + vstore2(dot04.xy, 0, dst_write0); dst_write0[2] = dot04.z; - dst_write0 += N; dst_write = (__global float2 *)dst_write0; + dst_write0 += N; } else return; if(mad24(global_y, 8, 5) < M) { - dst_write[0] = dot05.xy; + vstore2(dot05.xy, 0, dst_write0); dst_write0[2] = dot05.z; - dst_write0 += N; dst_write = (__global float2 *)dst_write0; + dst_write0 += N; } else return; if(mad24(global_y, 8, 6) < M) { - dst_write[0] = dot06.xy; + vstore2(dot06.xy, 0, dst_write0); dst_write0[2] = dot06.z; - dst_write0 += N; dst_write = (__global float2 *)dst_write0; + dst_write0 += N; } else return; if(mad24(global_y, 8, 7) < M) { - dst_write[0] = dot07.xy; + vstore2(dot07.xy, 0, dst_write0); dst_write0[2] = dot07.z; } } else if(mad24(global_x, 4, 1) < N) { - __global float2 *dst_write = (__global float2 *)dst_write0; - dst_write[0] = dot00.xy; dst_write0 += 
N; dst_write = (__global float2 *)dst_write0; - if(mad24(global_y, 8, 1) < M) { dst_write[0] = dot01.xy; dst_write0 += N; dst_write = (__global float2 *)dst_write0; } + vstore2(dot00.xy, 0, dst_write0); dst_write0 += N; + if(mad24(global_y, 8, 1) < M) { vstore2(dot01.xy, 0, dst_write0); dst_write0 += N; } else return; - if(mad24(global_y, 8, 2) < M) { dst_write[0] = dot02.xy; dst_write0 += N; dst_write = (__global float2 *)dst_write0; } + if(mad24(global_y, 8, 2) < M) { vstore2(dot02.xy, 0, dst_write0); dst_write0 += N; } else return; - if(mad24(global_y, 8, 3) < M) { dst_write[0] = dot03.xy; dst_write0 += N; dst_write = (__global float2 *)dst_write0; } + if(mad24(global_y, 8, 3) < M) { vstore2(dot03.xy, 0, dst_write0); dst_write0 += N; } else return; - if(mad24(global_y, 8, 4) < M) { dst_write[0] = dot04.xy; dst_write0 += N; dst_write = (__global float2 *)dst_write0; } + if(mad24(global_y, 8, 4) < M) { vstore2(dot04.xy, 0, dst_write0); dst_write0 += N; } else return; - if(mad24(global_y, 8, 5) < M) { dst_write[0] = dot05.xy; dst_write0 += N; dst_write = (__global float2 *)dst_write0; } + if(mad24(global_y, 8, 5) < M) { vstore2(dot05.xy, 0, dst_write0); dst_write0 += N; } else return; - if(mad24(global_y, 8, 6) < M) { dst_write[0] = dot06.xy; dst_write0 += N; dst_write = (__global float2 *)dst_write0; } + if(mad24(global_y, 8, 6) < M) { vstore2(dot06.xy, 0, dst_write0); dst_write0 += N; } else return; - if(mad24(global_y, 8, 7) < M) { dst_write[0] = dot07.xy; } + if(mad24(global_y, 8, 7) < M) { vstore2(dot07.xy, 0, dst_write0); } } else { dst_write0[0] = dot00.x; dst_write0 += N; if(mad24(global_y, 8, 1) < M) { dst_write0[0] = dot01.x; dst_write0 += N; } @@ -779,6 +1202,7 @@ __kernel void TEMPLATE(gemm_buffer_NN, Dtype)( #undef TILE_K #undef TILE_N + #define VEC_SIZE 1 #define LWG_HEIGHT 16 #define TILE_M 8 @@ -787,6 +1211,7 @@ __kernel void TEMPLATE(gemm_buffer_NN, Dtype)( #define SLM_BLOCK 512 __attribute__((reqd_work_group_size(8, LWG_HEIGHT, 1))) 
+__attribute__((intel_reqd_sub_group_size(8))) __kernel void TEMPLATE(gemm_buffer_NT, Dtype)( const __global float *src0, int off0, const __global float *src1, int off1, @@ -794,9 +1219,11 @@ __kernel void TEMPLATE(gemm_buffer_NT, Dtype)( int M, int N, int K, - float alpha, - float beta) + float alpha_in, + float beta_in) { + const float alpha = (float)alpha_in; + const float beta = (float)beta_in; const int group_x = get_group_id(0); const int group_y = get_group_id(1); const int local_x = get_local_id(0); @@ -822,11 +1249,11 @@ __kernel void TEMPLATE(gemm_buffer_NT, Dtype)( float4 brow6; float4 brow7; - __global float *dst_write0 = dst + local_x * VEC_SIZE + ( group_x * TILE_N ) + ( group_y * LWG_HEIGHT * TILE_M + local_y * TILE_M) * N + offd; + __global float *dst_write0 = dst + local_x * VEC_SIZE + (group_x * TILE_N) + (group_y * LWG_HEIGHT * TILE_M + local_y * TILE_M) * N + offd; - const __global float *src0_read = src0 + local_x * ( TILE_K / 8 ) + ( group_y * LWG_HEIGHT * TILE_M + local_y * TILE_M ) * K + off0; + const __global float *src0_read = src0 + local_x * (TILE_K / 8) + (group_y * LWG_HEIGHT * TILE_M + local_y * TILE_M) * K + off0; - const __global float *src1_read0 = src1 + ( group_x * TILE_N ) * K + off1; + const __global float *src1_read0 = src1 + (group_x * TILE_N) * K + off1; __local float slm_brow[8 * SLM_BLOCK]; __local float* slm_brow0; @@ -835,14 +1262,14 @@ __kernel void TEMPLATE(gemm_buffer_NT, Dtype)( int w; for(int b_tile = 0; b_tile < K; b_tile += SLM_BLOCK) { barrier(CLK_LOCAL_MEM_FENCE); - ((__local float4 *)(slm_brow + mad24(0, SLM_BLOCK, local_index)))[0] = ((__global float4 *)(src1_read0 + mad24(0, K, local_index)))[0]; - ((__local float4 *)(slm_brow + mad24(1, SLM_BLOCK, local_index)))[0] = ((__global float4 *)(src1_read0 + mad24(1, K, local_index)))[0]; - ((__local float4 *)(slm_brow + mad24(2, SLM_BLOCK, local_index)))[0] = ((__global float4 *)(src1_read0 + mad24(2, K, local_index)))[0]; - ((__local float4 *)(slm_brow + mad24(3, 
SLM_BLOCK, local_index)))[0] = ((__global float4 *)(src1_read0 + mad24(3, K, local_index)))[0]; - ((__local float4 *)(slm_brow + mad24(4, SLM_BLOCK, local_index)))[0] = ((__global float4 *)(src1_read0 + mad24(4, K, local_index)))[0]; - ((__local float4 *)(slm_brow + mad24(5, SLM_BLOCK, local_index)))[0] = ((__global float4 *)(src1_read0 + mad24(5, K, local_index)))[0]; - ((__local float4 *)(slm_brow + mad24(6, SLM_BLOCK, local_index)))[0] = ((__global float4 *)(src1_read0 + mad24(6, K, local_index)))[0]; - ((__local float4 *)(slm_brow + mad24(7, SLM_BLOCK, local_index)))[0] = ((__global float4 *)(src1_read0 + mad24(7, K, local_index)))[0]; + vstore4(vload4(0, src1_read0 + mad24(0, K, local_index)), 0, slm_brow + mad24(0, SLM_BLOCK, local_index)); + vstore4(vload4(0, src1_read0 + mad24(1, K, local_index)), 0, slm_brow + mad24(1, SLM_BLOCK, local_index)); + vstore4(vload4(0, src1_read0 + mad24(2, K, local_index)), 0, slm_brow + mad24(2, SLM_BLOCK, local_index)); + vstore4(vload4(0, src1_read0 + mad24(3, K, local_index)), 0, slm_brow + mad24(3, SLM_BLOCK, local_index)); + vstore4(vload4(0, src1_read0 + mad24(4, K, local_index)), 0, slm_brow + mad24(4, SLM_BLOCK, local_index)); + vstore4(vload4(0, src1_read0 + mad24(5, K, local_index)), 0, slm_brow + mad24(5, SLM_BLOCK, local_index)); + vstore4(vload4(0, src1_read0 + mad24(6, K, local_index)), 0, slm_brow + mad24(6, SLM_BLOCK, local_index)); + vstore4(vload4(0, src1_read0 + mad24(7, K, local_index)), 0, slm_brow + mad24(7, SLM_BLOCK, local_index)); barrier(CLK_LOCAL_MEM_FENCE); slm_brow0 = slm_brow + local_x * (TILE_K / 8); @@ -851,17 +1278,17 @@ __kernel void TEMPLATE(gemm_buffer_NT, Dtype)( while( w + TILE_K <= end_w ) { float4 arow; - brow0 = ((__local float4 *)(slm_brow0 + 0 * SLM_BLOCK))[0]; - brow1 = ((__local float4 *)(slm_brow0 + 1 * SLM_BLOCK))[0]; - brow2 = ((__local float4 *)(slm_brow0 + 2 * SLM_BLOCK))[0]; - brow3 = ((__local float4 *)(slm_brow0 + 3 * SLM_BLOCK))[0]; - brow4 = ((__local float4 *)(slm_brow0 
+ 4 * SLM_BLOCK))[0]; - brow5 = ((__local float4 *)(slm_brow0 + 5 * SLM_BLOCK))[0]; - brow6 = ((__local float4 *)(slm_brow0 + 6 * SLM_BLOCK))[0]; - brow7 = ((__local float4 *)(slm_brow0 + 7 * SLM_BLOCK))[0]; + brow0 = vload4(0, slm_brow0 + 0 * SLM_BLOCK); + brow1 = vload4(0, slm_brow0 + 1 * SLM_BLOCK); + brow2 = vload4(0, slm_brow0 + 2 * SLM_BLOCK); + brow3 = vload4(0, slm_brow0 + 3 * SLM_BLOCK); + brow4 = vload4(0, slm_brow0 + 4 * SLM_BLOCK); + brow5 = vload4(0, slm_brow0 + 5 * SLM_BLOCK); + brow6 = vload4(0, slm_brow0 + 6 * SLM_BLOCK); + brow7 = vload4(0, slm_brow0 + 7 * SLM_BLOCK); #define MM_DOT_PRODUCT( _row, _dot ) \ - arow = ((__global float4 *)(src0_read + _row * K))[0]; \ + arow = vload4(0, src0_read + _row * K); \ _dot = mad( (float8)(arow.x), (float8)(brow0.x, brow1.x, brow2.x, brow3.x, brow4.x, brow5.x, brow6.x, brow7.x), _dot ); \ _dot = mad( (float8)(arow.y), (float8)(brow0.y, brow1.y, brow2.y, brow3.y, brow4.y, brow5.y, brow6.y, brow7.y), _dot ); \ _dot = mad( (float8)(arow.z), (float8)(brow0.z, brow1.z, brow2.z, brow3.z, brow4.z, brow5.z, brow6.z, brow7.z), _dot ); \ @@ -888,7 +1315,7 @@ __kernel void TEMPLATE(gemm_buffer_NT, Dtype)( float4 arow; #define READ_BROW(_brow, _row) \ - _brow = ((__local float4 *)(slm_brow0 + _row * SLM_BLOCK))[0]; \ + _brow = vload4(0, slm_brow0 + _row * SLM_BLOCK); \ _brow.x = (mad24(local_x, 4, w) < K) ? _brow.x : 0.0f; \ _brow.y = (mad24(local_x, 4, w + 1) < K) ? _brow.y : 0.0f; \ _brow.z = (mad24(local_x, 4, w + 2) < K) ? _brow.z : 0.0f; \ @@ -904,7 +1331,7 @@ __kernel void TEMPLATE(gemm_buffer_NT, Dtype)( READ_BROW(brow7, 7); #define MM_DOT_PRODUCT( _row, _dot ) \ - arow = ((__global float4 *)(src0_read + _row * K))[0]; \ + arow = vload4(0, src0_read + _row * K); \ arow.x = (mad24(local_x, 4, w) < K) ? arow.x : 0.0f; \ arow.y = (mad24(local_x, 4, w + 1) < K) ? arow.y : 0.0f; \ arow.z = (mad24(local_x, 4, w + 2) < K) ? 
arow.z : 0.0f; \ @@ -926,8 +1353,8 @@ __kernel void TEMPLATE(gemm_buffer_NT, Dtype)( } #define REDUCE(_dot) \ - _dot = intel_sub_group_shuffle(_dot, 0) + intel_sub_group_shuffle(_dot, 1) + intel_sub_group_shuffle(_dot, 2) + intel_sub_group_shuffle(_dot, 3) + \ - intel_sub_group_shuffle(_dot, 4) + intel_sub_group_shuffle(_dot, 5) + intel_sub_group_shuffle(_dot, 6) + intel_sub_group_shuffle(_dot, 7); \ + _dot = as_float8(intel_sub_group_shuffle(SHUFFLE_TYPE8(_dot), 0)) + as_float8(intel_sub_group_shuffle(SHUFFLE_TYPE8(_dot), 1)) + as_float8(intel_sub_group_shuffle(SHUFFLE_TYPE8(_dot), 2)) + as_float8(intel_sub_group_shuffle(SHUFFLE_TYPE8(_dot), 3)) + \ + as_float8(intel_sub_group_shuffle(SHUFFLE_TYPE8(_dot), 4)) + as_float8(intel_sub_group_shuffle(SHUFFLE_TYPE8(_dot), 5)) + as_float8(intel_sub_group_shuffle(SHUFFLE_TYPE8(_dot), 6)) + as_float8(intel_sub_group_shuffle(SHUFFLE_TYPE8(_dot), 7)); \ REDUCE(dot00); REDUCE(dot01); @@ -974,32 +1401,32 @@ __kernel void TEMPLATE(gemm_buffer_NT, Dtype)( #define SLM_SIZE 64 void TEMPLATE(gemm_buffer_NT_M_2_edgerows,Dtype)( - const __global Dtype* srca_read0, - const __global Dtype* srca_read1, - const __global Dtype* srcb_read, - __local Dtype4* work0, - __local Dtype4* work1, + const __global float* srca_read0, + const __global float* srca_read1, + const __global float* srcb_read, + __local float4* work0, + __local float4* work1, int N, int K, int x_gid, int lid, - Dtype alpha, - Dtype beta, - __global Dtype* dstc0, - __global Dtype* dstc1) + float alpha, + float beta, + __global float* dstc0, + __global float* dstc1) { - __local Dtype* work_each0 = (__local Dtype*)work0; - __local Dtype* work_each1 = (__local Dtype*)work1; + __local float* work_each0 = (__local float*)work0; + __local float* work_each1 = (__local float*)work1; int rows = N - x_gid * 4; - Dtype4 dot0[3] = {(Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.)}; - Dtype4 dot1[3] = {(Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.)}; + float4 dot0[3] = {(float4)(0.), (float4)(0.), 
(float4)(0.)}; + float4 dot1[3] = {(float4)(0.), (float4)(0.), (float4)(0.)}; int i = lid; while( i < K / 4) { - const Dtype4 b0 = {srca_read0[i*4], srca_read0[(i*4+1)], srca_read0[(i*4+2)], srca_read0[(i*4+3)]}; - const Dtype4 b1 = {srca_read1[i*4], srca_read1[(i*4+1)], srca_read1[(i*4+2)], srca_read1[(i*4+3)]}; + const float4 b0 = {srca_read0[i*4], srca_read0[(i*4+1)], srca_read0[(i*4+2)], srca_read0[(i*4+3)]}; + const float4 b1 = {srca_read1[i*4], srca_read1[(i*4+1)], srca_read1[(i*4+2)], srca_read1[(i*4+3)]}; #pragma unroll for(int j = 0; j < rows; ++j) { dot0[j] += b0 * vload4(i, srcb_read + j * K); @@ -1018,13 +1445,13 @@ void TEMPLATE(gemm_buffer_NT_M_2_edgerows,Dtype)( short tail_items = K % 4; if(tail_items != 0) { - const __global Dtype *srcb_tail = srcb_read + i * 4; - const __global Dtype *srca_tail0 = srca_read0 + i * 4; - const __global Dtype *srca_tail1 = srca_read1 + i * 4; + const __global float *srcb_tail = srcb_read + i * 4; + const __global float *srca_tail0 = srca_read0 + i * 4; + const __global float *srca_tail1 = srca_read1 + i * 4; #pragma unroll for(short i = 0; i < tail_items; ++i) { - const Dtype at0 = srca_tail0[i]; - const Dtype at1 = srca_tail1[i]; + const float at0 = srca_tail0[i]; + const float at1 = srca_tail1[i]; #pragma unroll for(int j = 0; j < rows; ++j) { work_each0[lid * 4 + j] += at0 * srcb_tail[i + j * K]; @@ -1052,11 +1479,11 @@ void TEMPLATE(gemm_buffer_NT_M_2_edgerows,Dtype)( } __kernel void TEMPLATE(gemm_buffer_NT_M_2,Dtype)( - __global const Dtype * A, + __global const float * A, int offA, - __global const Dtype * B, + __global const float * B, int offB, - __global Dtype * C, + __global float * C, int offC, int M, int N, @@ -1064,37 +1491,37 @@ __kernel void TEMPLATE(gemm_buffer_NT_M_2,Dtype)( float alpha_f, float beta_f) { - Dtype alpha = (Dtype)alpha_f; - Dtype beta = (Dtype)beta_f; + float alpha = (float)alpha_f; + float beta = (float)beta_f; int x_gid = get_group_id(0); int lid = get_local_id(0); - const __global 
Dtype *srca_read0 = A + offA; - const __global Dtype *srca_read1 = srca_read0 + K; + const __global float *srca_read0 = A + offA; + const __global float *srca_read1 = srca_read0 + K; - const __global Dtype *srcb_read = B + x_gid * 4 * K + offB; + const __global float *srcb_read = B + x_gid * 4 * K + offB; - __global Dtype4 *dstc0 = (__global Dtype4*)(C + offC); - __global Dtype4 *dstc1 = (__global Dtype4*)((__global Dtype*)(dstc0) + N); + __global float4 *dstc0 = (__global float4*)(C + offC); + __global float4 *dstc1 = (__global float4*)((__global float*)(dstc0) + N); - __local Dtype4 work0[SLM_SIZE]; - __local Dtype4 work1[SLM_SIZE]; - __local Dtype* work_each0 = (__local Dtype*)work0; - __local Dtype* work_each1 = (__local Dtype*)work1; + __local float4 work0[SLM_SIZE]; + __local float4 work1[SLM_SIZE]; + __local float* work_each0 = (__local float*)work0; + __local float* work_each1 = (__local float*)work1; if(x_gid == N / 4) { TEMPLATE(gemm_buffer_NT_M_2_edgerows,Dtype) \ - (srca_read0, srca_read1, srcb_read, work0, work1, N, K, x_gid, lid, alpha, beta, (__global Dtype*)dstc0, (__global Dtype*)dstc1); + (srca_read0, srca_read1, srcb_read, work0, work1, N, K, x_gid, lid, alpha, beta, (__global float*)dstc0, (__global float*)dstc1); } else { - Dtype4 dot0[4] = {(Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.)}; - Dtype4 dot1[4] = {(Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.)}; + float4 dot0[4] = {(float4)(0.), (float4)(0.), (float4)(0.), (float4)(0.)}; + float4 dot1[4] = {(float4)(0.), (float4)(0.), (float4)(0.), (float4)(0.)}; int i = lid; while( i < K / 4) { - const Dtype4 b0 = vload4(i, srca_read0); - const Dtype4 b1 = vload4(i, srca_read1); + const float4 b0 = vload4(i, srca_read0); + const float4 b1 = vload4(i, srca_read1); #pragma unroll for(int j = 0; j < 4; ++j) { - Dtype4 a = vload4(i, srcb_read + j * K); + float4 a = vload4(i, srcb_read + j * K); dot0[j] += b0 * a; dot1[j] += b1 * a; } @@ -1110,14 +1537,14 @@ __kernel void 
TEMPLATE(gemm_buffer_NT_M_2,Dtype)( if(i == K / 4) { short tail_items = K % 4; if(tail_items != 0) { - const __global Dtype *srcb_tail = srcb_read + i * 4; + const __global float *srcb_tail = srcb_read + i * 4; - const __global Dtype *srca_tail0 = srca_read0 + i * 4; - const __global Dtype *srca_tail1 = srca_read1 + i * 4; + const __global float *srca_tail0 = srca_read0 + i * 4; + const __global float *srca_tail1 = srca_read1 + i * 4; #pragma unroll for(short i = 0; i < tail_items; ++i) { - const Dtype at0 = srca_tail0[i]; - const Dtype at1 = srca_tail1[i]; + const float at0 = srca_tail0[i]; + const float at1 = srca_tail1[i]; #pragma unroll for(int j = 0; j < 4; ++j) { work_each0[lid * 4 + j] += at0 * srcb_tail[i + j * K]; @@ -1145,44 +1572,44 @@ __kernel void TEMPLATE(gemm_buffer_NT_M_2,Dtype)( #define SLM_SIZE 32 void TEMPLATE(gemm_buffer_NT_M_4_edgerows,Dtype)( - const __global Dtype* srca_read0, - const __global Dtype* srca_read1, - const __global Dtype* srca_read2, - const __global Dtype* srca_read3, - const __global Dtype* srcb_read, - __local Dtype4* work0, - __local Dtype4* work1, - __local Dtype4* work2, - __local Dtype4* work3, + const __global float* srca_read0, + const __global float* srca_read1, + const __global float* srca_read2, + const __global float* srca_read3, + const __global float* srcb_read, + __local float4* work0, + __local float4* work1, + __local float4* work2, + __local float4* work3, int N, int K, int x_gid, int lid, - Dtype alpha, - Dtype beta, - __global Dtype* dstc0, - __global Dtype* dstc1, - __global Dtype* dstc2, - __global Dtype* dstc3) + float alpha, + float beta, + __global float* dstc0, + __global float* dstc1, + __global float* dstc2, + __global float* dstc3) { - __local Dtype* work_each0 = (__local Dtype*)(work0 + lid); - __local Dtype* work_each1 = (__local Dtype*)(work1 + lid); - __local Dtype* work_each2 = (__local Dtype*)(work2 + lid); - __local Dtype* work_each3 = (__local Dtype*)(work3 + lid); + __local float* 
work_each0 = (__local float*)(work0 + lid); + __local float* work_each1 = (__local float*)(work1 + lid); + __local float* work_each2 = (__local float*)(work2 + lid); + __local float* work_each3 = (__local float*)(work3 + lid); int rows = N - x_gid * 4; - Dtype4 dot0[3] = {(Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.)}; - Dtype4 dot1[3] = {(Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.)}; - Dtype4 dot2[3] = {(Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.)}; - Dtype4 dot3[3] = {(Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.)}; + float4 dot0[3] = {(float4)(0.), (float4)(0.), (float4)(0.)}; + float4 dot1[3] = {(float4)(0.), (float4)(0.), (float4)(0.)}; + float4 dot2[3] = {(float4)(0.), (float4)(0.), (float4)(0.)}; + float4 dot3[3] = {(float4)(0.), (float4)(0.), (float4)(0.)}; int i = lid; while( i < K / 4) { - const Dtype4 a0 = {srca_read0[i*4], srca_read0[(i*4+1)], srca_read0[(i*4+2)], srca_read0[(i*4+3)]}; - const Dtype4 a1 = {srca_read1[i*4], srca_read1[(i*4+1)], srca_read1[(i*4+2)], srca_read1[(i*4+3)]}; - const Dtype4 a2 = {srca_read2[i*4], srca_read2[(i*4+1)], srca_read2[(i*4+2)], srca_read2[(i*4+3)]}; - const Dtype4 a3 = {srca_read3[i*4], srca_read3[(i*4+1)], srca_read3[(i*4+2)], srca_read3[(i*4+3)]}; + const float4 a0 = {srca_read0[i*4], srca_read0[(i*4+1)], srca_read0[(i*4+2)], srca_read0[(i*4+3)]}; + const float4 a1 = {srca_read1[i*4], srca_read1[(i*4+1)], srca_read1[(i*4+2)], srca_read1[(i*4+3)]}; + const float4 a2 = {srca_read2[i*4], srca_read2[(i*4+1)], srca_read2[(i*4+2)], srca_read2[(i*4+3)]}; + const float4 a3 = {srca_read3[i*4], srca_read3[(i*4+1)], srca_read3[(i*4+2)], srca_read3[(i*4+3)]}; #pragma unrol for(int j = 0; j < rows; ++j) { dot0[j] += a0 * vload4(i, srcb_read + j * K); @@ -1205,18 +1632,18 @@ void TEMPLATE(gemm_buffer_NT_M_4_edgerows,Dtype)( short tail_items = K % 4; if(tail_items != 0) { - const __global Dtype *srcb_tail = srcb_read + i * 4; + const __global float *srcb_tail = srcb_read + i * 4; - const __global Dtype *srca_tail0 = srca_read0 + i * 4; - const 
__global Dtype *srca_tail1 = srca_read1 + i * 4; - const __global Dtype *srca_tail2 = srca_read2 + i * 4; - const __global Dtype *srca_tail3 = srca_read3 + i * 4; + const __global float *srca_tail0 = srca_read0 + i * 4; + const __global float *srca_tail1 = srca_read1 + i * 4; + const __global float *srca_tail2 = srca_read2 + i * 4; + const __global float *srca_tail3 = srca_read3 + i * 4; #pragma unroll for(short i = 0; i < tail_items; ++i) { - const Dtype at0 = srca_tail0[i]; - const Dtype at1 = srca_tail1[i]; - const Dtype at2 = srca_tail2[i]; - const Dtype at3 = srca_tail3[i]; + const float at0 = srca_tail0[i]; + const float at1 = srca_tail1[i]; + const float at2 = srca_tail2[i]; + const float at3 = srca_tail3[i]; #pragma unroll for(int j = 0; j < rows; ++j) { work_each0[j] += at0 * srcb_tail[i + j * K]; @@ -1250,11 +1677,11 @@ void TEMPLATE(gemm_buffer_NT_M_4_edgerows,Dtype)( } __kernel void TEMPLATE(gemm_buffer_NT_M_4,Dtype)( - __global const Dtype * A, + __global const float * A, int offA, - __global const Dtype * B, + __global const float * B, int offB, - __global Dtype * C, + __global float * C, int offC, int M, int N, @@ -1262,53 +1689,53 @@ __kernel void TEMPLATE(gemm_buffer_NT_M_4,Dtype)( float alpha_f, float beta_f) { - Dtype alpha = (Dtype)alpha_f; - Dtype beta = (Dtype)beta_f; + float alpha = (float)alpha_f; + float beta = (float)beta_f; int x_gid = get_group_id(0); int lid = get_local_id(0); int lsize = get_local_size(0); - const __global Dtype *srca_read0 = A + offA; - const __global Dtype *srca_read1 = srca_read0 + K; - const __global Dtype *srca_read2 = srca_read1 + K; - const __global Dtype *srca_read3 = srca_read2 + K; + const __global float *srca_read0 = A + offA; + const __global float *srca_read1 = srca_read0 + K; + const __global float *srca_read2 = srca_read1 + K; + const __global float *srca_read3 = srca_read2 + K; - const __global Dtype *srcb_read = B + x_gid * 4 * K + offB; + const __global float *srcb_read = B + x_gid * 4 * K + offB; - 
__global Dtype4 *dstc0 = (__global Dtype4*)(C + offC); - __global Dtype4 *dstc1 = (__global Dtype4*)((__global Dtype*)(dstc0) + N); - __global Dtype4 *dstc2 = (__global Dtype4*)((__global Dtype*)(dstc1) + N); - __global Dtype4 *dstc3 = (__global Dtype4*)((__global Dtype*)(dstc2) + N); + __global float4 *dstc0 = (__global float4*)(C + offC); + __global float4 *dstc1 = (__global float4*)((__global float*)(dstc0) + N); + __global float4 *dstc2 = (__global float4*)((__global float*)(dstc1) + N); + __global float4 *dstc3 = (__global float4*)((__global float*)(dstc2) + N); - __local Dtype4 work0[SLM_SIZE]; - __local Dtype4 work1[SLM_SIZE]; - __local Dtype4 work2[SLM_SIZE]; - __local Dtype4 work3[SLM_SIZE]; - __local Dtype* work_each0 = (__local Dtype*)(work0 + lid); - __local Dtype* work_each1 = (__local Dtype*)(work1 + lid); - __local Dtype* work_each2 = (__local Dtype*)(work2 + lid); - __local Dtype* work_each3 = (__local Dtype*)(work3 + lid); + __local float4 work0[SLM_SIZE]; + __local float4 work1[SLM_SIZE]; + __local float4 work2[SLM_SIZE]; + __local float4 work3[SLM_SIZE]; + __local float* work_each0 = (__local float*)(work0 + lid); + __local float* work_each1 = (__local float*)(work1 + lid); + __local float* work_each2 = (__local float*)(work2 + lid); + __local float* work_each3 = (__local float*)(work3 + lid); if(x_gid == N / 4) { TEMPLATE(gemm_buffer_NT_M_4_edgerows,Dtype) \ (srca_read0, srca_read1, srca_read2, srca_read3, srcb_read, \ work0, work1, work2, work3, N, K, x_gid, lid, alpha, beta, \ - (__global Dtype*)dstc0, (__global Dtype*)dstc1, (__global Dtype*)dstc2, (__global Dtype*)dstc3); + (__global float*)dstc0, (__global float*)dstc1, (__global float*)dstc2, (__global float*)dstc3); } else { - Dtype4 dot0[4] = {(Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.)}; - Dtype4 dot1[4] = {(Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.)}; - Dtype4 dot2[4] = {(Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.)}; - Dtype4 dot3[4] = {(Dtype4)(0.), 
(Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.)}; + float4 dot0[4] = {(float4)(0.), (float4)(0.), (float4)(0.), (float4)(0.)}; + float4 dot1[4] = {(float4)(0.), (float4)(0.), (float4)(0.), (float4)(0.)}; + float4 dot2[4] = {(float4)(0.), (float4)(0.), (float4)(0.), (float4)(0.)}; + float4 dot3[4] = {(float4)(0.), (float4)(0.), (float4)(0.), (float4)(0.)}; int kid = lid; while( kid < K / 4) { - const Dtype4 b0 = vload4(kid, srca_read0); - const Dtype4 b1 = vload4(kid, srca_read1); - const Dtype4 b2 = vload4(kid, srca_read2); - const Dtype4 b3 = vload4(kid, srca_read3); + const float4 b0 = vload4(kid, srca_read0); + const float4 b1 = vload4(kid, srca_read1); + const float4 b2 = vload4(kid, srca_read2); + const float4 b3 = vload4(kid, srca_read3); #pragma unroll for(int j = 0; j < 4; ++j) { - Dtype4 a = vload4(kid, srcb_read + j * K); + float4 a = vload4(kid, srcb_read + j * K); dot0[j] += b0 * a; dot1[j] += b1 * a; dot2[j] += b2 * a; @@ -1328,18 +1755,18 @@ __kernel void TEMPLATE(gemm_buffer_NT_M_4,Dtype)( short tail_items = K % 4; if(tail_items != 0) { int offset = kid << 2; - const __global Dtype *srcb_tail = srcb_read + offset; + const __global float *srcb_tail = srcb_read + offset; - const __global Dtype *srca_tail0 = srca_read0 + offset; - const __global Dtype *srca_tail1 = srca_read1 + offset; - const __global Dtype *srca_tail2 = srca_read2 + offset; - const __global Dtype *srca_tail3 = srca_read3 + offset; + const __global float *srca_tail0 = srca_read0 + offset; + const __global float *srca_tail1 = srca_read1 + offset; + const __global float *srca_tail2 = srca_read2 + offset; + const __global float *srca_tail3 = srca_read3 + offset; #pragma unroll for(short i = 0; i < tail_items; ++i) { - const Dtype at0 = srca_tail0[i]; - const Dtype at1 = srca_tail1[i]; - const Dtype at2 = srca_tail2[i]; - const Dtype at3 = srca_tail3[i]; + const float at0 = srca_tail0[i]; + const float at1 = srca_tail1[i]; + const float at2 = srca_tail2[i]; + const float at3 = srca_tail3[i]; 
#pragma unroll for(int j = 0; j < 4; ++j) { work_each0[j] += at0 * srcb_tail[i + j * K]; @@ -1373,11 +1800,11 @@ __kernel void TEMPLATE(gemm_buffer_NT_M_4,Dtype)( #define SLM_SIZE 16 __kernel void TEMPLATE(gemm_buffer_NT_M_8,Dtype)( - __global const Dtype * A, + __global const float * A, int offA, - __global const Dtype * B, + __global const float * B, int offB, - __global Dtype * C, + __global float * C, int offC, int M, int N, @@ -1385,61 +1812,61 @@ __kernel void TEMPLATE(gemm_buffer_NT_M_8,Dtype)( float alpha_f, float beta_f) { - Dtype alpha = (Dtype)alpha_f; - Dtype beta = (Dtype)beta_f; + float alpha = (float)alpha_f; + float beta = (float)beta_f; int x_gid = get_group_id(0); int lid = get_local_id(0); int lsize = get_local_size(0); - const __global Dtype *srca_read0 = A + offA; - const __global Dtype *srca_read1 = srca_read0 + K; - const __global Dtype *srca_read2 = srca_read1 + K; - const __global Dtype *srca_read3 = srca_read2 + K; - const __global Dtype *srca_read4 = srca_read3 + K; - const __global Dtype *srca_read5 = srca_read4 + K; - const __global Dtype *srca_read6 = srca_read5 + K; - const __global Dtype *srca_read7 = srca_read6 + K; - - const __global Dtype *srcb_read = B + x_gid * K + offB; - - __global Dtype *dstc0 = C + offC; - __global Dtype *dstc1 = dstc0 + N; - __global Dtype *dstc2 = dstc1 + N; - __global Dtype *dstc3 = dstc2 + N; - __global Dtype *dstc4 = dstc3 + N; - __global Dtype *dstc5 = dstc4 + N; - __global Dtype *dstc6 = dstc5 + N; - __global Dtype *dstc7 = dstc6 + N; - - __local Dtype work0[SLM_SIZE]; - __local Dtype work1[SLM_SIZE]; - __local Dtype work2[SLM_SIZE]; - __local Dtype work3[SLM_SIZE]; - __local Dtype work4[SLM_SIZE]; - __local Dtype work5[SLM_SIZE]; - __local Dtype work6[SLM_SIZE]; - __local Dtype work7[SLM_SIZE]; - - Dtype4 dot0 = (Dtype4)(0.); - Dtype4 dot1 = (Dtype4)(0.); - Dtype4 dot2 = (Dtype4)(0.); - Dtype4 dot3 = (Dtype4)(0.); - Dtype4 dot4 = (Dtype4)(0.); - Dtype4 dot5 = (Dtype4)(0.); - Dtype4 dot6 = 
(Dtype4)(0.); - Dtype4 dot7 = (Dtype4)(0.); + const __global float *srca_read0 = A + offA; + const __global float *srca_read1 = srca_read0 + K; + const __global float *srca_read2 = srca_read1 + K; + const __global float *srca_read3 = srca_read2 + K; + const __global float *srca_read4 = srca_read3 + K; + const __global float *srca_read5 = srca_read4 + K; + const __global float *srca_read6 = srca_read5 + K; + const __global float *srca_read7 = srca_read6 + K; + + const __global float *srcb_read = B + x_gid * K + offB; + + __global float *dstc0 = C + offC; + __global float *dstc1 = dstc0 + N; + __global float *dstc2 = dstc1 + N; + __global float *dstc3 = dstc2 + N; + __global float *dstc4 = dstc3 + N; + __global float *dstc5 = dstc4 + N; + __global float *dstc6 = dstc5 + N; + __global float *dstc7 = dstc6 + N; + + __local float work0[SLM_SIZE]; + __local float work1[SLM_SIZE]; + __local float work2[SLM_SIZE]; + __local float work3[SLM_SIZE]; + __local float work4[SLM_SIZE]; + __local float work5[SLM_SIZE]; + __local float work6[SLM_SIZE]; + __local float work7[SLM_SIZE]; + + float4 dot0 = (float4)(0.); + float4 dot1 = (float4)(0.); + float4 dot2 = (float4)(0.); + float4 dot3 = (float4)(0.); + float4 dot4 = (float4)(0.); + float4 dot5 = (float4)(0.); + float4 dot6 = (float4)(0.); + float4 dot7 = (float4)(0.); int kid = lid; while( kid < K / 4) { - const Dtype4 a0 = vload4(kid, srca_read0); - const Dtype4 a1 = vload4(kid, srca_read1); - const Dtype4 a2 = vload4(kid, srca_read2); - const Dtype4 a3 = vload4(kid, srca_read3); - const Dtype4 a4 = vload4(kid, srca_read4); - const Dtype4 a5 = vload4(kid, srca_read5); - const Dtype4 a6 = vload4(kid, srca_read6); - const Dtype4 a7 = vload4(kid, srca_read7); - Dtype4 b = vload4(kid, srcb_read); + const float4 a0 = vload4(kid, srca_read0); + const float4 a1 = vload4(kid, srca_read1); + const float4 a2 = vload4(kid, srca_read2); + const float4 a3 = vload4(kid, srca_read3); + const float4 a4 = vload4(kid, srca_read4); + const 
float4 a5 = vload4(kid, srca_read5); + const float4 a6 = vload4(kid, srca_read6); + const float4 a7 = vload4(kid, srca_read7); + float4 b = vload4(kid, srcb_read); dot0 += a0 * b; dot1 += a1 * b; dot2 += a2 * b; @@ -1464,16 +1891,16 @@ __kernel void TEMPLATE(gemm_buffer_NT_M_8,Dtype)( short tail_items = K % 4; if(tail_items != 0) { int offset = kid << 2; - const __global Dtype *srcb_tail = srcb_read + offset; - - const __global Dtype *srca_tail0 = srca_read0 + offset; - const __global Dtype *srca_tail1 = srca_read1 + offset; - const __global Dtype *srca_tail2 = srca_read2 + offset; - const __global Dtype *srca_tail3 = srca_read3 + offset; - const __global Dtype *srca_tail4 = srca_read4 + offset; - const __global Dtype *srca_tail5 = srca_read5 + offset; - const __global Dtype *srca_tail6 = srca_read6 + offset; - const __global Dtype *srca_tail7 = srca_read7 + offset; + const __global float *srcb_tail = srcb_read + offset; + + const __global float *srca_tail0 = srca_read0 + offset; + const __global float *srca_tail1 = srca_read1 + offset; + const __global float *srca_tail2 = srca_read2 + offset; + const __global float *srca_tail3 = srca_read3 + offset; + const __global float *srca_tail4 = srca_read4 + offset; + const __global float *srca_tail5 = srca_read5 + offset; + const __global float *srca_tail6 = srca_read6 + offset; + const __global float *srca_tail7 = srca_read7 + offset; #pragma unroll for(short item = 0; item < tail_items; ++item) { work0[lid] += srca_tail0[item] * srcb_tail[item]; @@ -1518,10 +1945,16 @@ __kernel void TEMPLATE(gemm_buffer_NT_M_8,Dtype)( #define VEC_SIZE 4 #define LWG_HEIGHT 4 #define TILE_M 8 +#if TYPE == TYPE_HALF +#define TILE_K 32 +#define TILE_N 64 +#else #define TILE_K 16 #define TILE_N 32 +#endif -__attribute__((reqd_work_group_size(8, LWG_HEIGHT, 1))) +__attribute__((reqd_work_group_size(SIMD_SIZE_GEMM, LWG_HEIGHT, 1))) +__attribute__((intel_reqd_sub_group_size(SIMD_SIZE_GEMM))) __kernel void TEMPLATE(gemm_buffer_TN, Dtype)( const 
__global float *src0, int off0, const __global float *src1, int off1, @@ -1529,11 +1962,13 @@ __kernel void TEMPLATE(gemm_buffer_TN, Dtype)( int M, int N, int K, - float alpha, - float beta, + float alpha_in, + float beta_in, int start_index) { + const float alpha = (float)alpha_in; + const float beta = (float)beta_in; const int group_x = get_group_id(0); const int group_y = get_group_id(1); const int local_x = get_local_id(0); @@ -1543,28 +1978,28 @@ __kernel void TEMPLATE(gemm_buffer_TN, Dtype)( float4 brow; - __global float *dst_write0 = dst + local_x * VEC_SIZE + ( group_x * TILE_N ) + ( group_y * LWG_HEIGHT * TILE_M + local_y * TILE_M) * N + offd; + __global float *dst_write0 = dst + local_x * VEC_SIZE + (group_x * TILE_N) + (group_y * LWG_HEIGHT * TILE_M + local_y * TILE_M) * N + offd; - const __global float *src0_read = src0 + (local_x * ( TILE_K / 8 ) + start_index) * M + group_y * LWG_HEIGHT * TILE_M + local_y * TILE_M + off0; + const __global float *src0_read = src0 + (local_x * (TILE_K / SIMD_SIZE_GEMM) + start_index) * M + group_y * LWG_HEIGHT * TILE_M + local_y * TILE_M + off0; - const __global float *src1_read0 = src1 + local_x * VEC_SIZE + ( group_x * TILE_N ) + start_index * N + off1; + const __global float *src1_read0 = src1 + local_x * VEC_SIZE + (group_x * TILE_N) + start_index * N + off1; - float4 dot00 = (start_index != 0) ? ((__global float4 *)dst_write0)[0] : (beta * ((__global float4 *)dst_write0)[0]); - float4 dot01 = (start_index != 0) ? ((__global float4 *)(dst_write0 + N))[0] : (beta * ((__global float4 *)(dst_write0 + N))[0]); - float4 dot02 = (start_index != 0) ? ((__global float4 *)(dst_write0 + 2 * N))[0] : (beta * ((__global float4 *)(dst_write0 + 2 * N))[0]); - float4 dot03 = (start_index != 0) ? ((__global float4 *)(dst_write0 + 3 * N))[0] : (beta * ((__global float4 *)(dst_write0 + 3 * N))[0]); - float4 dot04 = (start_index != 0) ? 
((__global float4 *)(dst_write0 + 4 * N))[0] : (beta * ((__global float4 *)(dst_write0 + 4 * N))[0]); - float4 dot05 = (start_index != 0) ? ((__global float4 *)(dst_write0 + 5 * N))[0] : (beta * ((__global float4 *)(dst_write0 + 5 * N))[0]); - float4 dot06 = (start_index != 0) ? ((__global float4 *)(dst_write0 + 6 * N))[0] : (beta * ((__global float4 *)(dst_write0 + 6 * N))[0]); - float4 dot07 = (start_index != 0) ? ((__global float4 *)(dst_write0 + 7 * N))[0] : (beta * ((__global float4 *)(dst_write0 + 7 * N))[0]); + float4 dot00 = (start_index != 0) ? vload4(0, dst_write0) : beta * vload4(0, dst_write0); + float4 dot01 = (start_index != 0) ? vload4(0, dst_write0 + N) : beta * vload4(0, dst_write0 + N); + float4 dot02 = (start_index != 0) ? vload4(0, dst_write0 + 2 * N) : beta * vload4(0, dst_write0 + 2 * N); + float4 dot03 = (start_index != 0) ? vload4(0, dst_write0 + 3 * N) : beta * vload4(0, dst_write0 + 3 * N); + float4 dot04 = (start_index != 0) ? vload4(0, dst_write0 + 4 * N) : beta * vload4(0, dst_write0 + 4 * N); + float4 dot05 = (start_index != 0) ? vload4(0, dst_write0 + 5 * N) : beta * vload4(0, dst_write0 + 5 * N); + float4 dot06 = (start_index != 0) ? vload4(0, dst_write0 + 6 * N) : beta * vload4(0, dst_write0 + 6 * N); + float4 dot07 = (start_index != 0) ? 
vload4(0, dst_write0 + 7 * N) : beta * vload4(0, dst_write0 + 7 * N); int end_index = min(start_index + 256, K); while( start_index + TILE_K <= end_index ) { - float8 arow0 = alpha * ((__global float8 *)src0_read)[0]; - float8 arow1 = alpha * ((__global float8 *)(src0_read + M))[0]; + float8 arow0 = alpha * vload8(0, src0_read); + float8 arow1 = alpha * vload8(0, src0_read + M); #define MM_DOT_PRODUCT( _arow ) \ - brow = ((__global float4 *)src1_read0)[0]; src1_read0 += N; \ + brow = vload4(0, src1_read0); src1_read0 += N; \ dot00 = mad( (float4)(_arow.s0), brow, dot00 ); \ dot01 = mad( (float4)(_arow.s1), brow, dot01 ); \ dot02 = mad( (float4)(_arow.s2), brow, dot02 ); \ @@ -1574,22 +2009,40 @@ __kernel void TEMPLATE(gemm_buffer_TN, Dtype)( dot06 = mad( (float4)(_arow.s6), brow, dot06 ); \ dot07 = mad( (float4)(_arow.s7), brow, dot07 ); \ - MM_DOT_PRODUCT( intel_sub_group_shuffle( arow0, 0 ) ); - MM_DOT_PRODUCT( intel_sub_group_shuffle( arow1, 0 ) ); - MM_DOT_PRODUCT( intel_sub_group_shuffle( arow0, 1 ) ); - MM_DOT_PRODUCT( intel_sub_group_shuffle( arow1, 1 ) ); - MM_DOT_PRODUCT( intel_sub_group_shuffle( arow0, 2 ) ); - MM_DOT_PRODUCT( intel_sub_group_shuffle( arow1, 2 ) ); - MM_DOT_PRODUCT( intel_sub_group_shuffle( arow0, 3 ) ); - MM_DOT_PRODUCT( intel_sub_group_shuffle( arow1, 3 ) ); - MM_DOT_PRODUCT( intel_sub_group_shuffle( arow0, 4 ) ); - MM_DOT_PRODUCT( intel_sub_group_shuffle( arow1, 4 ) ); - MM_DOT_PRODUCT( intel_sub_group_shuffle( arow0, 5 ) ); - MM_DOT_PRODUCT( intel_sub_group_shuffle( arow1, 5 ) ); - MM_DOT_PRODUCT( intel_sub_group_shuffle( arow0, 6 ) ); - MM_DOT_PRODUCT( intel_sub_group_shuffle( arow1, 6 ) ); - MM_DOT_PRODUCT( intel_sub_group_shuffle( arow0, 7 ) ); - MM_DOT_PRODUCT( intel_sub_group_shuffle( arow1, 7 ) ); + MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 0 )) ); + MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 0 )) ); + MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( 
SHUFFLE_TYPE8(arow0), 1 )) ); + MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 1 )) ); + MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 2 )) ); + MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 2 )) ); + MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 3 )) ); + MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 3 )) ); + MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 4 )) ); + MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 4 )) ); + MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 5 )) ); + MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 5 )) ); + MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 6 )) ); + MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 6 )) ); + MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 7 )) ); + MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 7 )) ); +#if TYPE == TYPE_HALF + MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 8 )) ); + MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 8 )) ); + MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 9 )) ); + MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 9 )) ); + MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 10 )) ); + MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 10 )) ); + MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 11 )) ); + MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 11 )) ); + MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 12 )) ); + MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 12 )) ); + MM_DOT_PRODUCT( 
as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 13 )) ); + MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 13 )) ); + MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 14 )) ); + MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 14 )) ); + MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 15 )) ); + MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 15 )) ); +#endif #undef MM_DOT_PRODUCT src0_read += TILE_K * M; @@ -1597,11 +2050,11 @@ __kernel void TEMPLATE(gemm_buffer_TN, Dtype)( } if(start_index < end_index) { - float8 arow0 = ((start_index + local_x * 2) < K) ? (alpha * ((__global float8 *)src0_read)[0]) : 0.0f; - float8 arow1 = ((start_index + local_x * 2 + 1) < K) ? (alpha * ((__global float8 *)(src0_read + M))[0]) : 0.0f; + float8 arow0 = ((start_index + local_x * 2) < K) ? alpha * vload8(0, src0_read) : (float8)0.0f; + float8 arow1 = ((start_index + local_x * 2 + 1) < K) ? alpha * vload8(0, src0_read + M) : (float8)0.0f; #define MM_DOT_PRODUCT( _arow ) \ - brow = (start_index < K) ? ((__global float4 *)src1_read0)[0] : 0.0f; src1_read0 += N; start_index++; \ + brow = (start_index < K) ? 
vload4(0, src1_read0) : (float4)0.0f; src1_read0 += N; start_index++; \ dot00 = mad( (float4)(_arow.s0), brow, dot00 ); \ dot01 = mad( (float4)(_arow.s1), brow, dot01 ); \ dot02 = mad( (float4)(_arow.s2), brow, dot02 ); \ @@ -1611,95 +2064,110 @@ __kernel void TEMPLATE(gemm_buffer_TN, Dtype)( dot06 = mad( (float4)(_arow.s6), brow, dot06 ); \ dot07 = mad( (float4)(_arow.s7), brow, dot07 ); \ - MM_DOT_PRODUCT( intel_sub_group_shuffle( arow0, 0 ) ); - MM_DOT_PRODUCT( intel_sub_group_shuffle( arow1, 0 ) ); - MM_DOT_PRODUCT( intel_sub_group_shuffle( arow0, 1 ) ); - MM_DOT_PRODUCT( intel_sub_group_shuffle( arow1, 1 ) ); - MM_DOT_PRODUCT( intel_sub_group_shuffle( arow0, 2 ) ); - MM_DOT_PRODUCT( intel_sub_group_shuffle( arow1, 2 ) ); - MM_DOT_PRODUCT( intel_sub_group_shuffle( arow0, 3 ) ); - MM_DOT_PRODUCT( intel_sub_group_shuffle( arow1, 3 ) ); - MM_DOT_PRODUCT( intel_sub_group_shuffle( arow0, 4 ) ); - MM_DOT_PRODUCT( intel_sub_group_shuffle( arow1, 4 ) ); - MM_DOT_PRODUCT( intel_sub_group_shuffle( arow0, 5 ) ); - MM_DOT_PRODUCT( intel_sub_group_shuffle( arow1, 5 ) ); - MM_DOT_PRODUCT( intel_sub_group_shuffle( arow0, 6 ) ); - MM_DOT_PRODUCT( intel_sub_group_shuffle( arow1, 6 ) ); - MM_DOT_PRODUCT( intel_sub_group_shuffle( arow0, 7 ) ); - MM_DOT_PRODUCT( intel_sub_group_shuffle( arow1, 7 ) ); + MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 0 )) ); + MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 0 )) ); + MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 1 )) ); + MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 1 )) ); + MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 2 )) ); + MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 2 )) ); + MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 3 )) ); + MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 3 )) ); + MM_DOT_PRODUCT( 
as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 4 )) ); + MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 4 )) ); + MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 5 )) ); + MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 5 )) ); + MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 6 )) ); + MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 6 )) ); + MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 7 )) ); + MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 7 )) ); +#if TYPE == TYPE_HALF + MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 8 )) ); + MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 8 )) ); + MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 9 )) ); + MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 9 )) ); + MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 10 )) ); + MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 10 )) ); + MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 11 )) ); + MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 11 )) ); + MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 12 )) ); + MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 12 )) ); + MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 13 )) ); + MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 13 )) ); + MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 14 )) ); + MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 14 )) ); + MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 15 )) ); + MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 
15 )) ); +#endif #undef MM_DOT_PRODUCT } if(global_x * 4 < N && global_y * 8 < M) { if(mad24(global_x, 4, 3) < N) { - __global float4 *dst_write = (__global float4 *)dst_write0; - dst_write[0] = dot00; dst_write0 += N; dst_write = (__global float4 *)dst_write0; - if(mad24(global_y, 8, 1) < M) { dst_write[0] = dot01; dst_write0 += N; dst_write = (__global float4 *)dst_write0; } + vstore4(dot00, 0, dst_write0); dst_write0 += N; + if(mad24(global_y, 8, 1) < M) { vstore4(dot01, 0, dst_write0); dst_write0 += N; } else return; - if(mad24(global_y, 8, 2) < M) { dst_write[0] = dot02; dst_write0 += N; dst_write = (__global float4 *)dst_write0; } + if(mad24(global_y, 8, 2) < M) { vstore4(dot02, 0, dst_write0); dst_write0 += N; } else return; - if(mad24(global_y, 8, 3) < M) { dst_write[0] = dot03; dst_write0 += N; dst_write = (__global float4 *)dst_write0; } + if(mad24(global_y, 8, 3) < M) { vstore4(dot03, 0, dst_write0); dst_write0 += N; } else return; - if(mad24(global_y, 8, 4) < M) { dst_write[0] = dot04; dst_write0 += N; dst_write = (__global float4 *)dst_write0; } + if(mad24(global_y, 8, 4) < M) { vstore4(dot04, 0, dst_write0); dst_write0 += N; } else return; - if(mad24(global_y, 8, 5) < M) { dst_write[0] = dot05; dst_write0 += N; dst_write = (__global float4 *)dst_write0; } + if(mad24(global_y, 8, 5) < M) { vstore4(dot05, 0, dst_write0); dst_write0 += N; } else return; - if(mad24(global_y, 8, 6) < M) { dst_write[0] = dot06; dst_write0 += N; dst_write = (__global float4 *)dst_write0; } + if(mad24(global_y, 8, 6) < M) { vstore4(dot06, 0, dst_write0); dst_write0 += N; } else return; - if(mad24(global_y, 8, 7) < M) { dst_write[0] = dot07; } + if(mad24(global_y, 8, 7) < M) { vstore4(dot07, 0, dst_write0); } } else if(mad24(global_x, 4, 2) < N) { - __global float2 *dst_write = (__global float2 *)dst_write0; - dst_write[0] = dot00.xy; dst_write0[2] = dot00.z; - dst_write0 += N; dst_write = (__global float2 *)dst_write0; + vstore2(dot00.xy, 0, dst_write0); dst_write0[2] = 
dot00.z; + dst_write0 += N; if(mad24(global_y, 8, 1) < M) { - dst_write[0] = dot01.xy; dst_write0[2] = dot01.z; - dst_write0 += N; dst_write = (__global float2 *)dst_write0; + vstore2(dot01.xy, 0, dst_write0); dst_write0[2] = dot01.z; + dst_write0 += N; } else return; if(mad24(global_y, 8, 2) < M) { - dst_write[0] = dot02.xy; dst_write0[2] = dot02.z; - dst_write0 += N; dst_write = (__global float2 *)dst_write0; + vstore2(dot02.xy, 0, dst_write0); dst_write0[2] = dot02.z; + dst_write0 += N; } else return; if(mad24(global_y, 8, 3) < M) { - dst_write[0] = dot03.xy; dst_write0[2] = dot03.z; - dst_write0 += N; dst_write = (__global float2 *)dst_write0; + vstore2(dot03.xy, 0, dst_write0); dst_write0[2] = dot03.z; + dst_write0 += N; } else return; if(mad24(global_y, 8, 4) < M) { - dst_write[0] = dot04.xy; dst_write0[2] = dot04.z; - dst_write0 += N; dst_write = (__global float2 *)dst_write0; + vstore2(dot04.xy, 0, dst_write0); dst_write0[2] = dot04.z; + dst_write0 += N; } else return; if(mad24(global_y, 8, 5) < M) { - dst_write[0] = dot05.xy; dst_write0[2] = dot05.z; - dst_write0 += N; dst_write = (__global float2 *)dst_write0; + vstore2(dot05.xy, 0, dst_write0); dst_write0[2] = dot05.z; + dst_write0 += N; } else return; if(mad24(global_y, 8, 6) < M) { - dst_write[0] = dot06.xy; dst_write0[2] = dot06.z; - dst_write0 += N; dst_write = (__global float2 *)dst_write0; + vstore2(dot06.xy, 0, dst_write0); dst_write0[2] = dot06.z; + dst_write0 += N; } else return; if(mad24(global_y, 8, 7) < M) { - dst_write[0] = dot07.xy; dst_write0[2] = dot07.z; + vstore2(dot07.xy, 0, dst_write0); dst_write0[2] = dot07.z; } } else if(mad24(global_x, 4, 1) < N) { - __global float2 *dst_write = (__global float2 *)dst_write0; - dst_write[0] = dot00.xy; dst_write0 += N; dst_write = (__global float2 *)dst_write0; - if(mad24(global_y, 8, 1) < M) { dst_write[0] = dot01.xy; dst_write0 += N; dst_write = (__global float2 *)dst_write0; } + vstore2(dot00.xy, 0, dst_write0); dst_write0 += N; + 
if(mad24(global_y, 8, 1) < M) { vstore2(dot01.xy, 0, dst_write0); dst_write0 += N; } else return; - if(mad24(global_y, 8, 2) < M) { dst_write[0] = dot02.xy; dst_write0 += N; dst_write = (__global float2 *)dst_write0; } + if(mad24(global_y, 8, 2) < M) { vstore2(dot02.xy, 0, dst_write0); dst_write0 += N; } else return; - if(mad24(global_y, 8, 3) < M) { dst_write[0] = dot03.xy; dst_write0 += N; dst_write = (__global float2 *)dst_write0; } + if(mad24(global_y, 8, 3) < M) { vstore2(dot03.xy, 0, dst_write0); dst_write0 += N; } else return; - if(mad24(global_y, 8, 4) < M) { dst_write[0] = dot04.xy; dst_write0 += N; dst_write = (__global float2 *)dst_write0; } + if(mad24(global_y, 8, 4) < M) { vstore2(dot04.xy, 0, dst_write0); dst_write0 += N; } else return; - if(mad24(global_y, 8, 5) < M) { dst_write[0] = dot05.xy; dst_write0 += N; dst_write = (__global float2 *)dst_write0; } + if(mad24(global_y, 8, 5) < M) { vstore2(dot05.xy, 0, dst_write0); dst_write0 += N; } else return; - if(mad24(global_y, 8, 6) < M) { dst_write[0] = dot06.xy; dst_write0 += N; dst_write = (__global float2 *)dst_write0; } + if(mad24(global_y, 8, 6) < M) { vstore2(dot06.xy, 0, dst_write0); dst_write0 += N; } else return; - if(mad24(global_y, 8, 7) < M) { dst_write[0] = dot07.xy; } + if(mad24(global_y, 8, 7) < M) { vstore2(dot07.xy, 0, dst_write0); } } else { dst_write0[0] = dot00.x; dst_write0 += N; if(mad24(global_y, 8, 1) < M) { dst_write0[0] = dot01.x; dst_write0 += N; } @@ -1732,6 +2200,7 @@ __kernel void TEMPLATE(gemm_buffer_TN, Dtype)( #define TILE_N 32 __attribute__((reqd_work_group_size(8, LWG_HEIGHT, 1))) +__attribute__((intel_reqd_sub_group_size(8))) __kernel void TEMPLATE(gemm_buffer_TT, Dtype)( const __global float *src0, int off0, const __global float *src1, int off1, @@ -1739,11 +2208,13 @@ __kernel void TEMPLATE(gemm_buffer_TT, Dtype)( int M, int N, int K, - float alpha, - float beta, + float alpha_in, + float beta_in, int start_index) { + const float alpha = (float)alpha_in; + const 
float beta = (float)beta_in; const int group_x = get_group_id(0); const int group_y = get_group_id(1); const int local_x = get_local_id(0); @@ -1761,48 +2232,48 @@ __kernel void TEMPLATE(gemm_buffer_TT, Dtype)( float16 brow2; float16 brow3; - __global float *dst_write0 = dst + local_x * VEC_SIZE + ( group_x * TILE_N ) + ( group_y * LWG_HEIGHT * TILE_M + local_y * TILE_M) * N + offd; + __global float *dst_write0 = dst + local_x * VEC_SIZE + (group_x * TILE_N) + (group_y * LWG_HEIGHT * TILE_M + local_y * TILE_M) * N + offd; - const __global float *src0_read = src0 + (local_x * ( TILE_K / 8 ) + start_index) * M + group_y * LWG_HEIGHT * TILE_M + local_y * TILE_M + off0; + const __global float *src0_read = src0 + (local_x * (TILE_K / 8) + start_index) * M + group_y * LWG_HEIGHT * TILE_M + local_y * TILE_M + off0; - const __global float *src1_read0 = src1 + (local_x * VEC_SIZE + ( group_x * TILE_N )) * K + start_index + off1; + const __global float *src1_read0 = src1 + (local_x * VEC_SIZE + (group_x * TILE_N)) * K + start_index + off1; - float4 dot00 = (start_index != 0) ? ((__global float4 *)dst_write0)[0] : (beta * ((__global float4 *)dst_write0)[0]); - float4 dot01 = (start_index != 0) ? ((__global float4 *)(dst_write0 + N))[0] : (beta * ((__global float4 *)(dst_write0 + N))[0]); - float4 dot02 = (start_index != 0) ? ((__global float4 *)(dst_write0 + 2 * N))[0] : (beta * ((__global float4 *)(dst_write0 + 2 * N))[0]); - float4 dot03 = (start_index != 0) ? ((__global float4 *)(dst_write0 + 3 * N))[0] : (beta * ((__global float4 *)(dst_write0 + 3 * N))[0]); - float4 dot04 = (start_index != 0) ? ((__global float4 *)(dst_write0 + 4 * N))[0] : (beta * ((__global float4 *)(dst_write0 + 4 * N))[0]); - float4 dot05 = (start_index != 0) ? ((__global float4 *)(dst_write0 + 5 * N))[0] : (beta * ((__global float4 *)(dst_write0 + 5 * N))[0]); - float4 dot06 = (start_index != 0) ? 
((__global float4 *)(dst_write0 + 6 * N))[0] : (beta * ((__global float4 *)(dst_write0 + 6 * N))[0]); - float4 dot07 = (start_index != 0) ? ((__global float4 *)(dst_write0 + 7 * N))[0] : (beta * ((__global float4 *)(dst_write0 + 7 * N))[0]); + float4 dot00 = (start_index != 0) ? vload4(0, dst_write0) : beta * vload4(0, dst_write0); + float4 dot01 = (start_index != 0) ? vload4(0, dst_write0 + N) : beta * vload4(0, dst_write0 + N); + float4 dot02 = (start_index != 0) ? vload4(0, dst_write0 + 2 * N) : beta * vload4(0, dst_write0 + 2 * N); + float4 dot03 = (start_index != 0) ? vload4(0, dst_write0 + 3 * N) : beta * vload4(0, dst_write0 + 3 * N); + float4 dot04 = (start_index != 0) ? vload4(0, dst_write0 + 4 * N) : beta * vload4(0, dst_write0 + 4 * N); + float4 dot05 = (start_index != 0) ? vload4(0, dst_write0 + 5 * N) : beta * vload4(0, dst_write0 + 5 * N); + float4 dot06 = (start_index != 0) ? vload4(0, dst_write0 + 6 * N) : beta * vload4(0, dst_write0 + 6 * N); + float4 dot07 = (start_index != 0) ? 
vload4(0, dst_write0 + 7 * N) : beta * vload4(0, dst_write0 + 7 * N); int end_index = min(start_index + 256, K); while( start_index + TILE_K <= end_index ) { - brow0 = ((__global float16 *)src1_read0)[0]; - brow1 = ((__global float16 *)(src1_read0 + K))[0]; - brow2 = ((__global float16 *)(src1_read0 + 2 * K))[0]; - brow3 = ((__global float16 *)(src1_read0 + 3 * K))[0]; + brow0 = vload16(0, src1_read0); + brow1 = vload16(0, src1_read0 + K); + brow2 = vload16(0, src1_read0 + 2 * K); + brow3 = vload16(0, src1_read0 + 3 * K); - float8 arow0 = alpha * ((__global float8 *)src0_read)[0]; - float8 arow1 = alpha * ((__global float8 *)(src0_read + M))[0]; + float8 arow0 = alpha * vload8(0, src0_read); + float8 arow1 = alpha * vload8(0, src0_read + M); #define MM_DOT_PRODUCT( _brow, _dot) \ - _dot = mad( intel_sub_group_shuffle( arow0, 0 ), (float8)_brow.s0, _dot ); \ - _dot = mad( intel_sub_group_shuffle( arow1, 0 ), (float8)_brow.s1, _dot ); \ - _dot = mad( intel_sub_group_shuffle( arow0, 1 ), (float8)_brow.s2, _dot ); \ - _dot = mad( intel_sub_group_shuffle( arow1, 1 ), (float8)_brow.s3, _dot ); \ - _dot = mad( intel_sub_group_shuffle( arow0, 2 ), (float8)_brow.s4, _dot ); \ - _dot = mad( intel_sub_group_shuffle( arow1, 2 ), (float8)_brow.s5, _dot ); \ - _dot = mad( intel_sub_group_shuffle( arow0, 3 ), (float8)_brow.s6, _dot ); \ - _dot = mad( intel_sub_group_shuffle( arow1, 3 ), (float8)_brow.s7, _dot ); \ - _dot = mad( intel_sub_group_shuffle( arow0, 4 ), (float8)_brow.s8, _dot ); \ - _dot = mad( intel_sub_group_shuffle( arow1, 4 ), (float8)_brow.s9, _dot ); \ - _dot = mad( intel_sub_group_shuffle( arow0, 5 ), (float8)_brow.sa, _dot ); \ - _dot = mad( intel_sub_group_shuffle( arow1, 5 ), (float8)_brow.sb, _dot ); \ - _dot = mad( intel_sub_group_shuffle( arow0, 6 ), (float8)_brow.sc, _dot ); \ - _dot = mad( intel_sub_group_shuffle( arow1, 6 ), (float8)_brow.sd, _dot ); \ - _dot = mad( intel_sub_group_shuffle( arow0, 7 ), (float8)_brow.se, _dot ); \ - _dot = mad( 
intel_sub_group_shuffle( arow1, 7 ), (float8)_brow.sf, _dot ); \ + _dot = mad( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 0 )), (float8)_brow.s0, _dot ); \ + _dot = mad( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 0 )), (float8)_brow.s1, _dot ); \ + _dot = mad( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 1 )), (float8)_brow.s2, _dot ); \ + _dot = mad( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 1 )), (float8)_brow.s3, _dot ); \ + _dot = mad( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 2 )), (float8)_brow.s4, _dot ); \ + _dot = mad( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 2 )), (float8)_brow.s5, _dot ); \ + _dot = mad( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 3 )), (float8)_brow.s6, _dot ); \ + _dot = mad( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 3 )), (float8)_brow.s7, _dot ); \ + _dot = mad( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 4 )), (float8)_brow.s8, _dot ); \ + _dot = mad( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 4 )), (float8)_brow.s9, _dot ); \ + _dot = mad( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 5 )), (float8)_brow.sa, _dot ); \ + _dot = mad( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 5 )), (float8)_brow.sb, _dot ); \ + _dot = mad( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 6 )), (float8)_brow.sc, _dot ); \ + _dot = mad( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 6 )), (float8)_brow.sd, _dot ); \ + _dot = mad( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 7 )), (float8)_brow.se, _dot ); \ + _dot = mad( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 7 )), (float8)_brow.sf, _dot ); \ MM_DOT_PRODUCT( brow0, dot0 ); MM_DOT_PRODUCT( brow1, dot1 ); @@ -1816,31 +2287,31 @@ __kernel void TEMPLATE(gemm_buffer_TT, Dtype)( } if(start_index < end_index) { - brow0 = ((__global float16 *)src1_read0)[0]; src1_read0 += K; - brow1 = ((__global 
float16 *)src1_read0)[0]; src1_read0 += K; - brow2 = ((__global float16 *)src1_read0)[0]; src1_read0 += K; - brow3 = ((__global float16 *)src1_read0)[0]; + brow0 = vload16(0, src1_read0); src1_read0 += K; + brow1 = vload16(0, src1_read0); src1_read0 += K; + brow2 = vload16(0, src1_read0); src1_read0 += K; + brow3 = vload16(0, src1_read0); - float8 arow0 = alpha * ((__global float8 *)src0_read)[0]; - float8 arow1 = alpha * ((__global float8 *)(src0_read + M))[0]; + float8 arow0 = alpha * vload8(0, src0_read); + float8 arow1 = alpha * vload8(0, src0_read + M); #define MM_DOT_PRODUCT( _brow, _dot) \ - _dot = (w++ < K) ? mad( intel_sub_group_shuffle( arow0, 0 ), (float8)_brow.s0, _dot ) : _dot; \ - _dot = (w++ < K) ? mad( intel_sub_group_shuffle( arow1, 0 ), (float8)_brow.s1, _dot ) : _dot; \ - _dot = (w++ < K) ? mad( intel_sub_group_shuffle( arow0, 1 ), (float8)_brow.s2, _dot ) : _dot; \ - _dot = (w++ < K) ? mad( intel_sub_group_shuffle( arow1, 1 ), (float8)_brow.s3, _dot ) : _dot; \ - _dot = (w++ < K) ? mad( intel_sub_group_shuffle( arow0, 2 ), (float8)_brow.s4, _dot ) : _dot; \ - _dot = (w++ < K) ? mad( intel_sub_group_shuffle( arow1, 2 ), (float8)_brow.s5, _dot ) : _dot; \ - _dot = (w++ < K) ? mad( intel_sub_group_shuffle( arow0, 3 ), (float8)_brow.s6, _dot ) : _dot; \ - _dot = (w++ < K) ? mad( intel_sub_group_shuffle( arow1, 3 ), (float8)_brow.s7, _dot ) : _dot; \ - _dot = (w++ < K) ? mad( intel_sub_group_shuffle( arow0, 4 ), (float8)_brow.s8, _dot ) : _dot; \ - _dot = (w++ < K) ? mad( intel_sub_group_shuffle( arow1, 4 ), (float8)_brow.s9, _dot ) : _dot; \ - _dot = (w++ < K) ? mad( intel_sub_group_shuffle( arow0, 5 ), (float8)_brow.sa, _dot ) : _dot; \ - _dot = (w++ < K) ? mad( intel_sub_group_shuffle( arow1, 5 ), (float8)_brow.sb, _dot ) : _dot; \ - _dot = (w++ < K) ? mad( intel_sub_group_shuffle( arow0, 6 ), (float8)_brow.sc, _dot ) : _dot; \ - _dot = (w++ < K) ? mad( intel_sub_group_shuffle( arow1, 6 ), (float8)_brow.sd, _dot ) : _dot; \ - _dot = (w++ < K) ? 
mad( intel_sub_group_shuffle( arow0, 7 ), (float8)_brow.se, _dot ) : _dot; \ - _dot = (w++ < K) ? mad( intel_sub_group_shuffle( arow1, 7 ), (float8)_brow.sf, _dot ) : _dot; \ + _dot = (w++ < K) ? mad( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 0 )), (float8)_brow.s0, _dot ) : _dot; \ + _dot = (w++ < K) ? mad( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 0 )), (float8)_brow.s1, _dot ) : _dot; \ + _dot = (w++ < K) ? mad( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 1 )), (float8)_brow.s2, _dot ) : _dot; \ + _dot = (w++ < K) ? mad( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 1 )), (float8)_brow.s3, _dot ) : _dot; \ + _dot = (w++ < K) ? mad( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 2 )), (float8)_brow.s4, _dot ) : _dot; \ + _dot = (w++ < K) ? mad( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 2 )), (float8)_brow.s5, _dot ) : _dot; \ + _dot = (w++ < K) ? mad( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 3 )), (float8)_brow.s6, _dot ) : _dot; \ + _dot = (w++ < K) ? mad( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 3 )), (float8)_brow.s7, _dot ) : _dot; \ + _dot = (w++ < K) ? mad( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 4 )), (float8)_brow.s8, _dot ) : _dot; \ + _dot = (w++ < K) ? mad( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 4 )), (float8)_brow.s9, _dot ) : _dot; \ + _dot = (w++ < K) ? mad( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 5 )), (float8)_brow.sa, _dot ) : _dot; \ + _dot = (w++ < K) ? mad( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 5 )), (float8)_brow.sb, _dot ) : _dot; \ + _dot = (w++ < K) ? mad( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 6 )), (float8)_brow.sc, _dot ) : _dot; \ + _dot = (w++ < K) ? mad( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 6 )), (float8)_brow.sd, _dot ) : _dot; \ + _dot = (w++ < K) ? 
mad( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 7 )), (float8)_brow.se, _dot ) : _dot; \ + _dot = (w++ < K) ? mad( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 7 )), (float8)_brow.sf, _dot ) : _dot; \ int w = start_index; MM_DOT_PRODUCT( brow0, dot0 ); @@ -1864,74 +2335,71 @@ __kernel void TEMPLATE(gemm_buffer_TT, Dtype)( if(global_x * 4 < N && global_y * 8 < M) { if(mad24(global_x, 4, 3) < N) { - __global float4 *dst_write = (__global float4 *)dst_write0; - dst_write[0] = dot00; dst_write0 += N; dst_write = (__global float4 *)dst_write0; - if(mad24(global_y, 8, 1) < M) { dst_write[0] = dot01; dst_write0 += N; dst_write = (__global float4 *)dst_write0; } + vstore4(dot00, 0, dst_write0); dst_write0 += N; + if(mad24(global_y, 8, 1) < M) { vstore4(dot01, 0, dst_write0); dst_write0 += N; } else return; - if(mad24(global_y, 8, 2) < M) { dst_write[0] = dot02; dst_write0 += N; dst_write = (__global float4 *)dst_write0; } + if(mad24(global_y, 8, 2) < M) { vstore4(dot02, 0, dst_write0); dst_write0 += N; } else return; - if(mad24(global_y, 8, 3) < M) { dst_write[0] = dot03; dst_write0 += N; dst_write = (__global float4 *)dst_write0; } + if(mad24(global_y, 8, 3) < M) { vstore4(dot03, 0, dst_write0); dst_write0 += N; } else return; - if(mad24(global_y, 8, 4) < M) { dst_write[0] = dot04; dst_write0 += N; dst_write = (__global float4 *)dst_write0; } + if(mad24(global_y, 8, 4) < M) { vstore4(dot04, 0, dst_write0); dst_write0 += N; } else return; - if(mad24(global_y, 8, 5) < M) { dst_write[0] = dot05; dst_write0 += N; dst_write = (__global float4 *)dst_write0; } + if(mad24(global_y, 8, 5) < M) { vstore4(dot05, 0, dst_write0); dst_write0 += N; } else return; - if(mad24(global_y, 8, 6) < M) { dst_write[0] = dot06; dst_write0 += N; dst_write = (__global float4 *)dst_write0; } + if(mad24(global_y, 8, 6) < M) { vstore4(dot06, 0, dst_write0); dst_write0 += N; } else return; - if(mad24(global_y, 8, 7) < M) { dst_write[0] = dot07; } + if(mad24(global_y, 8, 7) < 
M) { vstore4(dot07, 0, dst_write0); } } else if(mad24(global_x, 4, 2) < N) { - __global float2 *dst_write = (__global float2 *)dst_write0; - dst_write[0] = dot00.xy; dst_write0[2] = dot00.z; - dst_write0 += N; dst_write = (__global float2 *)dst_write0; + vstore2(dot00.xy, 0, dst_write0); dst_write0[2] = dot00.z; + dst_write0 += N; if(mad24(global_y, 8, 1) < M) { - dst_write[0] = dot01.xy; dst_write0[2] = dot01.z; - dst_write0 += N; dst_write = (__global float2 *)dst_write0; + vstore2(dot01.xy, 0, dst_write0); dst_write0[2] = dot01.z; + dst_write0 += N; } else return; if(mad24(global_y, 8, 2) < M) { - dst_write[0] = dot02.xy; dst_write0[2] = dot02.z; - dst_write0 += N; dst_write = (__global float2 *)dst_write0; + vstore2(dot02.xy, 0, dst_write0); dst_write0[2] = dot02.z; + dst_write0 += N; } else return; if(mad24(global_y, 8, 3) < M) { - dst_write[0] = dot03.xy; dst_write0[2] = dot03.z; - dst_write0 += N; dst_write = (__global float2 *)dst_write0; + vstore2(dot03.xy, 0, dst_write0); dst_write0[2] = dot03.z; + dst_write0 += N; } else return; if(mad24(global_y, 8, 4) < M) { - dst_write[0] = dot04.xy; dst_write0[2] = dot04.z; - dst_write0 += N; dst_write = (__global float2 *)dst_write0; + vstore2(dot04.xy, 0, dst_write0); dst_write0[2] = dot04.z; + dst_write0 += N; } else return; if(mad24(global_y, 8, 5) < M) { - dst_write[0] = dot05.xy; dst_write0[2] = dot05.z; - dst_write0 += N; dst_write = (__global float2 *)dst_write0; + vstore2(dot05.xy, 0, dst_write0); dst_write0[2] = dot05.z; + dst_write0 += N; } else return; if(mad24(global_y, 8, 6) < M) { - dst_write[0] = dot06.xy; dst_write0[2] = dot06.z; - dst_write0 += N; dst_write = (__global float2 *)dst_write0; + vstore2(dot06.xy, 0, dst_write0); dst_write0[2] = dot06.z; + dst_write0 += N; } else return; if(mad24(global_y, 8, 7) < M) { - dst_write[0] = dot07.xy; dst_write0[2] = dot07.z; + vstore2(dot07.xy, 0, dst_write0); dst_write0[2] = dot07.z; } } else if(mad24(global_x, 4, 1) < N) { - __global float2 *dst_write = 
(__global float2 *)dst_write0; - dst_write[0] = dot00.xy; dst_write0 += N; dst_write = (__global float2 *)dst_write0; - if(mad24(global_y, 8, 1) < M) { dst_write[0] = dot01.xy; dst_write0 += N; dst_write = (__global float2 *)dst_write0; } + vstore2(dot00.xy, 0, dst_write0); dst_write0 += N; + if(mad24(global_y, 8, 1) < M) { vstore2(dot01.xy, 0, dst_write0); dst_write0 += N; } else return; - if(mad24(global_y, 8, 2) < M) { dst_write[0] = dot02.xy; dst_write0 += N; dst_write = (__global float2 *)dst_write0; } + if(mad24(global_y, 8, 2) < M) { vstore2(dot02.xy, 0, dst_write0); dst_write0 += N; } else return; - if(mad24(global_y, 8, 3) < M) { dst_write[0] = dot03.xy; dst_write0 += N; dst_write = (__global float2 *)dst_write0; } + if(mad24(global_y, 8, 3) < M) { vstore2(dot03.xy, 0, dst_write0); dst_write0 += N; } else return; - if(mad24(global_y, 8, 4) < M) { dst_write[0] = dot04.xy; dst_write0 += N; dst_write = (__global float2 *)dst_write0; } + if(mad24(global_y, 8, 4) < M) { vstore2(dot04.xy, 0, dst_write0); dst_write0 += N; } else return; - if(mad24(global_y, 8, 5) < M) { dst_write[0] = dot05.xy; dst_write0 += N; dst_write = (__global float2 *)dst_write0; } + if(mad24(global_y, 8, 5) < M) { vstore2(dot05.xy, 0, dst_write0); dst_write0 += N; } else return; - if(mad24(global_y, 8, 6) < M) { dst_write[0] = dot06.xy; dst_write0 += N; dst_write = (__global float2 *)dst_write0; } + if(mad24(global_y, 8, 6) < M) { vstore2(dot06.xy, 0, dst_write0); dst_write0 += N; } else return; - if(mad24(global_y, 8, 7) < M) { dst_write[0] = dot07.xy; } + if(mad24(global_y, 8, 7) < M) { vstore2(dot07.xy, 0, dst_write0); } } else { dst_write0[0] = dot00.x; dst_write0 += N; if(mad24(global_y, 8, 1) < M) { dst_write0[0] = dot01.x; dst_write0 += N; } @@ -1956,3 +2424,5 @@ __kernel void TEMPLATE(gemm_buffer_TT, Dtype)( #undef TILE_M #undef TILE_K #undef TILE_N + +#endif diff --git a/src/caffe/greentea/greentea_math_functions.cpp b/src/caffe/greentea/greentea_math_functions.cpp index 
60739c2cec0..1db968d730b 100644 --- a/src/caffe/greentea/greentea_math_functions.cpp +++ b/src/caffe/greentea/greentea_math_functions.cpp @@ -176,712 +176,14 @@ template void greentea_copy(const int_tp N, const cl_mem X, const int_tp offY, viennacl::ocl::context *ctx); -struct gemm_callback_arg { - std::vector evs; - std::vector imgs; -}; - -static void CL_CALLBACK gemm_callback (cl_event event, - cl_int event_command_exec_status, - void *user_data) { - struct gemm_callback_arg *arg = (struct gemm_callback_arg *) user_data; - for(int i = 0; i < arg->evs.size(); i++) { - clReleaseEvent(arg->evs[i]); - } - - for(int i = 0; i < arg->imgs.size(); i++) { - clReleaseMemObject(arg->imgs[i]); - } - delete arg; -} - -// Create and copy buffer to image for GEMM's matrix A and B. -// Will return image to caller if the input image is NULL. Otherwise, -// will use the image directly. It's caller's responsibility to -// release the created image. -void greentea_gpu_gemm_copy_buffer_to_image(int_tp ctx_id, - cl_mem *image, cl_mem buffer, int offset, - bool is_matrix_a, bool transpose, - bool padding, int padded_height, - int padded_width, int height, - int width, int wait_list_size, - cl_event *wait_list, - cl_event *event) { - - viennacl::ocl::context &ctx = viennacl::ocl::get_context(ctx_id); - viennacl::ocl::program &program = (Caffe::Get().GetDevice(ctx_id, false)) - ->program(); - cl_image_desc desc; - cl_image_format format; - - memset(&desc, 0, sizeof(desc)); - if (!is_matrix_a && transpose) { - // For matrix B with transpose, we need to handle them differently. - // As we can't use the sub group block read to get a row easily, - // we have to use CL_FLOAT type with read_imagef to get the row. 
- cl_int err; - format.image_channel_data_type = CL_FLOAT; - desc.image_type = CL_MEM_OBJECT_IMAGE2D; - if ( width % 4 == 0 ) { - desc.image_width = width / 4; - format.image_channel_order = CL_RGBA; - } else { - desc.image_width = width; - format.image_channel_order = CL_R; - } - desc.image_height = height; - // if (offB == 0 && (desc.image_width % 4) == 0 && N > 8 && K > 8) - // desc.mem_object = buffer; - if (*image == NULL) { - *image = clCreateImage( - ctx.handle().get(), - CL_MEM_READ_WRITE, - &format, - &desc, - NULL, - &err); - OCL_CHECK(err); - } - // if (!desc.mem_object) { - size_t origin[] = {0, 0, 0}; - size_t region[] = {(size_t)desc.image_width, - (size_t)desc.image_height, 1}; - OCL_CHECK(clEnqueueCopyBufferToImage(ctx.get_queue().handle().get(), - buffer, *image, sizeof(float) * offset, - origin, region, wait_list_size, - wait_list, event)); - // } - return; - } - - if (*image == NULL) { - desc.image_type = CL_MEM_OBJECT_IMAGE2D; - format.image_channel_data_type = CL_UNSIGNED_INT8; - format.image_channel_order = CL_RGBA; - if (!padding) { - //if (width % 4 == 0 && offset == 0 && height > 8 && width > 8) - // desc.buffer = buffer; - desc.image_width = width; - desc.image_height = height; - } else { - desc.image_width = padded_width; - desc.image_height = padded_height; - } - cl_int err; - *image = clCreateImage(ctx.handle().get(), - desc.buffer ? CL_MEM_READ_ONLY : CL_MEM_READ_WRITE, - &format, - &desc, - NULL, - &err); - OCL_CHECK(err); - } - if (!padding && desc.buffer != NULL) - return; - if (!padding && desc.buffer == NULL) { - // copy without padding. 
- size_t origin[] = {0, 0, 0}; - size_t region[] = {(size_t)width, (size_t)height, 1}; - OCL_CHECK(clEnqueueCopyBufferToImage(ctx.get_queue().handle().get(), - buffer, *image, sizeof(float) * offset, - origin, region, wait_list_size, wait_list, event)); - return; - } - viennacl::ocl::kernel &oclk_gemm_copy = program.get_kernel( - "gemm_buffer_copy_image_float"); - - size_t global_copy[2]; - global_copy[0] = padding ? padded_width : width; - global_copy[1] = padding ? padded_height : height; - oclk_gemm_copy.arg(0, WrapHandle(buffer, &ctx)); - oclk_gemm_copy.arg(1, WrapHandle(*image, &ctx)); - oclk_gemm_copy.arg(2, offset); - oclk_gemm_copy.arg(3, width); - oclk_gemm_copy.arg(4, height); - OCL_CHECK(clEnqueueNDRangeKernel(ctx.get_queue().handle().get(), - oclk_gemm_copy.handle().get(), - 2, NULL, global_copy, NULL, - wait_list_size, wait_list, - event)); -} - -// #define GEMM_PROFILING -#ifdef GEMM_PROFILING -#define START_TIMER(n) \ - clFinish(ctx.get_queue().handle().get()); \ - gettimeofday(&start[n], NULL); - -#define STOP_TIMER(n) \ - clFinish(ctx.get_queue().handle().get()); \ - gettimeofday(&end[n], NULL); -#else -#define START_TIMER(n) -#define STOP_TIMER(n) -#endif - -enum gemm_type_t { - GEMM_TYPE_NONE = 0, - GEMM_TYPE_CLBLAS, - GEMM_TYPE_CLBLAST, - GEMM_TYPE_VIENNACL, - GEMM_TYPE_FAST_IMAGE_32_1, - GEMM_TYPE_FAST_IMAGE_32_2, - GEMM_TYPE_FAST_BUFFER, - GEMM_TYPE_MAX -}; - -static void greentea_gpu_fast_image_gemm(const int_tp ctx_id, const CBLAS_TRANSPOSE TransA, - const CBLAS_TRANSPOSE TransB, const int_tp M, - const int_tp N, const int_tp K, const float alpha, - const cl_mem A, const int_tp offA, const cl_mem B, - const int_tp offB, const float beta, cl_mem C, - const int_tp offC, bool is_image_a, bool is_image_b, - enum gemm_type_t gemm_type) { - CHECK_EQ(gemm_type == GEMM_TYPE_FAST_IMAGE_32_1 - || gemm_type == GEMM_TYPE_FAST_IMAGE_32_2, true) - << "Invalid fast image gemm type." 
<< std::endl; - if (is_image_a) - CHECK_EQ(offA, 0) << "Invalid input image offset." << std::endl; - - if (is_image_b) - CHECK_EQ(offB, 0) << "Invalid input image offset." << std::endl; - - #ifdef GEMM_PROFILING - struct timeval start[4], end[4]; - for(int i = 0; i < 4; i++) - start[i] = end[i]; - #endif - uint32_t widthA = (TransA == CblasNoTrans) ? K : M; - uint32_t heightA = (TransA == CblasNoTrans) ? M : K; - uint32_t widthB = (TransB == CblasNoTrans) ? N : K; - uint32_t heightB = (TransB == CblasNoTrans) ? K : N; - // To fix the edge problem casued by the sub group block read. - // we have to pad the image if it's not multiple of tile. - // just padding one line is enough as the sub group block read - // will clamp to edge according to the spec. - uint32_t padded_k = K + ((K & 7) ? 1 : 0); - uint32_t imageA_w = (TransA == CblasNoTrans) ? padded_k : M; - uint32_t imageA_h = (TransA == CblasNoTrans) ? M : padded_k; - uint32_t imageB_w = (TransB == CblasNoTrans) ? N : padded_k; - uint32_t imageB_h = (TransB == CblasNoTrans) ? 
padded_k : N; - viennacl::ocl::context &ctx = viennacl::ocl::get_context(ctx_id); - viennacl::ocl::program &program = (Caffe::Get().GetDevice(ctx_id, false)) - ->program(); - - cl_mem ImA = NULL; - cl_mem ImB = NULL; - - cl_event ev[5]; - cl_uint ev_idx = 0; - memset(ev, 0, sizeof(cl_event) * 5); - struct gemm_callback_arg * arg = new gemm_callback_arg; - if (TransB == CblasNoTrans) { - bool padding_A = false; - bool padding_B = false; - - if (!is_image_a && !is_image_b) { - if (M * K < N * K) - padding_B = true; - else - padding_A = true; - } - - START_TIMER(0); - if (!is_image_a) { - greentea_gpu_gemm_copy_buffer_to_image(ctx_id, &ImA, A, offA, - true, TransA != CblasNoTrans, - padding_A, imageA_h, imageA_w, - heightA, widthA, 0, NULL, &ev[ev_idx]); - if (ev[ev_idx] != NULL) - ev_idx++; - } - - STOP_TIMER(0); - START_TIMER(1); - - if (!is_image_b) { - greentea_gpu_gemm_copy_buffer_to_image(ctx_id, &ImB, B, offB, - false, false, - padding_B, imageB_h, imageB_w, - heightB, widthB, 0, NULL, &ev[ev_idx]); - if (ev[ev_idx] != NULL) - ev_idx++; - } - STOP_TIMER(1); - } else { - // We will use normal read_imagef to read image B when B has transpose. - // thus we don't need to pad image A at all. 
- START_TIMER(2); - if (!is_image_a) { - bool padding; - padding = !is_image_b; - greentea_gpu_gemm_copy_buffer_to_image(ctx_id, &ImA, A, offA, - true, TransA != CblasNoTrans, - padding, imageA_h, imageA_w, - heightA, widthA, 0, NULL, &ev[ev_idx]); - if (ev[ev_idx] != NULL) - ev_idx++; - } - STOP_TIMER(2); - } - if (!is_image_a) - arg->imgs.push_back(ImA); - else - ImA = A; - if (!is_image_b) - arg->imgs.push_back(ImB); - else - ImB = B; - - viennacl::ocl::kernel *oclk_gemm_float; - std::string kernel_name("gemm_"); - if (gemm_type == GEMM_TYPE_FAST_IMAGE_32_1) - kernel_name += "32_1_"; - else - kernel_name += "32_2_"; - - if (TransA == CblasNoTrans) - kernel_name += "N"; - else - kernel_name += "T"; - - if (TransB == CblasNoTrans) - kernel_name += "N_"; - else { - kernel_name += "T_"; - if (is_image_b) { - if (K % 4 == 0) - kernel_name += "VEC4_"; - else - kernel_name += "SCALAR_"; - } else { - kernel_name += "BUFFER_"; - } - } - - if (alpha == 1) - kernel_name += "1_"; - else - kernel_name += "0_"; - - if (beta == 0) - kernel_name += "0"; - else - kernel_name += "1"; - kernel_name += "_float"; - - oclk_gemm_float = &program.get_kernel(kernel_name); - - size_t global[2]; - - if (gemm_type == GEMM_TYPE_FAST_IMAGE_32_1) - global[0] = (size_t)( N + 7 ) & ~7; - else - global[0] = (size_t)( (N / 2 ) + 7 ) ^ ~7; - - global[1] = (size_t)(M + 31) / 32; - const size_t local[] = {8, 1}; - - cl_uint arg_idx = 0; - oclk_gemm_float->arg(arg_idx++, WrapHandle(ImA, &ctx)); - if (TransB == CblasNoTrans || is_image_b) - oclk_gemm_float->arg(arg_idx++, WrapHandle(ImB, &ctx)); - else { - oclk_gemm_float->arg(arg_idx++, WrapHandle(B, &ctx)); - oclk_gemm_float->arg(arg_idx++, offB); - } - oclk_gemm_float->arg(arg_idx++, WrapHandle(C, &ctx)); - oclk_gemm_float->arg(arg_idx++, offC); - oclk_gemm_float->arg(arg_idx++, M); - oclk_gemm_float->arg(arg_idx++, N); - oclk_gemm_float->arg(arg_idx++, alpha); - oclk_gemm_float->arg(arg_idx++, beta); - oclk_gemm_float->arg(arg_idx++, padded_k); - 
if (TransB != CblasNoTrans) - oclk_gemm_float->arg(arg_idx++, K); - - cl_event *wait_list = NULL; - if (ev_idx != 0) - wait_list = &ev[0]; - START_TIMER(3); - OCL_CHECK(clEnqueueNDRangeKernel(ctx.get_queue().handle().get(), - oclk_gemm_float->handle().get(), 2, NULL, - global, local, ev_idx, - wait_list, &ev[ev_idx])); - STOP_TIMER(3); - #ifdef GEMM_PROFILING - double elapsed[4], total_elapsed; - for( int i = 0; i < 4; i++ ) { - elapsed[i] = (end[i].tv_sec - start[i].tv_sec) * 1e6 + (end[i].tv_usec - start[i].tv_usec); - total_elapsed += elapsed[i]; - } - printf("kernel name %s \n", kernel_name.c_str()); - printf("gemm %d %d %d %f %f %d %d %f %f %f %f %fGFLOPS %f GFLOPS\n", - M, K, N, alpha, beta, TransA == CblasNoTrans, TransB == CblasNoTrans, - elapsed[0] / 1000., elapsed[1] / 1000., elapsed[2] / 1000., - elapsed[3] / 1000., - M * N * ( 2*K - 1. ) / ( elapsed[3] * 1e3 ), - M * N * ( 2 * K - 1.) / ( total_elapsed * 1e3 ) ); - #endif - arg->evs.assign(ev, ev + ev_idx + 1); - clSetEventCallback(ev[ev_idx], CL_COMPLETE, &gemm_callback, (void*)arg); -} - -static void greentea_gpu_fast_buffer_gemm(const int_tp ctx_id, const CBLAS_TRANSPOSE TransA, - const CBLAS_TRANSPOSE TransB, const int_tp M, - const int_tp N, const int_tp K, const float alpha, - const cl_mem A, const int_tp offA, const cl_mem B, - const int_tp offB, const float beta, cl_mem C, - const int_tp offC, enum gemm_type_t gemm_type) { - CHECK_EQ(gemm_type == GEMM_TYPE_FAST_BUFFER, true) - << "Invalid fast buffer gemm type." 
<< std::endl; - -#ifdef GEMM_PROFILING - struct timeval start[1], end[1]; - start[0] = end[0]; -#endif - - cl_event ev; - - viennacl::ocl::context &ctx = viennacl::ocl::get_context(ctx_id); - viennacl::ocl::program &program = (Caffe::Get().GetDevice(ctx_id, false)) - ->program(); - size_t sub_group_size = 8; - bool is_small_batch = (M == 2 || M == 4 || M == 8); - viennacl::ocl::kernel *oclk_gemm_float; - std::string kernel_name("gemm_buffer_"); - if(TransA == CblasNoTrans && TransB == CblasNoTrans) { - kernel_name += "NN_float"; - } else if(TransA == CblasNoTrans && TransB != CblasNoTrans) { - if (M == 2) - kernel_name +="NT_M_2_float"; - else if (M == 4) - kernel_name +="NT_M_4_float"; - else if (M == 8) - kernel_name +="NT_M_8_float"; - else - kernel_name += "NT_float"; - } else if(TransA != CblasNoTrans && TransB == CblasNoTrans) { - kernel_name += "TN_float"; - } else { - kernel_name += "TT_float"; - } - oclk_gemm_float = &program.get_kernel(kernel_name); - size_t local[2] = {}; - size_t global[2] = {}; - if (TransA == CblasNoTrans && TransB != CblasNoTrans && is_small_batch ) { - if(M == 8) - local[0] = 16; - else if(M == 4) - local[0] = 32; - else - local[0] = 64; - local[1] = 1; - - if(M == 8) - global[0] = N * local[0]; - else - global[0] = (N + 3) / 4 * local[0]; - global[1] = 1; - } else { - size_t lx = sub_group_size; - size_t ly = (TransB != CblasNoTrans && TransA == CblasNoTrans) ? 16 : 4; - int dx = (TransB != CblasNoTrans && TransA == CblasNoTrans) ? 
1 : 4; - int dy = 8; - size_t gx = (size_t)(N + dx - 1) / dx; - size_t gy = (size_t)(M + dy - 1) / dy; - global[0] = (gx + lx - 1) / lx * lx; - global[1] = (gy + ly - 1) / ly * ly; - local[0] = lx; - local[1] = ly; - } - - cl_uint arg_idx = 0; - oclk_gemm_float->arg(arg_idx++, WrapHandle(A, &ctx)); - oclk_gemm_float->arg(arg_idx++, offA); - oclk_gemm_float->arg(arg_idx++, WrapHandle(B, &ctx)); - oclk_gemm_float->arg(arg_idx++, offB); - oclk_gemm_float->arg(arg_idx++, WrapHandle(C, &ctx)); - oclk_gemm_float->arg(arg_idx++, offC); - oclk_gemm_float->arg(arg_idx++, M); - oclk_gemm_float->arg(arg_idx++, N); - oclk_gemm_float->arg(arg_idx++, K); - oclk_gemm_float->arg(arg_idx++, alpha); - oclk_gemm_float->arg(arg_idx++, beta); - - START_TIMER(0); - if(TransB == CblasNoTrans || TransA != CblasNoTrans) { - int stride = 256; - for(int start_index = 0; start_index < K; start_index += stride) { - oclk_gemm_float->arg(arg_idx, start_index); - OCL_CHECK(clEnqueueNDRangeKernel(ctx.get_queue().handle().get(), - oclk_gemm_float->handle().get(), 2, NULL, - global, local, 0, - NULL, &ev)); - } - } else { - OCL_CHECK(clEnqueueNDRangeKernel(ctx.get_queue().handle().get(), - oclk_gemm_float->handle().get(), 2, NULL, - global, local, 0, - NULL, &ev)); - } - STOP_TIMER(0); - clReleaseEvent(ev); - -#ifdef GEMM_PROFILING - double total_elapsed; - total_elapsed = (end[0].tv_sec - start[0].tv_sec) * 1e6 + (end[0].tv_usec - start[0].tv_usec); - printf("kernel name %s \n", kernel_name.c_str()); - printf("gemm %d %d %d %f %f %d %d %f %fGFLOPS\n", - M, K, N, alpha, beta, TransA == CblasNoTrans, TransB == CblasNoTrans, - total_elapsed / 1000., M * N * ( 2 * K - 1.) 
/ ( total_elapsed * 1e3 ) ); -#endif -} - -template -static void greentea_gpu_gemm_common(const int_tp ctx_id, const CBLAS_TRANSPOSE TransA, - const CBLAS_TRANSPOSE TransB, const int_tp M, - const int_tp N, const int_tp K, const Dtype alpha, - const cl_mem A, const int_tp offA, const cl_mem B, - const int_tp offB, const Dtype beta, cl_mem C, - const int_tp offC, bool is_image_a, bool is_image_b, - gemm_type_t gemm_type) { - - viennacl::ocl::context &ctx = viennacl::ocl::get_context(ctx_id); - int_tp lda = (TransA == CblasNoTrans) ? K : M; - int_tp ldb = (TransB == CblasNoTrans) ? N : K; - int_tp ldc = N; - - if (gemm_type == GEMM_TYPE_FAST_IMAGE_32_1 || - gemm_type == GEMM_TYPE_FAST_IMAGE_32_2) { - greentea_gpu_fast_image_gemm(ctx_id, TransA, TransB, M, N, K, - alpha, A, offA, B, offB, beta, C, - offC, is_image_a, is_image_b, - gemm_type); - } else if (gemm_type == GEMM_TYPE_FAST_BUFFER) { - greentea_gpu_fast_buffer_gemm(ctx_id, TransA, TransB, M, N, K, - alpha, A, offA, B, offB, beta, C, - offC, gemm_type); - } else if (gemm_type == GEMM_TYPE_CLBLAS) { - #if defined(USE_CLBLAS) - if ((M == 2 || M == 4 || M == 8) && std::is_same::value - && TransA == CblasNoTrans && TransB != CblasNoTrans) { - greentea_gpu_fast_buffer_gemm(ctx_id, TransA, TransB, M, N, K, - alpha, A, offA, B, offB, beta, C, - offC, GEMM_TYPE_FAST_BUFFER); - } else { - clblasOrder clOrder = clblasRowMajor; - clblasTranspose clTransA = - (TransA == CblasNoTrans) ? clblasNoTrans : clblasTrans; - clblasTranspose clTransB = - (TransB == CblasNoTrans) ? 
clblasNoTrans : clblasTrans; - - cl_command_queue queue = ctx.get_queue().handle().get(); - - if (std::is_same::value) { - GREENTEA_CL_BLAS_CHECK( - clblasSgemm(clOrder, clTransA, clTransB, - M, N, K, alpha, A, offA, lda, B, offB, ldb, beta, - C, offC, ldc, 1, &queue, 0, NULL, NULL)); - } else { - GREENTEA_CL_BLAS_CHECK( - clblasDgemm(clOrder, clTransA, clTransB, - M, N, K, alpha, A, offA, lda, B, offB, ldb, beta, - C, offC, ldc, 1, &queue, 0, NULL, NULL)); - } - } - #endif - } else if (gemm_type == GEMM_TYPE_CLBLAST) { - #ifdef USE_CLBLAST - cl_command_queue queue = ctx.get_queue().handle().get(); - - clblast::Layout layout = clblast::Layout::kRowMajor; - clblast::Transpose a_transpose = (TransA == CblasNoTrans) ? - clblast::Transpose::kNo : clblast::Transpose::kYes; - clblast::Transpose b_transpose = (TransB == CblasNoTrans) ? - clblast::Transpose::kNo : clblast::Transpose::kYes; - - if (std::is_same::value) { - GREENTEA_CLBLAST_CHECK( - clblast::Gemm( - layout, a_transpose, b_transpose, - M, N, K, - alpha, - A, offA, lda, - B, offB, ldb, - beta, - C, offC, ldc, - &queue)); - } else { - GREENTEA_CLBLAST_CHECK( - clblast::Gemm( - layout, a_transpose, b_transpose, - M, N, K, - alpha, - A, offA, lda, - B, offB, ldb, - beta, - C, offC, ldc, - &queue)); - } - #endif - } else if (gemm_type == GEMM_TYPE_VIENNACL) { - typedef typename viennacl::matrix_base::size_type size_type; - typedef typename viennacl::matrix_base::size_type difference_type; - - size_type A_size1 = static_cast((TransA == CblasTrans) ? K : M); - size_type A_size2 = static_cast((TransA == CblasTrans) ? M : K); - - size_type B_size1 = static_cast((TransB == CblasTrans) ? N : K); - size_type B_size2 = static_cast((TransB == CblasTrans) ? 
K : N); - - viennacl::matrix_base matA(A, ctx, A_size1, - size_type(0), - difference_type(1), - size_type(M), A_size2, - size_type(offA), - difference_type(1), - size_type(lda) - VCL_ROW_MAJOR); - - viennacl::matrix_base matB(B, ctx, B_size1, - size_type(0), - difference_type(1), - size_type(K), B_size2, - size_type(offB), - difference_type(1), - size_type(ldb) - VCL_ROW_MAJOR); - - viennacl::matrix_base matC(C, ctx, size_type(M), - size_type(0), - difference_type(1), - size_type(M), - size_type(N), - size_type(offC), - difference_type(1), - size_type(ldc) - VCL_ROW_MAJOR); - - if (TransA == CblasTrans && TransB == CblasTrans) - viennacl::linalg::prod_impl(viennacl::trans(matA), viennacl::trans(matB), - matC, alpha, beta); - else if (TransA == CblasTrans && TransB == CblasNoTrans) - viennacl::linalg::prod_impl(viennacl::trans(matA), matB, matC, alpha, - beta); - else if (TransA == CblasNoTrans && TransB == CblasTrans) - viennacl::linalg::prod_impl(matA, viennacl::trans(matB), matC, alpha, - beta); - else if (TransA == CblasNoTrans && TransB == CblasNoTrans) - viennacl::linalg::prod_impl(matA, matB, matC, alpha, beta); - } -} - -static void auto_tune_gemm(int ctx_id, const CBLAS_TRANSPOSE TransA, - const CBLAS_TRANSPOSE TransB, - gemm_type_t *tuned_gemm_types, - bool use_fast_gemm_image) { - viennacl::ocl::context &ctx = viennacl::ocl::get_context(ctx_id); - int M = 1024; - int K = 512; - int N = 1024; - cl_int err; - cl_mem A = clCreateBuffer(ctx.handle().get(), CL_MEM_ALLOC_HOST_PTR, M * K * sizeof(float), NULL, &err); - OCL_CHECK(err); - cl_mem B = clCreateBuffer(ctx.handle().get(), CL_MEM_ALLOC_HOST_PTR, K * N * sizeof(float), NULL, &err); - OCL_CHECK(err); - cl_mem C = clCreateBuffer(ctx.handle().get(), CL_MEM_ALLOC_HOST_PTR, M * N * sizeof(float), NULL, &err); - OCL_CHECK(err); - - std::vector gemm_tests; - - gemm_tests.push_back(GEMM_TYPE_VIENNACL); - if(use_fast_gemm_image) - gemm_tests.push_back(GEMM_TYPE_FAST_IMAGE_32_1); - 
gemm_tests.push_back(GEMM_TYPE_FAST_BUFFER); - -#ifdef USE_CLBLAS - gemm_tests.push_back(GEMM_TYPE_CLBLAS); -#endif -#ifdef USE_CLBLAST - gemm_tests.push_back(GEMM_TYPE_CLBLAST); -#endif - // warm up. - for( int i = 0; i < gemm_tests.size(); i++ ) { - greentea_gpu_gemm_common(ctx_id, TransA, TransB, M, N, K, - 1.0f, A, 0, B, 0, 0.0f, C, 0, false, false, - gemm_tests[i]); - } - float fastest_time = 1e10; - int fastest_index = -1; - clFinish(ctx.get_queue().handle().get()); - for( int i = 0; i < gemm_tests.size(); i++ ) { - struct timeval start, end; - gettimeofday(&start, NULL); - greentea_gpu_gemm_common(ctx_id, TransA, TransB, M, N, K, - 1.0f, A, 0, B, 0, 0.0f, C, 0, false, false, - gemm_tests[i]); - clFinish(ctx.get_queue().handle().get()); - gettimeofday(&end, NULL); - float elapsed = (end.tv_sec - start.tv_sec) * 1e6 + (end.tv_usec - start.tv_usec); - if (elapsed < fastest_time) { - fastest_time = elapsed; - fastest_index = i; - } - } - clReleaseMemObject(A); - clReleaseMemObject(B); - clReleaseMemObject(C); - - if (fastest_index >= 0) { - tuned_gemm_types[ctx_id] = gemm_tests[fastest_index]; -#ifdef GEMM_PROFILING - printf("The tuned GEMM kernel get %f GFLOPS with kernel type %d.\n", - M*N*(2*(double)K-1)/(fastest_time * 1e3), - tuned_gemm_types[ctx_id]); -#endif - } -} - -static gemm_type_t tuned_gemm_nn_types_with_image[16] = {GEMM_TYPE_NONE}; -static gemm_type_t tuned_gemm_nt_types_with_image[16] = {GEMM_TYPE_NONE}; -static gemm_type_t tuned_gemm_tn_types_with_image[16] = {GEMM_TYPE_NONE}; -static gemm_type_t tuned_gemm_tt_types_with_image[16] = {GEMM_TYPE_NONE}; - -static gemm_type_t tuned_gemm_nn_types_without_image[16] = {GEMM_TYPE_NONE}; -static gemm_type_t tuned_gemm_nt_types_without_image[16] = {GEMM_TYPE_NONE}; -static gemm_type_t tuned_gemm_tn_types_without_image[16] = {GEMM_TYPE_NONE}; -static gemm_type_t tuned_gemm_tt_types_without_image[16] = {GEMM_TYPE_NONE}; - -static void auto_tune_gemm_all(int ctx_id, bool use_fast_gemm_image) { - 
if(use_fast_gemm_image) { - auto_tune_gemm(ctx_id, CblasNoTrans, CblasNoTrans, tuned_gemm_nn_types_with_image, true); - auto_tune_gemm(ctx_id, CblasNoTrans, CblasTrans, tuned_gemm_nt_types_with_image, true); - auto_tune_gemm(ctx_id, CblasTrans, CblasNoTrans, tuned_gemm_tn_types_with_image, true); - auto_tune_gemm(ctx_id, CblasTrans, CblasTrans, tuned_gemm_tt_types_with_image, true); - } else { - auto_tune_gemm(ctx_id, CblasNoTrans, CblasNoTrans, tuned_gemm_nn_types_without_image, false); - auto_tune_gemm(ctx_id, CblasNoTrans, CblasTrans, tuned_gemm_nt_types_without_image, false); - auto_tune_gemm(ctx_id, CblasTrans, CblasNoTrans, tuned_gemm_tn_types_without_image, false); - auto_tune_gemm(ctx_id, CblasTrans, CblasTrans, tuned_gemm_tt_types_without_image, false); - } -} - -static boost::mutex auto_tune_gemm_mutex; - template void greentea_gpu_gemm(const int_tp ctx_id, const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB, const int_tp M, const int_tp N, const int_tp K, const Dtype alpha, const cl_mem A, const int_tp offA, const cl_mem B, const int_tp offB, const Dtype beta, cl_mem C, - const int_tp offC, bool is_image_a, bool is_image_b) { - CHECK_LT(ctx_id, 16) << "Too many GPU devices."; + const int_tp offC) { viennacl::ocl::context &ctx = viennacl::ocl::get_context(ctx_id); - bool use_fast_gemm_image = false; - bool use_fast_gemm_buffer = false; if (ctx.devices()[0].type() == CL_DEVICE_TYPE_CPU) { Dtype* Aptr = reinterpret_cast(clEnqueueMapBuffer( @@ -903,80 +205,121 @@ void greentea_gpu_gemm(const int_tp ctx_id, const CBLAS_TRANSPOSE TransA, NULL); clEnqueueUnmapMemObject(ctx.get_queue().handle().get(), C, Cptr, 0, NULL, NULL); - return; - } + } else { + int_tp lda = (TransA == CblasNoTrans) ? K : M; + int_tp ldb = (TransB == CblasNoTrans) ? N : K; + int_tp ldc = N; + +#if defined(USE_CLBLAS) - if (ctx.devices()[0].type() == CL_DEVICE_TYPE_GPU && - std::is_same::value) { - // Check whether can/should we use the fast gemm driver. 
- // There are the following considerations/restrications: - // 1. The fast gemm kernel is using image which has a size limitation. - // 2. The fast gemm kernel is using the intel sub group extension. - // 3. Currently, only the IGC compiler (the driver version is 16.xxx) - // can get better performance with the fast gemm. - // Cap at 1 MB to capture faulty OpenCL implementations (nVidia) - bool has_sub_group_ext = ctx.devices()[0].extensions().find("cl_intel_subgroups") - != std::string::npos; - if (has_sub_group_ext) { - size_t max_image_size = std::min(ctx.devices()[0].image2d_max_width(), - ctx.devices()[0].image2d_max_height()); - if (M <= max_image_size && - K <= max_image_size && - N <= max_image_size) { - use_fast_gemm_image = true; - } - use_fast_gemm_buffer = true; + clblasOrder clOrder = clblasRowMajor; + clblasTranspose clTransA = + (TransA == CblasNoTrans) ? clblasNoTrans : clblasTrans; + clblasTranspose clTransB = + (TransB == CblasNoTrans) ? clblasNoTrans : clblasTrans; + + cl_command_queue queue = ctx.get_queue().handle().get(); + if (std::is_same::value) { + GREENTEA_CL_BLAS_CHECK( + clblasSgemm(clOrder, clTransA, clTransB, + M, N, K, alpha, A, offA, lda, B, offB, ldb, beta, + C, offC, ldc, 1, &queue, 0, NULL, NULL)); + } else { + GREENTEA_CL_BLAS_CHECK( + clblasDgemm(clOrder, clTransA, clTransB, + M, N, K, alpha, A, offA, lda, B, offB, ldb, beta, + C, offC, ldc, 1, &queue, 0, NULL, NULL)); } - } - gemm_type_t preferred_gemm_type = GEMM_TYPE_VIENNACL; -#ifdef USE_CLBLAS - preferred_gemm_type = GEMM_TYPE_CLBLAS; -#endif -#ifdef USE_CLBLAST - preferred_gemm_type = GEMM_TYPE_CLBLAST; -#endif +#elif defined(USE_CLBLAST) + + cl_command_queue queue = ctx.get_queue().handle().get(); + + clblast::Layout layout = clblast::Layout::kRowMajor; + clblast::Transpose a_transpose = (TransA == CblasNoTrans) ? + clblast::Transpose::kNo : clblast::Transpose::kYes; + clblast::Transpose b_transpose = (TransB == CblasNoTrans) ? 
+ clblast::Transpose::kNo : clblast::Transpose::kYes; - { - boost::mutex::scoped_lock lock(auto_tune_gemm_mutex); - if(use_fast_gemm_image) { - if (tuned_gemm_nn_types_with_image[ctx_id] == GEMM_TYPE_NONE) { - auto_tune_gemm_all(ctx_id, true); - } - - if (TransA == CblasNoTrans && TransB == CblasNoTrans) - preferred_gemm_type = tuned_gemm_nn_types_with_image[ctx_id]; - else if (TransA == CblasTrans && TransB == CblasNoTrans) - preferred_gemm_type = tuned_gemm_tn_types_with_image[ctx_id]; - else if (TransA == CblasNoTrans && TransB == CblasTrans) - preferred_gemm_type = tuned_gemm_nt_types_with_image[ctx_id]; - else if (TransA == CblasTrans && TransB == CblasTrans) - preferred_gemm_type = tuned_gemm_tt_types_with_image[ctx_id]; - } else if(use_fast_gemm_buffer) { - if (tuned_gemm_nn_types_without_image[ctx_id] == GEMM_TYPE_NONE) { - auto_tune_gemm_all(ctx_id, false); - } - - if (TransA == CblasNoTrans && TransB == CblasNoTrans) - preferred_gemm_type = tuned_gemm_nn_types_without_image[ctx_id]; - else if (TransA == CblasTrans && TransB == CblasNoTrans) - preferred_gemm_type = tuned_gemm_tn_types_without_image[ctx_id]; - else if (TransA == CblasNoTrans && TransB == CblasTrans) - preferred_gemm_type = tuned_gemm_nt_types_without_image[ctx_id]; - else if (TransA == CblasTrans && TransB == CblasTrans) - preferred_gemm_type = tuned_gemm_tt_types_without_image[ctx_id]; + if (std::is_same::value) { + GREENTEA_CLBLAST_CHECK( + clblast::Gemm( + layout, a_transpose, b_transpose, + M, N, K, + alpha, + A, offA, lda, + B, offB, ldb, + beta, + C, offC, ldc, + &queue)); + } else { + GREENTEA_CLBLAST_CHECK( + clblast::Gemm( + layout, a_transpose, b_transpose, + M, N, K, + alpha, + A, offA, lda, + B, offB, ldb, + beta, + C, offC, ldc, + &queue)); } - } - CHECK_EQ(use_fast_gemm_image || (!is_image_a && !is_image_b), true) - << "Invalid GEMM parameters."; +#else // default (ViennaCL) + + typedef typename viennacl::matrix_base::size_type size_type; + typedef typename 
viennacl::matrix_base::size_type difference_type; + + size_type A_size1 = static_cast((TransA == CblasTrans) ? K : M); + size_type A_size2 = static_cast((TransA == CblasTrans) ? M : K); + + size_type B_size1 = static_cast((TransB == CblasTrans) ? N : K); + size_type B_size2 = static_cast((TransB == CblasTrans) ? K : N); + + viennacl::matrix_base matA(A, ctx, A_size1, + size_type(0), + difference_type(1), + size_type(M), A_size2, + size_type(offA), + difference_type(1), + size_type(lda) + VCL_ROW_MAJOR); + + viennacl::matrix_base matB(B, ctx, B_size1, + size_type(0), + difference_type(1), + size_type(K), B_size2, + size_type(offB), + difference_type(1), + size_type(ldb) + VCL_ROW_MAJOR); + + viennacl::matrix_base matC(C, ctx, size_type(M), + size_type(0), + difference_type(1), + size_type(M), + size_type(N), + size_type(offC), + difference_type(1), + size_type(ldc) + VCL_ROW_MAJOR); - if (is_image_a || is_image_b) - preferred_gemm_type = GEMM_TYPE_FAST_IMAGE_32_1; + if (TransA == CblasTrans && TransB == CblasTrans) + viennacl::linalg::prod_impl(viennacl::trans(matA), viennacl::trans(matB), + matC, alpha, beta); + else if (TransA == CblasTrans && TransB == CblasNoTrans) + viennacl::linalg::prod_impl(viennacl::trans(matA), matB, matC, alpha, + beta); + else if (TransA == CblasNoTrans && TransB == CblasTrans) + viennacl::linalg::prod_impl(matA, viennacl::trans(matB), matC, alpha, + beta); + else if (TransA == CblasNoTrans && TransB == CblasNoTrans) + viennacl::linalg::prod_impl(matA, matB, matC, alpha, beta); - greentea_gpu_gemm_common(ctx_id, TransA, TransB, M, N, K, alpha, A, offA, - B, offB, beta, C, offC, is_image_a, is_image_b, - preferred_gemm_type); +#endif // clBLAS, CLBlast, or default (ViennaCL) + } } template void greentea_gpu_gemm(const int_tp ctx_id, @@ -987,9 +330,7 @@ template void greentea_gpu_gemm(const int_tp ctx_id, const cl_mem A, const int_tp offA, const cl_mem B, const int_tp offB, const float beta, cl_mem C, - const int_tp offC, - const bool 
is_image_a = false, - const bool is_image_b = false); + const int_tp offC); template void greentea_gpu_gemm(const int_tp ctx_id, const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB, @@ -998,34 +339,7 @@ template void greentea_gpu_gemm(const int_tp ctx_id, const cl_mem A, const int_tp offA, const cl_mem B, const int_tp offB, const double beta, cl_mem C, - const int_tp offC, - const bool is_image_a = false, - const bool is_image_b = false); - -template void greentea_gpu_gemm_common(const int_tp ctx_id, - const CBLAS_TRANSPOSE TransA, - const CBLAS_TRANSPOSE TransB, - const int_tp M, const int_tp N, - const int_tp K, const float alpha, - const cl_mem A, const int_tp offA, - const cl_mem B, const int_tp offB, - const float beta, cl_mem C, - const int_tp offC, - const bool is_image_a, - const bool is_image_b, - const gemm_type_t); -template void greentea_gpu_gemm_common(const int_tp ctx_id, - const CBLAS_TRANSPOSE TransA, - const CBLAS_TRANSPOSE TransB, - const int_tp M, const int_tp N, - const int_tp K, const double alpha, - const cl_mem A, const int_tp offA, - const cl_mem B, const int_tp offB, - const double beta, cl_mem C, - const int_tp offC, - const bool is_image_a, - const bool is_image_b, - const gemm_type_t); + const int_tp offC); template void greentea_gpu_gemv(const int_tp ctx_id, const CBLAS_TRANSPOSE TransA, diff --git a/src/caffe/layers/inner_product_layer.cpp b/src/caffe/layers/inner_product_layer.cpp index cef42443b0e..c03fd981e28 100644 --- a/src/caffe/layers/inner_product_layer.cpp +++ b/src/caffe/layers/inner_product_layer.cpp @@ -53,33 +53,7 @@ void InnerProductLayer::LayerSetUp(const vector*>& bottom, } } // parameter initialization this->param_propagate_down_.resize(this->blobs_.size(), true); - - if (this->device_->backend() == BACKEND_OpenCL && this->phase_ == TEST) { - viennacl::ocl::context &ctx = - viennacl::ocl::get_context(this->device_->id()); - size_t max_image_size = std::min(ctx.devices()[0].image2d_max_width(), - 
ctx.devices()[0].image2d_max_height()); - // For inference only, we can load the weights data to image on Intel platform. - // As image based GEMM is much faster than the buffer based GEMM for most cases. - if (N_ <= max_image_size && - K_ <= max_image_size && - std::is_same::value && - this->device_->CheckCapability("cl_intel_subgroups")) { - const Dtype* weight = this->blobs_[0]->gpu_data(); - int height = !transpose_ ? N_ : K_; - int width = !transpose_ ? K_ : N_; - int padded_height = !transpose_ ? height : (height + ((height & 7) ? 1 : 0)); - int padded_width = !transpose_ ? width : (width + ((width & 7) ? 1 : 0)); - greentea_gpu_gemm_copy_buffer_to_image(this->device_->id(), - &weight_image_, (cl_mem) weight, 0, - false, !transpose_, - true, padded_height, padded_width, - height, width, (int)0, NULL, NULL); - copied_weight_data_ = this->blobs_[0]->data().get(); - } - } else { - copied_weight_data_ = NULL; - } + copied_weight_data_ = NULL; test_only_ = this->phase_ == TEST; } diff --git a/src/caffe/layers/inner_product_layer.cu b/src/caffe/layers/inner_product_layer.cu index c6e98bec391..90f023b3b1b 100644 --- a/src/caffe/layers/inner_product_layer.cu +++ b/src/caffe/layers/inner_product_layer.cu @@ -3,9 +3,702 @@ #include "caffe/filler.hpp" #include "caffe/layers/inner_product_layer.hpp" #include "caffe/util/math_functions.hpp" +#ifdef USE_GREENTEA +#include "viennacl/tools/sha1.hpp" +#include "caffe/util/benchmark.hpp" +#endif namespace caffe { +struct gemm_callback_arg { + std::vector evs; + std::vector imgs; +}; + +static void CL_CALLBACK gemm_callback (cl_event event, + cl_int event_command_exec_status, + void *user_data) { + struct gemm_callback_arg *arg = (struct gemm_callback_arg *) user_data; + for(int i = 0; i < arg->evs.size(); i++) { + clReleaseEvent(arg->evs[i]); + } + + for(int i = 0; i < arg->imgs.size(); i++) { + clReleaseMemObject(arg->imgs[i]); + } + delete arg; +} + +// Create and copy buffer to image for GEMM's matrix A and B. 
+// Will return image to caller if the input image is NULL. Otherwise, +// will use the image directly. It's caller's responsibility to +// release the created image. +static void greentea_gpu_gemm_copy_buffer_to_image(int_tp ctx_id, + cl_mem *image, cl_mem buffer, int offset, + bool is_matrix_a, bool transpose, + bool padding, int padded_height, + int padded_width, int height, + int width, int ld, int wait_list_size, + cl_event *wait_list, + cl_event *event) { + + viennacl::ocl::context &ctx = viennacl::ocl::get_context(ctx_id); + viennacl::ocl::program &program = (Caffe::Get().GetDevice(ctx_id, false)) + ->program(); + cl_image_desc desc; + cl_image_format format; + + bool halfPrecisionMode = false; + + memset(&desc, 0, sizeof(desc)); + int src_offset = halfPrecisionMode ? sizeof(unsigned short) * offset : sizeof(float) * offset; + if (!is_matrix_a && transpose) { + // For matrix B with transpose, we need to handle them differently. + // As we can't use the sub group block read to get a row easily, + // we have to use CL_FLOAT type with read_imagef to get the row. 
+ cl_int err; + if(halfPrecisionMode) { + format.image_channel_data_type = CL_HALF_FLOAT; + } else { + format.image_channel_data_type = CL_FLOAT; + } + desc.image_type = CL_MEM_OBJECT_IMAGE2D; + desc.image_width = width; + format.image_channel_order = CL_R; + + desc.image_height = height; + if (*image == NULL) { + *image = clCreateImage( + ctx.handle().get(), + CL_MEM_READ_WRITE, + &format, + &desc, + NULL, + &err); + OCL_CHECK(err); + } + + if(ld == width) { + size_t origin[] = {0, 0, 0}; + size_t region[] = {(size_t)desc.image_width, + (size_t)desc.image_height, 1}; + + OCL_CHECK(clEnqueueCopyBufferToImage(ctx.get_queue().handle().get(), + buffer, *image, src_offset, + origin, region, wait_list_size, + wait_list, event)); + } else { + std::string kernel_name("gemm_buffer_copy_image_transpose_float"); + + viennacl::ocl::kernel &oclk_gemm_copy = program.get_kernel(kernel_name); + + size_t global_copy[2]; + global_copy[0] = width; + global_copy[1] = height; + oclk_gemm_copy.arg(0, WrapHandle(buffer, &ctx)); + oclk_gemm_copy.arg(1, WrapHandle(*image, &ctx)); + oclk_gemm_copy.arg(2, offset); + oclk_gemm_copy.arg(3, width); + oclk_gemm_copy.arg(4, height); + oclk_gemm_copy.arg(5, ld); + OCL_CHECK(clEnqueueNDRangeKernel(ctx.get_queue().handle().get(), + oclk_gemm_copy.handle().get(), + 2, NULL, global_copy, NULL, + wait_list_size, wait_list, + event)); + } + } else { + if (*image == NULL) { + desc.image_type = CL_MEM_OBJECT_IMAGE2D; + if(halfPrecisionMode) { + format.image_channel_data_type = CL_HALF_FLOAT; + format.image_channel_order = CL_R; + } else { + format.image_channel_data_type = CL_UNSIGNED_INT8; + format.image_channel_order = CL_RGBA; + } + + if (!padding) { + desc.image_width = width; + desc.image_height = height; + } else { + desc.image_width = padded_width; + desc.image_height = padded_height; + } + cl_int err; + *image = clCreateImage(ctx.handle().get(), + desc.buffer ? 
CL_MEM_READ_ONLY : CL_MEM_READ_WRITE, + &format, + &desc, + NULL, + &err); + OCL_CHECK(err); + } + if (!padding && desc.buffer != NULL) + return; + if (!padding && desc.buffer == NULL) { + // copy without padding. + size_t origin[] = {0, 0, 0}; + size_t region[] = {(size_t)width, (size_t)height, 1}; + OCL_CHECK(clEnqueueCopyBufferToImage(ctx.get_queue().handle().get(), + buffer, *image, src_offset, + origin, region, wait_list_size, wait_list, event)); + } else { + std::string kernel_name("gemm_buffer_copy_image_no_transpose_float"); + + viennacl::ocl::kernel &oclk_gemm_copy = program.get_kernel(kernel_name); + + size_t global_copy[2]; + global_copy[0] = padding ? padded_width : width; + global_copy[1] = padding ? padded_height : height; + oclk_gemm_copy.arg(0, WrapHandle(buffer, &ctx)); + oclk_gemm_copy.arg(1, WrapHandle(*image, &ctx)); + oclk_gemm_copy.arg(2, offset); + oclk_gemm_copy.arg(3, width); + oclk_gemm_copy.arg(4, height); + oclk_gemm_copy.arg(5, ld); + OCL_CHECK(clEnqueueNDRangeKernel(ctx.get_queue().handle().get(), + oclk_gemm_copy.handle().get(), + 2, NULL, global_copy, NULL, + wait_list_size, wait_list, + event)); + } + } +} + +static void greentea_gpu_fast_image_gemm(const int_tp ctx_id, const CBLAS_TRANSPOSE TransA, + const CBLAS_TRANSPOSE TransB, const int_tp M, + const int_tp N, const int_tp K, const float alpha, + const cl_mem A, const int_tp offA, const cl_mem B, + const int_tp offB, const float beta, cl_mem C, + const int_tp offC, bool is_image_a, bool is_image_b, + enum gemm_type_t gemm_type, const size_t max_image_size) { + CHECK_EQ(gemm_type == GEMM_TYPE_FAST_IMAGE_32_1 + || gemm_type == GEMM_TYPE_FAST_IMAGE_32_2 + || gemm_type == GEMM_TYPE_FAST_IMAGE_B_IMAGE, true) + << "Invalid fast image gemm type." << std::endl; + if (is_image_a) + CHECK_EQ(offA, 0) << "Invalid input image offset." << std::endl; + + if (is_image_b) + CHECK_EQ(offB, 0) << "Invalid input image offset." 
<< std::endl; + + bool halfPrecisionMode = false; + int widthA = (TransA == CblasNoTrans) ? K : M; + int heightA = (TransA == CblasNoTrans) ? M : K; + int widthB = (TransB == CblasNoTrans) ? N : K; + int heightB = (TransB == CblasNoTrans) ? K : N; + + int ldA = widthA; + int ldB = widthB; + int ldC = N; + + int A_start_x = 0, A_start_y = 0, B_start_x = 0, B_start_y = 0, C_start_x = 0, C_start_y = 0; + int blocksize = 1024; + if (gemm_type == GEMM_TYPE_FAST_IMAGE_B_IMAGE) + blocksize = max_image_size; + int blockA_width = blocksize; + int blockA_height = blocksize; + int blockB_width = blocksize; + int blockB_height = blocksize; + int blockC_width = blocksize; + int blockC_height = blocksize; + + int use_buffer_indicator = halfPrecisionMode ? 16 : 8; + // To fix the edge problem casued by the sub group block read. + // we have to pad the image if it's not multiple of tile. + // just padding one line is enough as the sub group block read + // will clamp to edge according to the spec. + + viennacl::ocl::context &ctx = viennacl::ocl::get_context(ctx_id); + viennacl::ocl::program &program = (Caffe::Get().GetDevice(ctx_id, false)) + ->program(); + + cl_mem ImA = NULL; + cl_mem ImB = NULL; + + viennacl::ocl::kernel *oclk_gemm_float; + std::string kernel_name("gemm_"); + if (gemm_type == GEMM_TYPE_FAST_IMAGE_32_1 + || gemm_type == GEMM_TYPE_FAST_IMAGE_B_IMAGE) + kernel_name += "32_1_"; + else + kernel_name += "32_2_"; + + if (TransA == CblasNoTrans) + kernel_name += "N"; + else + kernel_name += "T"; + + if (TransB == CblasNoTrans) + kernel_name += "N_"; + else { + kernel_name += "T_"; + if (is_image_b || (K % use_buffer_indicator != 0)) { + kernel_name += "SCALAR_"; + } else { + kernel_name += "BUFFER_"; + } + } + + if (alpha == 1) + kernel_name += "1_"; + else + kernel_name += "0_"; + + if (beta == 0) + kernel_name += "0"; + else + kernel_name += "1"; + kernel_name += "_float"; + + oclk_gemm_float = &program.get_kernel(kernel_name); + while(C_start_y < M) { + blockC_width 
= std::min((int)N - C_start_x, blocksize); + blockC_height = std::min((int)M - C_start_y, blocksize); + + int isFirstColBlock = 1; + for(int k = 0; k < K; k += blocksize) { + cl_event ev[5]; + cl_uint ev_idx = 0; + memset(ev, 0, sizeof(cl_event) * 5); + struct gemm_callback_arg * arg = new gemm_callback_arg; + + blockA_width = std::min(widthA - A_start_x, blocksize); + blockA_height = std::min(heightA - A_start_y, blocksize); + blockB_width = std::min(widthB - B_start_x, blocksize); + blockB_height = std::min(heightB - B_start_y, blocksize); + int block_Ksize = std::min((int)K - k, blocksize); + + int padded_k = block_Ksize + ((block_Ksize & 7) ? (8 - (block_Ksize & 7)) : 0); + int imageA_w = (TransA == CblasNoTrans) ? padded_k : blockA_width; + int imageA_h = (TransA == CblasNoTrans) ? blockA_height : padded_k; + int imageB_w = (TransB == CblasNoTrans) ? blockB_width : padded_k; + int imageB_h = (TransB == CblasNoTrans) ? padded_k : blockB_height; + + int blockA_offset = offA + A_start_y * ldA + A_start_x; + int blockB_offset = offB + B_start_y * ldB + B_start_x; + int blockC_offset = offC + C_start_y * ldC + C_start_x; + if (TransB == CblasNoTrans) { + bool padding_A = false; + bool padding_B = false; + + if(halfPrecisionMode && is_image_b) { + padding_A = true; + } + + if (!is_image_a && !is_image_b) { + if (M * K < N * K) + padding_B = true; + else + padding_A = true; + } + + if (!is_image_a) { + greentea_gpu_gemm_copy_buffer_to_image(ctx_id, &ImA, A, blockA_offset, + true, TransA != CblasNoTrans, + padding_A, imageA_h, imageA_w, + blockA_height, blockA_width, ldA, 0, NULL, &ev[ev_idx]); + if (ev[ev_idx] != NULL) + ev_idx++; + } + if (!is_image_b) { + greentea_gpu_gemm_copy_buffer_to_image(ctx_id, &ImB, B, blockB_offset, + false, false, + padding_B, imageB_h, imageB_w, + blockB_height, blockB_width, ldB, 0, NULL, &ev[ev_idx]); + if (ev[ev_idx] != NULL) + ev_idx++; + } + } else { + // We will use normal read_imagef to read image B when B has transpose. 
+ // thus we don't need to pad image A at all. + if (!is_image_a) { + bool padding; + padding = !is_image_b || halfPrecisionMode; + greentea_gpu_gemm_copy_buffer_to_image(ctx_id, &ImA, A, blockA_offset, + true, TransA != CblasNoTrans, + padding, imageA_h, imageA_w, + blockA_height, blockA_width, ldA, 0, NULL, &ev[ev_idx]); + if (ev[ev_idx] != NULL) + ev_idx++; + } + + if(!is_image_b && (K % use_buffer_indicator != 0)) { + greentea_gpu_gemm_copy_buffer_to_image(ctx_id, &ImB, B, blockB_offset, + false, true, false, imageB_h, imageB_w, + blockB_height, blockB_width, ldB, 0, NULL, &ev[ev_idx]); + if (ev[ev_idx] != NULL) + ev_idx++; + } + } + if (is_image_a) + ImA = A; + if (is_image_b) + ImB = B; + + size_t global[2]; + if (gemm_type == GEMM_TYPE_FAST_IMAGE_32_1 || + gemm_type == GEMM_TYPE_FAST_IMAGE_B_IMAGE ) { + if(halfPrecisionMode) { + global[0] = (size_t)( blockC_width + 15 ) & ~15; + } else { + global[0] = (size_t)( blockC_width + 7 ) & ~7; + } + } + else { + if(halfPrecisionMode) { + global[0] = (size_t)( (blockC_width / 2 ) + 15 ) ^ ~15; + } else { + global[0] = (size_t)( (blockC_width / 2 ) + 7 ) ^ ~7; + } + } + global[1] = (size_t)(blockC_height + 31) / 32; + + size_t local[2]; + + if (halfPrecisionMode) { + local[0] = 16; + } else { + local[0] = 8; + } + local[1] = 1; + + cl_uint arg_idx = 0; + oclk_gemm_float->arg(arg_idx++, WrapHandle(ImA, &ctx)); + if (TransB == CblasNoTrans || is_image_b || (K % use_buffer_indicator != 0)) + oclk_gemm_float->arg(arg_idx++, WrapHandle(ImB, &ctx)); + else { + oclk_gemm_float->arg(arg_idx++, WrapHandle(B, &ctx)); + oclk_gemm_float->arg(arg_idx++, blockB_offset); + oclk_gemm_float->arg(arg_idx++, ldB); + } + oclk_gemm_float->arg(arg_idx++, WrapHandle(C, &ctx)); + oclk_gemm_float->arg(arg_idx++, blockC_offset); + oclk_gemm_float->arg(arg_idx++, blockC_height); + oclk_gemm_float->arg(arg_idx++, blockC_width); + oclk_gemm_float->arg(arg_idx++, ldC); + oclk_gemm_float->arg(arg_idx++, alpha); + oclk_gemm_float->arg(arg_idx++, 
beta); + oclk_gemm_float->arg(arg_idx++, padded_k); + if (TransB != CblasNoTrans) + oclk_gemm_float->arg(arg_idx++, block_Ksize); + oclk_gemm_float->arg(arg_idx++, isFirstColBlock); + + cl_event *wait_list = NULL; + if (ev_idx != 0) + wait_list = &ev[0]; + OCL_CHECK(clEnqueueNDRangeKernel(ctx.get_queue().handle().get(), + oclk_gemm_float->handle().get(), 2, NULL, + global, local, ev_idx, + wait_list, &ev[ev_idx])); + if(TransA == CblasNoTrans) + A_start_x += blockA_width; + else + A_start_y += blockA_height; + + if(TransB == CblasNoTrans) + B_start_y += blockB_height; + else + B_start_x += blockB_width; + + isFirstColBlock = 0; + arg->evs.assign(ev, ev + ev_idx + 1); + clSetEventCallback(ev[ev_idx], CL_COMPLETE, &gemm_callback, (void*)arg); + } + + C_start_x += blockC_width; + if(TransA == CblasNoTrans) + A_start_x = 0; + else + A_start_y = 0; + if(TransB == CblasNoTrans) { + B_start_x += blockB_width; + B_start_y = 0; + } else { + B_start_y += blockB_height; + B_start_x = 0; + } + if(C_start_x >= N) { + C_start_x = 0; + B_start_x = 0; + B_start_y = 0; + C_start_y += blockC_height; + if(TransA == CblasNoTrans) + A_start_y += blockA_height; + else + A_start_x += blockA_width; + } + } + + if(ImA && !is_image_a) + clReleaseMemObject(ImA); + if(ImB && !is_image_b) + clReleaseMemObject(ImB); +} + +static void greentea_gpu_fast_buffer_gemm(const int_tp ctx_id, const CBLAS_TRANSPOSE TransA, + const CBLAS_TRANSPOSE TransB, const int_tp M, + const int_tp N, const int_tp K, const float alpha, + const cl_mem A, const int_tp offA, const cl_mem B, + const int_tp offB, const float beta, cl_mem C, + const int_tp offC, enum gemm_type_t gemm_type) { + CHECK_EQ(gemm_type == GEMM_TYPE_FAST_BUFFER, true) + << "Invalid fast buffer gemm type." 
<< std::endl; + + cl_event ev; + + viennacl::ocl::context &ctx = viennacl::ocl::get_context(ctx_id); + viennacl::ocl::program &program = (Caffe::Get().GetDevice(ctx_id, false)) + ->program(); + bool halfPrecisionMode= false; + + size_t sub_group_size = 8; + bool is_small_batch = (M == 2 || M == 4 || M == 8); + viennacl::ocl::kernel *oclk_gemm_float; + std::string kernel_name("gemm_buffer_"); + if(TransA == CblasNoTrans && TransB == CblasNoTrans) { + kernel_name += "NN_float"; + if(halfPrecisionMode) { + sub_group_size = 16; + } + } else if(TransA == CblasNoTrans && TransB != CblasNoTrans) { + if (M == 2) + kernel_name +="NT_M_2_float"; + else if (M == 4) + kernel_name +="NT_M_4_float"; + else if (M == 8) + kernel_name +="NT_M_8_float"; + else + kernel_name += "NT_float"; + } else if(TransA != CblasNoTrans && TransB == CblasNoTrans) { + kernel_name += "TN_float"; + if(halfPrecisionMode) { + sub_group_size = 16; + } + } else { + kernel_name += "TT_float"; + } + + oclk_gemm_float = &program.get_kernel(kernel_name); + size_t local[2] = {}; + size_t global[2] = {}; + if (TransA == CblasNoTrans && TransB != CblasNoTrans && is_small_batch ) { + if(M == 8) + local[0] = 16; + else if(M == 4) + local[0] = 32; + else + local[0] = 64; + local[1] = 1; + + if(M == 8) + global[0] = N * local[0]; + else + global[0] = (N + 3) / 4 * local[0]; + global[1] = 1; + } else { + size_t lx = sub_group_size; + size_t ly = (TransB != CblasNoTrans && TransA == CblasNoTrans) ? 16 : 4; + int dx = (TransB != CblasNoTrans && TransA == CblasNoTrans) ? 
1 : 4; + int dy = 8; + size_t gx = (size_t)(N + dx - 1) / dx; + size_t gy = (size_t)(M + dy - 1) / dy; + global[0] = (gx + lx - 1) / lx * lx; + global[1] = (gy + ly - 1) / ly * ly; + local[0] = lx; + local[1] = ly; + } + + cl_uint arg_idx = 0; + oclk_gemm_float->arg(arg_idx++, WrapHandle(A, &ctx)); + oclk_gemm_float->arg(arg_idx++, offA); + oclk_gemm_float->arg(arg_idx++, WrapHandle(B, &ctx)); + oclk_gemm_float->arg(arg_idx++, offB); + oclk_gemm_float->arg(arg_idx++, WrapHandle(C, &ctx)); + oclk_gemm_float->arg(arg_idx++, offC); + oclk_gemm_float->arg(arg_idx++, M); + oclk_gemm_float->arg(arg_idx++, N); + oclk_gemm_float->arg(arg_idx++, K); + oclk_gemm_float->arg(arg_idx++, alpha); + oclk_gemm_float->arg(arg_idx++, beta); + + if(TransB == CblasNoTrans || TransA != CblasNoTrans) { + int stride = 256; + for(int start_index = 0; start_index < K; start_index += stride) { + oclk_gemm_float->arg(arg_idx, start_index); + OCL_CHECK(clEnqueueNDRangeKernel(ctx.get_queue().handle().get(), + oclk_gemm_float->handle().get(), 2, NULL, + global, local, 0, + NULL, &ev)); + } + } else { + OCL_CHECK(clEnqueueNDRangeKernel(ctx.get_queue().handle().get(), + oclk_gemm_float->handle().get(), 2, NULL, + global, local, 0, + NULL, &ev)); + } + clReleaseEvent(ev); +} + +template +static void innerprod_common(const int_tp ctx_id, const CBLAS_TRANSPOSE TransB, + const int_tp M, const int_tp N, const int_tp K, + const cl_mem A, const cl_mem B, const cl_mem B_image, + cl_mem C, gemm_type_t gemm_type, const size_t max_image_size) { + + if (gemm_type == GEMM_TYPE_FAST_IMAGE_32_1 || + gemm_type == GEMM_TYPE_FAST_IMAGE_32_2) { + greentea_gpu_fast_image_gemm(ctx_id, CblasNoTrans, TransB, M, N, K, + (Dtype)1., A, 0, B, 0, (Dtype)0., C, + 0, false, false, gemm_type, max_image_size); + } else if (gemm_type == GEMM_TYPE_FAST_IMAGE_B_IMAGE) { + greentea_gpu_fast_image_gemm(ctx_id, CblasNoTrans, TransB, M, N, K, + (Dtype)1., A, 0, B_image, 0, (Dtype)0., C, + 0, false, true, GEMM_TYPE_FAST_IMAGE_B_IMAGE, 
max_image_size); + + } else if (gemm_type == GEMM_TYPE_FAST_BUFFER) { + greentea_gpu_fast_buffer_gemm(ctx_id, CblasNoTrans, TransB, M, N, K, + 1.f, A, 0, B, 0, 0.f, C, + 0, gemm_type); + } else + greentea_gpu_gemm(ctx_id, CblasNoTrans, TransB, M, N, K, + (Dtype)1., A, 0, B, 0, (Dtype)0., C, 0); +} + + + +template +void InnerProductLayer::generate_key() { + std::stringstream keyBuilder; + keyBuilder << M_ << "_" + << N_ << "_" + << K_ << "_" + << transpose_; + + viennacl::ocl::context &ctx = viennacl::ocl::get_context + (this->device_->id()); + std::string prefix = ctx.current_device().name() + ctx.current_device().vendor() + + ctx.current_device().driver_version() + + std::to_string(ctx.current_device().max_compute_units()); + key_ = viennacl::tools::sha1(prefix + keyBuilder.str()); + // short_key_ = keyBuilder.str(); +} + +template void InnerProductLayer::generate_key(); +template void InnerProductLayer::generate_key(); + +template +bool InnerProductLayer::load_cache() { + if (tuned_) + return true; + else { + generate_key(); + // Find cached kernel configuration + string outputFile; + outputFile = cache_path_.str() + key_; + std::ifstream cachedKernel(outputFile.c_str()); + if (cachedKernel) { + int cache_config; + cachedKernel >> cache_config; + innerprod_type_ = (gemm_type_t)cache_config; + tuned_ = true; + return true; + } else { + return false; + } + } +} + +template bool InnerProductLayer::load_cache(); +template bool InnerProductLayer::load_cache(); + +template +void InnerProductLayer::tune_innerprod_type(const int_tp ctx_id, + const CBLAS_TRANSPOSE TransB, const cl_mem A, const cl_mem B, + const cl_mem B_image, const size_t max_image_size) { + if (std::is_same::value) { + return; + } else { + //1. load cache + if (load_cache()) { + return; + } else { + //2. 
if not cached generate tuning + uint element_size = 0; + bool halfPrecisionMode= false; + if(halfPrecisionMode) { + element_size = sizeof(uint16_t); + } else { + element_size = sizeof(float); + } + viennacl::ocl::context &ctx = viennacl::ocl::get_context(ctx_id); + cl_int err; + + cl_mem C = clCreateBuffer(ctx.handle().get(), CL_MEM_ALLOC_HOST_PTR, M_ * N_ * element_size, NULL, &err); + OCL_CHECK(err); + + std::vector gemm_tests; + + gemm_tests.push_back(GEMM_TYPE_FAST_IMAGE_32_1); + if (B_image != NULL) + gemm_tests.push_back(GEMM_TYPE_FAST_IMAGE_B_IMAGE); + gemm_tests.push_back(GEMM_TYPE_FAST_BUFFER); + if(!halfPrecisionMode) + gemm_tests.push_back(GEMM_TYPE_DEFAULT); + + // warm up. + for( int i = 0; i < gemm_tests.size(); i++ ) { + innerprod_common(ctx_id, TransB, M_, N_, K_, + A, B, B_image, C, gemm_tests[i], max_image_size); + } + float fastest_time = 1e10; + int fastest_index = -1; + clFinish(ctx.get_queue().handle().get()); + for( int i = 0; i < gemm_tests.size(); i++ ) { + Timer timer; + timer.initted(); + timer.Start(); + innerprod_common(ctx_id, TransB, M_, N_, K_, + A, B, B_image, C, gemm_tests[i], max_image_size); + timer.Stop(); + float elapsedTime = timer.MilliSeconds(); +// #define INNERPROD_PROFILING +#ifdef INNERPROD_PROFILING + std::cout << "innerprod type: " << gemm_tests[i] <<" eclipsed time: " + << elapsedTime << "ms." << std::endl; +#endif + if (elapsedTime < fastest_time) { + fastest_time = elapsedTime; + fastest_index = i; + } + } + clReleaseMemObject(C); + + if (fastest_index >= 0) { + innerprod_type_ = gemm_tests[fastest_index]; + } + //3. store cache. 
+ string outputFile; + outputFile = cache_path_.str() + key_; + std::ofstream outputKernel; + outputKernel.open(outputFile.c_str()); + outputKernel << innerprod_type_; + outputKernel.close(); + tuned_ = true; + return; + } + } + return; +} + +template void InnerProductLayer::tune_innerprod_type(const int_tp ctx_id, + const CBLAS_TRANSPOSE TransB, const cl_mem A, const cl_mem B, + const cl_mem B_image, const size_t max_image_size); +template void InnerProductLayer::tune_innerprod_type(const int_tp ctx_id, + const CBLAS_TRANSPOSE TransB, const cl_mem A, const cl_mem B, + const cl_mem B_image, const size_t max_image_size); + template void InnerProductLayer::Forward_gpu(const vector*>& bottom, const vector*>& top) { @@ -34,6 +727,20 @@ void InnerProductLayer::Forward_gpu(const vector*>& bottom, #endif // USE CUDA } else { #ifdef USE_GREENTEA + int padded_height = 0, padded_width = 0; + Dtype *bias_mult_data; + Dtype *bias_term_data; + if (bias_term_) { + bias_mult_data = (Dtype*)bias_multiplier_.gpu_data(); + bias_term_data = (Dtype*)this->blobs_[1]->gpu_data(); + } + int height = !transpose_ ? N_ : K_; + int width = !transpose_ ? K_ : N_; + if (M_ != 1) { + padded_height = !transpose_ ? height : (height + ((height & 7) ? 1 : 0)); + padded_width = !transpose_ ? width : (width + ((width & 7) ? 1 : 0)); + } + if (M_ == 1) { greentea_gpu_gemv(this->device_->id(), CblasNoTrans, N_, K_, (Dtype) 1., (cl_mem) weight, 0, @@ -64,30 +771,32 @@ void InnerProductLayer::Forward_gpu(const vector*>& bottom, weight_image_ = NULL; } greentea_gpu_gemm_copy_buffer_to_image(this->device_->id(), - &weight_image_, (cl_mem) weight, 0, - false, !transpose_, - true, padded_height, padded_width, - height, width, (int)0, NULL, NULL); + &weight_image_, (cl_mem) weight, 0, + false, !transpose_, + true, padded_height, padded_width, + height, width, width, (int)0, NULL, NULL); copied_weight_data_ = this->blobs_[0]->data().get(); } - greentea_gpu_gemm(this->device_->id(), CblasNoTrans, - transpose_ ? 
CblasNoTrans : CblasTrans, - M_, N_, K_, (Dtype) 1., - (cl_mem) bottom_data, 0, (cl_mem) weight_image_, 0, - (Dtype) 0., (cl_mem) top_data, 0, false, true); - } else - greentea_gpu_gemm(this->device_->id(), CblasNoTrans, - transpose_ ? CblasNoTrans : CblasTrans, - M_, N_, K_, (Dtype) 1., - (cl_mem) bottom_data, 0, (cl_mem) weight, 0, - (Dtype) 0., (cl_mem) top_data, 0); + } - if (bias_term_) + tune_innerprod_type(this->device_->id(), + transpose_ ? CblasNoTrans : CblasTrans, + (cl_mem) bottom_data, (cl_mem) weight, (cl_mem) weight_image_, + max_image_size); + + innerprod_common(this->device_->id(), + transpose_ ? CblasNoTrans : CblasTrans, + M_, N_, K_, (cl_mem) bottom_data, + (cl_mem) weight, (cl_mem) weight_image_, + (cl_mem) top_data, innerprod_type_, max_image_size); + if (bias_term_) { + // Execute kernel greentea_gpu_gemm(this->device_->id(), CblasNoTrans, CblasNoTrans, M_, N_, 1, (Dtype) 1., (cl_mem) (bias_multiplier_.gpu_data()), 0, (cl_mem) (this->blobs_[1]->gpu_data()), 0, (Dtype) 1., (cl_mem) top_data, 0); + } } #endif // USE_GREENTEA } diff --git a/src/caffe/test/test_inner_product_layer.cpp b/src/caffe/test/test_inner_product_layer.cpp index 2cc780e3b3c..e54aabf1dc5 100644 --- a/src/caffe/test/test_inner_product_layer.cpp +++ b/src/caffe/test/test_inner_product_layer.cpp @@ -138,7 +138,7 @@ TYPED_TEST(InnerProductLayerTest, TestForwardVGGFC6) { UniformFiller filler(filler_param); caffe::Caffe::SetDevice(0); - for(auto i = 1; i <= 8; i*=2) { + for(auto i = 1; i <= 64; i*=2) { Blob* const blob_bottom = new Blob(i, 392, 8, 8); Blob* const blob_top = new Blob(); filler.Fill(blob_bottom); @@ -187,10 +187,9 @@ TYPED_TEST(InnerProductLayerTest, TestForwardVGGFC6) { elapsedTime /= times; std::cout << "MNK(" << M << ","< filler(filler_param); caffe::Caffe::SetDevice(0); - for(auto i = 1; i <= 8; i*=2) { + for(auto i = 1; i <= 64; i*=2) { Blob* const blob_bottom = new Blob(i, 25088+1, 1, 1); Blob* const blob_top = new Blob(); filler.Fill(blob_bottom); @@ -252,10 
+251,9 @@ TYPED_TEST(InnerProductLayerTest, TestForwardVGGFC6_AddEdge) { elapsedTime /= times; std::cout << "MNK(" << M << ","< Date: Tue, 23 May 2017 06:32:22 +0800 Subject: [PATCH 23/33] Enable FP16 support for OpenCL backend. This patch depends on FP16 version of clBLAS which only supported by ISAAC, so the FP16 support will only be enabled with ISAAC used. For Intel platform, FP16 could get about 1.3x to 1.7x performance of FP32 format's according to different net models and different batch sizes. The patch introduce FP16 support into the framework. Currently, it could pass 90% of the half type test cases. Signed-off-by: Zhigang Gong --- cmake/Dependencies.cmake | 4 + include/3rdparty/LICENSE | 21 + include/3rdparty/half/half.hpp | 3067 ++++++++++++++++++++ include/caffe/common.hpp | 48 +- include/caffe/device.hpp | 1 + include/caffe/greentea/cl_kernels.hpp | 1 + include/caffe/greentea/greentea.hpp | 4 +- include/caffe/greentea/libdnn_tuner.hpp | 2 +- include/caffe/layer_factory.hpp | 10 +- include/caffe/test/test_caffe_main.hpp | 36 +- include/caffe/test/test_gradient_check_util.hpp | 5 + include/caffe/util/fp16.hpp | 56 + include/caffe/util/math_functions.hpp | 1 + include/caffe/util/mkl_alternate.hpp | 14 + src/caffe/blob.cpp | 24 +- src/caffe/common.cpp | 3 + src/caffe/device.cpp | 11 +- src/caffe/greentea/cl_headers/header.cl | 7 +- src/caffe/greentea/cl_kernels.cpp | 2487 +++++++++------- src/caffe/greentea/cl_kernels.sh | 154 +- src/caffe/greentea/cl_kernels/activation.cl | 6 +- src/caffe/greentea/cl_kernels/auxiliary.cl | 2 +- src/caffe/greentea/cl_kernels/batch_norm.cl | 20 +- src/caffe/greentea/cl_kernels/benchmark.cl | 2 +- src/caffe/greentea/cl_kernels/channel.cl | 2 +- src/caffe/greentea/cl_kernels/contrastive_loss.cl | 4 +- .../greentea/cl_kernels/conv_layer_spatial.cl | 1006 ++++--- src/caffe/greentea/cl_kernels/dropout.cl | 4 +- src/caffe/greentea/cl_kernels/eltwise.cl | 2 +- src/caffe/greentea/cl_kernels/elu.cl | 4 +- 
src/caffe/greentea/cl_kernels/embed.cl | 39 + src/caffe/greentea/cl_kernels/fft.cl | 6 +- src/caffe/greentea/cl_kernels/fillbuffer.cl | 2 +- src/caffe/greentea/cl_kernels/gemm.cl | 1412 ++++----- src/caffe/greentea/cl_kernels/lrn.cl | 42 +- src/caffe/greentea/cl_kernels/math.cl | 4 +- src/caffe/greentea/cl_kernels/matvec_mul.cl | 261 +- src/caffe/greentea/cl_kernels/pooling.cl | 6 +- src/caffe/greentea/cl_kernels/pooling_nd.cl | 4 +- src/caffe/greentea/cl_kernels/pooling_sk.cl | 6 +- src/caffe/greentea/cl_kernels/softmax_loss.cl | 6 +- src/caffe/greentea/cl_kernels/solvers.cl | 32 +- src/caffe/greentea/greentea_im2col.cpp | 67 + src/caffe/greentea/greentea_math_functions.cpp | 285 +- src/caffe/greentea/libdnn.cpp | 49 +- src/caffe/greentea/libdnn_conv_spatial.cpp | 480 ++- src/caffe/greentea/libdnn_pool.cpp | 16 +- src/caffe/layers/batch_norm_layer.cu | 4 +- src/caffe/layers/contrastive_loss_layer.cpp | 2 +- src/caffe/layers/contrastive_loss_layer.cu | 7 +- src/caffe/layers/conv_layer_spatial.cpp | 403 +-- src/caffe/layers/dropout_layer.cu | 6 +- src/caffe/layers/eltwise_layer.cpp | 5 +- src/caffe/layers/elu_layer.cu | 6 +- src/caffe/layers/embed_layer.cu | 7 + src/caffe/layers/hinge_loss_layer.cpp | 8 +- src/caffe/layers/inner_product_layer.cu | 116 +- src/caffe/layers/lrn_layer.cu | 31 +- src/caffe/layers/pooling_layer.cpp | 8 +- src/caffe/layers/power_layer.cpp | 4 +- src/caffe/layers/power_layer.cu | 4 +- src/caffe/layers/reduction_layer.cpp | 2 +- src/caffe/layers/relu_layer.cu | 5 +- src/caffe/layers/silence_layer.cu | 2 +- src/caffe/layers/softmax_loss_layer.cpp | 5 +- src/caffe/layers/threshold_layer.cu | 2 +- src/caffe/solvers/adadelta_solver.cpp | 4 +- src/caffe/solvers/adadelta_solver.cu | 10 +- src/caffe/solvers/adagrad_solver.cu | 6 +- src/caffe/solvers/adam_solver.cpp | 12 +- src/caffe/solvers/adam_solver.cu | 10 +- src/caffe/solvers/nesterov_solver.cpp | 4 +- src/caffe/solvers/nesterov_solver.cu | 8 +- src/caffe/solvers/rmsprop_solver.cu | 10 +- 
src/caffe/solvers/sgd_solver.cpp | 2 +- src/caffe/solvers/sgd_solver.cu | 8 +- src/caffe/test/test_accuracy_layer.cpp | 38 +- src/caffe/test/test_batch_norm_layer.cpp | 6 +- src/caffe/test/test_bias_layer.cpp | 54 +- src/caffe/test/test_blob.cpp | 5 +- src/caffe/test/test_caffe_main.cpp | 17 + src/caffe/test/test_contrastive_loss_layer.cpp | 10 +- src/caffe/test/test_convolution_layer.cpp | 54 +- src/caffe/test/test_convolution_layer_spatial.cpp | 78 +- src/caffe/test/test_deconvolution_layer.cpp | 4 +- src/caffe/test/test_eltwise_layer.cpp | 12 +- src/caffe/test/test_embed_layer.cpp | 4 +- src/caffe/test/test_euclidean_loss_layer.cpp | 6 +- src/caffe/test/test_gradient_based_solver.cpp | 10 +- src/caffe/test/test_image_data_layer.cpp | 2 +- src/caffe/test/test_inner_product_layer.cpp | 34 +- src/caffe/test/test_lrn_layer.cpp | 2 + src/caffe/test/test_lstm_layer.cpp | 3 +- src/caffe/test/test_math_functions.cpp | 9 +- src/caffe/test/test_mvn_layer.cpp | 9 +- src/caffe/test/test_net.cpp | 18 +- src/caffe/test/test_neuron_layer.cpp | 37 +- src/caffe/test/test_pooling_layer.cpp | 3 +- src/caffe/test/test_power_layer.cpp | 7 +- src/caffe/test/test_random_number_generator.cpp | 8 +- src/caffe/test/test_scale_layer.cpp | 63 +- .../test/test_sigmoid_cross_entropy_loss_layer.cpp | 6 +- src/caffe/test/test_softmax_layer.cpp | 19 +- src/caffe/test/test_softmax_with_loss_layer.cpp | 4 +- src/caffe/test/test_syncedmem.cpp | 2 +- src/caffe/test/test_tanh_layer.cpp | 9 +- src/caffe/util/blocking_queue.cpp | 4 +- src/caffe/util/hdf5.cpp | 38 + src/caffe/util/im2col.cpp | 36 + src/caffe/util/math_functions.cpp | 307 +- src/gtest/gtest-all.cpp | 49 + src/gtest/gtest.h | 164 +- tools/caffe-fp16.cpp | 592 ++++ 113 files changed, 8941 insertions(+), 3208 deletions(-) create mode 100644 include/3rdparty/LICENSE create mode 100644 include/3rdparty/half/half.hpp create mode 100644 include/caffe/util/fp16.hpp create mode 100644 tools/caffe-fp16.cpp diff --git a/cmake/Dependencies.cmake 
b/cmake/Dependencies.cmake index 45bd93b0e1b..b599c2c3d2d 100644 --- a/cmake/Dependencies.cmake +++ b/cmake/Dependencies.cmake @@ -185,6 +185,10 @@ if (USE_ISAAC) endif() list(APPEND Caffe_LINKER_LIBS PUBLIC ${ISAAC_LIBRARY}) list(APPEND Caffe_DEFINITIONS PUBLIC -DUSE_CLBLAS) + if (USE_GREENTEA AND NOT USE_CUDA) + message(STATUS "Enable half floating point supprot.") + list(APPEND Caffe_DEFINITIONS PUBLIC -DHAS_HALF_SUPPORT) + endif() endif() # ---[ CLBlast diff --git a/include/3rdparty/LICENSE b/include/3rdparty/LICENSE new file mode 100644 index 00000000000..abee50b1321 --- /dev/null +++ b/include/3rdparty/LICENSE @@ -0,0 +1,21 @@ +The MIT License + +Copyright (c) 2012-2017 Christian Rau + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. 
diff --git a/include/3rdparty/half/half.hpp b/include/3rdparty/half/half.hpp new file mode 100644 index 00000000000..c890a2a97f8 --- /dev/null +++ b/include/3rdparty/half/half.hpp @@ -0,0 +1,3067 @@ +// half - IEEE 754-based half-precision floating point library. +// +// Copyright (c) 2012-2017 Christian Rau +// +// Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation +// files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, +// modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the +// Software is furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE +// WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +// COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, +// ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +// Version 1.12.0 + +/// \file +/// Main header file for half precision functionality. + +#ifndef HALF_HALF_HPP +#define HALF_HALF_HPP + +/// Combined gcc version number. 
+#define HALF_GNUC_VERSION (__GNUC__*100+__GNUC_MINOR__) + +//check C++11 language features +#if defined(__clang__) //clang + #if __has_feature(cxx_static_assert) && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT) + #define HALF_ENABLE_CPP11_STATIC_ASSERT 1 + #endif + #if __has_feature(cxx_constexpr) && !defined(HALF_ENABLE_CPP11_CONSTEXPR) + #define HALF_ENABLE_CPP11_CONSTEXPR 1 + #endif + #if __has_feature(cxx_noexcept) && !defined(HALF_ENABLE_CPP11_NOEXCEPT) + #define HALF_ENABLE_CPP11_NOEXCEPT 1 + #endif + #if __has_feature(cxx_user_literals) && !defined(HALF_ENABLE_CPP11_USER_LITERALS) + #define HALF_ENABLE_CPP11_USER_LITERALS 1 + #endif + #if (defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103L) && !defined(HALF_ENABLE_CPP11_LONG_LONG) + #define HALF_ENABLE_CPP11_LONG_LONG 1 + #endif +/*#elif defined(__INTEL_COMPILER) //Intel C++ + #if __INTEL_COMPILER >= 1100 && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT) ???????? + #define HALF_ENABLE_CPP11_STATIC_ASSERT 1 + #endif + #if __INTEL_COMPILER >= 1300 && !defined(HALF_ENABLE_CPP11_CONSTEXPR) ???????? + #define HALF_ENABLE_CPP11_CONSTEXPR 1 + #endif + #if __INTEL_COMPILER >= 1300 && !defined(HALF_ENABLE_CPP11_NOEXCEPT) ???????? + #define HALF_ENABLE_CPP11_NOEXCEPT 1 + #endif + #if __INTEL_COMPILER >= 1100 && !defined(HALF_ENABLE_CPP11_LONG_LONG) ???????? 
+ #define HALF_ENABLE_CPP11_LONG_LONG 1 + #endif*/ +#elif defined(__GNUC__) //gcc + #if defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103L + #if HALF_GNUC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT) + #define HALF_ENABLE_CPP11_STATIC_ASSERT 1 + #endif + #if HALF_GNUC_VERSION >= 406 && !defined(HALF_ENABLE_CPP11_CONSTEXPR) + #define HALF_ENABLE_CPP11_CONSTEXPR 1 + #endif + #if HALF_GNUC_VERSION >= 406 && !defined(HALF_ENABLE_CPP11_NOEXCEPT) + #define HALF_ENABLE_CPP11_NOEXCEPT 1 + #endif + #if HALF_GNUC_VERSION >= 407 && !defined(HALF_ENABLE_CPP11_USER_LITERALS) + #define HALF_ENABLE_CPP11_USER_LITERALS 1 + #endif + #if !defined(HALF_ENABLE_CPP11_LONG_LONG) + #define HALF_ENABLE_CPP11_LONG_LONG 1 + #endif + #endif +#elif defined(_MSC_VER) //Visual C++ + #if _MSC_VER >= 1900 && !defined(HALF_ENABLE_CPP11_CONSTEXPR) + #define HALF_ENABLE_CPP11_CONSTEXPR 1 + #endif + #if _MSC_VER >= 1900 && !defined(HALF_ENABLE_CPP11_NOEXCEPT) + #define HALF_ENABLE_CPP11_NOEXCEPT 1 + #endif + #if _MSC_VER >= 1900 && !defined(HALF_ENABLE_CPP11_USER_LITERALS) + #define HALF_ENABLE_CPP11_USER_LITERALS 1 + #endif + #if _MSC_VER >= 1600 && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT) + #define HALF_ENABLE_CPP11_STATIC_ASSERT 1 + #endif + #if _MSC_VER >= 1310 && !defined(HALF_ENABLE_CPP11_LONG_LONG) + #define HALF_ENABLE_CPP11_LONG_LONG 1 + #endif + #define HALF_POP_WARNINGS 1 + #pragma warning(push) + #pragma warning(disable : 4099 4127 4146) //struct vs class, constant in if, negative unsigned +#endif + +//check C++11 library features +#include +#if defined(_LIBCPP_VERSION) //libc++ + #if defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103 + #ifndef HALF_ENABLE_CPP11_TYPE_TRAITS + #define HALF_ENABLE_CPP11_TYPE_TRAITS 1 + #endif + #ifndef HALF_ENABLE_CPP11_CSTDINT + #define HALF_ENABLE_CPP11_CSTDINT 1 + #endif + #ifndef HALF_ENABLE_CPP11_CMATH + #define HALF_ENABLE_CPP11_CMATH 1 + #endif + #ifndef HALF_ENABLE_CPP11_HASH + #define HALF_ENABLE_CPP11_HASH 1 + 
#endif + #endif +#elif defined(__GLIBCXX__) //libstdc++ + #if defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103 + #ifdef __clang__ + #if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_TYPE_TRAITS) + #define HALF_ENABLE_CPP11_TYPE_TRAITS 1 + #endif + #if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_CSTDINT) + #define HALF_ENABLE_CPP11_CSTDINT 1 + #endif + #if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_CMATH) + #define HALF_ENABLE_CPP11_CMATH 1 + #endif + #if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_HASH) + #define HALF_ENABLE_CPP11_HASH 1 + #endif + #else + #if HALF_GNUC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_CSTDINT) + #define HALF_ENABLE_CPP11_CSTDINT 1 + #endif + #if HALF_GNUC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_CMATH) + #define HALF_ENABLE_CPP11_CMATH 1 + #endif + #if HALF_GNUC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_HASH) + #define HALF_ENABLE_CPP11_HASH 1 + #endif + #endif + #endif +#elif defined(_CPPLIB_VER) //Dinkumware/Visual C++ + #if _CPPLIB_VER >= 520 + #ifndef HALF_ENABLE_CPP11_TYPE_TRAITS + #define HALF_ENABLE_CPP11_TYPE_TRAITS 1 + #endif + #ifndef HALF_ENABLE_CPP11_CSTDINT + #define HALF_ENABLE_CPP11_CSTDINT 1 + #endif + #ifndef HALF_ENABLE_CPP11_HASH + #define HALF_ENABLE_CPP11_HASH 1 + #endif + #endif + #if _CPPLIB_VER >= 610 + #ifndef HALF_ENABLE_CPP11_CMATH + #define HALF_ENABLE_CPP11_CMATH 1 + #endif + #endif +#endif +#undef HALF_GNUC_VERSION + +//support constexpr +#if HALF_ENABLE_CPP11_CONSTEXPR + #define HALF_CONSTEXPR constexpr + #define HALF_CONSTEXPR_CONST constexpr +#else + #define HALF_CONSTEXPR + #define HALF_CONSTEXPR_CONST const +#endif + +//support noexcept +#if HALF_ENABLE_CPP11_NOEXCEPT + #define HALF_NOEXCEPT noexcept + #define HALF_NOTHROW noexcept +#else + #define HALF_NOEXCEPT + #define HALF_NOTHROW throw() +#endif + +#include +#include +#include +#include +#include +#include +#if HALF_ENABLE_CPP11_TYPE_TRAITS + #include +#endif +#if HALF_ENABLE_CPP11_CSTDINT 
+ #include +#endif +#if HALF_ENABLE_CPP11_HASH + #include +#endif + + +/// Default rounding mode. +/// This specifies the rounding mode used for all conversions between [half](\ref half_float::half)s and `float`s as well as +/// for the half_cast() if not specifying a rounding mode explicitly. It can be redefined (before including half.hpp) to one +/// of the standard rounding modes using their respective constants or the equivalent values of `std::float_round_style`: +/// +/// `std::float_round_style` | value | rounding +/// ---------------------------------|-------|------------------------- +/// `std::round_indeterminate` | -1 | fastest (default) +/// `std::round_toward_zero` | 0 | toward zero +/// `std::round_to_nearest` | 1 | to nearest +/// `std::round_toward_infinity` | 2 | toward positive infinity +/// `std::round_toward_neg_infinity` | 3 | toward negative infinity +/// +/// By default this is set to `-1` (`std::round_indeterminate`), which uses truncation (round toward zero, but with overflows +/// set to infinity) and is the fastest rounding mode possible. It can even be set to `std::numeric_limits::round_style` +/// to synchronize the rounding mode with that of the underlying single-precision implementation. +#ifndef HALF_ROUND_STYLE + #define HALF_ROUND_STYLE -1 // = std::round_indeterminate +#endif + +/// Tie-breaking behaviour for round to nearest. +/// This specifies if ties in round to nearest should be resolved by rounding to the nearest even value. By default this is +/// defined to `0` resulting in the faster but slightly more biased behaviour of rounding away from zero in half-way cases (and +/// thus equal to the round() function), but can be redefined to `1` (before including half.hpp) if more IEEE-conformant +/// behaviour is needed. +#ifndef HALF_ROUND_TIES_TO_EVEN + #define HALF_ROUND_TIES_TO_EVEN 0 // ties away from zero +#endif + +/// Value signaling overflow. 
+/// In correspondence with `HUGE_VAL[F|L]` from `` this symbol expands to a positive value signaling the overflow of an +/// operation, in particular it just evaluates to positive infinity. +#define HUGE_VALH std::numeric_limits::infinity() + +/// Fast half-precision fma function. +/// This symbol is only defined if the fma() function generally executes as fast as, or faster than, a separate +/// half-precision multiplication followed by an addition. Due to the internal single-precision implementation of all +/// arithmetic operations, this is in fact always the case. +#define FP_FAST_FMAH 1 + +#ifndef FP_ILOGB0 + #define FP_ILOGB0 INT_MIN +#endif +#ifndef FP_ILOGBNAN + #define FP_ILOGBNAN INT_MAX +#endif +#ifndef FP_SUBNORMAL + #define FP_SUBNORMAL 0 +#endif +#ifndef FP_ZERO + #define FP_ZERO 1 +#endif +#ifndef FP_NAN + #define FP_NAN 2 +#endif +#ifndef FP_INFINITE + #define FP_INFINITE 3 +#endif +#ifndef FP_NORMAL + #define FP_NORMAL 4 +#endif + + +/// Main namespace for half precision functionality. +/// This namespace contains all the functionality provided by the library. +namespace half_float +{ + class half; + +#if HALF_ENABLE_CPP11_USER_LITERALS + /// Library-defined half-precision literals. + /// Import this namespace to enable half-precision floating point literals: + /// ~~~~{.cpp} + /// using namespace half_float::literal; + /// half_float::half = 4.2_h; + /// ~~~~ + namespace literal + { + half operator""_h(long double); + } +#endif + + /// \internal + /// \brief Implementation details. + namespace detail + { + #if HALF_ENABLE_CPP11_TYPE_TRAITS + /// Conditional type. + template struct conditional : std::conditional {}; + + /// Helper for tag dispatching. + template struct bool_type : std::integral_constant {}; + using std::true_type; + using std::false_type; + + /// Type traits for floating point types. + template struct is_float : std::is_floating_point {}; + #else + /// Conditional type. 
+ template struct conditional { typedef T type; }; + template struct conditional { typedef F type; }; + + /// Helper for tag dispatching. + template struct bool_type {}; + typedef bool_type true_type; + typedef bool_type false_type; + + /// Type traits for floating point types. + template struct is_float : false_type {}; + template struct is_float : is_float {}; + template struct is_float : is_float {}; + template struct is_float : is_float {}; + template<> struct is_float : true_type {}; + template<> struct is_float : true_type {}; + template<> struct is_float : true_type {}; + #endif + + /// Type traits for floating point bits. + template struct bits { typedef unsigned char type; }; + template struct bits : bits {}; + template struct bits : bits {}; + template struct bits : bits {}; + + #if HALF_ENABLE_CPP11_CSTDINT + /// Unsigned integer of (at least) 16 bits width. + typedef std::uint_least16_t uint16; + + /// Unsigned integer of (at least) 32 bits width. + template<> struct bits { typedef std::uint_least32_t type; }; + + /// Unsigned integer of (at least) 64 bits width. + template<> struct bits { typedef std::uint_least64_t type; }; + #else + /// Unsigned integer of (at least) 16 bits width. + typedef unsigned short uint16; + + /// Unsigned integer of (at least) 32 bits width. + template<> struct bits : conditional::digits>=32,unsigned int,unsigned long> {}; + + #if HALF_ENABLE_CPP11_LONG_LONG + /// Unsigned integer of (at least) 64 bits width. + template<> struct bits : conditional::digits>=64,unsigned long,unsigned long long> {}; + #else + /// Unsigned integer of (at least) 64 bits width. + template<> struct bits { typedef unsigned long type; }; + #endif + #endif + + /// Tag type for binary construction. + struct binary_t {}; + + /// Tag for binary construction. + HALF_CONSTEXPR_CONST binary_t binary = binary_t(); + + /// Temporary half-precision expression. 
+ /// This class represents a half-precision expression which just stores a single-precision value internally. + struct expr + { + /// Conversion constructor. + /// \param f single-precision value to convert + explicit HALF_CONSTEXPR expr(float f) HALF_NOEXCEPT : value_(f) {} + + /// Conversion to single-precision. + /// \return single precision value representing expression value + HALF_CONSTEXPR operator float() const HALF_NOEXCEPT { return value_; } + + private: + /// Internal expression value stored in single-precision. + float value_; + }; + + /// SFINAE helper for generic half-precision functions. + /// This class template has to be specialized for each valid combination of argument types to provide a corresponding + /// `type` member equivalent to \a T. + /// \tparam T type to return + template struct enable {}; + template struct enable { typedef T type; }; + template struct enable { typedef T type; }; + template struct enable { typedef T type; }; + template struct enable { typedef T type; }; + template struct enable { typedef T type; }; + template struct enable { typedef T type; }; + template struct enable { typedef T type; }; + template struct enable { typedef T type; }; + template struct enable { typedef T type; }; + template struct enable { typedef T type; }; + template struct enable { typedef T type; }; + template struct enable { typedef T type; }; + template struct enable { typedef T type; }; + template struct enable { typedef T type; }; + + /// Return type for specialized generic 2-argument half-precision functions. + /// This class template has to be specialized for each valid combination of argument types to provide a corresponding + /// `type` member denoting the appropriate return type. + /// \tparam T first argument type + /// \tparam U first argument type + template struct result : enable {}; + template<> struct result { typedef half type; }; + + /// \name Classification helpers + /// \{ + + /// Check for infinity. 
+ /// \tparam T argument type (builtin floating point type) + /// \param arg value to query + /// \retval true if infinity + /// \retval false else + template bool builtin_isinf(T arg) + { + #if HALF_ENABLE_CPP11_CMATH + return std::isinf(arg); + #elif defined(_MSC_VER) + return !::_finite(static_cast(arg)) && !::_isnan(static_cast(arg)); + #else + return arg == std::numeric_limits::infinity() || arg == -std::numeric_limits::infinity(); + #endif + } + + /// Check for NaN. + /// \tparam T argument type (builtin floating point type) + /// \param arg value to query + /// \retval true if not a number + /// \retval false else + template bool builtin_isnan(T arg) + { + #if HALF_ENABLE_CPP11_CMATH + return std::isnan(arg); + #elif defined(_MSC_VER) + return ::_isnan(static_cast(arg)) != 0; + #else + return arg != arg; + #endif + } + + /// Check sign. + /// \tparam T argument type (builtin floating point type) + /// \param arg value to query + /// \retval true if signbit set + /// \retval false else + template bool builtin_signbit(T arg) + { + #if HALF_ENABLE_CPP11_CMATH + return std::signbit(arg); + #else + return arg < T() || (arg == T() && T(1)/arg < T()); + #endif + } + + /// \} + /// \name Conversion + /// \{ + + /// Convert IEEE single-precision to half-precision. + /// Credit for this goes to [Jeroen van der Zijp](ftp://ftp.fox-toolkit.org/pub/fasthalffloatconversion.pdf). + /// \tparam R rounding mode to use, `std::round_indeterminate` for fastest rounding + /// \param value single-precision value + /// \return binary representation of half-precision value + template uint16 float2half_impl(float value, true_type) + { + typedef bits::type uint32; + uint32 bits;// = *reinterpret_cast(&value); //violating strict aliasing! 
+ std::memcpy(&bits, &value, sizeof(float)); +/* uint16 hbits = (bits>>16) & 0x8000; + bits &= 0x7FFFFFFF; + int exp = bits >> 23; + if(exp == 255) + return hbits | 0x7C00 | (0x3FF&-static_cast((bits&0x7FFFFF)!=0)); + if(exp > 142) + { + if(R == std::round_toward_infinity) + return hbits | 0x7C00 - (hbits>>15); + if(R == std::round_toward_neg_infinity) + return hbits | 0x7BFF + (hbits>>15); + return hbits | 0x7BFF + (R!=std::round_toward_zero); + } + int g, s; + if(exp > 112) + { + g = (bits>>12) & 1; + s = (bits&0xFFF) != 0; + hbits |= ((exp-112)<<10) | ((bits>>13)&0x3FF); + } + else if(exp > 101) + { + int i = 125 - exp; + bits = (bits&0x7FFFFF) | 0x800000; + g = (bits>>i) & 1; + s = (bits&((1L<> (i+1); + } + else + { + g = 0; + s = bits != 0; + } + if(R == std::round_to_nearest) + #if HALF_ROUND_TIES_TO_EVEN + hbits += g & (s|hbits); + #else + hbits += g; + #endif + else if(R == std::round_toward_infinity) + hbits += ~(hbits>>15) & (s|g); + else if(R == std::round_toward_neg_infinity) + hbits += (hbits>>15) & (g|s); +*/ static const uint16 base_table[512] = { + 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0001, 0x0002, 0x0004, 0x0008, 0x0010, 0x0020, 0x0040, 0x0080, 0x0100, + 0x0200, 
0x0400, 0x0800, 0x0C00, 0x1000, 0x1400, 0x1800, 0x1C00, 0x2000, 0x2400, 0x2800, 0x2C00, 0x3000, 0x3400, 0x3800, 0x3C00, + 0x4000, 0x4400, 0x4800, 0x4C00, 0x5000, 0x5400, 0x5800, 0x5C00, 0x6000, 0x6400, 0x6800, 0x6C00, 0x7000, 0x7400, 0x7800, 0x7C00, + 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, + 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, + 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, + 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, + 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, + 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, + 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, + 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, + 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, + 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, + 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, + 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, + 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, + 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 
0x8001, 0x8002, 0x8004, 0x8008, 0x8010, 0x8020, 0x8040, 0x8080, 0x8100, + 0x8200, 0x8400, 0x8800, 0x8C00, 0x9000, 0x9400, 0x9800, 0x9C00, 0xA000, 0xA400, 0xA800, 0xAC00, 0xB000, 0xB400, 0xB800, 0xBC00, + 0xC000, 0xC400, 0xC800, 0xCC00, 0xD000, 0xD400, 0xD800, 0xDC00, 0xE000, 0xE400, 0xE800, 0xEC00, 0xF000, 0xF400, 0xF800, 0xFC00, + 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, + 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, + 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, + 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, + 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, + 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, + 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00 }; + static const unsigned char shift_table[512] = { + 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, + 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, + 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, + 24, 24, 24, 24, 24, 24, 24, 23, 22, 21, 20, 19, 18, 17, 16, 15, 14, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, + 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, + 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 
24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, + 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, + 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 13, + 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, + 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, + 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, + 24, 24, 24, 24, 24, 24, 24, 23, 22, 21, 20, 19, 18, 17, 16, 15, 14, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, + 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, + 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, + 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, + 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 13 }; + uint16 hbits = base_table[bits>>23] + static_cast((bits&0x7FFFFF)>>shift_table[bits>>23]); + if(R == std::round_to_nearest) + hbits += (((bits&0x7FFFFF)>>(shift_table[bits>>23]-1))|(((bits>>23)&0xFF)==102)) & ((hbits&0x7C00)!=0x7C00) + #if HALF_ROUND_TIES_TO_EVEN + & (((((static_cast(1)<<(shift_table[bits>>23]-1))-1)&bits)!=0)|hbits) + #endif + ; + else if(R == std::round_toward_zero) + hbits -= ((hbits&0x7FFF)==0x7C00) & ~shift_table[bits>>23]; + else if(R == std::round_toward_infinity) + hbits += ((((bits&0x7FFFFF&((static_cast(1)<<(shift_table[bits>>23]))-1))!=0)|(((bits>>23)<=102)& + 
((bits>>23)!=0)))&(hbits<0x7C00)) - ((hbits==0xFC00)&((bits>>23)!=511)); + else if(R == std::round_toward_neg_infinity) + hbits += ((((bits&0x7FFFFF&((static_cast(1)<<(shift_table[bits>>23]))-1))!=0)|(((bits>>23)<=358)& + ((bits>>23)!=256)))&(hbits<0xFC00)&(hbits>>15)) - ((hbits==0x7C00)&((bits>>23)!=255)); + return hbits; + } + + /// Convert IEEE double-precision to half-precision. + /// \tparam R rounding mode to use, `std::round_indeterminate` for fastest rounding + /// \param value double-precision value + /// \return binary representation of half-precision value + template uint16 float2half_impl(double value, true_type) + { + typedef bits::type uint32; + typedef bits::type uint64; + uint64 bits;// = *reinterpret_cast(&value); //violating strict aliasing! + std::memcpy(&bits, &value, sizeof(double)); + uint32 hi = bits >> 32, lo = bits & 0xFFFFFFFF; + uint16 hbits = (hi>>16) & 0x8000; + hi &= 0x7FFFFFFF; + int exp = hi >> 20; + if(exp == 2047) + return hbits | 0x7C00 | (0x3FF&-static_cast((bits&0xFFFFFFFFFFFFF)!=0)); + if(exp > 1038) + { + if(R == std::round_toward_infinity) + return hbits | 0x7C00 - (hbits>>15); + if(R == std::round_toward_neg_infinity) + return hbits | 0x7BFF + (hbits>>15); + return hbits | 0x7BFF + (R!=std::round_toward_zero); + } + int g, s = lo != 0; + if(exp > 1008) + { + g = (hi>>9) & 1; + s |= (hi&0x1FF) != 0; + hbits |= ((exp-1008)<<10) | ((hi>>10)&0x3FF); + } + else if(exp > 997) + { + int i = 1018 - exp; + hi = (hi&0xFFFFF) | 0x100000; + g = (hi>>i) & 1; + s |= (hi&((1L<> (i+1); + } + else + { + g = 0; + s |= hi != 0; + } + if(R == std::round_to_nearest) + #if HALF_ROUND_TIES_TO_EVEN + hbits += g & (s|hbits); + #else + hbits += g; + #endif + else if(R == std::round_toward_infinity) + hbits += ~(hbits>>15) & (s|g); + else if(R == std::round_toward_neg_infinity) + hbits += (hbits>>15) & (g|s); + return hbits; + } + + /// Convert non-IEEE floating point to half-precision. 
+ /// \tparam R rounding mode to use, `std::round_indeterminate` for fastest rounding + /// \tparam T source type (builtin floating point type) + /// \param value floating point value + /// \return binary representation of half-precision value + template uint16 float2half_impl(T value, ...) + { + uint16 hbits = static_cast(builtin_signbit(value)) << 15; + if(value == T()) + return hbits; + if(builtin_isnan(value)) + return hbits | 0x7FFF; + if(builtin_isinf(value)) + return hbits | 0x7C00; + int exp; + std::frexp(value, &exp); + if(exp > 16) + { + if(R == std::round_toward_infinity) + return hbits | (0x7C00 - (hbits>>15)); + else if(R == std::round_toward_neg_infinity) + return hbits | (0x7BFF + (hbits>>15)); + return hbits | (0x7BFF + (R!=std::round_toward_zero)); + } + if(exp < -13) + value = std::ldexp(value, 24); + else + { + value = std::ldexp(value, 11-exp); + hbits |= ((exp+13)<<10); + } + T ival, frac = std::modf(value, &ival); + hbits += static_cast(std::abs(static_cast(ival))); + if(R == std::round_to_nearest) + { + frac = std::abs(frac); + #if HALF_ROUND_TIES_TO_EVEN + hbits += (frac>T(0.5)) | ((frac==T(0.5))&hbits); + #else + hbits += frac >= T(0.5); + #endif + } + else if(R == std::round_toward_infinity) + hbits += frac > T(); + else if(R == std::round_toward_neg_infinity) + hbits += frac < T(); + return hbits; + } + + /// Convert floating point to half-precision. + /// \tparam R rounding mode to use, `std::round_indeterminate` for fastest rounding + /// \tparam T source type (builtin floating point type) + /// \param value floating point value + /// \return binary representation of half-precision value + template uint16 float2half(T value) + { + return float2half_impl(value, bool_type::is_iec559&&sizeof(typename bits::type)==sizeof(T)>()); + } + + /// Convert integer to half-precision floating point. 
+ /// \tparam R rounding mode to use, `std::round_indeterminate` for fastest rounding + /// \tparam S `true` if value negative, `false` else + /// \tparam T type to convert (builtin integer type) + /// \param value non-negative integral value + /// \return binary representation of half-precision value + template uint16 int2half_impl(T value) + { + #if HALF_ENABLE_CPP11_STATIC_ASSERT && HALF_ENABLE_CPP11_TYPE_TRAITS + static_assert(std::is_integral::value, "int to half conversion only supports builtin integer types"); + #endif + if(S) + value = -value; + uint16 bits = S << 15; + if(value > 0xFFFF) + { + if(R == std::round_toward_infinity) + bits |= 0x7C00 - S; + else if(R == std::round_toward_neg_infinity) + bits |= 0x7BFF + S; + else + bits |= 0x7BFF + (R!=std::round_toward_zero); + } + else if(value) + { + unsigned int m = value, exp = 24; + for(; m<0x400; m<<=1,--exp) ; + for(; m>0x7FF; m>>=1,++exp) ; + bits |= (exp<<10) + m; + if(exp > 24) + { + if(R == std::round_to_nearest) + bits += (value>>(exp-25)) & 1 + #if HALF_ROUND_TIES_TO_EVEN + & (((((1<<(exp-25))-1)&value)!=0)|bits) + #endif + ; + else if(R == std::round_toward_infinity) + bits += ((value&((1<<(exp-24))-1))!=0) & !S; + else if(R == std::round_toward_neg_infinity) + bits += ((value&((1<<(exp-24))-1))!=0) & S; + } + } + return bits; + } + + /// Convert integer to half-precision floating point. + /// \tparam R rounding mode to use, `std::round_indeterminate` for fastest rounding + /// \tparam T type to convert (builtin integer type) + /// \param value integral value + /// \return binary representation of half-precision value + template uint16 int2half(T value) + { + return (value<0) ? int2half_impl(value) : int2half_impl(value); + } + + /// Convert half-precision to IEEE single-precision. + /// Credit for this goes to [Jeroen van der Zijp](ftp://ftp.fox-toolkit.org/pub/fasthalffloatconversion.pdf). 
+ /// \param value binary representation of half-precision value + /// \return single-precision value + inline float half2float_impl(uint16 value, float, true_type) + { + typedef bits::type uint32; +/* uint32 bits = static_cast(value&0x8000) << 16; + int abs = value & 0x7FFF; + if(abs) + { + bits |= 0x38000000 << static_cast(abs>=0x7C00); + for(; abs<0x400; abs<<=1,bits-=0x800000) ; + bits += static_cast(abs) << 13; + } +*/ static const uint32 mantissa_table[2048] = { + 0x00000000, 0x33800000, 0x34000000, 0x34400000, 0x34800000, 0x34A00000, 0x34C00000, 0x34E00000, 0x35000000, 0x35100000, 0x35200000, 0x35300000, 0x35400000, 0x35500000, 0x35600000, 0x35700000, + 0x35800000, 0x35880000, 0x35900000, 0x35980000, 0x35A00000, 0x35A80000, 0x35B00000, 0x35B80000, 0x35C00000, 0x35C80000, 0x35D00000, 0x35D80000, 0x35E00000, 0x35E80000, 0x35F00000, 0x35F80000, + 0x36000000, 0x36040000, 0x36080000, 0x360C0000, 0x36100000, 0x36140000, 0x36180000, 0x361C0000, 0x36200000, 0x36240000, 0x36280000, 0x362C0000, 0x36300000, 0x36340000, 0x36380000, 0x363C0000, + 0x36400000, 0x36440000, 0x36480000, 0x364C0000, 0x36500000, 0x36540000, 0x36580000, 0x365C0000, 0x36600000, 0x36640000, 0x36680000, 0x366C0000, 0x36700000, 0x36740000, 0x36780000, 0x367C0000, + 0x36800000, 0x36820000, 0x36840000, 0x36860000, 0x36880000, 0x368A0000, 0x368C0000, 0x368E0000, 0x36900000, 0x36920000, 0x36940000, 0x36960000, 0x36980000, 0x369A0000, 0x369C0000, 0x369E0000, + 0x36A00000, 0x36A20000, 0x36A40000, 0x36A60000, 0x36A80000, 0x36AA0000, 0x36AC0000, 0x36AE0000, 0x36B00000, 0x36B20000, 0x36B40000, 0x36B60000, 0x36B80000, 0x36BA0000, 0x36BC0000, 0x36BE0000, + 0x36C00000, 0x36C20000, 0x36C40000, 0x36C60000, 0x36C80000, 0x36CA0000, 0x36CC0000, 0x36CE0000, 0x36D00000, 0x36D20000, 0x36D40000, 0x36D60000, 0x36D80000, 0x36DA0000, 0x36DC0000, 0x36DE0000, + 0x36E00000, 0x36E20000, 0x36E40000, 0x36E60000, 0x36E80000, 0x36EA0000, 0x36EC0000, 0x36EE0000, 0x36F00000, 0x36F20000, 0x36F40000, 0x36F60000, 0x36F80000, 
0x36FA0000, 0x36FC0000, 0x36FE0000, + 0x37000000, 0x37010000, 0x37020000, 0x37030000, 0x37040000, 0x37050000, 0x37060000, 0x37070000, 0x37080000, 0x37090000, 0x370A0000, 0x370B0000, 0x370C0000, 0x370D0000, 0x370E0000, 0x370F0000, + 0x37100000, 0x37110000, 0x37120000, 0x37130000, 0x37140000, 0x37150000, 0x37160000, 0x37170000, 0x37180000, 0x37190000, 0x371A0000, 0x371B0000, 0x371C0000, 0x371D0000, 0x371E0000, 0x371F0000, + 0x37200000, 0x37210000, 0x37220000, 0x37230000, 0x37240000, 0x37250000, 0x37260000, 0x37270000, 0x37280000, 0x37290000, 0x372A0000, 0x372B0000, 0x372C0000, 0x372D0000, 0x372E0000, 0x372F0000, + 0x37300000, 0x37310000, 0x37320000, 0x37330000, 0x37340000, 0x37350000, 0x37360000, 0x37370000, 0x37380000, 0x37390000, 0x373A0000, 0x373B0000, 0x373C0000, 0x373D0000, 0x373E0000, 0x373F0000, + 0x37400000, 0x37410000, 0x37420000, 0x37430000, 0x37440000, 0x37450000, 0x37460000, 0x37470000, 0x37480000, 0x37490000, 0x374A0000, 0x374B0000, 0x374C0000, 0x374D0000, 0x374E0000, 0x374F0000, + 0x37500000, 0x37510000, 0x37520000, 0x37530000, 0x37540000, 0x37550000, 0x37560000, 0x37570000, 0x37580000, 0x37590000, 0x375A0000, 0x375B0000, 0x375C0000, 0x375D0000, 0x375E0000, 0x375F0000, + 0x37600000, 0x37610000, 0x37620000, 0x37630000, 0x37640000, 0x37650000, 0x37660000, 0x37670000, 0x37680000, 0x37690000, 0x376A0000, 0x376B0000, 0x376C0000, 0x376D0000, 0x376E0000, 0x376F0000, + 0x37700000, 0x37710000, 0x37720000, 0x37730000, 0x37740000, 0x37750000, 0x37760000, 0x37770000, 0x37780000, 0x37790000, 0x377A0000, 0x377B0000, 0x377C0000, 0x377D0000, 0x377E0000, 0x377F0000, + 0x37800000, 0x37808000, 0x37810000, 0x37818000, 0x37820000, 0x37828000, 0x37830000, 0x37838000, 0x37840000, 0x37848000, 0x37850000, 0x37858000, 0x37860000, 0x37868000, 0x37870000, 0x37878000, + 0x37880000, 0x37888000, 0x37890000, 0x37898000, 0x378A0000, 0x378A8000, 0x378B0000, 0x378B8000, 0x378C0000, 0x378C8000, 0x378D0000, 0x378D8000, 0x378E0000, 0x378E8000, 0x378F0000, 0x378F8000, + 0x37900000, 
0x37908000, 0x37910000, 0x37918000, 0x37920000, 0x37928000, 0x37930000, 0x37938000, 0x37940000, 0x37948000, 0x37950000, 0x37958000, 0x37960000, 0x37968000, 0x37970000, 0x37978000, + 0x37980000, 0x37988000, 0x37990000, 0x37998000, 0x379A0000, 0x379A8000, 0x379B0000, 0x379B8000, 0x379C0000, 0x379C8000, 0x379D0000, 0x379D8000, 0x379E0000, 0x379E8000, 0x379F0000, 0x379F8000, + 0x37A00000, 0x37A08000, 0x37A10000, 0x37A18000, 0x37A20000, 0x37A28000, 0x37A30000, 0x37A38000, 0x37A40000, 0x37A48000, 0x37A50000, 0x37A58000, 0x37A60000, 0x37A68000, 0x37A70000, 0x37A78000, + 0x37A80000, 0x37A88000, 0x37A90000, 0x37A98000, 0x37AA0000, 0x37AA8000, 0x37AB0000, 0x37AB8000, 0x37AC0000, 0x37AC8000, 0x37AD0000, 0x37AD8000, 0x37AE0000, 0x37AE8000, 0x37AF0000, 0x37AF8000, + 0x37B00000, 0x37B08000, 0x37B10000, 0x37B18000, 0x37B20000, 0x37B28000, 0x37B30000, 0x37B38000, 0x37B40000, 0x37B48000, 0x37B50000, 0x37B58000, 0x37B60000, 0x37B68000, 0x37B70000, 0x37B78000, + 0x37B80000, 0x37B88000, 0x37B90000, 0x37B98000, 0x37BA0000, 0x37BA8000, 0x37BB0000, 0x37BB8000, 0x37BC0000, 0x37BC8000, 0x37BD0000, 0x37BD8000, 0x37BE0000, 0x37BE8000, 0x37BF0000, 0x37BF8000, + 0x37C00000, 0x37C08000, 0x37C10000, 0x37C18000, 0x37C20000, 0x37C28000, 0x37C30000, 0x37C38000, 0x37C40000, 0x37C48000, 0x37C50000, 0x37C58000, 0x37C60000, 0x37C68000, 0x37C70000, 0x37C78000, + 0x37C80000, 0x37C88000, 0x37C90000, 0x37C98000, 0x37CA0000, 0x37CA8000, 0x37CB0000, 0x37CB8000, 0x37CC0000, 0x37CC8000, 0x37CD0000, 0x37CD8000, 0x37CE0000, 0x37CE8000, 0x37CF0000, 0x37CF8000, + 0x37D00000, 0x37D08000, 0x37D10000, 0x37D18000, 0x37D20000, 0x37D28000, 0x37D30000, 0x37D38000, 0x37D40000, 0x37D48000, 0x37D50000, 0x37D58000, 0x37D60000, 0x37D68000, 0x37D70000, 0x37D78000, + 0x37D80000, 0x37D88000, 0x37D90000, 0x37D98000, 0x37DA0000, 0x37DA8000, 0x37DB0000, 0x37DB8000, 0x37DC0000, 0x37DC8000, 0x37DD0000, 0x37DD8000, 0x37DE0000, 0x37DE8000, 0x37DF0000, 0x37DF8000, + 0x37E00000, 0x37E08000, 0x37E10000, 0x37E18000, 0x37E20000, 0x37E28000, 
0x37E30000, 0x37E38000, 0x37E40000, 0x37E48000, 0x37E50000, 0x37E58000, 0x37E60000, 0x37E68000, 0x37E70000, 0x37E78000, + 0x37E80000, 0x37E88000, 0x37E90000, 0x37E98000, 0x37EA0000, 0x37EA8000, 0x37EB0000, 0x37EB8000, 0x37EC0000, 0x37EC8000, 0x37ED0000, 0x37ED8000, 0x37EE0000, 0x37EE8000, 0x37EF0000, 0x37EF8000, + 0x37F00000, 0x37F08000, 0x37F10000, 0x37F18000, 0x37F20000, 0x37F28000, 0x37F30000, 0x37F38000, 0x37F40000, 0x37F48000, 0x37F50000, 0x37F58000, 0x37F60000, 0x37F68000, 0x37F70000, 0x37F78000, + 0x37F80000, 0x37F88000, 0x37F90000, 0x37F98000, 0x37FA0000, 0x37FA8000, 0x37FB0000, 0x37FB8000, 0x37FC0000, 0x37FC8000, 0x37FD0000, 0x37FD8000, 0x37FE0000, 0x37FE8000, 0x37FF0000, 0x37FF8000, + 0x38000000, 0x38004000, 0x38008000, 0x3800C000, 0x38010000, 0x38014000, 0x38018000, 0x3801C000, 0x38020000, 0x38024000, 0x38028000, 0x3802C000, 0x38030000, 0x38034000, 0x38038000, 0x3803C000, + 0x38040000, 0x38044000, 0x38048000, 0x3804C000, 0x38050000, 0x38054000, 0x38058000, 0x3805C000, 0x38060000, 0x38064000, 0x38068000, 0x3806C000, 0x38070000, 0x38074000, 0x38078000, 0x3807C000, + 0x38080000, 0x38084000, 0x38088000, 0x3808C000, 0x38090000, 0x38094000, 0x38098000, 0x3809C000, 0x380A0000, 0x380A4000, 0x380A8000, 0x380AC000, 0x380B0000, 0x380B4000, 0x380B8000, 0x380BC000, + 0x380C0000, 0x380C4000, 0x380C8000, 0x380CC000, 0x380D0000, 0x380D4000, 0x380D8000, 0x380DC000, 0x380E0000, 0x380E4000, 0x380E8000, 0x380EC000, 0x380F0000, 0x380F4000, 0x380F8000, 0x380FC000, + 0x38100000, 0x38104000, 0x38108000, 0x3810C000, 0x38110000, 0x38114000, 0x38118000, 0x3811C000, 0x38120000, 0x38124000, 0x38128000, 0x3812C000, 0x38130000, 0x38134000, 0x38138000, 0x3813C000, + 0x38140000, 0x38144000, 0x38148000, 0x3814C000, 0x38150000, 0x38154000, 0x38158000, 0x3815C000, 0x38160000, 0x38164000, 0x38168000, 0x3816C000, 0x38170000, 0x38174000, 0x38178000, 0x3817C000, + 0x38180000, 0x38184000, 0x38188000, 0x3818C000, 0x38190000, 0x38194000, 0x38198000, 0x3819C000, 0x381A0000, 0x381A4000, 0x381A8000, 
0x381AC000, 0x381B0000, 0x381B4000, 0x381B8000, 0x381BC000, + 0x381C0000, 0x381C4000, 0x381C8000, 0x381CC000, 0x381D0000, 0x381D4000, 0x381D8000, 0x381DC000, 0x381E0000, 0x381E4000, 0x381E8000, 0x381EC000, 0x381F0000, 0x381F4000, 0x381F8000, 0x381FC000, + 0x38200000, 0x38204000, 0x38208000, 0x3820C000, 0x38210000, 0x38214000, 0x38218000, 0x3821C000, 0x38220000, 0x38224000, 0x38228000, 0x3822C000, 0x38230000, 0x38234000, 0x38238000, 0x3823C000, + 0x38240000, 0x38244000, 0x38248000, 0x3824C000, 0x38250000, 0x38254000, 0x38258000, 0x3825C000, 0x38260000, 0x38264000, 0x38268000, 0x3826C000, 0x38270000, 0x38274000, 0x38278000, 0x3827C000, + 0x38280000, 0x38284000, 0x38288000, 0x3828C000, 0x38290000, 0x38294000, 0x38298000, 0x3829C000, 0x382A0000, 0x382A4000, 0x382A8000, 0x382AC000, 0x382B0000, 0x382B4000, 0x382B8000, 0x382BC000, + 0x382C0000, 0x382C4000, 0x382C8000, 0x382CC000, 0x382D0000, 0x382D4000, 0x382D8000, 0x382DC000, 0x382E0000, 0x382E4000, 0x382E8000, 0x382EC000, 0x382F0000, 0x382F4000, 0x382F8000, 0x382FC000, + 0x38300000, 0x38304000, 0x38308000, 0x3830C000, 0x38310000, 0x38314000, 0x38318000, 0x3831C000, 0x38320000, 0x38324000, 0x38328000, 0x3832C000, 0x38330000, 0x38334000, 0x38338000, 0x3833C000, + 0x38340000, 0x38344000, 0x38348000, 0x3834C000, 0x38350000, 0x38354000, 0x38358000, 0x3835C000, 0x38360000, 0x38364000, 0x38368000, 0x3836C000, 0x38370000, 0x38374000, 0x38378000, 0x3837C000, + 0x38380000, 0x38384000, 0x38388000, 0x3838C000, 0x38390000, 0x38394000, 0x38398000, 0x3839C000, 0x383A0000, 0x383A4000, 0x383A8000, 0x383AC000, 0x383B0000, 0x383B4000, 0x383B8000, 0x383BC000, + 0x383C0000, 0x383C4000, 0x383C8000, 0x383CC000, 0x383D0000, 0x383D4000, 0x383D8000, 0x383DC000, 0x383E0000, 0x383E4000, 0x383E8000, 0x383EC000, 0x383F0000, 0x383F4000, 0x383F8000, 0x383FC000, + 0x38400000, 0x38404000, 0x38408000, 0x3840C000, 0x38410000, 0x38414000, 0x38418000, 0x3841C000, 0x38420000, 0x38424000, 0x38428000, 0x3842C000, 0x38430000, 0x38434000, 0x38438000, 0x3843C000, 
+ 0x38440000, 0x38444000, 0x38448000, 0x3844C000, 0x38450000, 0x38454000, 0x38458000, 0x3845C000, 0x38460000, 0x38464000, 0x38468000, 0x3846C000, 0x38470000, 0x38474000, 0x38478000, 0x3847C000, + 0x38480000, 0x38484000, 0x38488000, 0x3848C000, 0x38490000, 0x38494000, 0x38498000, 0x3849C000, 0x384A0000, 0x384A4000, 0x384A8000, 0x384AC000, 0x384B0000, 0x384B4000, 0x384B8000, 0x384BC000, + 0x384C0000, 0x384C4000, 0x384C8000, 0x384CC000, 0x384D0000, 0x384D4000, 0x384D8000, 0x384DC000, 0x384E0000, 0x384E4000, 0x384E8000, 0x384EC000, 0x384F0000, 0x384F4000, 0x384F8000, 0x384FC000, + 0x38500000, 0x38504000, 0x38508000, 0x3850C000, 0x38510000, 0x38514000, 0x38518000, 0x3851C000, 0x38520000, 0x38524000, 0x38528000, 0x3852C000, 0x38530000, 0x38534000, 0x38538000, 0x3853C000, + 0x38540000, 0x38544000, 0x38548000, 0x3854C000, 0x38550000, 0x38554000, 0x38558000, 0x3855C000, 0x38560000, 0x38564000, 0x38568000, 0x3856C000, 0x38570000, 0x38574000, 0x38578000, 0x3857C000, + 0x38580000, 0x38584000, 0x38588000, 0x3858C000, 0x38590000, 0x38594000, 0x38598000, 0x3859C000, 0x385A0000, 0x385A4000, 0x385A8000, 0x385AC000, 0x385B0000, 0x385B4000, 0x385B8000, 0x385BC000, + 0x385C0000, 0x385C4000, 0x385C8000, 0x385CC000, 0x385D0000, 0x385D4000, 0x385D8000, 0x385DC000, 0x385E0000, 0x385E4000, 0x385E8000, 0x385EC000, 0x385F0000, 0x385F4000, 0x385F8000, 0x385FC000, + 0x38600000, 0x38604000, 0x38608000, 0x3860C000, 0x38610000, 0x38614000, 0x38618000, 0x3861C000, 0x38620000, 0x38624000, 0x38628000, 0x3862C000, 0x38630000, 0x38634000, 0x38638000, 0x3863C000, + 0x38640000, 0x38644000, 0x38648000, 0x3864C000, 0x38650000, 0x38654000, 0x38658000, 0x3865C000, 0x38660000, 0x38664000, 0x38668000, 0x3866C000, 0x38670000, 0x38674000, 0x38678000, 0x3867C000, + 0x38680000, 0x38684000, 0x38688000, 0x3868C000, 0x38690000, 0x38694000, 0x38698000, 0x3869C000, 0x386A0000, 0x386A4000, 0x386A8000, 0x386AC000, 0x386B0000, 0x386B4000, 0x386B8000, 0x386BC000, + 0x386C0000, 0x386C4000, 0x386C8000, 0x386CC000, 
0x386D0000, 0x386D4000, 0x386D8000, 0x386DC000, 0x386E0000, 0x386E4000, 0x386E8000, 0x386EC000, 0x386F0000, 0x386F4000, 0x386F8000, 0x386FC000, + 0x38700000, 0x38704000, 0x38708000, 0x3870C000, 0x38710000, 0x38714000, 0x38718000, 0x3871C000, 0x38720000, 0x38724000, 0x38728000, 0x3872C000, 0x38730000, 0x38734000, 0x38738000, 0x3873C000, + 0x38740000, 0x38744000, 0x38748000, 0x3874C000, 0x38750000, 0x38754000, 0x38758000, 0x3875C000, 0x38760000, 0x38764000, 0x38768000, 0x3876C000, 0x38770000, 0x38774000, 0x38778000, 0x3877C000, + 0x38780000, 0x38784000, 0x38788000, 0x3878C000, 0x38790000, 0x38794000, 0x38798000, 0x3879C000, 0x387A0000, 0x387A4000, 0x387A8000, 0x387AC000, 0x387B0000, 0x387B4000, 0x387B8000, 0x387BC000, + 0x387C0000, 0x387C4000, 0x387C8000, 0x387CC000, 0x387D0000, 0x387D4000, 0x387D8000, 0x387DC000, 0x387E0000, 0x387E4000, 0x387E8000, 0x387EC000, 0x387F0000, 0x387F4000, 0x387F8000, 0x387FC000, + 0x38000000, 0x38002000, 0x38004000, 0x38006000, 0x38008000, 0x3800A000, 0x3800C000, 0x3800E000, 0x38010000, 0x38012000, 0x38014000, 0x38016000, 0x38018000, 0x3801A000, 0x3801C000, 0x3801E000, + 0x38020000, 0x38022000, 0x38024000, 0x38026000, 0x38028000, 0x3802A000, 0x3802C000, 0x3802E000, 0x38030000, 0x38032000, 0x38034000, 0x38036000, 0x38038000, 0x3803A000, 0x3803C000, 0x3803E000, + 0x38040000, 0x38042000, 0x38044000, 0x38046000, 0x38048000, 0x3804A000, 0x3804C000, 0x3804E000, 0x38050000, 0x38052000, 0x38054000, 0x38056000, 0x38058000, 0x3805A000, 0x3805C000, 0x3805E000, + 0x38060000, 0x38062000, 0x38064000, 0x38066000, 0x38068000, 0x3806A000, 0x3806C000, 0x3806E000, 0x38070000, 0x38072000, 0x38074000, 0x38076000, 0x38078000, 0x3807A000, 0x3807C000, 0x3807E000, + 0x38080000, 0x38082000, 0x38084000, 0x38086000, 0x38088000, 0x3808A000, 0x3808C000, 0x3808E000, 0x38090000, 0x38092000, 0x38094000, 0x38096000, 0x38098000, 0x3809A000, 0x3809C000, 0x3809E000, + 0x380A0000, 0x380A2000, 0x380A4000, 0x380A6000, 0x380A8000, 0x380AA000, 0x380AC000, 0x380AE000, 0x380B0000, 
0x380B2000, 0x380B4000, 0x380B6000, 0x380B8000, 0x380BA000, 0x380BC000, 0x380BE000, + 0x380C0000, 0x380C2000, 0x380C4000, 0x380C6000, 0x380C8000, 0x380CA000, 0x380CC000, 0x380CE000, 0x380D0000, 0x380D2000, 0x380D4000, 0x380D6000, 0x380D8000, 0x380DA000, 0x380DC000, 0x380DE000, + 0x380E0000, 0x380E2000, 0x380E4000, 0x380E6000, 0x380E8000, 0x380EA000, 0x380EC000, 0x380EE000, 0x380F0000, 0x380F2000, 0x380F4000, 0x380F6000, 0x380F8000, 0x380FA000, 0x380FC000, 0x380FE000, + 0x38100000, 0x38102000, 0x38104000, 0x38106000, 0x38108000, 0x3810A000, 0x3810C000, 0x3810E000, 0x38110000, 0x38112000, 0x38114000, 0x38116000, 0x38118000, 0x3811A000, 0x3811C000, 0x3811E000, + 0x38120000, 0x38122000, 0x38124000, 0x38126000, 0x38128000, 0x3812A000, 0x3812C000, 0x3812E000, 0x38130000, 0x38132000, 0x38134000, 0x38136000, 0x38138000, 0x3813A000, 0x3813C000, 0x3813E000, + 0x38140000, 0x38142000, 0x38144000, 0x38146000, 0x38148000, 0x3814A000, 0x3814C000, 0x3814E000, 0x38150000, 0x38152000, 0x38154000, 0x38156000, 0x38158000, 0x3815A000, 0x3815C000, 0x3815E000, + 0x38160000, 0x38162000, 0x38164000, 0x38166000, 0x38168000, 0x3816A000, 0x3816C000, 0x3816E000, 0x38170000, 0x38172000, 0x38174000, 0x38176000, 0x38178000, 0x3817A000, 0x3817C000, 0x3817E000, + 0x38180000, 0x38182000, 0x38184000, 0x38186000, 0x38188000, 0x3818A000, 0x3818C000, 0x3818E000, 0x38190000, 0x38192000, 0x38194000, 0x38196000, 0x38198000, 0x3819A000, 0x3819C000, 0x3819E000, + 0x381A0000, 0x381A2000, 0x381A4000, 0x381A6000, 0x381A8000, 0x381AA000, 0x381AC000, 0x381AE000, 0x381B0000, 0x381B2000, 0x381B4000, 0x381B6000, 0x381B8000, 0x381BA000, 0x381BC000, 0x381BE000, + 0x381C0000, 0x381C2000, 0x381C4000, 0x381C6000, 0x381C8000, 0x381CA000, 0x381CC000, 0x381CE000, 0x381D0000, 0x381D2000, 0x381D4000, 0x381D6000, 0x381D8000, 0x381DA000, 0x381DC000, 0x381DE000, + 0x381E0000, 0x381E2000, 0x381E4000, 0x381E6000, 0x381E8000, 0x381EA000, 0x381EC000, 0x381EE000, 0x381F0000, 0x381F2000, 0x381F4000, 0x381F6000, 0x381F8000, 0x381FA000, 
0x381FC000, 0x381FE000, + 0x38200000, 0x38202000, 0x38204000, 0x38206000, 0x38208000, 0x3820A000, 0x3820C000, 0x3820E000, 0x38210000, 0x38212000, 0x38214000, 0x38216000, 0x38218000, 0x3821A000, 0x3821C000, 0x3821E000, + 0x38220000, 0x38222000, 0x38224000, 0x38226000, 0x38228000, 0x3822A000, 0x3822C000, 0x3822E000, 0x38230000, 0x38232000, 0x38234000, 0x38236000, 0x38238000, 0x3823A000, 0x3823C000, 0x3823E000, + 0x38240000, 0x38242000, 0x38244000, 0x38246000, 0x38248000, 0x3824A000, 0x3824C000, 0x3824E000, 0x38250000, 0x38252000, 0x38254000, 0x38256000, 0x38258000, 0x3825A000, 0x3825C000, 0x3825E000, + 0x38260000, 0x38262000, 0x38264000, 0x38266000, 0x38268000, 0x3826A000, 0x3826C000, 0x3826E000, 0x38270000, 0x38272000, 0x38274000, 0x38276000, 0x38278000, 0x3827A000, 0x3827C000, 0x3827E000, + 0x38280000, 0x38282000, 0x38284000, 0x38286000, 0x38288000, 0x3828A000, 0x3828C000, 0x3828E000, 0x38290000, 0x38292000, 0x38294000, 0x38296000, 0x38298000, 0x3829A000, 0x3829C000, 0x3829E000, + 0x382A0000, 0x382A2000, 0x382A4000, 0x382A6000, 0x382A8000, 0x382AA000, 0x382AC000, 0x382AE000, 0x382B0000, 0x382B2000, 0x382B4000, 0x382B6000, 0x382B8000, 0x382BA000, 0x382BC000, 0x382BE000, + 0x382C0000, 0x382C2000, 0x382C4000, 0x382C6000, 0x382C8000, 0x382CA000, 0x382CC000, 0x382CE000, 0x382D0000, 0x382D2000, 0x382D4000, 0x382D6000, 0x382D8000, 0x382DA000, 0x382DC000, 0x382DE000, + 0x382E0000, 0x382E2000, 0x382E4000, 0x382E6000, 0x382E8000, 0x382EA000, 0x382EC000, 0x382EE000, 0x382F0000, 0x382F2000, 0x382F4000, 0x382F6000, 0x382F8000, 0x382FA000, 0x382FC000, 0x382FE000, + 0x38300000, 0x38302000, 0x38304000, 0x38306000, 0x38308000, 0x3830A000, 0x3830C000, 0x3830E000, 0x38310000, 0x38312000, 0x38314000, 0x38316000, 0x38318000, 0x3831A000, 0x3831C000, 0x3831E000, + 0x38320000, 0x38322000, 0x38324000, 0x38326000, 0x38328000, 0x3832A000, 0x3832C000, 0x3832E000, 0x38330000, 0x38332000, 0x38334000, 0x38336000, 0x38338000, 0x3833A000, 0x3833C000, 0x3833E000, + 0x38340000, 0x38342000, 
0x38344000, 0x38346000, 0x38348000, 0x3834A000, 0x3834C000, 0x3834E000, 0x38350000, 0x38352000, 0x38354000, 0x38356000, 0x38358000, 0x3835A000, 0x3835C000, 0x3835E000, + 0x38360000, 0x38362000, 0x38364000, 0x38366000, 0x38368000, 0x3836A000, 0x3836C000, 0x3836E000, 0x38370000, 0x38372000, 0x38374000, 0x38376000, 0x38378000, 0x3837A000, 0x3837C000, 0x3837E000, + 0x38380000, 0x38382000, 0x38384000, 0x38386000, 0x38388000, 0x3838A000, 0x3838C000, 0x3838E000, 0x38390000, 0x38392000, 0x38394000, 0x38396000, 0x38398000, 0x3839A000, 0x3839C000, 0x3839E000, + 0x383A0000, 0x383A2000, 0x383A4000, 0x383A6000, 0x383A8000, 0x383AA000, 0x383AC000, 0x383AE000, 0x383B0000, 0x383B2000, 0x383B4000, 0x383B6000, 0x383B8000, 0x383BA000, 0x383BC000, 0x383BE000, + 0x383C0000, 0x383C2000, 0x383C4000, 0x383C6000, 0x383C8000, 0x383CA000, 0x383CC000, 0x383CE000, 0x383D0000, 0x383D2000, 0x383D4000, 0x383D6000, 0x383D8000, 0x383DA000, 0x383DC000, 0x383DE000, + 0x383E0000, 0x383E2000, 0x383E4000, 0x383E6000, 0x383E8000, 0x383EA000, 0x383EC000, 0x383EE000, 0x383F0000, 0x383F2000, 0x383F4000, 0x383F6000, 0x383F8000, 0x383FA000, 0x383FC000, 0x383FE000, + 0x38400000, 0x38402000, 0x38404000, 0x38406000, 0x38408000, 0x3840A000, 0x3840C000, 0x3840E000, 0x38410000, 0x38412000, 0x38414000, 0x38416000, 0x38418000, 0x3841A000, 0x3841C000, 0x3841E000, + 0x38420000, 0x38422000, 0x38424000, 0x38426000, 0x38428000, 0x3842A000, 0x3842C000, 0x3842E000, 0x38430000, 0x38432000, 0x38434000, 0x38436000, 0x38438000, 0x3843A000, 0x3843C000, 0x3843E000, + 0x38440000, 0x38442000, 0x38444000, 0x38446000, 0x38448000, 0x3844A000, 0x3844C000, 0x3844E000, 0x38450000, 0x38452000, 0x38454000, 0x38456000, 0x38458000, 0x3845A000, 0x3845C000, 0x3845E000, + 0x38460000, 0x38462000, 0x38464000, 0x38466000, 0x38468000, 0x3846A000, 0x3846C000, 0x3846E000, 0x38470000, 0x38472000, 0x38474000, 0x38476000, 0x38478000, 0x3847A000, 0x3847C000, 0x3847E000, + 0x38480000, 0x38482000, 0x38484000, 0x38486000, 0x38488000, 0x3848A000, 0x3848C000, 
0x3848E000, 0x38490000, 0x38492000, 0x38494000, 0x38496000, 0x38498000, 0x3849A000, 0x3849C000, 0x3849E000, + 0x384A0000, 0x384A2000, 0x384A4000, 0x384A6000, 0x384A8000, 0x384AA000, 0x384AC000, 0x384AE000, 0x384B0000, 0x384B2000, 0x384B4000, 0x384B6000, 0x384B8000, 0x384BA000, 0x384BC000, 0x384BE000, + 0x384C0000, 0x384C2000, 0x384C4000, 0x384C6000, 0x384C8000, 0x384CA000, 0x384CC000, 0x384CE000, 0x384D0000, 0x384D2000, 0x384D4000, 0x384D6000, 0x384D8000, 0x384DA000, 0x384DC000, 0x384DE000, + 0x384E0000, 0x384E2000, 0x384E4000, 0x384E6000, 0x384E8000, 0x384EA000, 0x384EC000, 0x384EE000, 0x384F0000, 0x384F2000, 0x384F4000, 0x384F6000, 0x384F8000, 0x384FA000, 0x384FC000, 0x384FE000, + 0x38500000, 0x38502000, 0x38504000, 0x38506000, 0x38508000, 0x3850A000, 0x3850C000, 0x3850E000, 0x38510000, 0x38512000, 0x38514000, 0x38516000, 0x38518000, 0x3851A000, 0x3851C000, 0x3851E000, + 0x38520000, 0x38522000, 0x38524000, 0x38526000, 0x38528000, 0x3852A000, 0x3852C000, 0x3852E000, 0x38530000, 0x38532000, 0x38534000, 0x38536000, 0x38538000, 0x3853A000, 0x3853C000, 0x3853E000, + 0x38540000, 0x38542000, 0x38544000, 0x38546000, 0x38548000, 0x3854A000, 0x3854C000, 0x3854E000, 0x38550000, 0x38552000, 0x38554000, 0x38556000, 0x38558000, 0x3855A000, 0x3855C000, 0x3855E000, + 0x38560000, 0x38562000, 0x38564000, 0x38566000, 0x38568000, 0x3856A000, 0x3856C000, 0x3856E000, 0x38570000, 0x38572000, 0x38574000, 0x38576000, 0x38578000, 0x3857A000, 0x3857C000, 0x3857E000, + 0x38580000, 0x38582000, 0x38584000, 0x38586000, 0x38588000, 0x3858A000, 0x3858C000, 0x3858E000, 0x38590000, 0x38592000, 0x38594000, 0x38596000, 0x38598000, 0x3859A000, 0x3859C000, 0x3859E000, + 0x385A0000, 0x385A2000, 0x385A4000, 0x385A6000, 0x385A8000, 0x385AA000, 0x385AC000, 0x385AE000, 0x385B0000, 0x385B2000, 0x385B4000, 0x385B6000, 0x385B8000, 0x385BA000, 0x385BC000, 0x385BE000, + 0x385C0000, 0x385C2000, 0x385C4000, 0x385C6000, 0x385C8000, 0x385CA000, 0x385CC000, 0x385CE000, 0x385D0000, 0x385D2000, 0x385D4000, 0x385D6000, 
0x385D8000, 0x385DA000, 0x385DC000, 0x385DE000, + 0x385E0000, 0x385E2000, 0x385E4000, 0x385E6000, 0x385E8000, 0x385EA000, 0x385EC000, 0x385EE000, 0x385F0000, 0x385F2000, 0x385F4000, 0x385F6000, 0x385F8000, 0x385FA000, 0x385FC000, 0x385FE000, + 0x38600000, 0x38602000, 0x38604000, 0x38606000, 0x38608000, 0x3860A000, 0x3860C000, 0x3860E000, 0x38610000, 0x38612000, 0x38614000, 0x38616000, 0x38618000, 0x3861A000, 0x3861C000, 0x3861E000, + 0x38620000, 0x38622000, 0x38624000, 0x38626000, 0x38628000, 0x3862A000, 0x3862C000, 0x3862E000, 0x38630000, 0x38632000, 0x38634000, 0x38636000, 0x38638000, 0x3863A000, 0x3863C000, 0x3863E000, + 0x38640000, 0x38642000, 0x38644000, 0x38646000, 0x38648000, 0x3864A000, 0x3864C000, 0x3864E000, 0x38650000, 0x38652000, 0x38654000, 0x38656000, 0x38658000, 0x3865A000, 0x3865C000, 0x3865E000, + 0x38660000, 0x38662000, 0x38664000, 0x38666000, 0x38668000, 0x3866A000, 0x3866C000, 0x3866E000, 0x38670000, 0x38672000, 0x38674000, 0x38676000, 0x38678000, 0x3867A000, 0x3867C000, 0x3867E000, + 0x38680000, 0x38682000, 0x38684000, 0x38686000, 0x38688000, 0x3868A000, 0x3868C000, 0x3868E000, 0x38690000, 0x38692000, 0x38694000, 0x38696000, 0x38698000, 0x3869A000, 0x3869C000, 0x3869E000, + 0x386A0000, 0x386A2000, 0x386A4000, 0x386A6000, 0x386A8000, 0x386AA000, 0x386AC000, 0x386AE000, 0x386B0000, 0x386B2000, 0x386B4000, 0x386B6000, 0x386B8000, 0x386BA000, 0x386BC000, 0x386BE000, + 0x386C0000, 0x386C2000, 0x386C4000, 0x386C6000, 0x386C8000, 0x386CA000, 0x386CC000, 0x386CE000, 0x386D0000, 0x386D2000, 0x386D4000, 0x386D6000, 0x386D8000, 0x386DA000, 0x386DC000, 0x386DE000, + 0x386E0000, 0x386E2000, 0x386E4000, 0x386E6000, 0x386E8000, 0x386EA000, 0x386EC000, 0x386EE000, 0x386F0000, 0x386F2000, 0x386F4000, 0x386F6000, 0x386F8000, 0x386FA000, 0x386FC000, 0x386FE000, + 0x38700000, 0x38702000, 0x38704000, 0x38706000, 0x38708000, 0x3870A000, 0x3870C000, 0x3870E000, 0x38710000, 0x38712000, 0x38714000, 0x38716000, 0x38718000, 0x3871A000, 0x3871C000, 0x3871E000, + 
0x38720000, 0x38722000, 0x38724000, 0x38726000, 0x38728000, 0x3872A000, 0x3872C000, 0x3872E000, 0x38730000, 0x38732000, 0x38734000, 0x38736000, 0x38738000, 0x3873A000, 0x3873C000, 0x3873E000, + 0x38740000, 0x38742000, 0x38744000, 0x38746000, 0x38748000, 0x3874A000, 0x3874C000, 0x3874E000, 0x38750000, 0x38752000, 0x38754000, 0x38756000, 0x38758000, 0x3875A000, 0x3875C000, 0x3875E000, + 0x38760000, 0x38762000, 0x38764000, 0x38766000, 0x38768000, 0x3876A000, 0x3876C000, 0x3876E000, 0x38770000, 0x38772000, 0x38774000, 0x38776000, 0x38778000, 0x3877A000, 0x3877C000, 0x3877E000, + 0x38780000, 0x38782000, 0x38784000, 0x38786000, 0x38788000, 0x3878A000, 0x3878C000, 0x3878E000, 0x38790000, 0x38792000, 0x38794000, 0x38796000, 0x38798000, 0x3879A000, 0x3879C000, 0x3879E000, + 0x387A0000, 0x387A2000, 0x387A4000, 0x387A6000, 0x387A8000, 0x387AA000, 0x387AC000, 0x387AE000, 0x387B0000, 0x387B2000, 0x387B4000, 0x387B6000, 0x387B8000, 0x387BA000, 0x387BC000, 0x387BE000, + 0x387C0000, 0x387C2000, 0x387C4000, 0x387C6000, 0x387C8000, 0x387CA000, 0x387CC000, 0x387CE000, 0x387D0000, 0x387D2000, 0x387D4000, 0x387D6000, 0x387D8000, 0x387DA000, 0x387DC000, 0x387DE000, + 0x387E0000, 0x387E2000, 0x387E4000, 0x387E6000, 0x387E8000, 0x387EA000, 0x387EC000, 0x387EE000, 0x387F0000, 0x387F2000, 0x387F4000, 0x387F6000, 0x387F8000, 0x387FA000, 0x387FC000, 0x387FE000 }; + static const uint32 exponent_table[64] = { + 0x00000000, 0x00800000, 0x01000000, 0x01800000, 0x02000000, 0x02800000, 0x03000000, 0x03800000, 0x04000000, 0x04800000, 0x05000000, 0x05800000, 0x06000000, 0x06800000, 0x07000000, 0x07800000, + 0x08000000, 0x08800000, 0x09000000, 0x09800000, 0x0A000000, 0x0A800000, 0x0B000000, 0x0B800000, 0x0C000000, 0x0C800000, 0x0D000000, 0x0D800000, 0x0E000000, 0x0E800000, 0x0F000000, 0x47800000, + 0x80000000, 0x80800000, 0x81000000, 0x81800000, 0x82000000, 0x82800000, 0x83000000, 0x83800000, 0x84000000, 0x84800000, 0x85000000, 0x85800000, 0x86000000, 0x86800000, 0x87000000, 0x87800000, + 0x88000000, 
0x88800000, 0x89000000, 0x89800000, 0x8A000000, 0x8A800000, 0x8B000000, 0x8B800000, 0x8C000000, 0x8C800000, 0x8D000000, 0x8D800000, 0x8E000000, 0x8E800000, 0x8F000000, 0xC7800000 }; + static const unsigned short offset_table[64] = { + 0, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, + 0, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024 }; + uint32 bits = mantissa_table[offset_table[value>>10]+(value&0x3FF)] + exponent_table[value>>10]; +// return *reinterpret_cast(&bits); //violating strict aliasing! + float out; + std::memcpy(&out, &bits, sizeof(float)); + return out; + } + + /// Convert half-precision to IEEE double-precision. + /// \param value binary representation of half-precision value + /// \return double-precision value + inline double half2float_impl(uint16 value, double, true_type) + { + typedef bits::type uint32; + typedef bits::type uint64; + uint32 hi = static_cast(value&0x8000) << 16; + int abs = value & 0x7FFF; + if(abs) + { + hi |= 0x3F000000 << static_cast(abs>=0x7C00); + for(; abs<0x400; abs<<=1,hi-=0x100000) ; + hi += static_cast(abs) << 10; + } + uint64 bits = static_cast(hi) << 32; +// return *reinterpret_cast(&bits); //violating strict aliasing! + double out; + std::memcpy(&out, &bits, sizeof(double)); + return out; + } + + /// Convert half-precision to non-IEEE floating point. + /// \tparam T type to convert to (builtin integer type) + /// \param value binary representation of half-precision value + /// \return floating point value + template T half2float_impl(uint16 value, T, ...) + { + T out; + int abs = value & 0x7FFF; + if(abs > 0x7C00) + out = std::numeric_limits::has_quiet_NaN ? 
std::numeric_limits::quiet_NaN() : T(); + else if(abs == 0x7C00) + out = std::numeric_limits::has_infinity ? std::numeric_limits::infinity() : std::numeric_limits::max(); + else if(abs > 0x3FF) + out = std::ldexp(static_cast((abs&0x3FF)|0x400), (abs>>10)-25); + else + out = std::ldexp(static_cast(abs), -24); + return (value&0x8000) ? -out : out; + } + + /// Convert half-precision to floating point. + /// \tparam T type to convert to (builtin integer type) + /// \param value binary representation of half-precision value + /// \return floating point value + template T half2float(uint16 value) + { + return half2float_impl(value, T(), bool_type::is_iec559&&sizeof(typename bits::type)==sizeof(T)>()); + } + + /// Convert half-precision floating point to integer. + /// \tparam R rounding mode to use, `std::round_indeterminate` for fastest rounding + /// \tparam E `true` for round to even, `false` for round away from zero + /// \tparam T type to convert to (buitlin integer type with at least 16 bits precision, excluding any implicit sign bits) + /// \param value binary representation of half-precision value + /// \return integral value + template T half2int_impl(uint16 value) + { + #if HALF_ENABLE_CPP11_STATIC_ASSERT && HALF_ENABLE_CPP11_TYPE_TRAITS + static_assert(std::is_integral::value, "half to int conversion only supports builtin integer types"); + #endif + unsigned int e = value & 0x7FFF; + if(e >= 0x7C00) + return (value&0x8000) ? 
std::numeric_limits::min() : std::numeric_limits::max(); + if(e < 0x3800) + { + if(R == std::round_toward_infinity) + return T(~(value>>15)&(e!=0)); + else if(R == std::round_toward_neg_infinity) + return -T(value>0x8000); + return T(); + } + unsigned int m = (value&0x3FF) | 0x400; + e >>= 10; + if(e < 25) + { + if(R == std::round_to_nearest) + m += (1<<(24-e)) - (~(m>>(25-e))&E); + else if(R == std::round_toward_infinity) + m += ((value>>15)-1) & ((1<<(25-e))-1U); + else if(R == std::round_toward_neg_infinity) + m += -(value>>15) & ((1<<(25-e))-1U); + m >>= 25 - e; + } + else + m <<= e - 25; + return (value&0x8000) ? -static_cast(m) : static_cast(m); + } + + /// Convert half-precision floating point to integer. + /// \tparam R rounding mode to use, `std::round_indeterminate` for fastest rounding + /// \tparam T type to convert to (buitlin integer type with at least 16 bits precision, excluding any implicit sign bits) + /// \param value binary representation of half-precision value + /// \return integral value + template T half2int(uint16 value) { return half2int_impl(value); } + + /// Convert half-precision floating point to integer using round-to-nearest-away-from-zero. + /// \tparam T type to convert to (buitlin integer type with at least 16 bits precision, excluding any implicit sign bits) + /// \param value binary representation of half-precision value + /// \return integral value + template T half2int_up(uint16 value) { return half2int_impl(value); } + + /// Round half-precision number to nearest integer value. 
+ /// \tparam R rounding mode to use, `std::round_indeterminate` for fastest rounding + /// \tparam E `true` for round to even, `false` for round away from zero + /// \param value binary representation of half-precision value + /// \return half-precision bits for nearest integral value + template uint16 round_half_impl(uint16 value) + { + unsigned int e = value & 0x7FFF; + uint16 result = value; + if(e < 0x3C00) + { + result &= 0x8000; + if(R == std::round_to_nearest) + result |= 0x3C00U & -(e>=(0x3800+E)); + else if(R == std::round_toward_infinity) + result |= 0x3C00U & -(~(value>>15)&(e!=0)); + else if(R == std::round_toward_neg_infinity) + result |= 0x3C00U & -(value>0x8000); + } + else if(e < 0x6400) + { + e = 25 - (e>>10); + unsigned int mask = (1<>e)&E); + else if(R == std::round_toward_infinity) + result += mask & ((value>>15)-1); + else if(R == std::round_toward_neg_infinity) + result += mask & -(value>>15); + result &= ~mask; + } + return result; + } + + /// Round half-precision number to nearest integer value. + /// \tparam R rounding mode to use, `std::round_indeterminate` for fastest rounding + /// \param value binary representation of half-precision value + /// \return half-precision bits for nearest integral value + template uint16 round_half(uint16 value) { return round_half_impl(value); } + + /// Round half-precision number to nearest integer value using round-to-nearest-away-from-zero. + /// \param value binary representation of half-precision value + /// \return half-precision bits for nearest integral value + inline uint16 round_half_up(uint16 value) { return round_half_impl(value); } + /// \} + + struct functions; + template struct unary_specialized; + template struct binary_specialized; + template struct half_caster; + } + + /// Half-precision floating point type. + /// This class implements an IEEE-conformant half-precision floating point type with the usual arithmetic operators and + /// conversions. 
It is implicitly convertible to single-precision floating point, which makes artihmetic expressions and + /// functions with mixed-type operands to be of the most precise operand type. Additionally all arithmetic operations + /// (and many mathematical functions) are carried out in single-precision internally. All conversions from single- to + /// half-precision are done using the library's default rounding mode, but temporary results inside chained arithmetic + /// expressions are kept in single-precision as long as possible (while of course still maintaining a strong half-precision type). + /// + /// According to the C++98/03 definition, the half type is not a POD type. But according to C++11's less strict and + /// extended definitions it is both a standard layout type and a trivially copyable type (even if not a POD type), which + /// means it can be standard-conformantly copied using raw binary copies. But in this context some more words about the + /// actual size of the type. Although the half is representing an IEEE 16-bit type, it does not neccessarily have to be of + /// exactly 16-bits size. But on any reasonable implementation the actual binary representation of this type will most + /// probably not ivolve any additional "magic" or padding beyond the simple binary representation of the underlying 16-bit + /// IEEE number, even if not strictly guaranteed by the standard. But even then it only has an actual size of 16 bits if + /// your C++ implementation supports an unsigned integer type of exactly 16 bits width. But this should be the case on + /// nearly any reasonable platform. + /// + /// So if your C++ implementation is not totally exotic or imposes special alignment requirements, it is a reasonable + /// assumption that the data of a half is just comprised of the 2 bytes of the underlying IEEE representation. 
+ class half + { + friend struct detail::functions; + friend struct detail::unary_specialized; + friend struct detail::binary_specialized; + template friend struct detail::half_caster; + friend class std::numeric_limits; + #if HALF_ENABLE_CPP11_HASH + friend struct std::hash; + #endif + #if HALF_ENABLE_CPP11_USER_LITERALS + friend half literal::operator""_h(long double); + #endif + + public: + /// Default constructor. + /// This initializes the half to 0. Although this does not match the builtin types' default-initialization semantics + /// and may be less efficient than no initialization, it is needed to provide proper value-initialization semantics. + HALF_CONSTEXPR half() HALF_NOEXCEPT : data_() {} + + /// Copy constructor. + /// \tparam T type of concrete half expression + /// \param rhs half expression to copy from + half(detail::expr rhs) : data_(detail::float2half(static_cast(rhs))) {} + + /// Conversion constructor. + /// \param rhs float to convert + half(float rhs) : data_(detail::float2half(rhs)) {} + //half(int rhs) : data_(detail::float2half(static_cast(rhs))) {} + + /// Conversion to single-precision. + /// \return single precision value representing expression value + operator float() const { return detail::half2float(data_); } + + /// Assignment operator. + /// \tparam T type of concrete half expression + /// \param rhs half expression to copy from + /// \return reference to this half + half& operator=(detail::expr rhs) { return *this = static_cast(rhs); } + + /// Arithmetic assignment. + /// \tparam T type of concrete half expression + /// \param rhs half expression to add + /// \return reference to this half + template typename detail::enable::type operator+=(T rhs) { return *this += static_cast(rhs); } + + /// Arithmetic assignment. 
+ /// \tparam T type of concrete half expression + /// \param rhs half expression to subtract + /// \return reference to this half + template typename detail::enable::type operator-=(T rhs) { return *this -= static_cast(rhs); } + + /// Arithmetic assignment. + /// \tparam T type of concrete half expression + /// \param rhs half expression to multiply with + /// \return reference to this half + template typename detail::enable::type operator*=(T rhs) { return *this *= static_cast(rhs); } + + /// Arithmetic assignment. + /// \tparam T type of concrete half expression + /// \param rhs half expression to divide by + /// \return reference to this half + template typename detail::enable::type operator/=(T rhs) { return *this /= static_cast(rhs); } + + /// Assignment operator. + /// \param rhs single-precision value to copy from + /// \return reference to this half + half& operator=(float rhs) { data_ = detail::float2half(rhs); return *this; } + + /// Arithmetic assignment. + /// \param rhs single-precision value to add + /// \return reference to this half + half& operator+=(float rhs) { data_ = detail::float2half(detail::half2float(data_)+rhs); return *this; } + + /// Arithmetic assignment. + /// \param rhs single-precision value to subtract + /// \return reference to this half + half& operator-=(float rhs) { data_ = detail::float2half(detail::half2float(data_)-rhs); return *this; } + + /// Arithmetic assignment. + /// \param rhs single-precision value to multiply with + /// \return reference to this half + half& operator*=(float rhs) { data_ = detail::float2half(detail::half2float(data_)*rhs); return *this; } + + /// Arithmetic assignment. + /// \param rhs single-precision value to divide by + /// \return reference to this half + half& operator/=(float rhs) { data_ = detail::float2half(detail::half2float(data_)/rhs); return *this; } + + /// Prefix increment. + /// \return incremented half value + half& operator++() { return *this += 1.0f; } + + /// Prefix decrement. 
+ /// \return decremented half value + half& operator--() { return *this -= 1.0f; } + + /// Postfix increment. + /// \return non-incremented half value + half operator++(int) { half out(*this); ++*this; return out; } + + /// Postfix decrement. + /// \return non-decremented half value + half operator--(int) { half out(*this); --*this; return out; } + + private: + /// Rounding mode to use + static const std::float_round_style round_style = (std::float_round_style)(HALF_ROUND_STYLE); + + /// Constructor. + /// \param bits binary representation to set half to + HALF_CONSTEXPR half(detail::binary_t, detail::uint16 bits) HALF_NOEXCEPT : data_(bits) {} + + /// Internal binary representation + detail::uint16 data_; + }; + +#if HALF_ENABLE_CPP11_USER_LITERALS + namespace literal + { + /// Half literal. + /// While this returns an actual half-precision value, half literals can unfortunately not be constant expressions due + /// to rather involved conversions. + /// \param value literal value + /// \return half with given value (if representable) + inline half operator""_h(long double value) { return half(detail::binary, detail::float2half(value)); } + } +#endif + + namespace detail + { + /// Wrapper implementing unspecialized half-precision functions. + struct functions + { + /// Addition implementation. + /// \param x first operand + /// \param y second operand + /// \return Half-precision sum stored in single-precision + static expr plus(float x, float y) { return expr(x+y); } + + /// Subtraction implementation. + /// \param x first operand + /// \param y second operand + /// \return Half-precision difference stored in single-precision + static expr minus(float x, float y) { return expr(x-y); } + + /// Multiplication implementation. + /// \param x first operand + /// \param y second operand + /// \return Half-precision product stored in single-precision + static expr multiplies(float x, float y) { return expr(x*y); } + + /// Division implementation. 
+ /// \param x first operand + /// \param y second operand + /// \return Half-precision quotient stored in single-precision + static expr divides(float x, float y) { return expr(x/y); } + + /// Output implementation. + /// \param out stream to write to + /// \param arg value to write + /// \return reference to stream + template static std::basic_ostream& write(std::basic_ostream &out, float arg) { return out << arg; } + + /// Input implementation. + /// \param in stream to read from + /// \param arg half to read into + /// \return reference to stream + template static std::basic_istream& read(std::basic_istream &in, half &arg) + { + float f; + if(in >> f) + arg = f; + return in; + } + + /// Modulo implementation. + /// \param x first operand + /// \param y second operand + /// \return Half-precision division remainder stored in single-precision + static expr fmod(float x, float y) { return expr(std::fmod(x, y)); } + + /// Remainder implementation. + /// \param x first operand + /// \param y second operand + /// \return Half-precision division remainder stored in single-precision + static expr remainder(float x, float y) + { + #if HALF_ENABLE_CPP11_CMATH + return expr(std::remainder(x, y)); + #else + if(builtin_isnan(x) || builtin_isnan(y)) + return expr(std::numeric_limits::quiet_NaN()); + float ax = std::fabs(x), ay = std::fabs(y); + if(ax >= 65536.0f || ay < std::ldexp(1.0f, -24)) + return expr(std::numeric_limits::quiet_NaN()); + if(ay >= 65536.0f) + return expr(x); + if(ax == ay) + return expr(builtin_signbit(x) ? -0.0f : 0.0f); + ax = std::fmod(ax, ay+ay); + float y2 = 0.5f * ay; + if(ax > y2) + { + ax -= ay; + if(ax >= y2) + ax -= ay; + } + return expr(builtin_signbit(x) ? -ax : ax); + #endif + } + + /// Remainder implementation. 
+ /// \param x first operand + /// \param y second operand + /// \param quo address to store quotient bits at + /// \return Half-precision division remainder stored in single-precision + static expr remquo(float x, float y, int *quo) + { + #if HALF_ENABLE_CPP11_CMATH + return expr(std::remquo(x, y, quo)); + #else + if(builtin_isnan(x) || builtin_isnan(y)) + return expr(std::numeric_limits::quiet_NaN()); + bool sign = builtin_signbit(x), qsign = static_cast(sign^builtin_signbit(y)); + float ax = std::fabs(x), ay = std::fabs(y); + if(ax >= 65536.0f || ay < std::ldexp(1.0f, -24)) + return expr(std::numeric_limits::quiet_NaN()); + if(ay >= 65536.0f) + return expr(x); + if(ax == ay) + return *quo = qsign ? -1 : 1, expr(sign ? -0.0f : 0.0f); + ax = std::fmod(ax, 8.0f*ay); + int cquo = 0; + if(ax >= 4.0f * ay) + { + ax -= 4.0f * ay; + cquo += 4; + } + if(ax >= 2.0f * ay) + { + ax -= 2.0f * ay; + cquo += 2; + } + float y2 = 0.5f * ay; + if(ax > y2) + { + ax -= ay; + ++cquo; + if(ax >= y2) + { + ax -= ay; + ++cquo; + } + } + return *quo = qsign ? -cquo : cquo, expr(sign ? -ax : ax); + #endif + } + + /// Positive difference implementation. + /// \param x first operand + /// \param y second operand + /// \return Positive difference stored in single-precision + static expr fdim(float x, float y) + { + #if HALF_ENABLE_CPP11_CMATH + return expr(std::fdim(x, y)); + #else + return expr((x<=y) ? 0.0f : (x-y)); + #endif + } + + /// Fused multiply-add implementation. + /// \param x first operand + /// \param y second operand + /// \param z third operand + /// \return \a x * \a y + \a z stored in single-precision + static expr fma(float x, float y, float z) + { + #if HALF_ENABLE_CPP11_CMATH && defined(FP_FAST_FMAF) + return expr(std::fma(x, y, z)); + #else + return expr(x*y+z); + #endif + } + + /// Get NaN. + /// \return Half-precision quiet NaN + static half nanh() { return half(binary, 0x7FFF); } + + /// Exponential implementation. 
+ /// \param arg function argument + /// \return function value stored in single-preicision + static expr exp(float arg) { return expr(std::exp(arg)); } + + /// Exponential implementation. + /// \param arg function argument + /// \return function value stored in single-preicision + static expr expm1(float arg) + { + #if HALF_ENABLE_CPP11_CMATH + return expr(std::expm1(arg)); + #else + return expr(static_cast(std::exp(static_cast(arg))-1.0)); + #endif + } + + /// Binary exponential implementation. + /// \param arg function argument + /// \return function value stored in single-preicision + static expr exp2(float arg) + { + #if HALF_ENABLE_CPP11_CMATH + return expr(std::exp2(arg)); + #else + return expr(static_cast(std::exp(arg*0.69314718055994530941723212145818))); + #endif + } + + /// Logarithm implementation. + /// \param arg function argument + /// \return function value stored in single-preicision + static expr log(float arg) { return expr(std::log(arg)); } + + /// Common logarithm implementation. + /// \param arg function argument + /// \return function value stored in single-preicision + static expr log10(float arg) { return expr(std::log10(arg)); } + + /// Logarithm implementation. + /// \param arg function argument + /// \return function value stored in single-preicision + static expr log1p(float arg) + { + #if HALF_ENABLE_CPP11_CMATH + return expr(std::log1p(arg)); + #else + return expr(static_cast(std::log(1.0+arg))); + #endif + } + + /// Binary logarithm implementation. + /// \param arg function argument + /// \return function value stored in single-preicision + static expr log2(float arg) + { + #if HALF_ENABLE_CPP11_CMATH + return expr(std::log2(arg)); + #else + return expr(static_cast(std::log(static_cast(arg))*1.4426950408889634073599246810019)); + #endif + } + + /// Square root implementation. 
+ /// \param arg function argument + /// \return function value stored in single-preicision + static expr sqrt(float arg) { return expr(std::sqrt(arg)); } + + /// Cubic root implementation. + /// \param arg function argument + /// \return function value stored in single-preicision + static expr cbrt(float arg) + { + #if HALF_ENABLE_CPP11_CMATH + return expr(std::cbrt(arg)); + #else + if(builtin_isnan(arg) || builtin_isinf(arg)) + return expr(arg); + return expr(builtin_signbit(arg) ? -static_cast(std::pow(-static_cast(arg), 1.0/3.0)) : + static_cast(std::pow(static_cast(arg), 1.0/3.0))); + #endif + } + + /// Hypotenuse implementation. + /// \param x first argument + /// \param y second argument + /// \return function value stored in single-preicision + static expr hypot(float x, float y) + { + #if HALF_ENABLE_CPP11_CMATH + return expr(std::hypot(x, y)); + #else + return expr((builtin_isinf(x) || builtin_isinf(y)) ? std::numeric_limits::infinity() : + static_cast(std::sqrt(static_cast(x)*x+static_cast(y)*y))); + #endif + } + + /// Power implementation. + /// \param base value to exponentiate + /// \param exp power to expontiate to + /// \return function value stored in single-preicision + static expr pow(float base, float exp) { return expr(std::pow(base, exp)); } + + /// Sine implementation. + /// \param arg function argument + /// \return function value stored in single-preicision + static expr sin(float arg) { return expr(std::sin(arg)); } + + /// Cosine implementation. + /// \param arg function argument + /// \return function value stored in single-preicision + static expr cos(float arg) { return expr(std::cos(arg)); } + + /// Tan implementation. + /// \param arg function argument + /// \return function value stored in single-preicision + static expr tan(float arg) { return expr(std::tan(arg)); } + + /// Arc sine implementation. 
+ /// \param arg function argument + /// \return function value stored in single-preicision + static expr asin(float arg) { return expr(std::asin(arg)); } + + /// Arc cosine implementation. + /// \param arg function argument + /// \return function value stored in single-preicision + static expr acos(float arg) { return expr(std::acos(arg)); } + + /// Arc tangent implementation. + /// \param arg function argument + /// \return function value stored in single-preicision + static expr atan(float arg) { return expr(std::atan(arg)); } + + /// Arc tangent implementation. + /// \param x first argument + /// \param y second argument + /// \return function value stored in single-preicision + static expr atan2(float x, float y) { return expr(std::atan2(x, y)); } + + /// Hyperbolic sine implementation. + /// \param arg function argument + /// \return function value stored in single-preicision + static expr sinh(float arg) { return expr(std::sinh(arg)); } + + /// Hyperbolic cosine implementation. + /// \param arg function argument + /// \return function value stored in single-preicision + static expr cosh(float arg) { return expr(std::cosh(arg)); } + + /// Hyperbolic tangent implementation. + /// \param arg function argument + /// \return function value stored in single-preicision + static expr tanh(float arg) { return expr(std::tanh(arg)); } + + /// Hyperbolic area sine implementation. + /// \param arg function argument + /// \return function value stored in single-preicision + static expr asinh(float arg) + { + #if HALF_ENABLE_CPP11_CMATH + return expr(std::asinh(arg)); + #else + return expr((arg==-std::numeric_limits::infinity()) ? arg : static_cast(std::log(arg+std::sqrt(arg*arg+1.0)))); + #endif + } + + /// Hyperbolic area cosine implementation. + /// \param arg function argument + /// \return function value stored in single-preicision + static expr acosh(float arg) + { + #if HALF_ENABLE_CPP11_CMATH + return expr(std::acosh(arg)); + #else + return expr((arg<-1.0f) ? 
std::numeric_limits::quiet_NaN() : static_cast(std::log(arg+std::sqrt(arg*arg-1.0)))); + #endif + } + + /// Hyperbolic area tangent implementation. + /// \param arg function argument + /// \return function value stored in single-preicision + static expr atanh(float arg) + { + #if HALF_ENABLE_CPP11_CMATH + return expr(std::atanh(arg)); + #else + return expr(static_cast(0.5*std::log((1.0+arg)/(1.0-arg)))); + #endif + } + + /// Error function implementation. + /// \param arg function argument + /// \return function value stored in single-preicision + static expr erf(float arg) + { + #if HALF_ENABLE_CPP11_CMATH + return expr(std::erf(arg)); + #else + return expr(static_cast(erf(static_cast(arg)))); + #endif + } + + /// Complementary implementation. + /// \param arg function argument + /// \return function value stored in single-preicision + static expr erfc(float arg) + { + #if HALF_ENABLE_CPP11_CMATH + return expr(std::erfc(arg)); + #else + return expr(static_cast(1.0-erf(static_cast(arg)))); + #endif + } + + /// Gamma logarithm implementation. + /// \param arg function argument + /// \return function value stored in single-preicision + static expr lgamma(float arg) + { + #if HALF_ENABLE_CPP11_CMATH + return expr(std::lgamma(arg)); + #else + if(builtin_isinf(arg)) + return expr(std::numeric_limits::infinity()); + if(arg < 0.0f) + { + float i, f = std::modf(-arg, &i); + if(f == 0.0f) + return expr(std::numeric_limits::infinity()); + return expr(static_cast(1.1447298858494001741434273513531- + std::log(std::abs(std::sin(3.1415926535897932384626433832795*f)))-lgamma(1.0-arg))); + } + return expr(static_cast(lgamma(static_cast(arg)))); + #endif + } + + /// Gamma implementation. + /// \param arg function argument + /// \return function value stored in single-preicision + static expr tgamma(float arg) + { + #if HALF_ENABLE_CPP11_CMATH + return expr(std::tgamma(arg)); + #else + if(arg == 0.0f) + return builtin_signbit(arg) ? 
expr(-std::numeric_limits::infinity()) : expr(std::numeric_limits::infinity()); + if(arg < 0.0f) + { + float i, f = std::modf(-arg, &i); + if(f == 0.0f) + return expr(std::numeric_limits::quiet_NaN()); + double value = 3.1415926535897932384626433832795 / (std::sin(3.1415926535897932384626433832795*f)*std::exp(lgamma(1.0-arg))); + return expr(static_cast((std::fmod(i, 2.0f)==0.0f) ? -value : value)); + } + if(builtin_isinf(arg)) + return expr(arg); + return expr(static_cast(std::exp(lgamma(static_cast(arg))))); + #endif + } + + /// Floor implementation. + /// \param arg value to round + /// \return rounded value + static half floor(half arg) { return half(binary, round_half(arg.data_)); } + + /// Ceiling implementation. + /// \param arg value to round + /// \return rounded value + static half ceil(half arg) { return half(binary, round_half(arg.data_)); } + + /// Truncation implementation. + /// \param arg value to round + /// \return rounded value + static half trunc(half arg) { return half(binary, round_half(arg.data_)); } + + /// Nearest integer implementation. + /// \param arg value to round + /// \return rounded value + static half round(half arg) { return half(binary, round_half_up(arg.data_)); } + + /// Nearest integer implementation. + /// \param arg value to round + /// \return rounded value + static long lround(half arg) { return detail::half2int_up(arg.data_); } + + /// Nearest integer implementation. + /// \param arg value to round + /// \return rounded value + static half rint(half arg) { return half(binary, round_half(arg.data_)); } + + /// Nearest integer implementation. + /// \param arg value to round + /// \return rounded value + static long lrint(half arg) { return detail::half2int(arg.data_); } + + #if HALF_ENABLE_CPP11_LONG_LONG + /// Nearest integer implementation. + /// \param arg value to round + /// \return rounded value + static long long llround(half arg) { return detail::half2int_up(arg.data_); } + + /// Nearest integer implementation. 
+ /// \param arg value to round + /// \return rounded value + static long long llrint(half arg) { return detail::half2int(arg.data_); } + #endif + + /// Decompression implementation. + /// \param arg number to decompress + /// \param exp address to store exponent at + /// \return normalized significant + static half frexp(half arg, int *exp) + { + int m = arg.data_ & 0x7FFF, e = -14; + if(m >= 0x7C00 || !m) + return *exp = 0, arg; + for(; m<0x400; m<<=1,--e) ; + return *exp = e+(m>>10), half(binary, (arg.data_&0x8000)|0x3800|(m&0x3FF)); + } + + /// Decompression implementation. + /// \param arg number to decompress + /// \param iptr address to store integer part at + /// \return fractional part + static half modf(half arg, half *iptr) + { + unsigned int e = arg.data_ & 0x7FFF; + if(e >= 0x6400) + return *iptr = arg, half(binary, arg.data_&(0x8000U|-(e>0x7C00))); + if(e < 0x3C00) + return iptr->data_ = arg.data_ & 0x8000, arg; + e >>= 10; + unsigned int mask = (1<<(25-e)) - 1, m = arg.data_ & mask; + iptr->data_ = arg.data_ & ~mask; + if(!m) + return half(binary, arg.data_&0x8000); + for(; m<0x400; m<<=1,--e) ; + return half(binary, static_cast((arg.data_&0x8000)|(e<<10)|(m&0x3FF))); + } + + /// Scaling implementation. 
+ /// \param arg number to scale + /// \param exp power of two to scale by + /// \return scaled number + static half scalbln(half arg, long exp) + { + unsigned int m = arg.data_ & 0x7FFF; + if(m >= 0x7C00 || !m) + return arg; + for(; m<0x400; m<<=1,--exp) ; + exp += m >> 10; + uint16 value = arg.data_ & 0x8000; + if(exp > 30) + { + if(half::round_style == std::round_toward_zero) + value |= 0x7BFF; + else if(half::round_style == std::round_toward_infinity) + value |= 0x7C00 - (value>>15); + else if(half::round_style == std::round_toward_neg_infinity) + value |= 0x7BFF + (value>>15); + else + value |= 0x7C00; + } + else if(exp > 0) + value |= (exp<<10) | (m&0x3FF); + else if(exp > -11) + { + m = (m&0x3FF) | 0x400; + if(half::round_style == std::round_to_nearest) + { + m += 1 << -exp; + #if HALF_ROUND_TIES_TO_EVEN + m -= (m>>(1-exp)) & 1; + #endif + } + else if(half::round_style == std::round_toward_infinity) + m += ((value>>15)-1) & ((1<<(1-exp))-1U); + else if(half::round_style == std::round_toward_neg_infinity) + m += -(value>>15) & ((1<<(1-exp))-1U); + value |= m >> (1-exp); + } + else if(half::round_style == std::round_toward_infinity) + value -= (value>>15) - 1; + else if(half::round_style == std::round_toward_neg_infinity) + value += value >> 15; + return half(binary, value); + } + + /// Exponent implementation. + /// \param arg number to query + /// \return floating point exponent + static int ilogb(half arg) + { + int abs = arg.data_ & 0x7FFF; + if(!abs) + return FP_ILOGB0; + if(abs < 0x7C00) + { + int exp = (abs>>10) - 15; + if(abs < 0x400) + for(; abs<0x200; abs<<=1,--exp) ; + return exp; + } + if(abs > 0x7C00) + return FP_ILOGBNAN; + return INT_MAX; + } + + /// Exponent implementation. 
+ /// \param arg number to query + /// \return floating point exponent + static half logb(half arg) + { + int abs = arg.data_ & 0x7FFF; + if(!abs) + return half(binary, 0xFC00); + if(abs < 0x7C00) + { + int exp = (abs>>10) - 15; + if(abs < 0x400) + for(; abs<0x200; abs<<=1,--exp) ; + uint16 bits = (exp<0) << 15; + if(exp) + { + unsigned int m = std::abs(exp) << 6, e = 18; + for(; m<0x400; m<<=1,--e) ; + bits |= (e<<10) + m; + } + return half(binary, bits); + } + if(abs > 0x7C00) + return arg; + return half(binary, 0x7C00); + } + + /// Enumeration implementation. + /// \param from number to increase/decrease + /// \param to direction to enumerate into + /// \return next representable number + static half nextafter(half from, half to) + { + uint16 fabs = from.data_ & 0x7FFF, tabs = to.data_ & 0x7FFF; + if(fabs > 0x7C00) + return from; + if(tabs > 0x7C00 || from.data_ == to.data_ || !(fabs|tabs)) + return to; + if(!fabs) + return half(binary, (to.data_&0x8000)+1); + bool lt = ((fabs==from.data_) ? static_cast(fabs) : -static_cast(fabs)) < + ((tabs==to.data_) ? static_cast(tabs) : -static_cast(tabs)); + return half(binary, from.data_+(((from.data_>>15)^static_cast(lt))<<1)-1); + } + + /// Enumeration implementation. + /// \param from number to increase/decrease + /// \param to direction to enumerate into + /// \return next representable number + static half nexttoward(half from, long double to) + { + if(isnan(from)) + return from; + long double lfrom = static_cast(from); + if(builtin_isnan(to) || lfrom == to) + return half(static_cast(to)); + if(!(from.data_&0x7FFF)) + return half(binary, (static_cast(builtin_signbit(to))<<15)+1); + return half(binary, from.data_+(((from.data_>>15)^static_cast(lfrom0x3FF) ? ((abs>=0x7C00) ? ((abs>0x7C00) ? FP_NAN : FP_INFINITE) : FP_NORMAL) :FP_SUBNORMAL) : FP_ZERO; + } + + /// Classification implementation. 
+ /// \param arg value to classify + /// \retval true if finite number + /// \retval false else + static bool isfinite(half arg) { return (arg.data_&0x7C00) != 0x7C00; } + + /// Classification implementation. + /// \param arg value to classify + /// \retval true if infinite number + /// \retval false else + static bool isinf(half arg) { return (arg.data_&0x7FFF) == 0x7C00; } + + /// Classification implementation. + /// \param arg value to classify + /// \retval true if not a number + /// \retval false else + static bool isnan(half arg) { return (arg.data_&0x7FFF) > 0x7C00; } + + /// Classification implementation. + /// \param arg value to classify + /// \retval true if normal number + /// \retval false else + static bool isnormal(half arg) { return ((arg.data_&0x7C00)!=0) & ((arg.data_&0x7C00)!=0x7C00); } + + /// Sign bit implementation. + /// \param arg value to check + /// \retval true if signed + /// \retval false if unsigned + static bool signbit(half arg) { return (arg.data_&0x8000) != 0; } + + /// Comparison implementation. + /// \param x first operand + /// \param y second operand + /// \retval true if operands equal + /// \retval false else + static bool isequal(half x, half y) { return (x.data_==y.data_ || !((x.data_|y.data_)&0x7FFF)) && !isnan(x); } + + /// Comparison implementation. + /// \param x first operand + /// \param y second operand + /// \retval true if operands not equal + /// \retval false else + static bool isnotequal(half x, half y) { return (x.data_!=y.data_ && ((x.data_|y.data_)&0x7FFF)) || isnan(x); } + + /// Comparison implementation. + /// \param x first operand + /// \param y second operand + /// \retval true if \a x > \a y + /// \retval false else + static bool isgreater(half x, half y) + { + int xabs = x.data_ & 0x7FFF, yabs = y.data_ & 0x7FFF; + return xabs<=0x7C00 && yabs<=0x7C00 && (((xabs==x.data_) ? xabs : -xabs) > ((yabs==y.data_) ? yabs : -yabs)); + } + + /// Comparison implementation. 
+ /// \param x first operand + /// \param y second operand + /// \retval true if \a x >= \a y + /// \retval false else + static bool isgreaterequal(half x, half y) + { + int xabs = x.data_ & 0x7FFF, yabs = y.data_ & 0x7FFF; + return xabs<=0x7C00 && yabs<=0x7C00 && (((xabs==x.data_) ? xabs : -xabs) >= ((yabs==y.data_) ? yabs : -yabs)); + } + + /// Comparison implementation. + /// \param x first operand + /// \param y second operand + /// \retval true if \a x < \a y + /// \retval false else + static bool isless(half x, half y) + { + int xabs = x.data_ & 0x7FFF, yabs = y.data_ & 0x7FFF; + return xabs<=0x7C00 && yabs<=0x7C00 && (((xabs==x.data_) ? xabs : -xabs) < ((yabs==y.data_) ? yabs : -yabs)); + } + + /// Comparison implementation. + /// \param x first operand + /// \param y second operand + /// \retval true if \a x <= \a y + /// \retval false else + static bool islessequal(half x, half y) + { + int xabs = x.data_ & 0x7FFF, yabs = y.data_ & 0x7FFF; + return xabs<=0x7C00 && yabs<=0x7C00 && (((xabs==x.data_) ? xabs : -xabs) <= ((yabs==y.data_) ? yabs : -yabs)); + } + + /// Comparison implementation. + /// \param x first operand + /// \param y second operand + /// \retval true if either \a x > \a y nor \a x < \a y + /// \retval false else + static bool islessgreater(half x, half y) + { + int xabs = x.data_ & 0x7FFF, yabs = y.data_ & 0x7FFF; + if(xabs > 0x7C00 || yabs > 0x7C00) + return false; + int a = (xabs==x.data_) ? xabs : -xabs, b = (yabs==y.data_) ? yabs : -yabs; + return a < b || a > b; + } + + /// Comparison implementation. + /// \param x first operand + /// \param y second operand + /// \retval true if operand unordered + /// \retval false else + static bool isunordered(half x, half y) { return isnan(x) || isnan(y); } + + private: + static double erf(double arg) + { + if(builtin_isinf(arg)) + return (arg<0.0) ? 
-1.0 : 1.0; + double x2 = arg * arg, ax2 = 0.147 * x2, value = std::sqrt(1.0-std::exp(-x2*(1.2732395447351626861510701069801+ax2)/(1.0+ax2))); + return builtin_signbit(arg) ? -value : value; + } + + static double lgamma(double arg) + { + double v = 1.0; + for(; arg<8.0; ++arg) v *= arg; + double w = 1.0 / (arg*arg); + return (((((((-0.02955065359477124183006535947712*w+0.00641025641025641025641025641026)*w+ + -0.00191752691752691752691752691753)*w+8.4175084175084175084175084175084e-4)*w+ + -5.952380952380952380952380952381e-4)*w+7.9365079365079365079365079365079e-4)*w+ + -0.00277777777777777777777777777778)*w+0.08333333333333333333333333333333)/arg + + 0.91893853320467274178032973640562 - std::log(v) - arg + (arg-0.5) * std::log(arg); + } + }; + + /// Wrapper for unary half-precision functions needing specialization for individual argument types. + /// \tparam T argument type + template struct unary_specialized + { + /// Negation implementation. + /// \param arg value to negate + /// \return negated value + static HALF_CONSTEXPR half negate(half arg) { return half(binary, arg.data_^0x8000); } + + /// Absolute value implementation. + /// \param arg function argument + /// \return absolute value + static half fabs(half arg) { return half(binary, arg.data_&0x7FFF); } + }; + template<> struct unary_specialized + { + static HALF_CONSTEXPR expr negate(float arg) { return expr(-arg); } + static expr fabs(float arg) { return expr(std::fabs(arg)); } + }; + + /// Wrapper for binary half-precision functions needing specialization for individual argument types. + /// \tparam T first argument type + /// \tparam U first argument type + template struct binary_specialized + { + /// Minimum implementation. 
+ /// \param x first operand + /// \param y second operand + /// \return minimum value + static expr fmin(float x, float y) + { + #if HALF_ENABLE_CPP11_CMATH + return expr(std::fmin(x, y)); + #else + if(builtin_isnan(x)) + return expr(y); + if(builtin_isnan(y)) + return expr(x); + return expr(std::min(x, y)); + #endif + } + + /// Maximum implementation. + /// \param x first operand + /// \param y second operand + /// \return maximum value + static expr fmax(float x, float y) + { + #if HALF_ENABLE_CPP11_CMATH + return expr(std::fmax(x, y)); + #else + if(builtin_isnan(x)) + return expr(y); + if(builtin_isnan(y)) + return expr(x); + return expr(std::max(x, y)); + #endif + } + }; + template<> struct binary_specialized + { + static half fmin(half x, half y) + { + int xabs = x.data_ & 0x7FFF, yabs = y.data_ & 0x7FFF; + if(xabs > 0x7C00) + return y; + if(yabs > 0x7C00) + return x; + return (((xabs==x.data_) ? xabs : -xabs) > ((yabs==y.data_) ? yabs : -yabs)) ? y : x; + } + static half fmax(half x, half y) + { + int xabs = x.data_ & 0x7FFF, yabs = y.data_ & 0x7FFF; + if(xabs > 0x7C00) + return y; + if(yabs > 0x7C00) + return x; + return (((xabs==x.data_) ? xabs : -xabs) < ((yabs==y.data_) ? yabs : -yabs)) ? y : x; + } + }; + + /// Helper class for half casts. + /// This class template has to be specialized for all valid cast argument to define an appropriate static `cast` member + /// function and a corresponding `type` member denoting its return type. 
+ /// \tparam T destination type + /// \tparam U source type + /// \tparam R rounding mode to use + template struct half_caster {}; + template struct half_caster + { + #if HALF_ENABLE_CPP11_STATIC_ASSERT && HALF_ENABLE_CPP11_TYPE_TRAITS + static_assert(std::is_arithmetic::value, "half_cast from non-arithmetic type unsupported"); + #endif + + static half cast(U arg) { return cast_impl(arg, is_float()); }; + + private: + static half cast_impl(U arg, true_type) { return half(binary, float2half(arg)); } + static half cast_impl(U arg, false_type) { return half(binary, int2half(arg)); } + }; + template struct half_caster + { + #if HALF_ENABLE_CPP11_STATIC_ASSERT && HALF_ENABLE_CPP11_TYPE_TRAITS + static_assert(std::is_arithmetic::value, "half_cast to non-arithmetic type unsupported"); + #endif + + static T cast(half arg) { return cast_impl(arg, is_float()); } + + private: + static T cast_impl(half arg, true_type) { return half2float(arg.data_); } + static T cast_impl(half arg, false_type) { return half2int(arg.data_); } + }; + template struct half_caster + { + #if HALF_ENABLE_CPP11_STATIC_ASSERT && HALF_ENABLE_CPP11_TYPE_TRAITS + static_assert(std::is_arithmetic::value, "half_cast to non-arithmetic type unsupported"); + #endif + + static T cast(expr arg) { return cast_impl(arg, is_float()); } + + private: + static T cast_impl(float arg, true_type) { return static_cast(arg); } + static T cast_impl(half arg, false_type) { return half2int(arg.data_); } + }; + template struct half_caster + { + static half cast(half arg) { return arg; } + }; + template struct half_caster : half_caster {}; + + /// \name Comparison operators + /// \{ + + /// Comparison for equality. + /// \param x first operand + /// \param y second operand + /// \retval true if operands equal + /// \retval false else + template typename enable::type operator==(T x, U y) { return functions::isequal(x, y); } + + /// Comparison for inequality. 
+ /// \param x first operand + /// \param y second operand + /// \retval true if operands not equal + /// \retval false else + template typename enable::type operator!=(T x, U y) { return functions::isnotequal(x, y); } + + /// Comparison for less than. + /// \param x first operand + /// \param y second operand + /// \retval true if \a x less than \a y + /// \retval false else + template typename enable::type operator<(T x, U y) { return functions::isless(x, y); } + + /// Comparison for greater than. + /// \param x first operand + /// \param y second operand + /// \retval true if \a x greater than \a y + /// \retval false else + template typename enable::type operator>(T x, U y) { return functions::isgreater(x, y); } + + /// Comparison for less equal. + /// \param x first operand + /// \param y second operand + /// \retval true if \a x less equal \a y + /// \retval false else + template typename enable::type operator<=(T x, U y) { return functions::islessequal(x, y); } + + /// Comparison for greater equal. + /// \param x first operand + /// \param y second operand + /// \retval true if \a x greater equal \a y + /// \retval false else + template typename enable::type operator>=(T x, U y) { return functions::isgreaterequal(x, y); } + + /// \} + /// \name Arithmetic operators + /// \{ + + /// Add halfs. + /// \param x left operand + /// \param y right operand + /// \return sum of half expressions + template typename enable::type operator+(T x, U y) { return functions::plus(x, y); } + + /// Subtract halfs. + /// \param x left operand + /// \param y right operand + /// \return difference of half expressions + template typename enable::type operator-(T x, U y) { return functions::minus(x, y); } + + /// Multiply halfs. + /// \param x left operand + /// \param y right operand + /// \return product of half expressions + template typename enable::type operator*(T x, U y) { return functions::multiplies(x, y); } + + /// Divide halfs. 
+ /// \param x left operand + /// \param y right operand + /// \return quotient of half expressions + template typename enable::type operator/(T x, U y) { return functions::divides(x, y); } + + /// Identity. + /// \param arg operand + /// \return uncahnged operand + template HALF_CONSTEXPR typename enable::type operator+(T arg) { return arg; } + + /// Negation. + /// \param arg operand + /// \return negated operand + template HALF_CONSTEXPR typename enable::type operator-(T arg) { return unary_specialized::negate(arg); } + + /// \} + /// \name Input and output + /// \{ + + /// Output operator. + /// \param out output stream to write into + /// \param arg half expression to write + /// \return reference to output stream + template typename enable&,T>::type + operator<<(std::basic_ostream &out, T arg) { return functions::write(out, arg); } + + /// Input operator. + /// \param in input stream to read from + /// \param arg half to read into + /// \return reference to input stream + template std::basic_istream& + operator>>(std::basic_istream &in, half &arg) { return functions::read(in, arg); } + + /// \} + /// \name Basic mathematical operations + /// \{ + + /// Absolute value. + /// \param arg operand + /// \return absolute value of \a arg +// template typename enable::type abs(T arg) { return unary_specialized::fabs(arg); } + inline half abs(half arg) { return unary_specialized::fabs(arg); } + inline expr abs(expr arg) { return unary_specialized::fabs(arg); } + + /// Absolute value. + /// \param arg operand + /// \return absolute value of \a arg +// template typename enable::type fabs(T arg) { return unary_specialized::fabs(arg); } + inline half fabs(half arg) { return unary_specialized::fabs(arg); } + inline expr fabs(expr arg) { return unary_specialized::fabs(arg); } + + /// Remainder of division. + /// \param x first operand + /// \param y second operand + /// \return remainder of floating point division. 
+// template typename enable::type fmod(T x, U y) { return functions::fmod(x, y); } + inline expr fmod(half x, half y) { return functions::fmod(x, y); } + inline expr fmod(half x, expr y) { return functions::fmod(x, y); } + inline expr fmod(expr x, half y) { return functions::fmod(x, y); } + inline expr fmod(expr x, expr y) { return functions::fmod(x, y); } + + /// Remainder of division. + /// \param x first operand + /// \param y second operand + /// \return remainder of floating point division. +// template typename enable::type remainder(T x, U y) { return functions::remainder(x, y); } + inline expr remainder(half x, half y) { return functions::remainder(x, y); } + inline expr remainder(half x, expr y) { return functions::remainder(x, y); } + inline expr remainder(expr x, half y) { return functions::remainder(x, y); } + inline expr remainder(expr x, expr y) { return functions::remainder(x, y); } + + /// Remainder of division. + /// \param x first operand + /// \param y second operand + /// \param quo address to store some bits of quotient at + /// \return remainder of floating point division. +// template typename enable::type remquo(T x, U y, int *quo) { return functions::remquo(x, y, quo); } + inline expr remquo(half x, half y, int *quo) { return functions::remquo(x, y, quo); } + inline expr remquo(half x, expr y, int *quo) { return functions::remquo(x, y, quo); } + inline expr remquo(expr x, half y, int *quo) { return functions::remquo(x, y, quo); } + inline expr remquo(expr x, expr y, int *quo) { return functions::remquo(x, y, quo); } + + /// Fused multiply add. + /// \param x first operand + /// \param y second operand + /// \param z third operand + /// \return ( \a x * \a y ) + \a z rounded as one operation. 
+// template typename enable::type fma(T x, U y, V z) { return functions::fma(x, y, z); } + inline expr fma(half x, half y, half z) { return functions::fma(x, y, z); } + inline expr fma(half x, half y, expr z) { return functions::fma(x, y, z); } + inline expr fma(half x, expr y, half z) { return functions::fma(x, y, z); } + inline expr fma(half x, expr y, expr z) { return functions::fma(x, y, z); } + inline expr fma(expr x, half y, half z) { return functions::fma(x, y, z); } + inline expr fma(expr x, half y, expr z) { return functions::fma(x, y, z); } + inline expr fma(expr x, expr y, half z) { return functions::fma(x, y, z); } + inline expr fma(expr x, expr y, expr z) { return functions::fma(x, y, z); } + + /// Maximum of half expressions. + /// \param x first operand + /// \param y second operand + /// \return maximum of operands +// template typename result::type fmax(T x, U y) { return binary_specialized::fmax(x, y); } + inline half fmax(half x, half y) { return binary_specialized::fmax(x, y); } + inline expr fmax(half x, expr y) { return binary_specialized::fmax(x, y); } + inline expr fmax(expr x, half y) { return binary_specialized::fmax(x, y); } + inline expr fmax(expr x, expr y) { return binary_specialized::fmax(x, y); } + + /// Minimum of half expressions. + /// \param x first operand + /// \param y second operand + /// \return minimum of operands +// template typename result::type fmin(T x, U y) { return binary_specialized::fmin(x, y); } + inline half fmin(half x, half y) { return binary_specialized::fmin(x, y); } + inline expr fmin(half x, expr y) { return binary_specialized::fmin(x, y); } + inline expr fmin(expr x, half y) { return binary_specialized::fmin(x, y); } + inline expr fmin(expr x, expr y) { return binary_specialized::fmin(x, y); } + + /// Positive difference. 
+ /// \param x first operand + /// \param y second operand + /// \return \a x - \a y or 0 if difference negative +// template typename enable::type fdim(T x, U y) { return functions::fdim(x, y); } + inline expr fdim(half x, half y) { return functions::fdim(x, y); } + inline expr fdim(half x, expr y) { return functions::fdim(x, y); } + inline expr fdim(expr x, half y) { return functions::fdim(x, y); } + inline expr fdim(expr x, expr y) { return functions::fdim(x, y); } + + /// Get NaN value. + /// \return quiet NaN + inline half nanh(const char*) { return functions::nanh(); } + + /// \} + /// \name Exponential functions + /// \{ + + /// Exponential function. + /// \param arg function argument + /// \return e raised to \a arg +// template typename enable::type exp(T arg) { return functions::exp(arg); } + inline expr exp(half arg) { return functions::exp(arg); } + inline expr exp(expr arg) { return functions::exp(arg); } + + /// Exponential minus one. + /// \param arg function argument + /// \return e raised to \a arg subtracted by 1 +// template typename enable::type expm1(T arg) { return functions::expm1(arg); } + inline expr expm1(half arg) { return functions::expm1(arg); } + inline expr expm1(expr arg) { return functions::expm1(arg); } + + /// Binary exponential. + /// \param arg function argument + /// \return 2 raised to \a arg +// template typename enable::type exp2(T arg) { return functions::exp2(arg); } + inline expr exp2(half arg) { return functions::exp2(arg); } + inline expr exp2(expr arg) { return functions::exp2(arg); } + + /// Natural logorithm. + /// \param arg function argument + /// \return logarithm of \a arg to base e +// template typename enable::type log(T arg) { return functions::log(arg); } + inline expr log(half arg) { return functions::log(arg); } + inline expr log(expr arg) { return functions::log(arg); } + + /// Common logorithm. 
+ /// \param arg function argument + /// \return logarithm of \a arg to base 10 +// template typename enable::type log10(T arg) { return functions::log10(arg); } + inline expr log10(half arg) { return functions::log10(arg); } + inline expr log10(expr arg) { return functions::log10(arg); } + + /// Natural logorithm. + /// \param arg function argument + /// \return logarithm of \a arg plus 1 to base e +// template typename enable::type log1p(T arg) { return functions::log1p(arg); } + inline expr log1p(half arg) { return functions::log1p(arg); } + inline expr log1p(expr arg) { return functions::log1p(arg); } + + /// Binary logorithm. + /// \param arg function argument + /// \return logarithm of \a arg to base 2 +// template typename enable::type log2(T arg) { return functions::log2(arg); } + inline expr log2(half arg) { return functions::log2(arg); } + inline expr log2(expr arg) { return functions::log2(arg); } + + /// \} + /// \name Power functions + /// \{ + + /// Square root. + /// \param arg function argument + /// \return square root of \a arg +// template typename enable::type sqrt(T arg) { return functions::sqrt(arg); } + inline expr sqrt(half arg) { return functions::sqrt(arg); } + inline expr sqrt(expr arg) { return functions::sqrt(arg); } + + /// Cubic root. + /// \param arg function argument + /// \return cubic root of \a arg +// template typename enable::type cbrt(T arg) { return functions::cbrt(arg); } + inline expr cbrt(half arg) { return functions::cbrt(arg); } + inline expr cbrt(expr arg) { return functions::cbrt(arg); } + + /// Hypotenuse function. 
+ /// \param x first argument + /// \param y second argument + /// \return square root of sum of squares without internal over- or underflows +// template typename enable::type hypot(T x, U y) { return functions::hypot(x, y); } + inline expr hypot(half x, half y) { return functions::hypot(x, y); } + inline expr hypot(half x, expr y) { return functions::hypot(x, y); } + inline expr hypot(expr x, half y) { return functions::hypot(x, y); } + inline expr hypot(expr x, expr y) { return functions::hypot(x, y); } + + /// Power function. + /// \param base first argument + /// \param exp second argument + /// \return \a base raised to \a exp +// template typename enable::type pow(T base, U exp) { return functions::pow(base, exp); } + inline expr pow(half base, half exp) { return functions::pow(base, exp); } + inline expr pow(half base, expr exp) { return functions::pow(base, exp); } + inline expr pow(expr base, half exp) { return functions::pow(base, exp); } + inline expr pow(expr base, expr exp) { return functions::pow(base, exp); } + + /// \} + /// \name Trigonometric functions + /// \{ + + /// Sine function. + /// \param arg function argument + /// \return sine value of \a arg +// template typename enable::type sin(T arg) { return functions::sin(arg); } + inline expr sin(half arg) { return functions::sin(arg); } + inline expr sin(expr arg) { return functions::sin(arg); } + + /// Cosine function. + /// \param arg function argument + /// \return cosine value of \a arg +// template typename enable::type cos(T arg) { return functions::cos(arg); } + inline expr cos(half arg) { return functions::cos(arg); } + inline expr cos(expr arg) { return functions::cos(arg); } + + /// Tangent function. + /// \param arg function argument + /// \return tangent value of \a arg +// template typename enable::type tan(T arg) { return functions::tan(arg); } + inline expr tan(half arg) { return functions::tan(arg); } + inline expr tan(expr arg) { return functions::tan(arg); } + + /// Arc sine. 
+ /// \param arg function argument + /// \return arc sine value of \a arg +// template typename enable::type asin(T arg) { return functions::asin(arg); } + inline expr asin(half arg) { return functions::asin(arg); } + inline expr asin(expr arg) { return functions::asin(arg); } + + /// Arc cosine function. + /// \param arg function argument + /// \return arc cosine value of \a arg +// template typename enable::type acos(T arg) { return functions::acos(arg); } + inline expr acos(half arg) { return functions::acos(arg); } + inline expr acos(expr arg) { return functions::acos(arg); } + + /// Arc tangent function. + /// \param arg function argument + /// \return arc tangent value of \a arg +// template typename enable::type atan(T arg) { return functions::atan(arg); } + inline expr atan(half arg) { return functions::atan(arg); } + inline expr atan(expr arg) { return functions::atan(arg); } + + /// Arc tangent function. + /// \param x first argument + /// \param y second argument + /// \return arc tangent value +// template typename enable::type atan2(T x, U y) { return functions::atan2(x, y); } + inline expr atan2(half x, half y) { return functions::atan2(x, y); } + inline expr atan2(half x, expr y) { return functions::atan2(x, y); } + inline expr atan2(expr x, half y) { return functions::atan2(x, y); } + inline expr atan2(expr x, expr y) { return functions::atan2(x, y); } + + /// \} + /// \name Hyperbolic functions + /// \{ + + /// Hyperbolic sine. + /// \param arg function argument + /// \return hyperbolic sine value of \a arg +// template typename enable::type sinh(T arg) { return functions::sinh(arg); } + inline expr sinh(half arg) { return functions::sinh(arg); } + inline expr sinh(expr arg) { return functions::sinh(arg); } + + /// Hyperbolic cosine. 
+ /// \param arg function argument + /// \return hyperbolic cosine value of \a arg +// template typename enable::type cosh(T arg) { return functions::cosh(arg); } + inline expr cosh(half arg) { return functions::cosh(arg); } + inline expr cosh(expr arg) { return functions::cosh(arg); } + + /// Hyperbolic tangent. + /// \param arg function argument + /// \return hyperbolic tangent value of \a arg +// template typename enable::type tanh(T arg) { return functions::tanh(arg); } + inline expr tanh(half arg) { return functions::tanh(arg); } + inline expr tanh(expr arg) { return functions::tanh(arg); } + + /// Hyperbolic area sine. + /// \param arg function argument + /// \return area sine value of \a arg +// template typename enable::type asinh(T arg) { return functions::asinh(arg); } + inline expr asinh(half arg) { return functions::asinh(arg); } + inline expr asinh(expr arg) { return functions::asinh(arg); } + + /// Hyperbolic area cosine. + /// \param arg function argument + /// \return area cosine value of \a arg +// template typename enable::type acosh(T arg) { return functions::acosh(arg); } + inline expr acosh(half arg) { return functions::acosh(arg); } + inline expr acosh(expr arg) { return functions::acosh(arg); } + + /// Hyperbolic area tangent. + /// \param arg function argument + /// \return area tangent value of \a arg +// template typename enable::type atanh(T arg) { return functions::atanh(arg); } + inline expr atanh(half arg) { return functions::atanh(arg); } + inline expr atanh(expr arg) { return functions::atanh(arg); } + + /// \} + /// \name Error and gamma functions + /// \{ + + /// Error function. + /// \param arg function argument + /// \return error function value of \a arg +// template typename enable::type erf(T arg) { return functions::erf(arg); } + inline expr erf(half arg) { return functions::erf(arg); } + inline expr erf(expr arg) { return functions::erf(arg); } + + /// Complementary error function. 
+ /// \param arg function argument + /// \return 1 minus error function value of \a arg +// template typename enable::type erfc(T arg) { return functions::erfc(arg); } + inline expr erfc(half arg) { return functions::erfc(arg); } + inline expr erfc(expr arg) { return functions::erfc(arg); } + + /// Natural logarithm of gamma function. + /// \param arg function argument + /// \return natural logarith of gamma function for \a arg +// template typename enable::type lgamma(T arg) { return functions::lgamma(arg); } + inline expr lgamma(half arg) { return functions::lgamma(arg); } + inline expr lgamma(expr arg) { return functions::lgamma(arg); } + + /// Gamma function. + /// \param arg function argument + /// \return gamma function value of \a arg +// template typename enable::type tgamma(T arg) { return functions::tgamma(arg); } + inline expr tgamma(half arg) { return functions::tgamma(arg); } + inline expr tgamma(expr arg) { return functions::tgamma(arg); } + + /// \} + /// \name Rounding + /// \{ + + /// Nearest integer not less than half value. + /// \param arg half to round + /// \return nearest integer not less than \a arg +// template typename enable::type ceil(T arg) { return functions::ceil(arg); } + inline half ceil(half arg) { return functions::ceil(arg); } + inline half ceil(expr arg) { return functions::ceil(arg); } + + /// Nearest integer not greater than half value. + /// \param arg half to round + /// \return nearest integer not greater than \a arg +// template typename enable::type floor(T arg) { return functions::floor(arg); } + inline half floor(half arg) { return functions::floor(arg); } + inline half floor(expr arg) { return functions::floor(arg); } + + /// Nearest integer not greater in magnitude than half value. 
+ /// \param arg half to round + /// \return nearest integer not greater in magnitude than \a arg +// template typename enable::type trunc(T arg) { return functions::trunc(arg); } + inline half trunc(half arg) { return functions::trunc(arg); } + inline half trunc(expr arg) { return functions::trunc(arg); } + + /// Nearest integer. + /// \param arg half to round + /// \return nearest integer, rounded away from zero in half-way cases +// template typename enable::type round(T arg) { return functions::round(arg); } + inline half round(half arg) { return functions::round(arg); } + inline half round(expr arg) { return functions::round(arg); } + + /// Nearest integer. + /// \param arg half to round + /// \return nearest integer, rounded away from zero in half-way cases +// template typename enable::type lround(T arg) { return functions::lround(arg); } + inline long lround(half arg) { return functions::lround(arg); } + inline long lround(expr arg) { return functions::lround(arg); } + + /// Nearest integer using half's internal rounding mode. + /// \param arg half expression to round + /// \return nearest integer using default rounding mode +// template typename enable::type nearbyint(T arg) { return functions::nearbyint(arg); } + inline half nearbyint(half arg) { return functions::rint(arg); } + inline half nearbyint(expr arg) { return functions::rint(arg); } + + /// Nearest integer using half's internal rounding mode. + /// \param arg half expression to round + /// \return nearest integer using default rounding mode +// template typename enable::type rint(T arg) { return functions::rint(arg); } + inline half rint(half arg) { return functions::rint(arg); } + inline half rint(expr arg) { return functions::rint(arg); } + + /// Nearest integer using half's internal rounding mode. 
+ /// \param arg half expression to round + /// \return nearest integer using default rounding mode +// template typename enable::type lrint(T arg) { return functions::lrint(arg); } + inline long lrint(half arg) { return functions::lrint(arg); } + inline long lrint(expr arg) { return functions::lrint(arg); } + #if HALF_ENABLE_CPP11_LONG_LONG + /// Nearest integer. + /// \param arg half to round + /// \return nearest integer, rounded away from zero in half-way cases +// template typename enable::type llround(T arg) { return functions::llround(arg); } + inline long long llround(half arg) { return functions::llround(arg); } + inline long long llround(expr arg) { return functions::llround(arg); } + + /// Nearest integer using half's internal rounding mode. + /// \param arg half expression to round + /// \return nearest integer using default rounding mode +// template typename enable::type llrint(T arg) { return functions::llrint(arg); } + inline long long llrint(half arg) { return functions::llrint(arg); } + inline long long llrint(expr arg) { return functions::llrint(arg); } + #endif + + /// \} + /// \name Floating point manipulation + /// \{ + + /// Decompress floating point number. + /// \param arg number to decompress + /// \param exp address to store exponent at + /// \return significant in range [0.5, 1) +// template typename enable::type frexp(T arg, int *exp) { return functions::frexp(arg, exp); } + inline half frexp(half arg, int *exp) { return functions::frexp(arg, exp); } + inline half frexp(expr arg, int *exp) { return functions::frexp(arg, exp); } + + /// Multiply by power of two. 
+ /// \param arg number to modify + /// \param exp power of two to multiply with + /// \return \a arg multplied by 2 raised to \a exp +// template typename enable::type ldexp(T arg, int exp) { return functions::scalbln(arg, exp); } + inline half ldexp(half arg, int exp) { return functions::scalbln(arg, exp); } + inline half ldexp(expr arg, int exp) { return functions::scalbln(arg, exp); } + + /// Extract integer and fractional parts. + /// \param arg number to decompress + /// \param iptr address to store integer part at + /// \return fractional part +// template typename enable::type modf(T arg, half *iptr) { return functions::modf(arg, iptr); } + inline half modf(half arg, half *iptr) { return functions::modf(arg, iptr); } + inline half modf(expr arg, half *iptr) { return functions::modf(arg, iptr); } + + /// Multiply by power of two. + /// \param arg number to modify + /// \param exp power of two to multiply with + /// \return \a arg multplied by 2 raised to \a exp +// template typename enable::type scalbn(T arg, int exp) { return functions::scalbln(arg, exp); } + inline half scalbn(half arg, int exp) { return functions::scalbln(arg, exp); } + inline half scalbn(expr arg, int exp) { return functions::scalbln(arg, exp); } + + /// Multiply by power of two. + /// \param arg number to modify + /// \param exp power of two to multiply with + /// \return \a arg multplied by 2 raised to \a exp +// template typename enable::type scalbln(T arg, long exp) { return functions::scalbln(arg, exp); } + inline half scalbln(half arg, long exp) { return functions::scalbln(arg, exp); } + inline half scalbln(expr arg, long exp) { return functions::scalbln(arg, exp); } + + /// Extract exponent. 
+ /// \param arg number to query + /// \return floating point exponent + /// \retval FP_ILOGB0 for zero + /// \retval FP_ILOGBNAN for NaN + /// \retval MAX_INT for infinity +// template typename enable::type ilogb(T arg) { return functions::ilogb(arg); } + inline int ilogb(half arg) { return functions::ilogb(arg); } + inline int ilogb(expr arg) { return functions::ilogb(arg); } + + /// Extract exponent. + /// \param arg number to query + /// \return floating point exponent +// template typename enable::type logb(T arg) { return functions::logb(arg); } + inline half logb(half arg) { return functions::logb(arg); } + inline half logb(expr arg) { return functions::logb(arg); } + + /// Next representable value. + /// \param from value to compute next representable value for + /// \param to direction towards which to compute next value + /// \return next representable value after \a from in direction towards \a to +// template typename enable::type nextafter(T from, U to) { return functions::nextafter(from, to); } + inline half nextafter(half from, half to) { return functions::nextafter(from, to); } + inline half nextafter(half from, expr to) { return functions::nextafter(from, to); } + inline half nextafter(expr from, half to) { return functions::nextafter(from, to); } + inline half nextafter(expr from, expr to) { return functions::nextafter(from, to); } + + /// Next representable value. + /// \param from value to compute next representable value for + /// \param to direction towards which to compute next value + /// \return next representable value after \a from in direction towards \a to +// template typename enable::type nexttoward(T from, long double to) { return functions::nexttoward(from, to); } + inline half nexttoward(half from, long double to) { return functions::nexttoward(from, to); } + inline half nexttoward(expr from, long double to) { return functions::nexttoward(from, to); } + + /// Take sign. 
+ /// \param x value to change sign for + /// \param y value to take sign from + /// \return value equal to \a x in magnitude and to \a y in sign +// template typename enable::type copysign(T x, U y) { return functions::copysign(x, y); } + inline half copysign(half x, half y) { return functions::copysign(x, y); } + inline half copysign(half x, expr y) { return functions::copysign(x, y); } + inline half copysign(expr x, half y) { return functions::copysign(x, y); } + inline half copysign(expr x, expr y) { return functions::copysign(x, y); } + + /// \} + /// \name Floating point classification + /// \{ + + + /// Classify floating point value. + /// \param arg number to classify + /// \retval FP_ZERO for positive and negative zero + /// \retval FP_SUBNORMAL for subnormal numbers + /// \retval FP_INFINITY for positive and negative infinity + /// \retval FP_NAN for NaNs + /// \retval FP_NORMAL for all other (normal) values +// template typename enable::type fpclassify(T arg) { return functions::fpclassify(arg); } + inline int fpclassify(half arg) { return functions::fpclassify(arg); } + inline int fpclassify(expr arg) { return functions::fpclassify(arg); } + + /// Check if finite number. + /// \param arg number to check + /// \retval true if neither infinity nor NaN + /// \retval false else +// template typename enable::type isfinite(T arg) { return functions::isfinite(arg); } + inline bool isfinite(half arg) { return functions::isfinite(arg); } + inline bool isfinite(expr arg) { return functions::isfinite(arg); } + + /// Check for infinity. + /// \param arg number to check + /// \retval true for positive or negative infinity + /// \retval false else +// template typename enable::type isinf(T arg) { return functions::isinf(arg); } + inline bool isinf(half arg) { return functions::isinf(arg); } + inline bool isinf(expr arg) { return functions::isinf(arg); } + + /// Check for NaN. 
+ /// \param arg number to check + /// \retval true for NaNs + /// \retval false else +// template typename enable::type isnan(T arg) { return functions::isnan(arg); } + inline bool isnan(half arg) { return functions::isnan(arg); } + inline bool isnan(expr arg) { return functions::isnan(arg); } + + /// Check if normal number. + /// \param arg number to check + /// \retval true if normal number + /// \retval false if either subnormal, zero, infinity or NaN +// template typename enable::type isnormal(T arg) { return functions::isnormal(arg); } + inline bool isnormal(half arg) { return functions::isnormal(arg); } + inline bool isnormal(expr arg) { return functions::isnormal(arg); } + + /// Check sign. + /// \param arg number to check + /// \retval true for negative number + /// \retval false for positive number +// template typename enable::type signbit(T arg) { return functions::signbit(arg); } + inline bool signbit(half arg) { return functions::signbit(arg); } + inline bool signbit(expr arg) { return functions::signbit(arg); } + + /// \} + /// \name Comparison + /// \{ + + /// Comparison for greater than. + /// \param x first operand + /// \param y second operand + /// \retval true if \a x greater than \a y + /// \retval false else +// template typename enable::type isgreater(T x, U y) { return functions::isgreater(x, y); } + inline bool isgreater(half x, half y) { return functions::isgreater(x, y); } + inline bool isgreater(half x, expr y) { return functions::isgreater(x, y); } + inline bool isgreater(expr x, half y) { return functions::isgreater(x, y); } + inline bool isgreater(expr x, expr y) { return functions::isgreater(x, y); } + + /// Comparison for greater equal. 
+ /// \param x first operand + /// \param y second operand + /// \retval true if \a x greater equal \a y + /// \retval false else +// template typename enable::type isgreaterequal(T x, U y) { return functions::isgreaterequal(x, y); } + inline bool isgreaterequal(half x, half y) { return functions::isgreaterequal(x, y); } + inline bool isgreaterequal(half x, expr y) { return functions::isgreaterequal(x, y); } + inline bool isgreaterequal(expr x, half y) { return functions::isgreaterequal(x, y); } + inline bool isgreaterequal(expr x, expr y) { return functions::isgreaterequal(x, y); } + + /// Comparison for less than. + /// \param x first operand + /// \param y second operand + /// \retval true if \a x less than \a y + /// \retval false else +// template typename enable::type isless(T x, U y) { return functions::isless(x, y); } + inline bool isless(half x, half y) { return functions::isless(x, y); } + inline bool isless(half x, expr y) { return functions::isless(x, y); } + inline bool isless(expr x, half y) { return functions::isless(x, y); } + inline bool isless(expr x, expr y) { return functions::isless(x, y); } + + /// Comparison for less equal. + /// \param x first operand + /// \param y second operand + /// \retval true if \a x less equal \a y + /// \retval false else +// template typename enable::type islessequal(T x, U y) { return functions::islessequal(x, y); } + inline bool islessequal(half x, half y) { return functions::islessequal(x, y); } + inline bool islessequal(half x, expr y) { return functions::islessequal(x, y); } + inline bool islessequal(expr x, half y) { return functions::islessequal(x, y); } + inline bool islessequal(expr x, expr y) { return functions::islessequal(x, y); } + + /// Comarison for less or greater. 
+ /// \param x first operand + /// \param y second operand + /// \retval true if either less or greater + /// \retval false else +// template typename enable::type islessgreater(T x, U y) { return functions::islessgreater(x, y); } + inline bool islessgreater(half x, half y) { return functions::islessgreater(x, y); } + inline bool islessgreater(half x, expr y) { return functions::islessgreater(x, y); } + inline bool islessgreater(expr x, half y) { return functions::islessgreater(x, y); } + inline bool islessgreater(expr x, expr y) { return functions::islessgreater(x, y); } + + /// Check if unordered. + /// \param x first operand + /// \param y second operand + /// \retval true if unordered (one or two NaN operands) + /// \retval false else +// template typename enable::type isunordered(T x, U y) { return functions::isunordered(x, y); } + inline bool isunordered(half x, half y) { return functions::isunordered(x, y); } + inline bool isunordered(half x, expr y) { return functions::isunordered(x, y); } + inline bool isunordered(expr x, half y) { return functions::isunordered(x, y); } + inline bool isunordered(expr x, expr y) { return functions::isunordered(x, y); } + + /// \name Casting + /// \{ + + /// Cast to or from half-precision floating point number. + /// This casts between [half](\ref half_float::half) and any built-in arithmetic type. The values are converted + /// directly using the given rounding mode, without any roundtrip over `float` that a `static_cast` would otherwise do. + /// It uses the default rounding mode. + /// + /// Using this cast with neither of the two types being a [half](\ref half_float::half) or with any of the two types + /// not being a built-in arithmetic type (apart from [half](\ref half_float::half), of course) results in a compiler + /// error and casting between [half](\ref half_float::half)s is just a no-op. 
+ /// \tparam T destination type (half or built-in arithmetic type) + /// \tparam U source type (half or built-in arithmetic type) + /// \param arg value to cast + /// \return \a arg converted to destination type + template T half_cast(U arg) { return half_caster::cast(arg); } + + /// Cast to or from half-precision floating point number. + /// This casts between [half](\ref half_float::half) and any built-in arithmetic type. The values are converted + /// directly using the given rounding mode, without any roundtrip over `float` that a `static_cast` would otherwise do. + /// + /// Using this cast with neither of the two types being a [half](\ref half_float::half) or with any of the two types + /// not being a built-in arithmetic type (apart from [half](\ref half_float::half), of course) results in a compiler + /// error and casting between [half](\ref half_float::half)s is just a no-op. + /// \tparam T destination type (half or built-in arithmetic type) + /// \tparam R rounding mode to use. 
+ /// \tparam U source type (half or built-in arithmetic type) + /// \param arg value to cast + /// \return \a arg converted to destination type + template T half_cast(U arg) { return half_caster::cast(arg); } + /// \} + } + + using detail::operator==; + using detail::operator!=; + using detail::operator<; + using detail::operator>; + using detail::operator<=; + using detail::operator>=; + using detail::operator+; + using detail::operator-; + using detail::operator*; + using detail::operator/; + using detail::operator<<; + using detail::operator>>; + + using detail::abs; + using detail::fabs; + using detail::fmod; + using detail::remainder; + using detail::remquo; + using detail::fma; + using detail::fmax; + using detail::fmin; + using detail::fdim; + using detail::nanh; + using detail::exp; + using detail::expm1; + using detail::exp2; + using detail::log; + using detail::log10; + using detail::log1p; + using detail::log2; + using detail::sqrt; + using detail::cbrt; + using detail::hypot; + using detail::pow; + using detail::sin; + using detail::cos; + using detail::tan; + using detail::asin; + using detail::acos; + using detail::atan; + using detail::atan2; + using detail::sinh; + using detail::cosh; + using detail::tanh; + using detail::asinh; + using detail::acosh; + using detail::atanh; + using detail::erf; + using detail::erfc; + using detail::lgamma; + using detail::tgamma; + using detail::ceil; + using detail::floor; + using detail::trunc; + using detail::round; + using detail::lround; + using detail::nearbyint; + using detail::rint; + using detail::lrint; +#if HALF_ENABLE_CPP11_LONG_LONG + using detail::llround; + using detail::llrint; +#endif + using detail::frexp; + using detail::ldexp; + using detail::modf; + using detail::scalbn; + using detail::scalbln; + using detail::ilogb; + using detail::logb; + using detail::nextafter; + using detail::nexttoward; + using detail::copysign; + using detail::fpclassify; + using detail::isfinite; + using detail::isinf; 
+ using detail::isnan; + using detail::isnormal; + using detail::signbit; + using detail::isgreater; + using detail::isgreaterequal; + using detail::isless; + using detail::islessequal; + using detail::islessgreater; + using detail::isunordered; + + using detail::half_cast; +} + + +/// Extensions to the C++ standard library. +namespace std +{ + /// Numeric limits for half-precision floats. + /// Because of the underlying single-precision implementation of many operations, it inherits some properties from + /// `std::numeric_limits`. + template<> class numeric_limits : public numeric_limits + { + public: + /// Supports signed values. + static HALF_CONSTEXPR_CONST bool is_signed = true; + + /// Is not exact. + static HALF_CONSTEXPR_CONST bool is_exact = false; + + /// Doesn't provide modulo arithmetic. + static HALF_CONSTEXPR_CONST bool is_modulo = false; + + /// IEEE conformant. + static HALF_CONSTEXPR_CONST bool is_iec559 = true; + + /// Supports infinity. + static HALF_CONSTEXPR_CONST bool has_infinity = true; + + /// Supports quiet NaNs. + static HALF_CONSTEXPR_CONST bool has_quiet_NaN = true; + + /// Supports subnormal values. + static HALF_CONSTEXPR_CONST float_denorm_style has_denorm = denorm_present; + + /// Rounding mode. + /// Due to the mix of internal single-precision computations (using the rounding mode of the underlying + /// single-precision implementation) with the rounding mode of the single-to-half conversions, the actual rounding + /// mode might be `std::round_indeterminate` if the default half-precision rounding mode doesn't match the + /// single-precision rounding mode. + static HALF_CONSTEXPR_CONST float_round_style round_style = (std::numeric_limits::round_style== + half_float::half::round_style) ? half_float::half::round_style : round_indeterminate; + + /// Significant digits. + static HALF_CONSTEXPR_CONST int digits = 11; + + /// Significant decimal digits. 
+ static HALF_CONSTEXPR_CONST int digits10 = 3; + + /// Required decimal digits to represent all possible values. + static HALF_CONSTEXPR_CONST int max_digits10 = 5; + + /// Number base. + static HALF_CONSTEXPR_CONST int radix = 2; + + /// One more than smallest exponent. + static HALF_CONSTEXPR_CONST int min_exponent = -13; + + /// Smallest normalized representable power of 10. + static HALF_CONSTEXPR_CONST int min_exponent10 = -4; + + /// One more than largest exponent + static HALF_CONSTEXPR_CONST int max_exponent = 16; + + /// Largest finitely representable power of 10. + static HALF_CONSTEXPR_CONST int max_exponent10 = 4; + + /// Smallest positive normal value. + static HALF_CONSTEXPR half_float::half min() HALF_NOTHROW { return half_float::half(half_float::detail::binary, 0x0400); } + + /// Smallest finite value. + static HALF_CONSTEXPR half_float::half lowest() HALF_NOTHROW { return half_float::half(half_float::detail::binary, 0xFBFF); } + + /// Largest finite value. + static HALF_CONSTEXPR half_float::half max() HALF_NOTHROW { return half_float::half(half_float::detail::binary, 0x7BFF); } + + /// Difference between one and next representable value. + static HALF_CONSTEXPR half_float::half epsilon() HALF_NOTHROW { return half_float::half(half_float::detail::binary, 0x1400); } + + /// Maximum rounding error. + static HALF_CONSTEXPR half_float::half round_error() HALF_NOTHROW + { return half_float::half(half_float::detail::binary, (round_style==std::round_to_nearest) ? 0x3800 : 0x3C00); } + + /// Positive infinity. + static HALF_CONSTEXPR half_float::half infinity() HALF_NOTHROW { return half_float::half(half_float::detail::binary, 0x7C00); } + + /// Quiet NaN. + static HALF_CONSTEXPR half_float::half quiet_NaN() HALF_NOTHROW { return half_float::half(half_float::detail::binary, 0x7FFF); } + + /// Signalling NaN. 
+ static HALF_CONSTEXPR half_float::half signaling_NaN() HALF_NOTHROW { return half_float::half(half_float::detail::binary, 0x7DFF); } + + /// Smallest positive subnormal value. + static HALF_CONSTEXPR half_float::half denorm_min() HALF_NOTHROW { return half_float::half(half_float::detail::binary, 0x0001); } + }; + +#if HALF_ENABLE_CPP11_HASH + /// Hash function for half-precision floats. + /// This is only defined if C++11 `std::hash` is supported and enabled. + template<> struct hash //: unary_function + { + /// Type of function argument. + typedef half_float::half argument_type; + + /// Function return type. + typedef size_t result_type; + + /// Compute hash function. + /// \param arg half to hash + /// \return hash value + result_type operator()(argument_type arg) const + { return hash()(static_cast(arg.data_)&-(arg.data_!=0x8000)); } + }; +#endif +} + + +#undef HALF_CONSTEXPR +#undef HALF_CONSTEXPR_CONST +#undef HALF_NOEXCEPT +#undef HALF_NOTHROW +#ifdef HALF_POP_WARNINGS + #pragma warning(pop) + #undef HALF_POP_WARNINGS +#endif +#endif diff --git a/include/caffe/common.hpp b/include/caffe/common.hpp index 44027e2851b..cab8b8fd8a3 100644 --- a/include/caffe/common.hpp +++ b/include/caffe/common.hpp @@ -29,6 +29,7 @@ #endif #include "caffe/util/device_alternate.hpp" +#include "caffe/util/fp16.hpp" // Convert macro to string #define STRINGIFY(m) #m @@ -50,12 +51,17 @@ private:\ classname& operator=(const classname&) // Instantiate a class with float and double specifications. 
+#ifdef HAS_HALF_SUPPORT #define INSTANTIATE_CLASS(classname) \ char gInstantiationGuard##classname; \ + template class classname; \ template class classname; \ - template class classname + template class classname; #define INSTANTIATE_LAYER_GPU_FORWARD(classname) \ + template void classname::Forward_gpu( \ + const std::vector*>& bottom, \ + const std::vector*>& top); \ template void classname::Forward_gpu( \ const std::vector*>& bottom, \ const std::vector*>& top); \ @@ -64,6 +70,10 @@ private:\ const std::vector*>& top); #define INSTANTIATE_LAYER_GPU_BACKWARD(classname) \ + template void classname::Backward_gpu( \ + const std::vector*>& top, \ + const std::vector& propagate_down, \ + const std::vector*>& bottom); \ template void classname::Backward_gpu( \ const std::vector*>& top, \ const std::vector& propagate_down, \ @@ -72,6 +82,30 @@ private:\ const std::vector*>& top, \ const std::vector& propagate_down, \ const std::vector*>& bottom) +#else +#define INSTANTIATE_CLASS(classname) \ + char gInstantiationGuard##classname; \ + template class classname; \ + template class classname; \ + +#define INSTANTIATE_LAYER_GPU_FORWARD(classname) \ + template void classname::Forward_gpu( \ + const std::vector*>& bottom, \ + const std::vector*>& top); \ + template void classname::Forward_gpu( \ + const std::vector*>& bottom, \ + const std::vector*>& top); + +#define INSTANTIATE_LAYER_GPU_BACKWARD(classname) \ + template void classname::Backward_gpu( \ + const std::vector*>& top, \ + const std::vector& propagate_down, \ + const std::vector*>& bottom); \ + template void classname::Backward_gpu( \ + const std::vector*>& top, \ + const std::vector& propagate_down, \ + const std::vector*>& bottom) +#endif #define INSTANTIATE_LAYER_GPU_FUNCS(classname) \ INSTANTIATE_LAYER_GPU_FORWARD(classname); \ @@ -226,19 +260,17 @@ class Caffe { #endif #endif // !CPU_ONLY shared_ptr random_generator_; - Brew mode_; + // The shared ptrs are being referenced on every thread, + // while the 
default device will be handled thread local + shared_ptr cpu_device_; + device* default_device_; + static vector > devices_; // Parallel training int solver_count_; int solver_rank_; bool multiprocess_; - - // The shared ptrs are being referenced on every thread, - // while the default device will be handled thread local - static vector > devices_; - shared_ptr cpu_device_; - device* default_device_; }; } // namespace caffe diff --git a/include/caffe/device.hpp b/include/caffe/device.hpp index fbb132c0d8c..f7ca666145b 100644 --- a/include/caffe/device.hpp +++ b/include/caffe/device.hpp @@ -65,6 +65,7 @@ class device { Backend backend_; uint_tp memory_usage_; uint_tp peak_memory_usage_; + std::vector > > buff_h_; std::vector > > buff_f_; std::vector > > buff_d_; bool host_unified_; diff --git a/include/caffe/greentea/cl_kernels.hpp b/include/caffe/greentea/cl_kernels.hpp index 0bb31369c0a..95b85156fe1 100755 --- a/include/caffe/greentea/cl_kernels.hpp +++ b/include/caffe/greentea/cl_kernels.hpp @@ -12,6 +12,7 @@ #include "viennacl/ocl/platform.hpp" namespace caffe { viennacl::ocl::program & RegisterKernels(viennacl::ocl::context *ctx); +template viennacl::ocl::program & submit_conv_spatial_program( viennacl::ocl::context *ctx, string name, string options); std::string getKernelBundleName(int index); diff --git a/include/caffe/greentea/greentea.hpp b/include/caffe/greentea/greentea.hpp index 92158edd9e4..13931c47018 100644 --- a/include/caffe/greentea/greentea.hpp +++ b/include/caffe/greentea/greentea.hpp @@ -94,8 +94,8 @@ struct is_same { // Macro to select the single (_float) or double (_double) precision kernel #define CL_KERNEL_SELECT(kernel) \ is_same::value ? 
\ - kernel "_float" : \ - kernel "_double" + kernel "_float" : (is_same::value ?\ + kernel "_double" : kernel "_half") #endif diff --git a/include/caffe/greentea/libdnn_tuner.hpp b/include/caffe/greentea/libdnn_tuner.hpp index dd6cea99645..733b36f7dc4 100644 --- a/include/caffe/greentea/libdnn_tuner.hpp +++ b/include/caffe/greentea/libdnn_tuner.hpp @@ -105,11 +105,11 @@ class LibDNNTunerParam { void add_constraint(std::shared_ptr constraint); protected: + std::vector> constraints_; LibDNNTuner* tuner_; std::string name_; int_tp curr_idx_; int_tp def_idx_; - std::vector> constraints_; }; class LibDNNTunerParamInt: public LibDNNTunerParam { diff --git a/include/caffe/layer_factory.hpp b/include/caffe/layer_factory.hpp index 5b306216e00..018d06d714a 100644 --- a/include/caffe/layer_factory.hpp +++ b/include/caffe/layer_factory.hpp @@ -83,10 +83,16 @@ class LayerRegisterer { LayerRegisterer(const string& type, shared_ptr > (*creator)(const LayerParameter&)); }; - +#ifdef HAS_HALF_SUPPORT +#define REGISTER_LAYER_CREATOR(type, creator) \ + static LayerRegisterer g_creator_h_##type(#type, creator); \ + static LayerRegisterer g_creator_f_##type(#type, creator); \ + static LayerRegisterer g_creator_d_##type(#type, creator) +#else #define REGISTER_LAYER_CREATOR(type, creator) \ static LayerRegisterer g_creator_f_##type(#type, creator); \ - static LayerRegisterer g_creator_d_##type(#type, creator) \ + static LayerRegisterer g_creator_d_##type(#type, creator) +#endif #define REGISTER_LAYER_CLASS(type) \ template \ diff --git a/include/caffe/test/test_caffe_main.hpp b/include/caffe/test/test_caffe_main.hpp index 77c5e627b73..19fb295803c 100644 --- a/include/caffe/test/test_caffe_main.hpp +++ b/include/caffe/test/test_caffe_main.hpp @@ -41,7 +41,11 @@ class MultiDeviceTest : public ::testing::Test { virtual ~MultiDeviceTest() { RemoveCaffeTempDir(); } }; +#ifdef HAS_HALF_SUPPORT +typedef ::testing::Types TestDtypes; +#else typedef ::testing::Types TestDtypes; +#endif template 
struct CPUDevice { @@ -73,16 +77,29 @@ template class GPUDeviceTest : public MultiDeviceTest > { }; +#ifdef HAS_HALF_SUPPORT +typedef ::testing::Types, CPUDevice, CPUDevice, + GPUDevice, GPUDevice, GPUDevice > + TestDtypesAndDevices; + +typedef ::testing::Types, + CPUDevice, + GPUDevice, + GPUDevice> + TestFloatAndDevices; +#else typedef ::testing::Types, CPUDevice, GPUDevice, GPUDevice > TestDtypesAndDevices; typedef ::testing::Types, - GPUDevice > + GPUDevice> TestFloatAndDevices; #endif +#endif + #if defined(USE_LEVELDB) && defined(USE_LMDB) struct TypeLevelDB { static DataParameter_DB backend; @@ -105,6 +122,19 @@ template <> bool isSupported >(void); template <> +bool isSupported >(void); +#ifdef HAS_HALF_SUPPORT +template <> +bool isSupported(void); + +template <> +bool isSupported >(void); + +template <> +bool isSupported >(void); +#endif + +template <> bool isSupported(void); template <> @@ -113,9 +143,6 @@ bool isSupported >(void); template <> bool isSupported >(void); -template <> -bool isSupported >(void); - #if defined(USE_LEVELDB) && defined(USE_LMDB) template <> bool isSupported(void); @@ -154,6 +181,7 @@ bool isSupported(void); template \ void GTEST_TEST_CLASS_NAME_(CaseName, TestName) \ ::TestBody_Impl() + #endif diff --git a/include/caffe/test/test_gradient_check_util.hpp b/include/caffe/test/test_gradient_check_util.hpp index 8ec9de992f5..4d85246b166 100644 --- a/include/caffe/test/test_gradient_check_util.hpp +++ b/include/caffe/test/test_gradient_check_util.hpp @@ -26,6 +26,11 @@ class GradientChecker { const Dtype kink_range = -1) : stepsize_(stepsize), threshold_(threshold), seed_(seed), kink_(kink), kink_range_(kink_range) { + if (std::is_same::value) { + //stepsize_ = 10 * stepsize; + threshold_ = 100 * threshold; + stepsize_ = stepsize; + } } // Checks the gradient of a layer, with provided bottom layers and top // layers. 
diff --git a/include/caffe/util/fp16.hpp b/include/caffe/util/fp16.hpp new file mode 100644 index 00000000000..f2738ede3fa --- /dev/null +++ b/include/caffe/util/fp16.hpp @@ -0,0 +1,56 @@ +#ifndef CAFFE_UTIL_FP16_H_ +#define CAFFE_UTIL_FP16_H_ + +#include "3rdparty/half/half.hpp" +using half_float::half; + +#define HALF_MAX 0x1.ffcp15f +#define HALF_MIN 0x1.0p-14f + +#include + +inline float fixup_arg_type(float v) { + return v; +} + +inline float fixup_arg_type(half_float::half v) { + return float(v); +} + +inline double fixup_arg_type(double v) { + return v; +} + +inline int fixup_arg_type(int v) { + return v; +} + +inline unsigned int fixup_arg_type(unsigned int v) { + return v; +} + +inline long long fixup_arg_type(long long v) { + return v; +} + +inline unsigned long long fixup_arg_type(unsigned long long v) { + return v; +} + +inline long fixup_arg_type(long v) { + return v; +} + +inline unsigned long fixup_arg_type(unsigned long v) { + return v; +} + +inline float fixup_arg_type(const half_float::detail::expr& expr) { + return float(expr); +} + +inline const void * fixup_arg_type(const boost::shared_ptr& share_ptr) { + return (const void*)share_ptr.get(); +} + +#endif // CAFFE_UTIL_HDF5_H_ diff --git a/include/caffe/util/math_functions.hpp b/include/caffe/util/math_functions.hpp index c54eb9f9810..9718ba72a61 100644 --- a/include/caffe/util/math_functions.hpp +++ b/include/caffe/util/math_functions.hpp @@ -84,6 +84,7 @@ void caffe_rng_uniform(const int_tp n, uint_tp* r); template void caffe_rng_uniform(const int_tp n, const Dtype a, const Dtype b, Dtype* r); + template void caffe_rng_gaussian(const int_tp n, const Dtype mu, const Dtype sigma, Dtype* r); diff --git a/include/caffe/util/mkl_alternate.hpp b/include/caffe/util/mkl_alternate.hpp index d7881a1b7c1..188a4199f3a 100644 --- a/include/caffe/util/mkl_alternate.hpp +++ b/include/caffe/util/mkl_alternate.hpp @@ -86,6 +86,20 @@ DEFINE_VSL_BINARY_FUNC(Div, y[i] = a[i] / b[i]) // In addition, MKL comes with 
an additional function axpby that is not present // in standard blas. We will simply use a two-step (inefficient, of course) way // to mimic that. +#ifdef HAS_HALF_SUPPORT +inline void cblas_haxpby(const int_tp N, const half alpha, const half* X, + const int_tp incX, const half beta, half* Y, + const int_tp incY) { + + for (int_tp n = 0; n < N; n++) + Y[n * incY] *= beta; + + for (int_tp n = 0; n < N; n++) { + Y[n * incY] += alpha * X[n * incX]; + } +} +#endif + inline void cblas_saxpby(const int_tp N, const float alpha, const float* X, const int_tp incX, const float beta, float* Y, const int_tp incY) { diff --git a/src/caffe/blob.cpp b/src/caffe/blob.cpp index 4adbe6497d9..6fe876455e4 100644 --- a/src/caffe/blob.cpp +++ b/src/caffe/blob.cpp @@ -248,7 +248,7 @@ template<> int_tp Blob::asum_data() const { template Dtype Blob::asum_data() const { if (!data_) { - return 0; + return (Dtype)0; } switch (data_->head()) { case SyncedMemory::HEAD_AT_CPU: @@ -644,7 +644,27 @@ void Blob::FromProto(const BlobProto& proto, bool reshape) { } } } - +#ifdef HAS_HALF_SUPPORT +template <> +void Blob::ToProto(BlobProto* proto, bool write_diff) const { + proto->clear_shape(); + for (int_tp i = 0; i < shape_.size(); ++i) { + proto->mutable_shape()->add_dim(shape_[i]); + } + proto->clear_double_data(); + proto->clear_double_diff(); + const half* data_vec = cpu_data(); + for (int_tp i = 0; i < count_; ++i) { + proto->add_double_data(data_vec[i]); + } + if (write_diff) { + const half* diff_vec = cpu_diff(); + for (int_tp i = 0; i < count_; ++i) { + proto->add_double_diff(diff_vec[i]); + } + } +} +#endif template <> void Blob::ToProto(BlobProto* proto, bool write_diff) const { proto->clear_shape(); diff --git a/src/caffe/common.cpp b/src/caffe/common.cpp index 3743ab06c18..7896517cfa7 100644 --- a/src/caffe/common.cpp +++ b/src/caffe/common.cpp @@ -52,6 +52,9 @@ size_t dtsizeof(DataType data_type) { } } +template<> DataType dtypeof() { + return DFP16; +} template<> DataType dtypeof() { 
return DFP32; } diff --git a/src/caffe/device.cpp b/src/caffe/device.cpp index 2581d6dd426..12051cc5827 100644 --- a/src/caffe/device.cpp +++ b/src/caffe/device.cpp @@ -100,7 +100,16 @@ int device::num_queues() { } return 1; } - +#ifdef HAS_HALF_SUPPORT +template<> +shared_ptr > device::Buffer(int id) { + if (buff_h_.size() <= id) { + shared_ptr > blob_pointer(new Blob(this)); + buff_h_.push_back(blob_pointer); + } + return buff_h_[id]; +} +#endif template<> shared_ptr > device::Buffer(int id) { if (buff_f_.size() <= id) { diff --git a/src/caffe/greentea/cl_headers/header.cl b/src/caffe/greentea/cl_headers/header.cl index b052cf3af29..ec511b9fc30 100644 --- a/src/caffe/greentea/cl_headers/header.cl +++ b/src/caffe/greentea/cl_headers/header.cl @@ -31,6 +31,7 @@ #define TYPE_FLOAT 1 #define TYPE_DOUBLE 2 +#define TYPE_HALF 3 #if defined(cl_khr_fp64) #pragma OPENCL EXTENSION cl_khr_fp64 : enable @@ -44,6 +45,11 @@ #endif //DISABLE_DOUBLE_SUPPORT #endif +#if defined(cl_khr_fp16) +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#define HALF_SUPPORT_AVAILABLE +#endif + #if defined(cl_khr_int64_base_atomics) #pragma OPENCL EXTENSION cl_khr_int64_base_atomics : enable #define ATOMICS_64_AVAILABLE @@ -58,4 +64,3 @@ #pragma OPENCL EXTENSION cl_khr_global_int32_base_atomics : enable #define ATOMICS_32_AVAILABLE #endif - diff --git a/src/caffe/greentea/cl_kernels.cpp b/src/caffe/greentea/cl_kernels.cpp index 455f03103b7..ab7cedcfab6 100644 --- a/src/caffe/greentea/cl_kernels.cpp +++ b/src/caffe/greentea/cl_kernels.cpp @@ -13,10 +13,10 @@ #endif // DISABLE_DOUBLE_SUPPORT namespace caffe { #ifdef USE_INDEX_64 -static std::string header = DOUBLE_SUPPORT "#ifndef __OPENCL_VERSION__\n#define __kernel\n#define __global\n#define __constant\n#define __local\n#define get_global_id(x) 0\n#define get_global_size(x) 0\n#define get_local_id(x) 0\n#define get_local_size(x) 0\n#define FLT_MAX 0\n#define FLT_MIN 0\n#define cl_khr_fp64\n#define cl_amd_fp64\n#ifndef 
DISABLE_DOUBLE_SUPPORT\n#define DOUBLE_SUPPORT_AVAILABLE\n#endif //DISABLE_DOUBLE_SUPPORT\n#define CLK_LOCAL_MEM_FENCE\n#define CLK_GLOBAL_MEM_FENCE\n#define Dtype float\n#define barrier(x)\n#define atomic_cmpxchg(x, y, z) x\n#define signbit(x) x\n#define int_tp long\n#define uint_tp unsigned long\n#define int_tpc long\n#define uint_tpc unsigned long\n#endif\n\n#define CONCAT(A,B) A##_##B\n#define TEMPLATE(name,type) CONCAT(name,type)\n\n#define TYPE_FLOAT 1\n#define TYPE_DOUBLE 2\n\n#if defined(cl_khr_fp64)\n#pragma OPENCL EXTENSION cl_khr_fp64 : enable\n#ifndef DISABLE_DOUBLE_SUPPORT\n#define DOUBLE_SUPPORT_AVAILABLE\n#endif //DISABLE_DOUBLE_SUPPORT\n#elif defined(cl_amd_fp64)\n#pragma OPENCL EXTENSION cl_amd_fp64 : enable\n#ifndef DISABLE_DOUBLE_SUPPORT\n#define DOUBLE_SUPPORT_AVAILABLE\n#endif //DISABLE_DOUBLE_SUPPORT\n#endif\n\n#if defined(cl_khr_int64_base_atomics)\n#pragma OPENCL EXTENSION cl_khr_int64_base_atomics : enable\n#define ATOMICS_64_AVAILABLE\n#endif\n\n#if defined(cl_khr_int32_base_atomics)\n#pragma OPENCL EXTENSION cl_khr_int32_base_atomics : enable\n#define ATOMICS_32_AVAILABLE\n#endif\n\n#if defined(cl_khr_global_int32_base_atomics)\n#pragma OPENCL EXTENSION cl_khr_global_int32_base_atomics : enable\n#define ATOMICS_32_AVAILABLE\n#endif"; // NOLINT +static std::string header = DOUBLE_SUPPORT "#ifndef __OPENCL_VERSION__\n#define __kernel\n#define __global\n#define __constant\n#define __local\n#define get_global_id(x) 0\n#define get_global_size(x) 0\n#define get_local_id(x) 0\n#define get_local_size(x) 0\n#define FLT_MAX 0\n#define FLT_MIN 0\n#define cl_khr_fp64\n#define cl_amd_fp64\n#ifndef DISABLE_DOUBLE_SUPPORT\n#define DOUBLE_SUPPORT_AVAILABLE\n#endif //DISABLE_DOUBLE_SUPPORT\n#define CLK_LOCAL_MEM_FENCE\n#define CLK_GLOBAL_MEM_FENCE\n#define Dtype float\n#define barrier(x)\n#define atomic_cmpxchg(x, y, z) x\n#define signbit(x) x\n#define int_tp long\n#define uint_tp unsigned long\n#define int_tpc long\n#define uint_tpc unsigned 
long\n#endif\n\n#define CONCAT(A,B) A##_##B\n#define TEMPLATE(name,type) CONCAT(name,type)\n\n#define TYPE_FLOAT 1\n#define TYPE_DOUBLE 2\n#define TYPE_HALF 3\n\n#if defined(cl_khr_fp64)\n#pragma OPENCL EXTENSION cl_khr_fp64 : enable\n#ifndef DISABLE_DOUBLE_SUPPORT\n#define DOUBLE_SUPPORT_AVAILABLE\n#endif //DISABLE_DOUBLE_SUPPORT\n#elif defined(cl_amd_fp64)\n#pragma OPENCL EXTENSION cl_amd_fp64 : enable\n#ifndef DISABLE_DOUBLE_SUPPORT\n#define DOUBLE_SUPPORT_AVAILABLE\n#endif //DISABLE_DOUBLE_SUPPORT\n#endif\n\n#if defined(cl_khr_fp16)\n#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n#define HALF_SUPPORT_AVAILABLE\n#endif\n\n#if defined(cl_khr_int64_base_atomics)\n#pragma OPENCL EXTENSION cl_khr_int64_base_atomics : enable\n#define ATOMICS_64_AVAILABLE\n#endif\n\n#if defined(cl_khr_int32_base_atomics)\n#pragma OPENCL EXTENSION cl_khr_int32_base_atomics : enable\n#define ATOMICS_32_AVAILABLE\n#endif\n\n#if defined(cl_khr_global_int32_base_atomics)\n#pragma OPENCL EXTENSION cl_khr_global_int32_base_atomics : enable\n#define ATOMICS_32_AVAILABLE\n#endif"; // NOLINT static std::string definitions_64 = DOUBLE_SUPPORT "// Types used for parameters, offset computations and so on\n#define int_tp long\n#define uint_tp unsigned long\n\n// Definitions used to cast the types above as needed\n#define int_tpc long\n#define uint_tpc unsigned long"; // NOLINT #else -static std::string header = DOUBLE_SUPPORT "#ifndef __OPENCL_VERSION__\n#define __kernel\n#define __global\n#define __constant\n#define __local\n#define get_global_id(x) 0\n#define get_global_size(x) 0\n#define get_local_id(x) 0\n#define get_local_size(x) 0\n#define FLT_MAX 0\n#define FLT_MIN 0\n#define cl_khr_fp64\n#define cl_amd_fp64\n#ifndef DISABLE_DOUBLE_SUPPORT\n#define DOUBLE_SUPPORT_AVAILABLE\n#endif //DISABLE_DOUBLE_SUPPORT\n#define CLK_LOCAL_MEM_FENCE\n#define CLK_GLOBAL_MEM_FENCE\n#define Dtype float\n#define barrier(x)\n#define atomic_cmpxchg(x, y, z) x\n#define signbit(x) x\n#define int_tp long\n#define 
uint_tp unsigned long\n#define int_tpc long\n#define uint_tpc unsigned long\n#endif\n\n#define CONCAT(A,B) A##_##B\n#define TEMPLATE(name,type) CONCAT(name,type)\n\n#define TYPE_FLOAT 1\n#define TYPE_DOUBLE 2\n\n#if defined(cl_khr_fp64)\n#pragma OPENCL EXTENSION cl_khr_fp64 : enable\n#ifndef DISABLE_DOUBLE_SUPPORT\n#define DOUBLE_SUPPORT_AVAILABLE\n#endif //DISABLE_DOUBLE_SUPPORT\n#elif defined(cl_amd_fp64)\n#pragma OPENCL EXTENSION cl_amd_fp64 : enable\n#ifndef DISABLE_DOUBLE_SUPPORT\n#define DOUBLE_SUPPORT_AVAILABLE\n#endif //DISABLE_DOUBLE_SUPPORT\n#endif\n\n#if defined(cl_khr_int64_base_atomics)\n#pragma OPENCL EXTENSION cl_khr_int64_base_atomics : enable\n#define ATOMICS_64_AVAILABLE\n#endif\n\n#if defined(cl_khr_int32_base_atomics)\n#pragma OPENCL EXTENSION cl_khr_int32_base_atomics : enable\n#define ATOMICS_32_AVAILABLE\n#endif\n\n#if defined(cl_khr_global_int32_base_atomics)\n#pragma OPENCL EXTENSION cl_khr_global_int32_base_atomics : enable\n#define ATOMICS_32_AVAILABLE\n#endif"; // NOLINT +static std::string header = DOUBLE_SUPPORT "#ifndef __OPENCL_VERSION__\n#define __kernel\n#define __global\n#define __constant\n#define __local\n#define get_global_id(x) 0\n#define get_global_size(x) 0\n#define get_local_id(x) 0\n#define get_local_size(x) 0\n#define FLT_MAX 0\n#define FLT_MIN 0\n#define cl_khr_fp64\n#define cl_amd_fp64\n#ifndef DISABLE_DOUBLE_SUPPORT\n#define DOUBLE_SUPPORT_AVAILABLE\n#endif //DISABLE_DOUBLE_SUPPORT\n#define CLK_LOCAL_MEM_FENCE\n#define CLK_GLOBAL_MEM_FENCE\n#define Dtype float\n#define barrier(x)\n#define atomic_cmpxchg(x, y, z) x\n#define signbit(x) x\n#define int_tp long\n#define uint_tp unsigned long\n#define int_tpc long\n#define uint_tpc unsigned long\n#endif\n\n#define CONCAT(A,B) A##_##B\n#define TEMPLATE(name,type) CONCAT(name,type)\n\n#define TYPE_FLOAT 1\n#define TYPE_DOUBLE 2\n#define TYPE_HALF 3\n\n#if defined(cl_khr_fp64)\n#pragma OPENCL EXTENSION cl_khr_fp64 : enable\n#ifndef DISABLE_DOUBLE_SUPPORT\n#define 
DOUBLE_SUPPORT_AVAILABLE\n#endif //DISABLE_DOUBLE_SUPPORT\n#elif defined(cl_amd_fp64)\n#pragma OPENCL EXTENSION cl_amd_fp64 : enable\n#ifndef DISABLE_DOUBLE_SUPPORT\n#define DOUBLE_SUPPORT_AVAILABLE\n#endif //DISABLE_DOUBLE_SUPPORT\n#endif\n\n#if defined(cl_khr_fp16)\n#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n#define HALF_SUPPORT_AVAILABLE\n#endif\n\n#if defined(cl_khr_int64_base_atomics)\n#pragma OPENCL EXTENSION cl_khr_int64_base_atomics : enable\n#define ATOMICS_64_AVAILABLE\n#endif\n\n#if defined(cl_khr_int32_base_atomics)\n#pragma OPENCL EXTENSION cl_khr_int32_base_atomics : enable\n#define ATOMICS_32_AVAILABLE\n#endif\n\n#if defined(cl_khr_global_int32_base_atomics)\n#pragma OPENCL EXTENSION cl_khr_global_int32_base_atomics : enable\n#define ATOMICS_32_AVAILABLE\n#endif"; // NOLINT static std::string definitions_32 = DOUBLE_SUPPORT "// Types used for parameters, offset computations and so on\n#define int_tp int\n#define uint_tp unsigned int\n\n// Definitions used to cast the types above as needed\n#define int_tpc int\n#define uint_tpc unsigned int"; // NOLINT #endif static std::vector> cl_kernels{ @@ -27,7 +27,7 @@ static std::vector> cl_kernels{ "__kernel void TEMPLATE(relu_forward,Dtype)(const int_tp n,", // NOLINT "__global const Dtype* in,", // NOLINT "__global Dtype* out,", // NOLINT -"Dtype negative_slope) {", // NOLINT +"KERNEL_ARG_DTYPE negative_slope) {", // NOLINT "for (int_tp index = get_global_id(0); index < n; index += get_global_size(0)) {", // NOLINT "out[index] = in[index] > 0 ? 
in[index] : in[index] * negative_slope;", // NOLINT "}", // NOLINT @@ -37,10 +37,10 @@ static std::vector> cl_kernels{ "__global const Dtype* in_diff,", // NOLINT "__global const Dtype* in_data,", // NOLINT "__global Dtype* out_diff,", // NOLINT -"Dtype negative_slope) {", // NOLINT +"KERNEL_ARG_DTYPE negative_slope) {", // NOLINT "for (int_tp index = get_global_id(0); index < n; index += get_global_size(0)) {", // NOLINT "out_diff[index] = in_diff[index]", // NOLINT -"* ((in_data[index] > 0?1.0:0.0) + (in_data[index] <= 0?1.0:0.0) * negative_slope);", // NOLINT +"* ((Dtype)(in_data[index] > 0?1.0:0.0) + (Dtype)(in_data[index] <= 0?1.0:0.0) * negative_slope);", // NOLINT "}", // NOLINT "}", // NOLINT "", // NOLINT @@ -58,7 +58,7 @@ static std::vector> cl_kernels{ "__global Dtype* out_diff) {", // NOLINT "for (int_tp index = get_global_id(0); index < n; index += get_global_size(0)) {", // NOLINT "Dtype tanhx = out_data[index];", // NOLINT -"out_diff[index] = in_diff[index] * (1 - tanhx * tanhx);", // NOLINT +"out_diff[index] = in_diff[index] * ((Dtype)1 - tanhx * tanhx);", // NOLINT "}", // NOLINT "}", // NOLINT "", // NOLINT @@ -66,7 +66,7 @@ static std::vector> cl_kernels{ "__global const Dtype* in,", // NOLINT "__global Dtype* out) {", // NOLINT "for (int_tp index = get_global_id(0); index < n; index += get_global_size(0)) {", // NOLINT -"out[index] = 1.0 / (1.0 + exp(-in[index]));", // NOLINT +"out[index] = (Dtype)1.0 / ((Dtype)1.0 + exp(-in[index]));", // NOLINT "}", // NOLINT "}", // NOLINT "", // NOLINT @@ -76,11 +76,11 @@ static std::vector> cl_kernels{ "__global Dtype* out_diff) {", // NOLINT "for (int_tp index = get_global_id(0); index < n; index += get_global_size(0)) {", // NOLINT "const Dtype sigmoid_x = out_data[index];", // NOLINT -"out_diff[index] = in_diff[index] * sigmoid_x * (1 - sigmoid_x);", // NOLINT +"out_diff[index] = in_diff[index] * sigmoid_x * ((Dtype)1 - sigmoid_x);", // NOLINT "}", // NOLINT "}", // NOLINT "", // NOLINT -"__kernel void 
TEMPLATE(threshold,Dtype)(const int_tp n, const Dtype threshold,", // NOLINT +"__kernel void TEMPLATE(threshold,Dtype)(const int_tp n, const KERNEL_ARG_DTYPE threshold,", // NOLINT "__global const Dtype* in,", // NOLINT "__global Dtype* out) {", // NOLINT "for (int_tp index = get_global_id(0); index < n; index += get_global_size(0)) {", // NOLINT @@ -120,11 +120,11 @@ static std::vector> cl_kernels{ "__global const Dtype* in_data,", // NOLINT "__global Dtype* out_diff) {", // NOLINT "for (int_tp index = get_global_id(0); index < n; index += get_global_size(0)) {", // NOLINT -"out_diff[index] = in_diff[index] * in_data[index] * (in_data[index] <= 0?1.0:0.0);", // NOLINT +"out_diff[index] = in_diff[index] * in_data[index] * (Dtype)(in_data[index] <= 0?1.0:0.0);", // NOLINT "for (int k = 1; k < rows; k++) {", // NOLINT "out_diff[index] += in_diff[index + k * rowPitch]", // NOLINT "* in_data[index + k * rowPitch]", // NOLINT -"* (in_data[index + k * rowPitch] <= 0?1.0:0.0);", // NOLINT +"* (Dtype)(in_data[index + k * rowPitch] <= 0?1.0:0.0);", // NOLINT "}", // NOLINT "}", // NOLINT "}", // NOLINT @@ -166,7 +166,7 @@ static std::vector> cl_kernels{ "#include \"header.cl\"", // NOLINT "#endif", // NOLINT "", // NOLINT -"__kernel void TEMPLATE(gpu_set,Dtype)(const int_tp n, const Dtype alpha, __global Dtype* y) {", // NOLINT +"__kernel void TEMPLATE(gpu_set,Dtype)(const int_tp n, const KERNEL_ARG_DTYPE alpha, __global Dtype* y) {", // NOLINT "for (int_tp index = get_global_id(0); index < n; index += get_global_size(0)) {", // NOLINT "y[index] = alpha;", // NOLINT "}", // NOLINT @@ -177,11 +177,12 @@ static std::vector> cl_kernels{ "#include \"header.cl\"", // NOLINT "#endif", // NOLINT "", // NOLINT -"__kernel void TEMPLATE(batch_norm_use_global_stats_in_place,Dtype)(const int_tp num, const int_tp channels, const int_tp spatial_dim,", // NOLINT +"Dtype TEMPLATE(bn_common,Dtype)(const int_tp num, const int_tp channels, const int_tp spatial_dim,", // NOLINT "const Dtype 
scale, const Dtype eps,", // NOLINT "__global const Dtype* mean,", // NOLINT "__global const Dtype* variance,", // NOLINT -"__global Dtype* top) {", // NOLINT +"__global const Dtype* data,", // NOLINT +"int_tp *out_off) {", // NOLINT "const int_tp idx_num = get_global_id(0);", // NOLINT "const int_tp idx_chans = get_global_id(1);", // NOLINT "const int_tp idx_spatial_dim = get_global_id(2);", // NOLINT @@ -190,30 +191,53 @@ static std::vector> cl_kernels{ "Dtype v = variance[idx_chans];", // NOLINT "", // NOLINT "m = -scale * m;", // NOLINT -"v = (Dtype)native_powr((float)mad(scale, v, eps), (float)-0.5);", // NOLINT +"v = (Dtype)native_powr((Dtype)mad(scale, v, eps), (Dtype)-0.5);", // NOLINT "", // NOLINT -"const int_tp out_off = (idx_num * channels + idx_chans) * spatial_dim + idx_spatial_dim;", // NOLINT -"top[out_off] = v * (top[out_off] + m);", // NOLINT +"*out_off = (idx_num * channels + idx_chans) * spatial_dim + idx_spatial_dim;", // NOLINT +"return (v * (data[*out_off] + m));", // NOLINT "}", // NOLINT "", // NOLINT -"__kernel void TEMPLATE(batch_norm_use_global_stats,Dtype)(const int_tp num, const int_tp channels, const int_tp spatial_dim,", // NOLINT -"const Dtype scale, const Dtype eps,", // NOLINT +"", // NOLINT +"__kernel void TEMPLATE(bn_use_global_stats_in_place,Dtype)(const int_tp num, const int_tp channels, const int_tp spatial_dim,", // NOLINT +"const KERNEL_ARG_DTYPE scale, const KERNEL_ARG_DTYPE eps,", // NOLINT "__global const Dtype* mean,", // NOLINT "__global const Dtype* variance,", // NOLINT -"__global const Dtype* bottom,", // NOLINT "__global Dtype* top) {", // NOLINT -"const int_tp idx_num = get_global_id(0);", // NOLINT -"const int_tp idx_chans = get_global_id(1);", // NOLINT -"const int_tp idx_spatial_dim = get_global_id(2);", // NOLINT +"int_tp out_off;", // NOLINT +"Dtype val = TEMPLATE(bn_common,Dtype)(num, channels, spatial_dim, scale, eps, mean, variance, top, &out_off);", // NOLINT +"top[out_off] = val;", // NOLINT +"}", // 
NOLINT "", // NOLINT -"Dtype m = mean[idx_chans];", // NOLINT -"Dtype v = variance[idx_chans];", // NOLINT +"__kernel void TEMPLATE(bn_use_global_stats_in_place_fused_relu,Dtype)(const int_tp num, const int_tp channels, const int_tp spatial_dim,", // NOLINT +"const KERNEL_ARG_DTYPE scale, const KERNEL_ARG_DTYPE eps,", // NOLINT +"__global const Dtype* mean,", // NOLINT +"__global const Dtype* variance,", // NOLINT +"__global Dtype* top) {", // NOLINT +"int_tp out_off;", // NOLINT +"Dtype val = TEMPLATE(bn_common,Dtype)(num, channels, spatial_dim, scale, eps, mean, variance, top, &out_off);", // NOLINT +"top[out_off] = val > 0.0f ? val : 0.0f;", // NOLINT +"}", // NOLINT "", // NOLINT -"m = -scale * m;", // NOLINT -"v = (Dtype)native_powr((float)mad(scale, v, eps), (float)-0.5);", // NOLINT +"__kernel void TEMPLATE(bn_use_global_stats,Dtype)(const int_tp num, const int_tp channels, const int_tp spatial_dim,", // NOLINT +"const KERNEL_ARG_DTYPE scale, const KERNEL_ARG_DTYPE eps,", // NOLINT +"__global const Dtype* mean,", // NOLINT +"__global const Dtype* variance,", // NOLINT +"__global const Dtype* bottom,", // NOLINT +"__global Dtype* top) {", // NOLINT +"int_tp out_off;", // NOLINT +"Dtype val = TEMPLATE(bn_common,Dtype)(num, channels, spatial_dim, scale, eps, mean, variance, bottom, &out_off);", // NOLINT +"top[out_off] = val;", // NOLINT +"}", // NOLINT "", // NOLINT -"const int_tp out_off = (idx_num * channels + idx_chans) * spatial_dim + idx_spatial_dim;", // NOLINT -"top[out_off] = v * (bottom[out_off] + m);", // NOLINT +"__kernel void TEMPLATE(bn_use_global_stats_fused_relu,Dtype)(const int_tp num, const int_tp channels, const int_tp spatial_dim,", // NOLINT +"const KERNEL_ARG_DTYPE scale, const KERNEL_ARG_DTYPE eps,", // NOLINT +"__global const Dtype* mean,", // NOLINT +"__global const Dtype* variance,", // NOLINT +"__global const Dtype* bottom,", // NOLINT +"__global Dtype* top) {", // NOLINT +"int_tp out_off;", // NOLINT +"Dtype val = 
TEMPLATE(bn_common,Dtype)(num, channels, spatial_dim, scale, eps, mean, variance, bottom, &out_off);", // NOLINT +"top[out_off] = val > 0.0f ? val : 0.0f;", // NOLINT "}", // NOLINT ""}, // NOLINT {"#ifndef __OPENCL_VERSION__", // NOLINT @@ -255,7 +279,7 @@ static std::vector> cl_kernels{ "#include \"header.cl\"", // NOLINT "#endif", // NOLINT "", // NOLINT -"__kernel void TEMPLATE(null_kernel,Dtype)(Dtype arg) {", // NOLINT +"__kernel void TEMPLATE(null_kernel,Dtype)(KERNEL_ARG_DTYPE arg) {", // NOLINT "Dtype out = arg;", // NOLINT "}", // NOLINT ""}, // NOLINT @@ -312,7 +336,7 @@ static std::vector> cl_kernels{ "__global Dtype* out) {", // NOLINT "for (int_tp index = get_global_id(0); index < n; index += get_global_size(0)) {", // NOLINT "if (in[index] > 0.0f) {", // NOLINT -"out[index] = in[index] + log((Dtype) (1.0 + exp(-in[index])));", // NOLINT +"out[index] = in[index] + log((Dtype) ((Dtype)1.0 + exp(-in[index])));", // NOLINT "} else {", // NOLINT "out[index] = log((Dtype) (1.0 + exp(in[index])));", // NOLINT "}", // NOLINT @@ -326,7 +350,7 @@ static std::vector> cl_kernels{ "Dtype kBNLL_THRESHOLD = 50.;", // NOLINT "for (int_tp index = get_global_id(0); index < n; index += get_global_size(0)) {", // NOLINT "Dtype expval = exp(min(in_data[index], kBNLL_THRESHOLD));", // NOLINT -"out_diff[index] = in_diff[index] * expval / (expval + 1.);", // NOLINT +"out_diff[index] = in_diff[index] * expval / (expval + (Dtype)1.);", // NOLINT "}", // NOLINT "}", // NOLINT ""}, // NOLINT @@ -342,7 +366,7 @@ static std::vector> cl_kernels{ "get_global_size(0)) {", // NOLINT "int_tp n = index / spatial_dim;", // NOLINT "int_tp s = index % spatial_dim;", // NOLINT -"float maxval = -FLT_MAX;", // NOLINT +"Dtype maxval = -DTYPE_MAX;", // NOLINT "for (int_tp c = 0; c < channels; ++c) {", // NOLINT "maxval = max((Dtype)(data[(n * channels + c) * spatial_dim + s]), (Dtype)maxval);", // NOLINT "}", // NOLINT @@ -449,7 +473,7 @@ static std::vector> cl_kernels{ "#endif", // NOLINT "", 
// NOLINT "__kernel void TEMPLATE(cll_backward,Dtype)(const int_tp count, const int_tp channels,", // NOLINT -"const Dtype margin, const Dtype alpha, __global const Dtype* y,", // NOLINT +"const KERNEL_ARG_DTYPE margin, const KERNEL_ARG_DTYPE alpha, __global const Dtype* y,", // NOLINT "__global const Dtype* diff, __global const Dtype* dist_sq,", // NOLINT "__global Dtype *bottom_diff) {", // NOLINT "for (int_tp i = get_global_id(0); i < count;", // NOLINT @@ -462,7 +486,7 @@ static std::vector> cl_kernels{ "Dtype beta = 0.;", // NOLINT "Dtype dist = sqrt(dist_sq[n]);", // NOLINT "mdist = (margin - dist);", // NOLINT -"beta = -alpha * mdist / (dist + 1e-4) * diff[i];", // NOLINT +"beta = -alpha * mdist / (dist + (Dtype)1e-4) * diff[i];", // NOLINT "if (mdist > 0.) {", // NOLINT "bottom_diff[i] = beta;", // NOLINT "} else {", // NOLINT @@ -473,7 +497,7 @@ static std::vector> cl_kernels{ "}", // NOLINT "", // NOLINT "__kernel void TEMPLATE(cll_backward_legacy,Dtype)(const int count, const int channels,", // NOLINT -"const Dtype margin, const Dtype alpha, __global Dtype* y,", // NOLINT +"const KERNEL_ARG_DTYPE margin, const KERNEL_ARG_DTYPE alpha, __global Dtype* y,", // NOLINT "__global Dtype* diff, __global Dtype* dist_sq,", // NOLINT "__global Dtype* bottom_diff) {", // NOLINT "for (int_tp i = get_global_id(0); i < count;", // NOLINT @@ -499,7 +523,7 @@ static std::vector> cl_kernels{ "#include \"header.cl\"", // NOLINT "#endif", // NOLINT "", // NOLINT -"__kernel void TEMPLATE(conv_layer_spatial_phony,Dtype)(Dtype arg) {", // NOLINT +"__kernel void TEMPLATE(conv_layer_spatial_phony,Dtype)(KERNEL_ARG_DTYPE arg) {", // NOLINT "Dtype out = arg;", // NOLINT "}", // NOLINT "", // NOLINT @@ -619,8 +643,30 @@ static std::vector> cl_kernels{ "}", // NOLINT "}", // NOLINT "}", // NOLINT +"", // NOLINT "#endif", // NOLINT "", // NOLINT +"#if defined(convolve_simd) || defined(Conv_Interleaved)", // NOLINT +"#if TYPE == TYPE_HALF", // NOLINT +"#define INT_TYPE ushort", // 
NOLINT +"#define INT_TYPE2 ushort2", // NOLINT +"#define INT_TYPE4 ushort4", // NOLINT +"#define INT_TYPE8 ushort8", // NOLINT +"#define SUB_GROUP_BLOCK_READ2 intel_sub_group_block_read_us2", // NOLINT +"#define SUB_GROUP_BLOCK_READ4 intel_sub_group_block_read_us4", // NOLINT +"#define SUB_GROUP_BLOCK_READ8 intel_sub_group_block_read_us8", // NOLINT +"#define SUB_GROUP_BLOCK_READ intel_sub_group_block_read_us", // NOLINT +"#else", // NOLINT +"#define INT_TYPE uint", // NOLINT +"#define INT_TYPE2 uint2", // NOLINT +"#define INT_TYPE4 uint4", // NOLINT +"#define INT_TYPE8 uint8", // NOLINT +"#define SUB_GROUP_BLOCK_READ2 intel_sub_group_block_read2", // NOLINT +"#define SUB_GROUP_BLOCK_READ4 intel_sub_group_block_read4", // NOLINT +"#define SUB_GROUP_BLOCK_READ8 intel_sub_group_block_read8", // NOLINT +"#define SUB_GROUP_BLOCK_READ intel_sub_group_block_read", // NOLINT +"#endif", // NOLINT +"#endif", // NOLINT "", // NOLINT "//Begin IDLF kernels below here", // NOLINT "#ifdef IDLF", // NOLINT @@ -631,26 +677,28 @@ static std::vector> cl_kernels{ "// Each work-group (which will be mapped to 1 SIMD16/SIMD8 EU thread) will compute 16/8 different feature maps, but each feature map is for the same region of the imput image.", // NOLINT "// NDRange: (output_width+pad)/ OUT_BLOCK_WIDTH, (output_height+pad)/OUT_BLOCK_HEIGHT, NUM_FILTERS/OUT_BLOCK_DEPTH", // NOLINT "", // NOLINT -"// NOTE: for beignet this reqd_work_group_size does not guarantee that SIMD16/8 mode will be used, the compiler could choose to use two SIMD8 threads, and if that happens the code will break.", // NOLINT +"// NOTE: for beignet this reqd_work_group_size does not guarantee that SIMD16 mode will be used, the compiler could choose to use two SIMD8 threads, and if that happens the code will break.", // NOLINT +"#ifndef __BEIGNET__", // NOLINT "__attribute__((reqd_work_group_size(1, 1, SIMD_SIZE)))", // NOLINT +"#endif", // NOLINT "__kernel void", // NOLINT -"convolve_simd( // __global float *inputs, 
__global float* weights, __global float* outputs", // NOLINT +"convolve_simd(", // NOLINT "#ifdef FUSED_CONV_ELTWISE", // NOLINT "__global Dtype* eltwise_data,", // NOLINT "#endif", // NOLINT -"__global float* inputs_base,", // NOLINT -"filter_qualifier float* weights_base,", // NOLINT -"__global float* biases_base,", // NOLINT -"__global float* outputs_base,", // NOLINT +"__global Dtype* inputs_base,", // NOLINT +"filter_qualifier Dtype* weights_base,", // NOLINT +"__global Dtype* biases_base,", // NOLINT +"__global Dtype* outputs_base,", // NOLINT "const ushort input_width,", // NOLINT "const ushort input_height,", // NOLINT "const ushort output_width,", // NOLINT "const ushort output_height)", // NOLINT "{", // NOLINT -"__global float* outputs = outputs_base;", // NOLINT -"__global float* inputs = inputs_base;", // NOLINT -"filter_qualifier float* weights = weights_base;", // NOLINT -"__global float* biases = biases_base;", // NOLINT +"__global Dtype* outputs = outputs_base;", // NOLINT +"__global Dtype* inputs = inputs_base;", // NOLINT +"filter_qualifier Dtype* weights = weights_base;", // NOLINT +"__global Dtype* biases = biases_base;", // NOLINT "", // NOLINT "uint_tp oc = get_global_id(0) * OUT_BLOCK_WIDTH; // oc = Output Column", // NOLINT "uint_tp or = get_global_id(1) * OUT_BLOCK_HEIGHT;// or = Output Row", // NOLINT @@ -658,7 +706,7 @@ static std::vector> cl_kernels{ "uint_tp fmg = get_group_id(2);", // NOLINT "uint_tp lid = get_local_id(2);", // NOLINT "", // NOLINT -"float out[OUT_BLOCK_SIZE];", // NOLINT +"Dtype out[OUT_BLOCK_SIZE];", // NOLINT "", // NOLINT "int_tp in_addr;", // NOLINT "", // NOLINT @@ -684,8 +732,8 @@ static std::vector> cl_kernels{ "+ (curr_y - INPUT_PAD_H) * input_width // y tile offset", // NOLINT "+ curr_x - INPUT_PAD_W; // x tile offset", // NOLINT "union {", // NOLINT -"float4 in_vec[INVEC_SIZE];", // NOLINT -"float in_array[INVEC_SIZE * 4];", // NOLINT +"Dtype4 in_vec[INVEC_SIZE];", // NOLINT +"Dtype in_array[INVEC_SIZE * 
4];", // NOLINT "} in_buf;", // NOLINT "", // NOLINT "for(int_tp kd = 0; kd < INPUT_DEPTH; kd++)", // NOLINT @@ -709,7 +757,7 @@ static std::vector> cl_kernels{ "in_buf.in_vec[reg].s2 = 0;", // NOLINT "in_buf.in_vec[reg].s3 = *(inputs + in_offset + 3);", // NOLINT "} else {", // NOLINT -"in_buf.in_vec[reg] = *(global float4*)(inputs + in_offset); // read SIMD_SIZE elements", // NOLINT +"in_buf.in_vec[reg] = vload4(0, (inputs + in_offset)); // read 16 elements", // NOLINT "if (curr_x + 1 >= input_width + INPUT_PAD_W)", // NOLINT "in_buf.in_vec[reg].s1 = 0;", // NOLINT "if (curr_x + 2 >= input_width + INPUT_PAD_W)", // NOLINT @@ -722,7 +770,7 @@ static std::vector> cl_kernels{ "}", // NOLINT "curr_y += TILE_Y_STRIDE;", // NOLINT "#else", // NOLINT -"in_buf.in_vec[reg] = *(global float4*)(inputs + in_offset); // read SIMD_SIZE elements", // NOLINT +"in_buf.in_vec[reg] = vload4(0, (inputs + in_offset)); // read 16 elements", // NOLINT "#endif", // NOLINT "}", // NOLINT "in_offset += input_width * TILE_Y_STRIDE;", // NOLINT @@ -738,19 +786,19 @@ static std::vector> cl_kernels{ "#define WEIGHT_PREF 1", // NOLINT "#endif", // NOLINT "union {", // NOLINT -"float w[WEIGHT_PREF];", // NOLINT +"Dtype w[WEIGHT_PREF];", // NOLINT "#if KERNEL_WIDTH * KERNEL_HEIGHT != 1", // NOLINT -"uint8 ui8;", // NOLINT +"INT_TYPE8 ui8;", // NOLINT "#endif", // NOLINT "} weight_buf;", // NOLINT "int_tp w_idx=0;", // NOLINT "", // NOLINT "uint_tp orig_weight_addr = weight_addr;", // NOLINT "#if KERNEL_WIDTH * KERNEL_HEIGHT != 1", // NOLINT -"weight_buf.ui8 = intel_sub_group_block_read8((__global uint *)&weights[weight_addr]);", // NOLINT +"weight_buf.ui8 = SUB_GROUP_BLOCK_READ8((__global INT_TYPE *)&weights[weight_addr]);", // NOLINT "weight_addr += SIMD_SIZE * WEIGHT_PREF;", // NOLINT "#else", // NOLINT -"weight_buf.w[0] = as_float(intel_sub_group_block_read((__global uint *)&weights[weight_addr]));", // NOLINT +"weight_buf.w[0] = as_Dtype(SUB_GROUP_BLOCK_READ((__global INT_TYPE 
*)&weights[weight_addr]));", // NOLINT "weight_addr += SIMD_SIZE * 1;", // NOLINT "#endif", // NOLINT "", // NOLINT @@ -764,7 +812,7 @@ static std::vector> cl_kernels{ "{", // NOLINT "for(int_tp br=0; br < OUT_BLOCK_HEIGHT; br++) {", // NOLINT "for(int_tp bc=0; bc < OUT_BLOCK_WIDTH; bc++) {", // NOLINT -"float input = BLOCK_IN((br * STRIDEY + kr * DILATION_Y) * TILE_X + bc * STRIDEX + kc * DILATION_X);", // NOLINT +"Dtype input = BLOCK_IN((br * STRIDEY + kr * DILATION_Y) * TILE_X + bc * STRIDEX + kc * DILATION_X);", // NOLINT "out[br * OUT_BLOCK_WIDTH + bc] = mad(weight_buf.w[w_idx % WEIGHT_PREF], input, out[br * OUT_BLOCK_WIDTH + bc]);", // NOLINT "}", // NOLINT "}", // NOLINT @@ -775,7 +823,7 @@ static std::vector> cl_kernels{ "&& ((w_idx + 1) <= (KERNEL_WIDTH * KERNEL_HEIGHT - WEIGHT_PREF))", // NOLINT "#endif", // NOLINT ") {", // NOLINT -"weight_buf.ui8 = intel_sub_group_block_read8((__global uint *)&weights[weight_addr]);", // NOLINT +"weight_buf.ui8 = SUB_GROUP_BLOCK_READ8((__global INT_TYPE *)&weights[weight_addr]);", // NOLINT "weight_addr += SIMD_SIZE * WEIGHT_PREF; // weights must be stored in just the right SIMD swizzled format for this to work, see host code for details.", // NOLINT "}", // NOLINT "#if KERNEL_WIDTH*KERNEL_HEIGHT % 8 == 0", // NOLINT @@ -785,11 +833,11 @@ static std::vector> cl_kernels{ "#if KERNEL_WIDTH * KERNEL_HEIGHT % 8 == 1", // NOLINT "weight_buf.w[0] = weights[weight_addr];", // NOLINT "#elif KERNEL_WIDTH * KERNEL_HEIGHT % 8 == 2", // NOLINT -"weight_buf.ui8.s01 = intel_sub_group_block_read2((__global uint *)&weights[weight_addr]);", // NOLINT +"weight_buf.ui8.s01 = SUB_GROUP_BLOCK_READ2((__global INT_TYPE *)&weights[weight_addr]);", // NOLINT "#elif KERNEL_WIDTH * KERNEL_HEIGHT % 8 <= 4", // NOLINT -"weight_buf.ui8.s0123 = intel_sub_group_block_read4((__global uint *)&weights[weight_addr]);", // NOLINT +"weight_buf.ui8.s0123 = SUB_GROUP_BLOCK_READ4((__global INT_TYPE *)&weights[weight_addr]);", // NOLINT "#else", // NOLINT 
-"weight_buf.ui8 = intel_sub_group_block_read8((__global uint *)&weights[weight_addr]);", // NOLINT +"weight_buf.ui8 = SUB_GROUP_BLOCK_READ8((__global INT_TYPE *)&weights[weight_addr]);", // NOLINT "#endif", // NOLINT "#endif", // NOLINT "#endif", // NOLINT @@ -808,7 +856,8 @@ static std::vector> cl_kernels{ "if ((ALIGNED_NUM_FILTERS == NUM_FILTERS || fm < NUM_FILTERS)) {", // NOLINT "uint_tp out_addr = OUT_BUFF_OFFSET + ( num_in_batch * TOTAL_OUTPUT_DEPTH + fm ) * output_width * output_height;", // NOLINT "out_addr += or * output_width + oc;", // NOLINT -"float bias = biases[(fm % ALIGNED_NUM_FILTERS)];", // NOLINT +"// we need this address calculation for biases because we support views and batching", // NOLINT +"Dtype bias = biases[fm];", // NOLINT "for(uint_tp r = 0; r < OUT_BLOCK_HEIGHT; r++) {", // NOLINT "if (r + or >= output_height) break;", // NOLINT "for(uint_tp c = 0; c < OUT_BLOCK_WIDTH; c++) {", // NOLINT @@ -862,6 +911,25 @@ static std::vector> cl_kernels{ "float s6; float s7; float s8; float s9; float sa; float sb; float sc; float sd; float se; } float15;", // NOLINT "typedef struct float0 { float s0; } float0; //never used but makes compiler happy.", // NOLINT "", // NOLINT +"typedef struct half1 { half s0; } half1;", // NOLINT +"typedef struct half5 { half s0; half s1; half s2; half s3; half s4; } half5;", // NOLINT +"typedef struct half6 { half s0; half s1; half s2; half s3; half s4; half s5; } half6;", // NOLINT +"typedef struct half7 { half s0; half s1; half s2; half s3; half s4; half s5; half s6; } half7;", // NOLINT +"typedef struct half9 { half s0; half s1; half s2; half s3; half s4; half s5; half s6; half s7; half s8; } half9;", // NOLINT +"typedef struct half10 { half s0; half s1; half s2; half s3; half s4; half s5;", // NOLINT +"half s6; half s7; half s8; half s9; } half10;", // NOLINT +"typedef struct half11 { half s0; half s1; half s2; half s3; half s4; half s5;", // NOLINT +"half s6; half s7; half s8; half s9; half sa; } half11;", // 
NOLINT +"typedef struct half12 { half s0; half s1; half s2; half s3; half s4; half s5;", // NOLINT +"half s6; half s7; half s8; half s9; half sa; half sb; } half12;", // NOLINT +"typedef struct half13 { half s0; half s1; half s2; half s3; half s4; half s5;", // NOLINT +"half s6; half s7; half s8; half s9; half sa; half sb; half sc; } half13;", // NOLINT +"typedef struct half14 { half s0; half s1; half s2; half s3; half s4; half s5;", // NOLINT +"half s6; half s7; half s8; half s9; half sa; half sb; half sc; half sd; } half14;", // NOLINT +"typedef struct half15 { half s0; half s1; half s2; half s3; half s4; half s5;", // NOLINT +"half s6; half s7; half s8; half s9; half sa; half sb; half sc; half sd; half se; } half15;", // NOLINT +"typedef struct half0 { half s0; } half0; //never used but makes compiler happy.", // NOLINT +"", // NOLINT "#define OUT_PITCH_X output_width", // NOLINT "#define ROW_PITCH input_width", // NOLINT "", // NOLINT @@ -891,7 +959,7 @@ static std::vector> cl_kernels{ "#define TILE_K KERNEL_WIDTH", // NOLINT "#define TILE_N 32", // NOLINT "", // NOLINT -"#ifdef __BEIGNET__", // NOLINT +"#ifndef __BEIGNET__", // NOLINT "__attribute__((intel_reqd_sub_group_size(8)))", // NOLINT "#endif", // NOLINT "__kernel void Conv_Interleaved(GEMM_LIKE_KERNEL_ARGS)", // NOLINT @@ -906,7 +974,7 @@ static std::vector> cl_kernels{ "int kernel_idx;", // NOLINT "", // NOLINT "#define DOT_PRODUCT_8( _result, _rowA, colB ) { _result.s0 = mad( _rowA, sub_group_broadcast( colB, 0 ), _result.s0 ); _result.s1 = mad( _rowA, sub_group_broadcast( colB, 1 ), _result.s1 ); _result.s2 = mad( _rowA, sub_group_broadcast( colB, 2 ), _result.s2 ); _result.s3 = mad( _rowA, sub_group_broadcast( colB, 3 ), _result.s3 ); _result.s4 = mad( _rowA, sub_group_broadcast( colB, 4 ), _result.s4 ); _result.s5 = mad( _rowA, sub_group_broadcast( colB, 5 ), _result.s5 ); _result.s6 = mad( _rowA, sub_group_broadcast( colB, 6 ), _result.s6 ); _result.s7 = mad( _rowA, sub_group_broadcast( colB, 7 
), _result.s7 ); }", // NOLINT -"typedef CAT( float, KERNEL_WIDTH ) float_t;", // NOLINT +"typedef CAT( Dtype, KERNEL_WIDTH ) Dtype_t;", // NOLINT "", // NOLINT "// True for all threads if filter_width is multiple of TILE_N", // NOLINT "// else, true for all but right-most column of threads.", // NOLINT @@ -914,10 +982,10 @@ static std::vector> cl_kernels{ "{", // NOLINT "// Result ctile (*dst) is M rows x N columns", // NOLINT "// LWG size is 1x8. Thus each thread calculates 8*M rows x N cols of ctile.", // NOLINT -"float8 blockC00 = 0.f;", // NOLINT -"float8 blockC10 = 0.f;", // NOLINT -"float8 blockC20 = 0.f;", // NOLINT -"float8 blockC30 = 0.f;", // NOLINT +"Dtype8 blockC00 = 0.f;", // NOLINT +"Dtype8 blockC10 = 0.f;", // NOLINT +"Dtype8 blockC20 = 0.f;", // NOLINT +"Dtype8 blockC30 = 0.f;", // NOLINT "", // NOLINT "// Src0 (patch input) is directly used as atile.", // NOLINT "// Each work item points to the start of a different patch.", // NOLINT @@ -927,7 +995,7 @@ static std::vector> cl_kernels{ "#if INPUT_PAD_H != 0 || INPUT_PAD_W != 0 || DILATION_X != 1 || DILATION_Y != 1", // NOLINT "int saved_y = curr_y;", // NOLINT "#endif", // NOLINT -"const __global float *src0_read = src0", // NOLINT +"const __global Dtype *src0_read = src0", // NOLINT "+ aligned_input_size * global_z // batch offset", // NOLINT "+ (curr_y - INPUT_PAD_H) * ROW_PITCH // y offset", // NOLINT "+ (curr_x - INPUT_PAD_W); // x offset", // NOLINT @@ -935,7 +1003,7 @@ static std::vector> cl_kernels{ "// Src1 (filter) is directly used as btile.", // NOLINT "// It starts at the top of src1 and walks down.", // NOLINT "// btile is K rows x N columns.", // NOLINT -"const __global float *src1_read = src1 + ( global_x * TILE_N * 2);", // NOLINT +"const __global Dtype *src1_read = src1 + ( global_x * TILE_N * 2);", // NOLINT "", // NOLINT "// Walk DOWN src0 (patch 0, 1, 2, ...) 
and DOWN src1.", // NOLINT "// Inner loop loads and FMADs one row (KERNEL_WIDTH) of each input patch", // NOLINT @@ -951,7 +1019,7 @@ static std::vector> cl_kernels{ "do", // NOLINT "{", // NOLINT "// Load atile and btile.", // NOLINT -"// Kernel data is partially interleaved. Every 2 rows are interleaved at float8 granularity.", // NOLINT +"// Kernel data is partially interleaved. Every 2 rows are interleaved at Dtype8 granularity.", // NOLINT "// The exception is that if KERNEL_WIDTH is odd the last row is not interleaved. The non", // NOLINT "// interleaved row is padded with zero to ensure same size as interleaved rows. This", // NOLINT "// interleaving is done to ensure 0% GDR bank conflicts. For example, this is how the", // NOLINT @@ -963,11 +1031,11 @@ static std::vector> cl_kernels{ "const bool kernel_width_is_odd = KERNEL_WIDTH % 2 == 1;", // NOLINT "", // NOLINT "#if INPUT_PAD_W == 0 && INPUT_PAD_H == 0 && DILATION_X == 1 && DILATION_Y == 1", // NOLINT -"float_t blockA00 = ( (const __global float_t*)src0_read )[ 0 ];", // NOLINT -"float* pblockA00 = (float*)(&blockA00);", // NOLINT +"Dtype_t blockA00 = ( (const __global Dtype_t*)src0_read )[ 0 ];", // NOLINT +"Dtype* pblockA00 = (Dtype*)(&blockA00);", // NOLINT "#else", // NOLINT -"float_t blockA00;", // NOLINT -"float* pblockA00 = (float*)(&blockA00);", // NOLINT +"Dtype_t blockA00;", // NOLINT +"Dtype* pblockA00 = (Dtype*)(&blockA00);", // NOLINT "int pos = 0;", // NOLINT "LOOP(KERNEL_WIDTH, pos,", // NOLINT "{", // NOLINT @@ -980,20 +1048,20 @@ static std::vector> cl_kernels{ "#endif", // NOLINT "src0_read += (ROW_PITCH * DILATION_Y);", // NOLINT "", // NOLINT -"float blockB00[KERNEL_WIDTH*4];", // NOLINT -"float8* p8BlockB00 = (float8*)blockB00;", // NOLINT -"float4* p4BlockB00 = (float4*)blockB00;", // NOLINT -"float* pBlockB00 = (float* )blockB00;", // NOLINT +"Dtype blockB00[KERNEL_WIDTH*4];", // NOLINT +"Dtype8* p8BlockB00 = (Dtype8*)blockB00;", // NOLINT +"Dtype4* p4BlockB00 = 
(Dtype4*)blockB00;", // NOLINT +"Dtype* pBlockB00 = (Dtype* )blockB00;", // NOLINT "", // NOLINT "interleaved_y = 0;", // NOLINT "LOOP(KERNEL_WIDTH_DIV2, interleaved_y,", // NOLINT "{", // NOLINT -"p8BlockB00[interleaved_y] = as_float8( intel_sub_group_block_read8( (const __global uint*)src1_read ) );", // NOLINT +"p8BlockB00[interleaved_y] = as_Dtype8( SUB_GROUP_BLOCK_READ8( (const __global INT_TYPE *)src1_read ) );", // NOLINT "src1_read += WIDTH1 * 2;", // NOLINT "} )", // NOLINT "if ( kernel_width_is_odd )", // NOLINT "{", // NOLINT -"p4BlockB00[KERNEL_WIDTH - 1] = as_float4( intel_sub_group_block_read4( (const __global uint*)src1_read ) );", // NOLINT +"p4BlockB00[KERNEL_WIDTH - 1] = as_Dtype4( SUB_GROUP_BLOCK_READ4( (const __global INT_TYPE *)src1_read ) );", // NOLINT "src1_read += WIDTH1 * 2;", // NOLINT "}", // NOLINT "", // NOLINT @@ -1037,11 +1105,11 @@ static std::vector> cl_kernels{ "+ ( ( global_y * TILE_M ) / output_width + OUT_PADDING_HEIGHT) * OUT_PITCH_X // y offset", // NOLINT "+ ( ( global_y * TILE_M ) % output_width ) + OUT_PADDING_LEFT; // x offset", // NOLINT "", // NOLINT -"__global float *out = dst + out_offset;", // NOLINT -"float bias[4];", // NOLINT -"float4 *bias_vec;", // NOLINT -"bias_vec = (float4*)bias;", // NOLINT -"*bias_vec = as_float4(intel_sub_group_block_read4((__global uint *)biases + group_x * TILE_N));", // NOLINT +"__global Dtype *out = dst + out_offset;", // NOLINT +"Dtype bias[4];", // NOLINT +"Dtype4 *bias_vec;", // NOLINT +"bias_vec = (Dtype4*)bias;", // NOLINT +"*bias_vec = as_Dtype4(SUB_GROUP_BLOCK_READ4((__global INT_TYPE *)biases + group_x * TILE_N));", // NOLINT "", // NOLINT "if (global_y * TILE_M < output_width * output_height )", // NOLINT "{", // NOLINT @@ -1061,7 +1129,7 @@ static std::vector> cl_kernels{ "// Result ctile (*dst) is M rows x N columns", // NOLINT "// LWG size is 1x8. 
Thus each thread calculates 8*M rows x N cols of ctile.", // NOLINT "int i = 0;", // NOLINT -"float8 blockC[TILE_N_LAST_DIV8];", // NOLINT +"Dtype8 blockC[TILE_N_LAST_DIV8];", // NOLINT "LOOP(TILE_N_LAST_DIV8, i,", // NOLINT "{", // NOLINT "blockC[i] = 0.f;", // NOLINT @@ -1075,7 +1143,7 @@ static std::vector> cl_kernels{ "#if INPUT_PAD_H != 0 || INPUT_PAD_W != 0 || DILATION_X != 1 || DILATION_Y != 1", // NOLINT "int saved_y = curr_y;", // NOLINT "#endif", // NOLINT -"const __global float *src0_read = src0", // NOLINT +"const __global Dtype *src0_read = src0", // NOLINT "+ aligned_input_size * global_z // batch offset", // NOLINT "+ (curr_y - INPUT_PAD_H) * ROW_PITCH // y offset", // NOLINT "+ (curr_x - INPUT_PAD_W); // x offset", // NOLINT @@ -1083,7 +1151,7 @@ static std::vector> cl_kernels{ "// Src1 (filter) is directly used as btile.", // NOLINT "// It starts at the top of src1 and walks down.", // NOLINT "// btile is K rows x N columns.", // NOLINT -"const __global float *src1_read = src1 + ( global_x * TILE_N * 2);", // NOLINT +"const __global Dtype *src1_read = src1 + ( global_x * TILE_N * 2);", // NOLINT "", // NOLINT "// Walk DOWN src0 (patch 0, 1, 2, ...) 
and DOWN src1.", // NOLINT "// Inner loop loads and FMADs one row (KERNEL_WIDTH) of each input patch", // NOLINT @@ -1100,11 +1168,11 @@ static std::vector> cl_kernels{ "// Load atile and interleaved btile.", // NOLINT "const bool kernel_width_is_odd = KERNEL_WIDTH % 2 == 1;", // NOLINT "#if INPUT_PAD_W == 0 && INPUT_PAD_H == 0 && DILATION_X == 1 && DILATION_Y == 1", // NOLINT -"float_t blockA00 = ( (const __global float_t*)src0_read )[ 0 ];", // NOLINT -"float* pblockA00 = (float*)(&blockA00);", // NOLINT +"Dtype_t blockA00 = ( (const __global Dtype_t*)src0_read )[ 0 ];", // NOLINT +"Dtype* pblockA00 = (Dtype*)(&blockA00);", // NOLINT "#else", // NOLINT -"float_t blockA00;", // NOLINT -"float* pblockA00 = (float*)(&blockA00);", // NOLINT +"Dtype_t blockA00;", // NOLINT +"Dtype* pblockA00 = (Dtype*)(&blockA00);", // NOLINT "int pos = 0;", // NOLINT "LOOP(KERNEL_WIDTH, pos,", // NOLINT "{", // NOLINT @@ -1116,43 +1184,43 @@ static std::vector> cl_kernels{ "curr_y += DILATION_Y;", // NOLINT "#endif", // NOLINT "src0_read += (ROW_PITCH * DILATION_Y);", // NOLINT -"float blockB[KERNEL_WIDTH * TILE_N_LAST_DIV8];", // NOLINT +"Dtype blockB[KERNEL_WIDTH * TILE_N_LAST_DIV8];", // NOLINT "", // NOLINT "interleaved_y = 0;", // NOLINT "LOOP(KERNEL_WIDTH_DIV2, interleaved_y,", // NOLINT "{", // NOLINT "#if TILE_N_LAST_DIV8 == 1", // NOLINT -"float2* p2BlockB = (float2* )blockB;", // NOLINT -"p2BlockB[interleaved_y] = as_float2( intel_sub_group_block_read2( (const __global uint*)src1_read ) );", // NOLINT +"Dtype2* p2BlockB = (Dtype2* )blockB;", // NOLINT +"p2BlockB[interleaved_y] = as_Dtype2( SUB_GROUP_BLOCK_READ2( (const __global INT_TYPE*)src1_read ) );", // NOLINT "#elif TILE_N_LAST_DIV8 == 2", // NOLINT -"float4* p4BlockB = (float4* )blockB;", // NOLINT -"p4BlockB[interleaved_y] = as_float4( intel_sub_group_block_read4( (const __global uint*)src1_read ) );", // NOLINT +"Dtype4* p4BlockB = (Dtype4* )blockB;", // NOLINT +"p4BlockB[interleaved_y] = as_Dtype4( 
SUB_GROUP_BLOCK_READ4( (const __global INT_TYPE*)src1_read ) );", // NOLINT "#elif TILE_N_LAST_DIV8 == 3", // NOLINT "//TODO: broken. No block_read6", // NOLINT -"float6* p6BlockB = (float6* )blockB;", // NOLINT -"(*((float8*)(&p6BlockB[interleaved_y]))).s0123 = as_float4( intel_sub_group_block_read4( (const __global uint*)src1_read ) );", // NOLINT -"(*((float8*)(&p6BlockB[interleaved_y]))).s45 = as_float2( intel_sub_group_block_read2( (const __global uint*)(src1_read + 4 * 8) ) );", // NOLINT +"Dtype6* p6BlockB = (Dtype6* )blockB;", // NOLINT +"(*((Dtype8*)(&p6BlockB[interleaved_y]))).s0123 = as_Dtype4( SUB_GROUP_BLOCK_READ4( (const __global INT_TYPE*)src1_read ) );", // NOLINT +"(*((Dtype8*)(&p6BlockB[interleaved_y]))).s45 = as_Dtype2( SUB_GROUP_BLOCK_READ2( (const __global INT_TYPE*)(src1_read + 4 * 8) ) );", // NOLINT "#endif", // NOLINT "src1_read += WIDTH1 * 2;", // NOLINT "} )", // NOLINT "if ( kernel_width_is_odd )", // NOLINT "{", // NOLINT "#if TILE_N_LAST_DIV8 == 1", // NOLINT -"float* pBlockB = (float* )blockB;", // NOLINT -"pBlockB[KERNEL_WIDTH - 1] = as_float( intel_sub_group_block_read( (const __global uint*)src1_read ) );", // NOLINT +"Dtype* pBlockB = (Dtype* )blockB;", // NOLINT +"pBlockB[KERNEL_WIDTH - 1] = as_Dtype( SUB_GROUP_BLOCK_READ( (const __global INT_TYPE*)src1_read ) );", // NOLINT "#elif TILE_N_LAST_DIV8 == 2", // NOLINT -"float2* p2BlockB = (float2* )blockB;", // NOLINT -"p2BlockB[KERNEL_WIDTH - 1] = as_float2( intel_sub_group_block_read2( (const __global uint*)src1_read ) );", // NOLINT +"Dtype2* p2BlockB = (Dtype2* )blockB;", // NOLINT +"p2BlockB[KERNEL_WIDTH - 1] = as_Dtype2( SUB_GROUP_BLOCK_READ2( (const __global INT_TYPE*)src1_read ) );", // NOLINT "#elif TILE_N_LAST_DIV8 == 3", // NOLINT -"float3* p3BlockB = (float3* )blockB;", // NOLINT -"p3BlockB[KERNEL_WIDTH - 1].s01 = as_float2( intel_sub_group_block_read2( (const __global uint*)src1_read ) );", // NOLINT -"p3BlockB[KERNEL_WIDTH - 1].s2 = as_float( 
intel_sub_group_block_read( (const __global uint*) (src1_read + 2 * 8) ) );", // NOLINT +"Dtype3* p3BlockB = (Dtype3* )blockB;", // NOLINT +"p3BlockB[KERNEL_WIDTH - 1].s01 = as_Dtype2( SUB_GROUP_BLOCK_READ2( (const __global INT_TYPE*)src1_read ) );", // NOLINT +"p3BlockB[KERNEL_WIDTH - 1].s2 = as_Dtype( SUB_GROUP_BLOCK_READ( (const __global INT_TYPE*) (src1_read + 2 * 8) ) );", // NOLINT "#endif", // NOLINT "src1_read += WIDTH1 * 2;", // NOLINT "}", // NOLINT "", // NOLINT "// Perform MADs", // NOLINT -"float* pBlockB = (float*)blockB;", // NOLINT +"Dtype* pBlockB = (Dtype*)blockB;", // NOLINT "kernel_idx = 0;", // NOLINT "interleaved_y = 0;", // NOLINT "LOOP(KERNEL_WIDTH_DIV2, interleaved_y,", // NOLINT @@ -1196,12 +1264,11 @@ static std::vector> cl_kernels{ "+ ( group_x * TILE_N ) * out_pitch_y // channel offset", // NOLINT "+ ( ( global_y * TILE_M ) / output_width + OUT_PADDING_HEIGHT) * OUT_PITCH_X // y offset", // NOLINT "+ ( ( global_y * TILE_M ) % output_width ) + OUT_PADDING_LEFT; // x offset", // NOLINT -"", // NOLINT -"__global float *out = dst + out_offset;", // NOLINT -"float bias[4];", // NOLINT -"float4 *bias_vec;", // NOLINT -"bias_vec = (float4*)bias;", // NOLINT -"*bias_vec = as_float4(intel_sub_group_block_read4((__global uint *)biases + group_x * TILE_N));", // NOLINT +"__global Dtype *out = dst + out_offset;", // NOLINT +"Dtype bias[4];", // NOLINT +"Dtype4 *bias_vec;", // NOLINT +"bias_vec = (Dtype4*)bias;", // NOLINT +"*bias_vec = as_Dtype4(SUB_GROUP_BLOCK_READ4((__global INT_TYPE *)biases + group_x * TILE_N));", // NOLINT "", // NOLINT "if (global_y * TILE_M < output_width * output_height )", // NOLINT "{", // NOLINT @@ -1217,14 +1284,25 @@ static std::vector> cl_kernels{ "#endif", // NOLINT "}", // NOLINT "#endif", // NOLINT +"#ifdef GEMM_LIKE_CONV_32_2", // NOLINT "", // NOLINT -"#ifdef GEMM_LIKE_CONV_32_1_SIMD16", // NOLINT -"#define TILE_M 1", // NOLINT +"//////////////////////////////////////////////////////////////////////////////", // 
NOLINT +"// Conv_Interleaved_32_2_flex", // NOLINT +"//", // NOLINT +"// Convolution: each workitem computes 1 patch x 32 filters worth of output", // NOLINT +"// data. Kernel's inner loop works on a single tile consisting of one", // NOLINT +"// row from each patch and the filter data corresponding to that row. Filter", // NOLINT +"// matrix is interleaved to reduce GRF bank conflicts. Patches are walked", // NOLINT +"// by rows and then by slices. Relies on sub_group extension for block", // NOLINT +"// reads and SIMD broadcast. Allows flexible sizing of TILE width (TILE_N)", // NOLINT +"// by dynamically selecting one of two code paths: one uses TILE_N = 32 and", // NOLINT +"// the other uses TILE_N = 8, 16, or 24.", // NOLINT +"#define TILE_M 2", // NOLINT "#define TILE_K KERNEL_WIDTH", // NOLINT "#define TILE_N 32", // NOLINT "", // NOLINT "#ifndef __BEIGNET__", // NOLINT -"__attribute__((intel_reqd_sub_group_size(16)))", // NOLINT +"__attribute__((intel_reqd_sub_group_size(8)))", // NOLINT "#endif", // NOLINT "__kernel void Conv_Interleaved(GEMM_LIKE_KERNEL_ARGS)", // NOLINT "{", // NOLINT @@ -1237,249 +1315,216 @@ static std::vector> cl_kernels{ "int kernel_y;", // NOLINT "int kernel_idx;", // NOLINT "", // NOLINT +"#define DOT_PRODUCT_8( _result, _rowA, colB ) { _result.s0 = mad( _rowA, sub_group_broadcast( colB, 0 ), _result.s0 ); _result.s1 = mad( _rowA, sub_group_broadcast( colB, 1 ), _result.s1 ); _result.s2 = mad( _rowA, sub_group_broadcast( colB, 2 ), _result.s2 ); _result.s3 = mad( _rowA, sub_group_broadcast( colB, 3 ), _result.s3 ); _result.s4 = mad( _rowA, sub_group_broadcast( colB, 4 ), _result.s4 ); _result.s5 = mad( _rowA, sub_group_broadcast( colB, 5 ), _result.s5 ); _result.s6 = mad( _rowA, sub_group_broadcast( colB, 6 ), _result.s6 ); _result.s7 = mad( _rowA, sub_group_broadcast( colB, 7 ), _result.s7 ); }", // NOLINT +"typedef CAT( Dtype, KERNEL_WIDTH ) Dtype_t;", // NOLINT +"", // NOLINT +"// True for all threads if filter_width is multiple 
of TILE_N", // NOLINT +"// else, true for all but right-most column of threads.", // NOLINT +"if( TILE_N_LAST == 0 || global_x < WIDTH1 / TILE_N )", // NOLINT +"{", // NOLINT "// Result ctile (*dst) is M rows x N columns", // NOLINT -"// LWG size is 1x16. Thus each thread calculates 16*M rows x N cols of ctile.", // NOLINT -"Dtype16 blockC00 = 0.f;", // NOLINT -"Dtype16 blockC10 = 0.f;", // NOLINT +"// LWG size is 1x8. Thus each thread calculates 8*M rows x N cols of ctile.", // NOLINT +"Dtype8 blockC00 = 0.f;", // NOLINT +"Dtype8 blockC10 = 0.f;", // NOLINT +"Dtype8 blockC20 = 0.f;", // NOLINT +"Dtype8 blockC30 = 0.f;", // NOLINT +"Dtype8 blockC01 = 0.f;", // NOLINT +"Dtype8 blockC11 = 0.f;", // NOLINT +"Dtype8 blockC21 = 0.f;", // NOLINT +"Dtype8 blockC31 = 0.f;", // NOLINT "", // NOLINT "// Src0 (patch input) is directly used as atile.", // NOLINT "// Each work item points to the start of a different patch.", // NOLINT "// atile is M rows x K columns.", // NOLINT -"int curr_x = ( global_y % output_width ) * STRIDE_X;", // NOLINT -"int curr_y = ( global_y / output_width ) * STRIDE_Y;", // NOLINT +"int curr_x0 = ( ( global_y * TILE_M + 0 ) % output_width ) * STRIDE_X;", // NOLINT +"int curr_x1 = ( ( global_y * TILE_M + 1 ) % output_width ) * STRIDE_X;", // NOLINT +"int curr_y0 = ( ( global_y * TILE_M + 0 ) / output_width ) * STRIDE_Y;", // NOLINT +"int curr_y1 = ( ( global_y * TILE_M + 1 ) / output_width ) * STRIDE_Y;", // NOLINT "#if INPUT_PAD_H != 0 || INPUT_PAD_W != 0 || DILATION_X != 1 || DILATION_Y != 1", // NOLINT -"int saved_y = curr_y;", // NOLINT +"int saved_y0 = curr_y0;", // NOLINT +"int saved_y1 = curr_y1;", // NOLINT "#endif", // NOLINT -"", // NOLINT -"const __global Dtype *src0_read = src0", // NOLINT -"+ aligned_input_size * global_z // batch offset", // NOLINT -"+ (curr_y - INPUT_PAD_H) * ROW_PITCH // y offset", // NOLINT -"+ curr_x - INPUT_PAD_W; // x offset", // NOLINT -"const __global Dtype *src0_read_orig = src0_read;", // NOLINT +"const 
__global Dtype *src0_read0 = src0", // NOLINT +"+ aligned_input_size * global_z // batch offset", // NOLINT +"+ (curr_y0 - INPUT_PAD_H) * ROW_PITCH // y offset", // NOLINT +"+ curr_x0 - INPUT_PAD_W; // x offset", // NOLINT +"const __global Dtype *src0_read1 = src0", // NOLINT +"+ aligned_input_size * global_z // batch offset", // NOLINT +"+ (curr_y1 - INPUT_PAD_H) * ROW_PITCH // y offset", // NOLINT +"+ curr_x1 - INPUT_PAD_W; // x offset", // NOLINT "", // NOLINT "// Src1 (filter) is directly used as btile.", // NOLINT "// It starts at the top of src1 and walks down.", // NOLINT "// btile is K rows x N columns.", // NOLINT -"const __global Dtype *src1_read = src1 + ( global_x * TILE_N * 2 );", // NOLINT +"const __global Dtype *src1_read = src1 + ( global_x * TILE_N * 2);", // NOLINT "", // NOLINT -"#define DOT_PRODUCT_16( _result, _rowA, colB ) { _result.s0 = mad( _rowA, sub_group_broadcast( colB, 0 ), _result.s0 ); _result.s1 = mad( _rowA, sub_group_broadcast( colB, 1 ), _result.s1 ); _result.s2 = mad( _rowA, sub_group_broadcast( colB, 2 ), _result.s2 ); _result.s3 = mad( _rowA, sub_group_broadcast( colB, 3 ), _result.s3 ); _result.s4 = mad( _rowA, sub_group_broadcast( colB, 4 ), _result.s4 ); _result.s5 = mad( _rowA, sub_group_broadcast( colB, 5 ), _result.s5 ); _result.s6 = mad( _rowA, sub_group_broadcast( colB, 6 ), _result.s6 ); _result.s7 = mad( _rowA, sub_group_broadcast( colB, 7 ), _result.s7 ); _result.s8 = mad( _rowA, sub_group_broadcast( colB, 8 ), _result.s8 ); _result.s9 = mad( _rowA, sub_group_broadcast( colB, 9 ), _result.s9 ); _result.sa = mad( _rowA, sub_group_broadcast( colB, 10 ), _result.sa ); _result.sb = mad( _rowA, sub_group_broadcast( colB, 11 ), _result.sb ); _result.sc = mad( _rowA, sub_group_broadcast( colB, 12 ), _result.sc ); _result.sd = mad( _rowA, sub_group_broadcast( colB, 13 ), _result.sd ); _result.se = mad( _rowA, sub_group_broadcast( colB, 14 ), _result.se ); _result.sf = mad( _rowA, sub_group_broadcast( colB, 15 ), _result.sf 
); }", // NOLINT -"typedef CAT( Dtype, KERNEL_WIDTH ) Dtype_t;", // NOLINT "// Walk DOWN src0 (patch 0, 1, 2, ...) and DOWN src1.", // NOLINT "// Inner loop loads and FMADs one row (KERNEL_WIDTH) of each input patch", // NOLINT "// and KERNEL_WIDTH/2 rows of interleaved filter.", // NOLINT "int patch_depth = 0;", // NOLINT -"#ifndef __BEIGNET__", // NOLINT -"__attribute__((opencl_unroll_hint(1)))", // NOLINT -"#endif", // NOLINT "do", // NOLINT "{", // NOLINT "int patch_row = 0;", // NOLINT -"#if INPUT_PAD_H != 0 || INPUT_PAD_W != 0", // NOLINT -"curr_y = saved_y;", // NOLINT -"#endif", // NOLINT -"#ifndef __BEIGNET__", // NOLINT -"__attribute__((opencl_unroll_hint(1)))", // NOLINT -"#endif", // NOLINT "do", // NOLINT "{", // NOLINT "// Load atile and btile.", // NOLINT -"// Kernel data is partially interleaved. Every 2 rows are interleaved at Dtype16 granularity.", // NOLINT +"// Kernel data is partially interleaved. Every 2 rows are interleaved at Dtype8 granularity.", // NOLINT "// The exception is that if KERNEL_WIDTH is odd the last row is not interleaved. The non", // NOLINT "// interleaved row is padded with zero to ensure same size as interleaved rows. This", // NOLINT "// interleaving is done to ensure 0% GDR bank conflicts. For example, this is how the", // NOLINT "// kernel data would be arranged before/after interleaving for KERNEL_WIDTH=3.", // NOLINT -"// (0, 0) (16, 0) (32, 0) (48, 0) ... (0, 0) ( 0, 1) (16, 0) ( 0, 1) (32, 0) (0, 1) (48, 0) ...", // NOLINT -"// (0, 1) (16, 1) (32, 1) (48, 1) ... => (0, 2) (16, 2) (32, 2) (48, 2) ...", // NOLINT -"// (0, 2) (16, 2) (32, 2) (48, 2) ... ...", // NOLINT +"// (0, 0) (8, 0) (16, 0) (24, 0) ... (0, 0) (0, 1) (8, 0) (0, 1) (16, 0) (0, 1) (24, 0) ..", // NOLINT +"// (0, 1) (8, 1) (16, 1) (24, 1) ... => (0, 2) (8, 2) (16, 2) (24, 2) ...", // NOLINT +"// (0, 2) (8, 2) (16, 2) (24, 2) ... 
...", // NOLINT "// ...", // NOLINT "const bool kernel_width_is_odd = KERNEL_WIDTH % 2 == 1;", // NOLINT -"", // NOLINT -"#if INPUT_PAD_W == 0 && INPUT_PAD_H == 0 && DILATION_X == 1 && DILATION_Y == 1", // NOLINT -"Dtype_t blockA00 = ( (const __global Dtype_t*)src0_read )[ 0 ];", // NOLINT +"#if INPUT_PAD_H == 0 && INPUT_PAD_W == 0 && DILATION_X == 1 && DILATION_Y == 1", // NOLINT +"Dtype_t blockA00 = ( (const __global Dtype_t*)src0_read0 )[ 0 ]; src0_read0 += ROW_PITCH;", // NOLINT +"Dtype_t blockA01 = ( (const __global Dtype_t*)src0_read1 )[ 0 ]; src0_read1 += ROW_PITCH;", // NOLINT "Dtype* pblockA00 = (Dtype*)(&blockA00);", // NOLINT +"Dtype* pblockA01 = (Dtype*)(&blockA01);", // NOLINT "#else", // NOLINT "Dtype_t blockA00;", // NOLINT "Dtype* pblockA00 = (Dtype*)(&blockA00);", // NOLINT "int pos = 0;", // NOLINT "LOOP(KERNEL_WIDTH, pos,", // NOLINT "{", // NOLINT -"if (curr_y >= INPUT_PAD_H && curr_y < input_height + INPUT_PAD_H && curr_x + pos * DILATION_X >= INPUT_PAD_W && curr_x + pos * DILATION_X < input_width + INPUT_PAD_W)", // NOLINT -"pblockA00[pos] = src0_read[pos * DILATION_X];", // NOLINT +"if (curr_y0 >= INPUT_PAD_H && curr_y0 < input_height + INPUT_PAD_H && curr_x0 + pos * DILATION_X >= INPUT_PAD_W && curr_x0 + pos * DILATION_X < input_width + INPUT_PAD_W)", // NOLINT +"pblockA00[pos] = src0_read0[pos * DILATION_X];", // NOLINT "else", // NOLINT "pblockA00[pos] = 0;", // NOLINT "})", // NOLINT -"curr_y += DILATION_Y;", // NOLINT +"curr_y0 += DILATION_Y;", // NOLINT +"Dtype_t blockA01;", // NOLINT +"Dtype* pblockA01 = (Dtype*)(&blockA01);", // NOLINT +"pos = 0;", // NOLINT +"LOOP(KERNEL_WIDTH, pos,", // NOLINT +"{", // NOLINT +"if (curr_y1 >= INPUT_PAD_H && curr_y1 < input_height + INPUT_PAD_H && curr_x1 + pos * DILATION_X >= INPUT_PAD_W && curr_x1 + pos * DILATION_X < input_width + INPUT_PAD_W)", // NOLINT +"pblockA01[pos] = src0_read1[pos * DILATION_X];", // NOLINT +"else", // NOLINT +"pblockA01[pos] = 0;", // NOLINT +"})", // NOLINT +"curr_y1 += 
DILATION_Y;", // NOLINT +"src0_read0 += (ROW_PITCH * DILATION_Y);", // NOLINT +"src0_read1 += (ROW_PITCH * DILATION_Y);", // NOLINT "#endif", // NOLINT -"src0_read += ROW_PITCH * DILATION_X;", // NOLINT -"uint blockB00[KERNEL_WIDTH * 2];", // NOLINT -"uint4* p4BlockB00 = (uint4*)blockB00;", // NOLINT -"uint2* p2BlockB00 = (uint2*)blockB00;", // NOLINT -"Dtype* pBlockB00 = (Dtype*)blockB00;", // NOLINT +"Dtype blockB00[KERNEL_WIDTH*4];", // NOLINT +"Dtype8* p8BlockB00 = (Dtype8*)blockB00;", // NOLINT +"Dtype4* p4BlockB00 = (Dtype4*)blockB00;", // NOLINT +"Dtype* pBlockB00 = (Dtype* )blockB00;", // NOLINT "", // NOLINT "interleaved_y = 0;", // NOLINT "LOOP(KERNEL_WIDTH_DIV2, interleaved_y,", // NOLINT "{", // NOLINT -"p4BlockB00[interleaved_y] = intel_sub_group_block_read4( (const __global uint*)src1_read );", // NOLINT +"p8BlockB00[interleaved_y] = as_Dtype8( SUB_GROUP_BLOCK_READ8( (const __global INT_TYPE*)src1_read ) );", // NOLINT "src1_read += WIDTH1 * 2;", // NOLINT "} )", // NOLINT "if ( kernel_width_is_odd )", // NOLINT "{", // NOLINT -"p2BlockB00[KERNEL_WIDTH - 1] = intel_sub_group_block_read2( (const __global uint*)src1_read );", // NOLINT +"p4BlockB00[KERNEL_WIDTH - 1] = as_Dtype4( SUB_GROUP_BLOCK_READ4( (const __global INT_TYPE*)src1_read ) );", // NOLINT "src1_read += WIDTH1 * 2;", // NOLINT "}", // NOLINT -"", // NOLINT "// Perform MADs", // NOLINT "kernel_idx = 0;", // NOLINT "interleaved_y = 0;", // NOLINT "LOOP(KERNEL_WIDTH_DIV2, interleaved_y,", // NOLINT "{", // NOLINT "kernel_y = interleaved_y * 2;", // NOLINT -"DOT_PRODUCT_16( blockC00, pblockA00[kernel_y ], pBlockB00[kernel_idx] ); kernel_idx++;", // NOLINT -"DOT_PRODUCT_16( blockC00, pblockA00[kernel_y + 1], pBlockB00[kernel_idx] ); kernel_idx++;", // NOLINT -"DOT_PRODUCT_16( blockC10, pblockA00[kernel_y ], pBlockB00[kernel_idx] ); kernel_idx++;", // NOLINT -"DOT_PRODUCT_16( blockC10, pblockA00[kernel_y + 1], pBlockB00[kernel_idx] ); kernel_idx++;", // NOLINT +"DOT_PRODUCT_8( blockC00, 
pblockA00[kernel_y ], pBlockB00[kernel_idx] );", // NOLINT +"DOT_PRODUCT_8( blockC01, pblockA01[kernel_y ], pBlockB00[kernel_idx] ); kernel_idx++;", // NOLINT +"DOT_PRODUCT_8( blockC00, pblockA00[kernel_y + 1], pBlockB00[kernel_idx] );", // NOLINT +"DOT_PRODUCT_8( blockC01, pblockA01[kernel_y + 1], pBlockB00[kernel_idx] ); kernel_idx++;", // NOLINT +"DOT_PRODUCT_8( blockC10, pblockA00[kernel_y ], pBlockB00[kernel_idx] );", // NOLINT +"DOT_PRODUCT_8( blockC11, pblockA01[kernel_y ], pBlockB00[kernel_idx] ); kernel_idx++;", // NOLINT +"DOT_PRODUCT_8( blockC10, pblockA00[kernel_y + 1], pBlockB00[kernel_idx] );", // NOLINT +"DOT_PRODUCT_8( blockC11, pblockA01[kernel_y + 1], pBlockB00[kernel_idx] ); kernel_idx++;", // NOLINT +"DOT_PRODUCT_8( blockC20, pblockA00[kernel_y ], pBlockB00[kernel_idx] );", // NOLINT +"DOT_PRODUCT_8( blockC21, pblockA01[kernel_y ], pBlockB00[kernel_idx] ); kernel_idx++;", // NOLINT +"DOT_PRODUCT_8( blockC20, pblockA00[kernel_y + 1], pBlockB00[kernel_idx] );", // NOLINT +"DOT_PRODUCT_8( blockC21, pblockA01[kernel_y + 1], pBlockB00[kernel_idx] ); kernel_idx++;", // NOLINT +"DOT_PRODUCT_8( blockC30, pblockA00[kernel_y ], pBlockB00[kernel_idx] );", // NOLINT +"DOT_PRODUCT_8( blockC31, pblockA01[kernel_y ], pBlockB00[kernel_idx] ); kernel_idx++;", // NOLINT +"DOT_PRODUCT_8( blockC30, pblockA00[kernel_y + 1], pBlockB00[kernel_idx] );", // NOLINT +"DOT_PRODUCT_8( blockC31, pblockA01[kernel_y + 1], pBlockB00[kernel_idx] ); kernel_idx++;", // NOLINT "} )", // NOLINT "if ( kernel_width_is_odd )", // NOLINT "{", // NOLINT "kernel_y = interleaved_y * 2;", // NOLINT -"DOT_PRODUCT_16( blockC00, pblockA00[kernel_y], pBlockB00[kernel_idx] ); kernel_idx++;", // NOLINT -"DOT_PRODUCT_16( blockC10, pblockA00[kernel_y], pBlockB00[kernel_idx] ); kernel_idx++;", // NOLINT +"DOT_PRODUCT_8( blockC00, pblockA00[kernel_y], pBlockB00[kernel_idx] );", // NOLINT +"DOT_PRODUCT_8( blockC01, pblockA01[kernel_y], pBlockB00[kernel_idx] ); kernel_idx++;", // NOLINT 
+"DOT_PRODUCT_8( blockC10, pblockA00[kernel_y], pBlockB00[kernel_idx] );", // NOLINT +"DOT_PRODUCT_8( blockC11, pblockA01[kernel_y], pBlockB00[kernel_idx] ); kernel_idx++;", // NOLINT +"DOT_PRODUCT_8( blockC20, pblockA00[kernel_y], pBlockB00[kernel_idx] );", // NOLINT +"DOT_PRODUCT_8( blockC21, pblockA01[kernel_y], pBlockB00[kernel_idx] ); kernel_idx++;", // NOLINT +"DOT_PRODUCT_8( blockC30, pblockA00[kernel_y], pBlockB00[kernel_idx] );", // NOLINT +"DOT_PRODUCT_8( blockC31, pblockA01[kernel_y], pBlockB00[kernel_idx] ); kernel_idx++;", // NOLINT "}", // NOLINT "}", // NOLINT "", // NOLINT "//while( ++patch_row < 1 ); //debug", // NOLINT "while( ++patch_row < KERNEL_HEIGHT );", // NOLINT -"", // NOLINT -"src0_read += slice_pitch - ( KERNEL_HEIGHT * ROW_PITCH * DILATION_Y ); // reset to start of next slice of patch", // NOLINT +"#if INPUT_PAD_W != 0 || INPUT_PAD_H != 0 || DILATION_X != 1 || DILATION_Y != 1", // NOLINT +"curr_y0 = saved_y0;", // NOLINT +"curr_y1 = saved_y1;", // NOLINT +"#endif", // NOLINT +"src0_read0 += slice_pitch - ( KERNEL_HEIGHT * ROW_PITCH * DILATION_Y ); // reset to start of next slice of patch", // NOLINT +"src0_read1 += slice_pitch - ( KERNEL_HEIGHT * ROW_PITCH * DILATION_Y );", // NOLINT "}", // NOLINT "//while ( ++patch_depth < 1 ); //debug", // NOLINT "while ( ++patch_depth < INPUT_DEPTH );", // NOLINT "", // NOLINT "// Dst resembles a cube of width x height x (output channel * batches). Each tile writes:", // NOLINT "// (SIMD * TILE_M) x 1 x TILE_N. 
Partial writes most likely generated if padding used.", // NOLINT -"int_tp out_offset = global_z * out_pitch_z // batch offset", // NOLINT -"+ ( group_x * TILE_N ) * out_pitch_y // channel offset", // NOLINT -"+ ( ( global_y * TILE_M ) / output_width + OUT_PADDING_HEIGHT) * OUT_PITCH_X // y offset", // NOLINT -"+ ( ( global_y * TILE_M ) % output_width ) + OUT_PADDING_LEFT; // x offset", // NOLINT -"__global Dtype *out = dst + out_offset;", // NOLINT +"int_tp out0_offset = global_z * out_pitch_z // batch offset", // NOLINT +"+ ( group_x * TILE_N ) * out_pitch_y // channel offset", // NOLINT +"+ ( ( global_y * TILE_M + 0 ) / output_width + OUT_PADDING_HEIGHT ) * OUT_PITCH_X // y offset", // NOLINT +"+ ( ( global_y * TILE_M + 0 ) % output_width ) + OUT_PADDING_LEFT; // x offset", // NOLINT +"int_tp out1_offset = global_z * out_pitch_z // batch offset", // NOLINT +"+ ( group_x * TILE_N ) * out_pitch_y // channel offset", // NOLINT +"+ ( ( global_y * TILE_M + 1 ) / output_width + OUT_PADDING_HEIGHT ) * OUT_PITCH_X // y offset", // NOLINT +"+ ( ( global_y * TILE_M + 1 ) % output_width ) + OUT_PADDING_LEFT; // x offset", // NOLINT "", // NOLINT -"Dtype bias[2];", // NOLINT -"Dtype2 *bias_vec;", // NOLINT -"bias_vec = (Dtype2*)bias;", // NOLINT -"*bias_vec = as_float2(intel_sub_group_block_read2((__global uint *)biases + group_x * TILE_N));", // NOLINT -"// Work around a potential compiler bug.", // NOLINT -"if (group_x > 0xFFFFFFFEul) {", // NOLINT -"out[0] = bias[0] + bias[1];", // NOLINT -"}", // NOLINT +"Dtype bias[4];", // NOLINT +"Dtype4 *bias_vec;", // NOLINT +"bias_vec = (Dtype4*)bias;", // NOLINT +"*bias_vec = as_Dtype4(SUB_GROUP_BLOCK_READ4((__global INT_TYPE *)biases + group_x * TILE_N));", // NOLINT "", // NOLINT -"if (global_y * TILE_M < output_width * output_height )", // NOLINT -"{", // NOLINT -"#if ( ( OUT_DEPTH % TILE_N ) == 0 )", // NOLINT -"for (int i = 0; i < 16; i++)", // NOLINT -"{", // NOLINT -"ACTIVATION_FUNCTION(dst, out_offset + ( 0+i) * 
out_pitch_y, blockC00[i] + intel_sub_group_shuffle(bias[0], i));", // NOLINT -"ACTIVATION_FUNCTION(dst, out_offset + (16+i) * out_pitch_y, blockC10[i] + intel_sub_group_shuffle(bias[1], i));", // NOLINT -"}", // NOLINT -"#elif ( ( OUT_DEPTH % 16 ) == 0 )", // NOLINT -"if ( ( global_x + 1 ) < get_global_size(0) )", // NOLINT +"if( global_y * TILE_M < output_width * output_height )", // NOLINT "{", // NOLINT -"for ( int i = 0; i < 16; i++ )", // NOLINT +"for( int i = 0; i < 8; i++ )", // NOLINT "{", // NOLINT -"ACTIVATION_FUNCTION(dst, out_offset + ( 0+i) * out_pitch_y, blockC00[i] + intel_sub_group_shuffle(bias[0], i));", // NOLINT -"ACTIVATION_FUNCTION(dst, out_offset + (16+i) * out_pitch_y, blockC10[i] + intel_sub_group_shuffle(bias[1], i));", // NOLINT +"ACTIVATION_FUNCTION(dst, out0_offset + ( 0+i) * out_pitch_y, blockC00[i] + intel_sub_group_shuffle(bias[0], i));", // NOLINT +"ACTIVATION_FUNCTION(dst, out0_offset + ( 8+i) * out_pitch_y, blockC10[i] + intel_sub_group_shuffle(bias[1], i));", // NOLINT +"ACTIVATION_FUNCTION(dst, out0_offset + (16+i) * out_pitch_y, blockC20[i] + intel_sub_group_shuffle(bias[2], i));", // NOLINT +"ACTIVATION_FUNCTION(dst, out0_offset + (24+i) * out_pitch_y, blockC30[i] + intel_sub_group_shuffle(bias[3], i));", // NOLINT "}", // NOLINT "}", // NOLINT -"else", // NOLINT +"if( global_y * TILE_M + 1 < output_width * output_height )", // NOLINT "{", // NOLINT -"for (int i = 0; i < 16; i++)", // NOLINT +"for( int i = 0; i < 8; i++ )", // NOLINT "{", // NOLINT -"ACTIVATION_FUNCTION(dst, out_offset + ( 0+i) * out_pitch_y, blockC00[i] + intel_sub_group_shuffle(bias[0], i));", // NOLINT -"}", // NOLINT +"ACTIVATION_FUNCTION(dst, out1_offset + ( 0+i) * out_pitch_y, blockC01[i] + intel_sub_group_shuffle(bias[0], i));", // NOLINT +"ACTIVATION_FUNCTION(dst, out1_offset + ( 8+i) * out_pitch_y, blockC11[i] + intel_sub_group_shuffle(bias[1], i));", // NOLINT +"ACTIVATION_FUNCTION(dst, out1_offset + (16+i) * out_pitch_y, blockC21[i] + 
intel_sub_group_shuffle(bias[2], i));", // NOLINT +"ACTIVATION_FUNCTION(dst, out1_offset + (24+i) * out_pitch_y, blockC31[i] + intel_sub_group_shuffle(bias[3], i));", // NOLINT "}", // NOLINT -"#else", // NOLINT -"if ( ( global_x + 1 ) < get_global_size(0) )", // NOLINT -"{", // NOLINT -"for ( int i = 0; i < 16; i++ )", // NOLINT -"{", // NOLINT -"ACTIVATION_FUNCTION(dst, out_offset + ( 0+i) * out_pitch_y, blockC00[i] + intel_sub_group_shuffle(bias[0], i));", // NOLINT -"ACTIVATION_FUNCTION(dst, out_offset + (16+i) * out_pitch_y, blockC10[i] + intel_sub_group_shuffle(bias[1], i));", // NOLINT "}", // NOLINT "}", // NOLINT +"#if TILE_N_LAST > 0", // NOLINT "else", // NOLINT "{", // NOLINT -"#if ( (OUT_DEPTH % TILE_N) > 16 )", // NOLINT -"{", // NOLINT -"for (int i = 0; i < 16 ; i++)", // NOLINT -"{", // NOLINT -"ACTIVATION_FUNCTION(dst, out_offset + ( 0+i) * out_pitch_y, blockC00[i] + intel_sub_group_shuffle(bias[0], i));", // NOLINT -"}", // NOLINT -"for (int i = 0; i < OUT_DEPTH % 16 ; i++)", // NOLINT -"{", // NOLINT -"ACTIVATION_FUNCTION(dst, out_offset + (16+i) * out_pitch_y, blockC10[i] + intel_sub_group_shuffle(bias[1], i));", // NOLINT -"}", // NOLINT -"}", // NOLINT -"#else", // NOLINT -"{", // NOLINT -"for (int i = 0; i < OUT_DEPTH % 16 ; i++)", // NOLINT -"{", // NOLINT -"ACTIVATION_FUNCTION(dst, out_offset + ( 0+i) * out_pitch_y, blockC00[i] + intel_sub_group_shuffle(bias[0], i));", // NOLINT -"}", // NOLINT -"}", // NOLINT -"#endif", // NOLINT -"}", // NOLINT -"#endif", // NOLINT -"}", // NOLINT -"}", // NOLINT -"#endif", // NOLINT -"", // NOLINT -"#ifdef GEMM_LIKE_CONV_32_2", // NOLINT "", // NOLINT -"//////////////////////////////////////////////////////////////////////////////", // NOLINT -"// Conv_Interleaved_32_2_flex", // NOLINT -"//", // NOLINT -"// Convolution: each workitem computes 1 patch x 32 filters worth of output", // NOLINT -"// data. 
Kernel's inner loop works on a single tile consisting of one", // NOLINT -"// row from each patch and the filter data corresponding to that row. Filter", // NOLINT -"// matrix is interleaved to reduce GRF bank conflicts. Patches are walked", // NOLINT -"// by rows and then by slices. Relies on sub_group extension for block", // NOLINT -"// reads and SIMD broadcast. Allows flexible sizing of TILE width (TILE_N)", // NOLINT -"// by dynamically selecting one of two code paths: one uses TILE_N = 32 and", // NOLINT -"// the other uses TILE_N = 8, 16, or 24.", // NOLINT -"#define TILE_M 2", // NOLINT -"#define TILE_K KERNEL_WIDTH", // NOLINT -"#define TILE_N 32", // NOLINT -"", // NOLINT -"#ifdef __BEIGNET__", // NOLINT -"__attribute__((intel_reqd_sub_group_size(8)))", // NOLINT -"#endif", // NOLINT -"__kernel void Conv_Interleaved(GEMM_LIKE_KERNEL_ARGS)", // NOLINT -"{", // NOLINT -"const int group_x = get_group_id(0);", // NOLINT -"const int group_y = get_group_id(1);", // NOLINT -"const int global_x = get_global_id(0);", // NOLINT -"const int global_y = get_global_id(1);", // NOLINT -"const int global_z = get_global_id(2);", // NOLINT -"int interleaved_y;", // NOLINT -"int kernel_y;", // NOLINT -"int kernel_idx;", // NOLINT -"", // NOLINT -"#define DOT_PRODUCT_8( _result, _rowA, colB ) { _result.s0 = mad( _rowA, sub_group_broadcast( colB, 0 ), _result.s0 ); _result.s1 = mad( _rowA, sub_group_broadcast( colB, 1 ), _result.s1 ); _result.s2 = mad( _rowA, sub_group_broadcast( colB, 2 ), _result.s2 ); _result.s3 = mad( _rowA, sub_group_broadcast( colB, 3 ), _result.s3 ); _result.s4 = mad( _rowA, sub_group_broadcast( colB, 4 ), _result.s4 ); _result.s5 = mad( _rowA, sub_group_broadcast( colB, 5 ), _result.s5 ); _result.s6 = mad( _rowA, sub_group_broadcast( colB, 6 ), _result.s6 ); _result.s7 = mad( _rowA, sub_group_broadcast( colB, 7 ), _result.s7 ); }", // NOLINT -"typedef CAT( float, KERNEL_WIDTH ) float_t;", // NOLINT -"", // NOLINT -"// True for all threads if 
filter_width is multiple of TILE_N", // NOLINT -"// else, true for all but right-most column of threads.", // NOLINT -"if( TILE_N_LAST == 0 || global_x < WIDTH1 / TILE_N )", // NOLINT -"{", // NOLINT "// Result ctile (*dst) is M rows x N columns", // NOLINT "// LWG size is 1x8. Thus each thread calculates 8*M rows x N cols of ctile.", // NOLINT -"float8 blockC00 = 0.f;", // NOLINT -"float8 blockC10 = 0.f;", // NOLINT -"float8 blockC20 = 0.f;", // NOLINT -"float8 blockC30 = 0.f;", // NOLINT -"float8 blockC01 = 0.f;", // NOLINT -"float8 blockC11 = 0.f;", // NOLINT -"float8 blockC21 = 0.f;", // NOLINT -"float8 blockC31 = 0.f;", // NOLINT +"int i = 0;", // NOLINT +"Dtype8 blockC0[TILE_N_LAST_DIV8];", // NOLINT +"Dtype8 blockC1[TILE_N_LAST_DIV8];", // NOLINT +"LOOP(TILE_N_LAST_DIV8, i,", // NOLINT +"{", // NOLINT +"blockC0[i] = 0.f;", // NOLINT +"blockC1[i] = 0.f;", // NOLINT +"} )", // NOLINT "", // NOLINT "// Src0 (patch input) is directly used as atile.", // NOLINT "// Each work item points to the start of a different patch.", // NOLINT @@ -1492,11 +1537,11 @@ static std::vector> cl_kernels{ "int saved_y0 = curr_y0;", // NOLINT "int saved_y1 = curr_y1;", // NOLINT "#endif", // NOLINT -"const __global float *src0_read0 = src0", // NOLINT +"const __global Dtype *src0_read0 = src0", // NOLINT "+ aligned_input_size * global_z // batch offset", // NOLINT "+ (curr_y0 - INPUT_PAD_H) * ROW_PITCH // y offset", // NOLINT "+ curr_x0 - INPUT_PAD_W; // x offset", // NOLINT -"const __global float *src0_read1 = src0", // NOLINT +"const __global Dtype *src0_read1 = src0", // NOLINT "+ aligned_input_size * global_z // batch offset", // NOLINT "+ (curr_y1 - INPUT_PAD_H) * ROW_PITCH // y offset", // NOLINT "+ curr_x1 - INPUT_PAD_W; // x offset", // NOLINT @@ -1504,7 +1549,7 @@ static std::vector> cl_kernels{ "// Src1 (filter) is directly used as btile.", // NOLINT "// It starts at the top of src1 and walks down.", // NOLINT "// btile is K rows x N columns.", // NOLINT -"const __global 
float *src1_read = src1 + ( global_x * TILE_N * 2);", // NOLINT +"const __global Dtype *src1_read = src1 + ( global_x * TILE_N * 2);", // NOLINT "", // NOLINT "// Walk DOWN src0 (patch 0, 1, 2, ...) and DOWN src1.", // NOLINT "// Inner loop loads and FMADs one row (KERNEL_WIDTH) of each input patch", // NOLINT @@ -1515,25 +1560,16 @@ static std::vector> cl_kernels{ "int patch_row = 0;", // NOLINT "do", // NOLINT "{", // NOLINT -"// Load atile and btile.", // NOLINT -"// Kernel data is partially interleaved. Every 2 rows are interleaved at float8 granularity.", // NOLINT -"// The exception is that if KERNEL_WIDTH is odd the last row is not interleaved. The non", // NOLINT -"// interleaved row is padded with zero to ensure same size as interleaved rows. This", // NOLINT -"// interleaving is done to ensure 0% GDR bank conflicts. For example, this is how the", // NOLINT -"// kernel data would be arranged before/after interleaving for KERNEL_WIDTH=3.", // NOLINT -"// (0, 0) (8, 0) (16, 0) (24, 0) ... (0, 0) (0, 1) (8, 0) (0, 1) (16, 0) (0, 1) (24, 0) ..", // NOLINT -"// (0, 1) (8, 1) (16, 1) (24, 1) ... => (0, 2) (8, 2) (16, 2) (24, 2) ...", // NOLINT -"// (0, 2) (8, 2) (16, 2) (24, 2) ... 
...", // NOLINT -"// ...", // NOLINT +"// Load atile and interleaved btile.", // NOLINT "const bool kernel_width_is_odd = KERNEL_WIDTH % 2 == 1;", // NOLINT "#if INPUT_PAD_H == 0 && INPUT_PAD_W == 0 && DILATION_X == 1 && DILATION_Y == 1", // NOLINT -"float_t blockA00 = ( (const __global float_t*)src0_read0 )[ 0 ]; src0_read0 += ROW_PITCH;", // NOLINT -"float_t blockA01 = ( (const __global float_t*)src0_read1 )[ 0 ]; src0_read1 += ROW_PITCH;", // NOLINT -"float* pblockA00 = (float*)(&blockA00);", // NOLINT -"float* pblockA01 = (float*)(&blockA01);", // NOLINT +"Dtype_t blockA00 = ( (const __global Dtype_t*)src0_read0 )[ 0 ]; src0_read0 += ROW_PITCH;", // NOLINT +"Dtype_t blockA01 = ( (const __global Dtype_t*)src0_read1 )[ 0 ]; src0_read1 += ROW_PITCH;", // NOLINT +"Dtype* pblockA00 = (Dtype*)(&blockA00);", // NOLINT +"Dtype* pblockA01 = (Dtype*)(&blockA01);", // NOLINT "#else", // NOLINT -"float_t blockA00;", // NOLINT -"float* pblockA00 = (float*)(&blockA00);", // NOLINT +"Dtype_t blockA00;", // NOLINT +"Dtype* pblockA00 = (Dtype*)(&blockA00);", // NOLINT "int pos = 0;", // NOLINT "LOOP(KERNEL_WIDTH, pos,", // NOLINT "{", // NOLINT @@ -1543,8 +1579,8 @@ static std::vector> cl_kernels{ "pblockA00[pos] = 0;", // NOLINT "})", // NOLINT "curr_y0 += DILATION_Y;", // NOLINT -"float_t blockA01;", // NOLINT -"float* pblockA01 = (float*)(&blockA01);", // NOLINT +"Dtype_t blockA01;", // NOLINT +"Dtype* pblockA01 = (Dtype*)(&blockA01);", // NOLINT "pos = 0;", // NOLINT "LOOP(KERNEL_WIDTH, pos,", // NOLINT "{", // NOLINT @@ -1554,59 +1590,81 @@ static std::vector> cl_kernels{ "pblockA01[pos] = 0;", // NOLINT "})", // NOLINT "curr_y1 += DILATION_Y;", // NOLINT -"src0_read0 += ROW_PITCH * DILATION_Y;", // NOLINT -"src0_read1 += ROW_PITCH * DILATION_Y;", // NOLINT +"src0_read0 += (ROW_PITCH * DILATION_Y);", // NOLINT +"src0_read1 += (ROW_PITCH * DILATION_Y);", // NOLINT "#endif", // NOLINT -"float blockB00[KERNEL_WIDTH*4];", // NOLINT -"float8* p8BlockB00 = (float8*)blockB00;", 
// NOLINT -"float4* p4BlockB00 = (float4*)blockB00;", // NOLINT -"float* pBlockB00 = (float* )blockB00;", // NOLINT +"Dtype blockB[KERNEL_WIDTH * TILE_N_LAST_DIV8];", // NOLINT "", // NOLINT "interleaved_y = 0;", // NOLINT "LOOP(KERNEL_WIDTH_DIV2, interleaved_y,", // NOLINT "{", // NOLINT -"p8BlockB00[interleaved_y] = as_float8( intel_sub_group_block_read8( (const __global uint*)src1_read ) );", // NOLINT +"#if TILE_N_LAST_DIV8 == 1", // NOLINT +"Dtype2* p2BlockB = (Dtype2* )blockB;", // NOLINT +"p2BlockB[interleaved_y] = as_Dtype2( SUB_GROUP_BLOCK_READ2( (const __global INT_TYPE*)src1_read ) );", // NOLINT +"#elif TILE_N_LAST_DIV8 == 2", // NOLINT +"Dtype4* p4BlockB = (Dtype4* )blockB;", // NOLINT +"p4BlockB[interleaved_y] = as_Dtype4( SUB_GROUP_BLOCK_READ4( (const __global INT_TYPE*)src1_read ) );", // NOLINT +"#elif TILE_N_LAST_DIV8 == 3", // NOLINT +"//TODO: broken. No block_read6", // NOLINT +"Dtype6* p6BlockB = (Dtype6* )blockB;", // NOLINT +"(*((Dtype8*)(&p6BlockB[interleaved_y]))).s0123 = as_Dtype4( SUB_GROUP_BLOCK_READ4( (const __global INT_TYPE*)src1_read ) );", // NOLINT +"(*((Dtype8*)(&p6BlockB[interleaved_y]))).s45 = as_Dtype2( SUB_GROUP_BLOCK_READ2( (const __global INT_TYPE*)(src1_read + 4 * 8) ) );", // NOLINT +"#endif", // NOLINT "src1_read += WIDTH1 * 2;", // NOLINT "} )", // NOLINT "if ( kernel_width_is_odd )", // NOLINT "{", // NOLINT -"p4BlockB00[KERNEL_WIDTH - 1] = as_float4( intel_sub_group_block_read4( (const __global uint*)src1_read ) );", // NOLINT +"#if TILE_N_LAST_DIV8 == 1", // NOLINT +"Dtype* pBlockB = (Dtype* )blockB;", // NOLINT +"pBlockB[KERNEL_WIDTH - 1] = as_Dtype( SUB_GROUP_BLOCK_READ( (const __global INT_TYPE*)src1_read ) );", // NOLINT +"#elif TILE_N_LAST_DIV8 == 2", // NOLINT +"Dtype2* p2BlockB = (Dtype2* )blockB;", // NOLINT +"p2BlockB[KERNEL_WIDTH - 1] = as_Dtype2( SUB_GROUP_BLOCK_READ2( (const __global INT_TYPE*)src1_read ) );", // NOLINT +"#elif TILE_N_LAST_DIV8 == 3", // NOLINT +"Dtype3* p3BlockB = (Dtype3* )blockB;", // 
NOLINT +"p3BlockB[KERNEL_WIDTH - 1].s01 = as_Dtype2( SUB_GROUP_BLOCK_READ2( (const __global INT_TYPE*)src1_read ) );", // NOLINT +"p3BlockB[KERNEL_WIDTH - 1].s2 = as_Dtype( SUB_GROUP_BLOCK_READ( (const __global INT_TYPE*) (src1_read + 8) ) );", // NOLINT +"#endif", // NOLINT "src1_read += WIDTH1 * 2;", // NOLINT "}", // NOLINT +"", // NOLINT "// Perform MADs", // NOLINT +"Dtype* pBlockB = (Dtype*)blockB;", // NOLINT "kernel_idx = 0;", // NOLINT "interleaved_y = 0;", // NOLINT "LOOP(KERNEL_WIDTH_DIV2, interleaved_y,", // NOLINT "{", // NOLINT "kernel_y = interleaved_y * 2;", // NOLINT -"DOT_PRODUCT_8( blockC00, pblockA00[kernel_y ], pBlockB00[kernel_idx] );", // NOLINT -"DOT_PRODUCT_8( blockC01, pblockA01[kernel_y ], pBlockB00[kernel_idx] ); kernel_idx++;", // NOLINT -"DOT_PRODUCT_8( blockC00, pblockA00[kernel_y + 1], pBlockB00[kernel_idx] );", // NOLINT -"DOT_PRODUCT_8( blockC01, pblockA01[kernel_y + 1], pBlockB00[kernel_idx] ); kernel_idx++;", // NOLINT -"DOT_PRODUCT_8( blockC10, pblockA00[kernel_y ], pBlockB00[kernel_idx] );", // NOLINT -"DOT_PRODUCT_8( blockC11, pblockA01[kernel_y ], pBlockB00[kernel_idx] ); kernel_idx++;", // NOLINT -"DOT_PRODUCT_8( blockC10, pblockA00[kernel_y + 1], pBlockB00[kernel_idx] );", // NOLINT -"DOT_PRODUCT_8( blockC11, pblockA01[kernel_y + 1], pBlockB00[kernel_idx] ); kernel_idx++;", // NOLINT -"DOT_PRODUCT_8( blockC20, pblockA00[kernel_y ], pBlockB00[kernel_idx] );", // NOLINT -"DOT_PRODUCT_8( blockC21, pblockA01[kernel_y ], pBlockB00[kernel_idx] ); kernel_idx++;", // NOLINT -"DOT_PRODUCT_8( blockC20, pblockA00[kernel_y + 1], pBlockB00[kernel_idx] );", // NOLINT -"DOT_PRODUCT_8( blockC21, pblockA01[kernel_y + 1], pBlockB00[kernel_idx] ); kernel_idx++;", // NOLINT -"DOT_PRODUCT_8( blockC30, pblockA00[kernel_y ], pBlockB00[kernel_idx] );", // NOLINT -"DOT_PRODUCT_8( blockC31, pblockA01[kernel_y ], pBlockB00[kernel_idx] ); kernel_idx++;", // NOLINT -"DOT_PRODUCT_8( blockC30, pblockA00[kernel_y + 1], pBlockB00[kernel_idx] );", // NOLINT 
-"DOT_PRODUCT_8( blockC31, pblockA01[kernel_y + 1], pBlockB00[kernel_idx] ); kernel_idx++;", // NOLINT +"DOT_PRODUCT_8( blockC0[0], pblockA00[kernel_y ], pBlockB[kernel_idx] );", // NOLINT +"DOT_PRODUCT_8( blockC1[0], pblockA01[kernel_y ], pBlockB[kernel_idx] ); kernel_idx++;", // NOLINT +"DOT_PRODUCT_8( blockC0[0], pblockA00[kernel_y + 1], pBlockB[kernel_idx] );", // NOLINT +"DOT_PRODUCT_8( blockC1[0], pblockA01[kernel_y + 1], pBlockB[kernel_idx] ); kernel_idx++;", // NOLINT +"#if TILE_N_LAST_DIV8 >= 2", // NOLINT +"DOT_PRODUCT_8( blockC0[1], pblockA00[kernel_y ], pBlockB[kernel_idx] );", // NOLINT +"DOT_PRODUCT_8( blockC1[1], pblockA01[kernel_y ], pBlockB[kernel_idx] ); kernel_idx++;", // NOLINT +"DOT_PRODUCT_8( blockC0[1], pblockA00[kernel_y + 1], pBlockB[kernel_idx] );", // NOLINT +"DOT_PRODUCT_8( blockC1[1], pblockA01[kernel_y + 1], pBlockB[kernel_idx] ); kernel_idx++;", // NOLINT +"#if TILE_N_LAST_DIV8 >= 3", // NOLINT +"DOT_PRODUCT_8( blockC0[2], pblockA00[kernel_y ], pBlockB[kernel_idx] );", // NOLINT +"DOT_PRODUCT_8( blockC1[2], pblockA01[kernel_y ], pBlockB[kernel_idx] ); kernel_idx++;", // NOLINT +"DOT_PRODUCT_8( blockC0[2], pblockA00[kernel_y + 1], pBlockB[kernel_idx] );", // NOLINT +"DOT_PRODUCT_8( blockC1[2], pblockA01[kernel_y + 1], pBlockB[kernel_idx] ); kernel_idx++;", // NOLINT +"#endif", // NOLINT +"#endif", // NOLINT "} )", // NOLINT +"kernel_y = interleaved_y * 2;", // NOLINT "if ( kernel_width_is_odd )", // NOLINT "{", // NOLINT -"kernel_y = interleaved_y * 2;", // NOLINT -"DOT_PRODUCT_8( blockC00, pblockA00[kernel_y], pBlockB00[kernel_idx] );", // NOLINT -"DOT_PRODUCT_8( blockC01, pblockA01[kernel_y], pBlockB00[kernel_idx] ); kernel_idx++;", // NOLINT -"DOT_PRODUCT_8( blockC10, pblockA00[kernel_y], pBlockB00[kernel_idx] );", // NOLINT -"DOT_PRODUCT_8( blockC11, pblockA01[kernel_y], pBlockB00[kernel_idx] ); kernel_idx++;", // NOLINT -"DOT_PRODUCT_8( blockC20, pblockA00[kernel_y], pBlockB00[kernel_idx] );", // NOLINT -"DOT_PRODUCT_8( blockC21, 
pblockA01[kernel_y], pBlockB00[kernel_idx] ); kernel_idx++;", // NOLINT -"DOT_PRODUCT_8( blockC30, pblockA00[kernel_y], pBlockB00[kernel_idx] );", // NOLINT -"DOT_PRODUCT_8( blockC31, pblockA01[kernel_y], pBlockB00[kernel_idx] ); kernel_idx++;", // NOLINT +"DOT_PRODUCT_8( blockC0[0], pblockA00[kernel_y], pBlockB[kernel_idx] );", // NOLINT +"DOT_PRODUCT_8( blockC1[0], pblockA01[kernel_y], pBlockB[kernel_idx] ); kernel_idx++;", // NOLINT +"#if TILE_N_LAST_DIV8 >= 2", // NOLINT +"DOT_PRODUCT_8( blockC0[1], pblockA00[kernel_y], pBlockB[kernel_idx] );", // NOLINT +"DOT_PRODUCT_8( blockC1[1], pblockA01[kernel_y], pBlockB[kernel_idx] ); kernel_idx++;", // NOLINT +"#if TILE_N_LAST_DIV8 >= 3", // NOLINT +"DOT_PRODUCT_8( blockC0[2], pblockA00[kernel_y], pBlockB[kernel_idx] );", // NOLINT +"DOT_PRODUCT_8( blockC1[2], pblockA01[kernel_y], pBlockB[kernel_idx] ); kernel_idx++;", // NOLINT +"#endif", // NOLINT +"#endif", // NOLINT "}", // NOLINT "}", // NOLINT "", // NOLINT @@ -1632,47 +1690,232 @@ static std::vector> cl_kernels{ "+ ( group_x * TILE_N ) * out_pitch_y // channel offset", // NOLINT "+ ( ( global_y * TILE_M + 1 ) / output_width + OUT_PADDING_HEIGHT ) * OUT_PITCH_X // y offset", // NOLINT "+ ( ( global_y * TILE_M + 1 ) % output_width ) + OUT_PADDING_LEFT; // x offset", // NOLINT +"__global Dtype *out1 = dst + out1_offset;", // NOLINT "", // NOLINT -"float bias[4];", // NOLINT -"float4 *bias_vec;", // NOLINT -"bias_vec = (float4*)bias;", // NOLINT -"*bias_vec = as_float4(intel_sub_group_block_read4((__global uint *)biases + group_x * TILE_N));", // NOLINT -"", // NOLINT +"Dtype bias[4];", // NOLINT +"Dtype4 *bias_vec;", // NOLINT +"bias_vec = (Dtype4*)bias;", // NOLINT +"*bias_vec = as_Dtype4(SUB_GROUP_BLOCK_READ4((__global INT_TYPE *)biases + group_x * TILE_N));", // NOLINT "if( global_y * TILE_M < output_width * output_height )", // NOLINT "{", // NOLINT "for( int i = 0; i < 8; i++ )", // NOLINT "{", // NOLINT -"ACTIVATION_FUNCTION(dst, out0_offset + ( 0+i) * 
out_pitch_y, blockC00[i] + intel_sub_group_shuffle(bias[0], i));", // NOLINT -"ACTIVATION_FUNCTION(dst, out0_offset + ( 8+i) * out_pitch_y, blockC10[i] + intel_sub_group_shuffle(bias[1], i));", // NOLINT -"ACTIVATION_FUNCTION(dst, out0_offset + (16+i) * out_pitch_y, blockC20[i] + intel_sub_group_shuffle(bias[2], i));", // NOLINT -"ACTIVATION_FUNCTION(dst, out0_offset + (24+i) * out_pitch_y, blockC30[i] + intel_sub_group_shuffle(bias[3], i));", // NOLINT +"if ( TILE_N_LAST_DIV8 > 0 ) ACTIVATION_FUNCTION(dst, out0_offset + ( 0+i) * out_pitch_y, blockC0[0][i] + intel_sub_group_shuffle(bias[0], i));", // NOLINT +"if ( TILE_N_LAST_DIV8 > 1 ) ACTIVATION_FUNCTION(dst, out0_offset + ( 8+i) * out_pitch_y, blockC0[1][i] + intel_sub_group_shuffle(bias[1], i));", // NOLINT +"if ( TILE_N_LAST_DIV8 > 2 ) ACTIVATION_FUNCTION(dst, out0_offset + (16+i) * out_pitch_y, blockC0[2][i] + intel_sub_group_shuffle(bias[2], i));", // NOLINT +"if ( TILE_N_LAST_DIV8 > 3 ) ACTIVATION_FUNCTION(dst, out0_offset + (24+i) * out_pitch_y, blockC0[3][i] + intel_sub_group_shuffle(bias[3], i));", // NOLINT "}", // NOLINT "}", // NOLINT "if( global_y * TILE_M + 1 < output_width * output_height )", // NOLINT "{", // NOLINT "for( int i = 0; i < 8; i++ )", // NOLINT "{", // NOLINT -"ACTIVATION_FUNCTION(dst, out1_offset + ( 0+i) * out_pitch_y, blockC01[i] + intel_sub_group_shuffle(bias[0], i));", // NOLINT -"ACTIVATION_FUNCTION(dst, out1_offset + ( 8+i) * out_pitch_y, blockC11[i] + intel_sub_group_shuffle(bias[1], i));", // NOLINT -"ACTIVATION_FUNCTION(dst, out1_offset + (16+i) * out_pitch_y, blockC21[i] + intel_sub_group_shuffle(bias[2], i));", // NOLINT -"ACTIVATION_FUNCTION(dst, out1_offset + (24+i) * out_pitch_y, blockC31[i] + intel_sub_group_shuffle(bias[3], i));", // NOLINT +"if ( TILE_N_LAST_DIV8 > 0 ) ACTIVATION_FUNCTION(dst, out1_offset + ( 0+i) * out_pitch_y, blockC1[0][i] + intel_sub_group_shuffle(bias[0], i));", // NOLINT +"if ( TILE_N_LAST_DIV8 > 1 ) ACTIVATION_FUNCTION(dst, out1_offset + ( 
8+i) * out_pitch_y, blockC1[1][i] + intel_sub_group_shuffle(bias[1], i));", // NOLINT +"if ( TILE_N_LAST_DIV8 > 2 ) ACTIVATION_FUNCTION(dst, out1_offset + (16+i) * out_pitch_y, blockC1[2][i] + intel_sub_group_shuffle(bias[2], i));", // NOLINT +"if ( TILE_N_LAST_DIV8 > 3 ) ACTIVATION_FUNCTION(dst, out1_offset + (24+i) * out_pitch_y, blockC1[3][i] + intel_sub_group_shuffle(bias[3], i));", // NOLINT +"}", // NOLINT +"}", // NOLINT +"}", // NOLINT +"#endif", // NOLINT +"}", // NOLINT +"#endif", // NOLINT +"#if defined(GEMM_LIKE_CONV_32_2_SIMD16) || defined(GEMM_LIKE_CONV_32_1_SIMD16)", // NOLINT +"", // NOLINT +"#define INTERLEAVED_SIMD16_OUTPUT(_out_, _offset_, _m_) do { if (global_y * TILE_M < output_width * output_height ) { if ( ( OUT_DEPTH % TILE_N ) == 0 ) { for (int i = 0; i < 16; i++) { ACTIVATION_FUNCTION(_out_, _offset_ + ( 0+i) * out_pitch_y, blockC0 ##_m_ [i] + intel_sub_group_shuffle(bias[0], i)); ACTIVATION_FUNCTION(_out_, _offset_ + (16+i) * out_pitch_y, blockC1 ##_m_ [i] + intel_sub_group_shuffle(bias[1], i)); } } else if( ( OUT_DEPTH % 16 ) == 0 ) { if ( ( global_x + 1 ) < get_global_size(0) ) { for ( int i = 0; i < 16; i++ ) { ACTIVATION_FUNCTION(_out_, _offset_ + ( 0+i) * out_pitch_y, blockC0 ##_m_ [i] + intel_sub_group_shuffle(bias[0], i)); ACTIVATION_FUNCTION(_out_, _offset_ + (16+i) * out_pitch_y, blockC1 ##_m_ [i] + intel_sub_group_shuffle(bias[1], i)); } } else { for (int i = 0; i < 16; i++) { ACTIVATION_FUNCTION(_out_, _offset_ + ( 0+i) * out_pitch_y, blockC0 ##_m_ [i] + intel_sub_group_shuffle(bias[0], i)); } } } else { if ( ( global_x + 1 ) < get_global_size(0) ) { for ( int i = 0; i < 16; i++ ) { ACTIVATION_FUNCTION(_out_, _offset_ + ( 0+i) * out_pitch_y, blockC0 ##_m_[i] + intel_sub_group_shuffle(bias[0], i)); ACTIVATION_FUNCTION(_out_, _offset_ + (16+i) * out_pitch_y, blockC1 ##_m_[i] + intel_sub_group_shuffle(bias[1], i)); } } else { if ( (OUT_DEPTH % TILE_N) > 16 ) { for (int i = 0; i < 16 ; i++) { ACTIVATION_FUNCTION(_out_, _offset_ + ( 
0+i) * out_pitch_y, blockC0 ##_m_[i] + intel_sub_group_shuffle(bias[0], i)); } for (int i = 0; i < OUT_DEPTH % 16 ; i++) { ACTIVATION_FUNCTION(_out_, _offset_ + (16+i) * out_pitch_y, blockC1 ##_m_[i] + intel_sub_group_shuffle(bias[1], i)); } } else { for (int i = 0; i < OUT_DEPTH % 16 ; i++) { ACTIVATION_FUNCTION(_out_, _offset_ + ( 0+i) * out_pitch_y, blockC0 ##_m_[i] + intel_sub_group_shuffle(bias[0], i)); } } } } } }while(0)", // NOLINT +"#endif", // NOLINT +"", // NOLINT +"#ifdef GEMM_LIKE_CONV_32_1_SIMD16", // NOLINT +"#define TILE_M 1", // NOLINT +"#define TILE_K KERNEL_WIDTH", // NOLINT +"#define TILE_N 32", // NOLINT +"", // NOLINT +"#ifndef __BEIGNET__", // NOLINT +"__attribute__((intel_reqd_sub_group_size(16)))", // NOLINT +"#endif", // NOLINT +"__kernel void Conv_Interleaved(GEMM_LIKE_KERNEL_ARGS)", // NOLINT +"{", // NOLINT +"const int group_x = get_group_id(0);", // NOLINT +"const int group_y = get_group_id(1);", // NOLINT +"const int global_x = get_global_id(0);", // NOLINT +"const int global_y = get_global_id(1);", // NOLINT +"const int global_z = get_global_id(2);", // NOLINT +"int interleaved_y;", // NOLINT +"int kernel_y;", // NOLINT +"int kernel_idx;", // NOLINT +"", // NOLINT +"// Result ctile (*dst) is M rows x N columns", // NOLINT +"// LWG size is 1x16. 
Thus each thread calculates 16*M rows x N cols of ctile.", // NOLINT +"Dtype16 blockC00 = 0.f;", // NOLINT +"Dtype16 blockC10 = 0.f;", // NOLINT +"", // NOLINT +"// Src0 (patch input) is directly used as atile.", // NOLINT +"// Each work item points to the start of a different patch.", // NOLINT +"// atile is M rows x K columns.", // NOLINT +"int curr_x = ( global_y % output_width ) * STRIDE_X;", // NOLINT +"int curr_y = ( global_y / output_width ) * STRIDE_Y;", // NOLINT +"#if INPUT_PAD_H != 0 || INPUT_PAD_W != 0 || DILATION_X != 1 || DILATION_Y != 1", // NOLINT +"int saved_y = curr_y;", // NOLINT +"#endif", // NOLINT +"const __global Dtype *src0_read = src0", // NOLINT +"+ aligned_input_size * global_z // batch offset", // NOLINT +"+ (curr_y - INPUT_PAD_H) * ROW_PITCH // y offset", // NOLINT +"+ curr_x - INPUT_PAD_W; // x offset", // NOLINT +"const __global Dtype *src0_read_orig = src0_read;", // NOLINT +"", // NOLINT +"// Src1 (filter) is directly used as btile.", // NOLINT +"// It starts at the top of src1 and walks down.", // NOLINT +"// btile is K rows x N columns.", // NOLINT +"const __global Dtype *src1_read = src1 + ( global_x * TILE_N * 2 );", // NOLINT +"", // NOLINT +"#define DOT_PRODUCT_16( _result, _rowA, colB ) { _result.s0 = mad( _rowA, sub_group_broadcast( colB, 0 ), _result.s0 ); _result.s1 = mad( _rowA, sub_group_broadcast( colB, 1 ), _result.s1 ); _result.s2 = mad( _rowA, sub_group_broadcast( colB, 2 ), _result.s2 ); _result.s3 = mad( _rowA, sub_group_broadcast( colB, 3 ), _result.s3 ); _result.s4 = mad( _rowA, sub_group_broadcast( colB, 4 ), _result.s4 ); _result.s5 = mad( _rowA, sub_group_broadcast( colB, 5 ), _result.s5 ); _result.s6 = mad( _rowA, sub_group_broadcast( colB, 6 ), _result.s6 ); _result.s7 = mad( _rowA, sub_group_broadcast( colB, 7 ), _result.s7 ); _result.s8 = mad( _rowA, sub_group_broadcast( colB, 8 ), _result.s8 ); _result.s9 = mad( _rowA, sub_group_broadcast( colB, 9 ), _result.s9 ); _result.sa = mad( _rowA, 
sub_group_broadcast( colB, 10 ), _result.sa ); _result.sb = mad( _rowA, sub_group_broadcast( colB, 11 ), _result.sb ); _result.sc = mad( _rowA, sub_group_broadcast( colB, 12 ), _result.sc ); _result.sd = mad( _rowA, sub_group_broadcast( colB, 13 ), _result.sd ); _result.se = mad( _rowA, sub_group_broadcast( colB, 14 ), _result.se ); _result.sf = mad( _rowA, sub_group_broadcast( colB, 15 ), _result.sf ); }", // NOLINT +"typedef CAT( Dtype, KERNEL_WIDTH ) Dtype_t;", // NOLINT +"// Walk DOWN src0 (patch 0, 1, 2, ...) and DOWN src1.", // NOLINT +"// Inner loop loads and FMADs one row (KERNEL_WIDTH) of each input patch", // NOLINT +"// and KERNEL_WIDTH/2 rows of interleaved filter.", // NOLINT +"int patch_depth = 0;", // NOLINT +"#ifndef __BEIGNET__", // NOLINT +"__attribute__((opencl_unroll_hint(1)))", // NOLINT +"#endif", // NOLINT +"do", // NOLINT +"{", // NOLINT +"int patch_row = 0;", // NOLINT +"#if INPUT_PAD_H != 0 || INPUT_PAD_W != 0 || DILATION_X != 1 || DILATION_Y != 1", // NOLINT +"curr_y = saved_y;", // NOLINT +"#endif", // NOLINT +"#ifndef __BEIGNET__", // NOLINT +"__attribute__((opencl_unroll_hint(1)))", // NOLINT +"#endif", // NOLINT +"do", // NOLINT +"{", // NOLINT +"// Load atile and btile.", // NOLINT +"// Kernel data is partially interleaved. Every 2 rows are interleaved at Dtype16 granularity.", // NOLINT +"// The exception is that if KERNEL_WIDTH is odd the last row is not interleaved. The non", // NOLINT +"// interleaved row is padded with zero to ensure same size as interleaved rows. This", // NOLINT +"// interleaving is done to ensure 0% GDR bank conflicts. For example, this is how the", // NOLINT +"// kernel data would be arranged before/after interleaving for KERNEL_WIDTH=3.", // NOLINT +"// (0, 0) (16, 0) (32, 0) (48, 0) ... (0, 0) ( 0, 1) (16, 0) ( 0, 1) (32, 0) (0, 1) (48, 0) ...", // NOLINT +"// (0, 1) (16, 1) (32, 1) (48, 1) ... => (0, 2) (16, 2) (32, 2) (48, 2) ...", // NOLINT +"// (0, 2) (16, 2) (32, 2) (48, 2) ... 
...", // NOLINT +"// ...", // NOLINT +"const bool kernel_width_is_odd = KERNEL_WIDTH % 2 == 1;", // NOLINT +"", // NOLINT +"#if INPUT_PAD_W == 0 && INPUT_PAD_H == 0 && DILATION_X == 1 && DILATION_Y == 1", // NOLINT +"Dtype_t blockA00 = ( (const __global Dtype_t*)src0_read )[ 0 ];", // NOLINT +"Dtype* pblockA00 = (Dtype*)(&blockA00);", // NOLINT +"#else", // NOLINT +"Dtype_t blockA00;", // NOLINT +"Dtype* pblockA00 = (Dtype*)(&blockA00);", // NOLINT +"int pos = 0;", // NOLINT +"LOOP(KERNEL_WIDTH, pos,", // NOLINT +"{", // NOLINT +"if (curr_y >= INPUT_PAD_H && curr_y < input_height + INPUT_PAD_H && curr_x + pos * DILATION_X >= INPUT_PAD_W && curr_x + pos * DILATION_X < input_width + INPUT_PAD_W)", // NOLINT +"pblockA00[pos] = src0_read[pos * DILATION_X];", // NOLINT +"else", // NOLINT +"pblockA00[pos] = 0;", // NOLINT +"})", // NOLINT +"curr_y += DILATION_Y;", // NOLINT +"#endif", // NOLINT +"src0_read += ROW_PITCH * DILATION_Y;", // NOLINT +"INT_TYPE blockB00[KERNEL_WIDTH * 2];", // NOLINT +"INT_TYPE4* p4BlockB00 = (INT_TYPE4*)blockB00;", // NOLINT +"INT_TYPE2* p2BlockB00 = (INT_TYPE2*)blockB00;", // NOLINT +"Dtype* pBlockB00 = (Dtype*)blockB00;", // NOLINT +"interleaved_y = 0;", // NOLINT +"LOOP(KERNEL_WIDTH_DIV2, interleaved_y,", // NOLINT +"{", // NOLINT +"p4BlockB00[interleaved_y] = SUB_GROUP_BLOCK_READ4( (const __global INT_TYPE*)src1_read );", // NOLINT +"src1_read += WIDTH1 * 2;", // NOLINT +"} )", // NOLINT +"if ( kernel_width_is_odd )", // NOLINT +"{", // NOLINT +"p2BlockB00[KERNEL_WIDTH - 1] = SUB_GROUP_BLOCK_READ2( (const __global INT_TYPE*)src1_read );", // NOLINT +"src1_read += WIDTH1 * 2;", // NOLINT +"}", // NOLINT +"", // NOLINT +"// Perform MADs", // NOLINT +"kernel_idx = 0;", // NOLINT +"interleaved_y = 0;", // NOLINT +"LOOP(KERNEL_WIDTH_DIV2, interleaved_y,", // NOLINT +"{", // NOLINT +"kernel_y = interleaved_y * 2;", // NOLINT +"DOT_PRODUCT_16( blockC00, pblockA00[kernel_y ], pBlockB00[kernel_idx] ); kernel_idx++;", // NOLINT +"DOT_PRODUCT_16( 
blockC00, pblockA00[kernel_y + 1], pBlockB00[kernel_idx] ); kernel_idx++;", // NOLINT +"DOT_PRODUCT_16( blockC10, pblockA00[kernel_y ], pBlockB00[kernel_idx] ); kernel_idx++;", // NOLINT +"DOT_PRODUCT_16( blockC10, pblockA00[kernel_y + 1], pBlockB00[kernel_idx] ); kernel_idx++;", // NOLINT +"} )", // NOLINT +"if ( kernel_width_is_odd )", // NOLINT +"{", // NOLINT +"kernel_y = interleaved_y * 2;", // NOLINT +"DOT_PRODUCT_16( blockC00, pblockA00[kernel_y], pBlockB00[kernel_idx] ); kernel_idx++;", // NOLINT +"DOT_PRODUCT_16( blockC10, pblockA00[kernel_y], pBlockB00[kernel_idx] ); kernel_idx++;", // NOLINT "}", // NOLINT "}", // NOLINT +"", // NOLINT +"//while( ++patch_row < 1 ); //debug", // NOLINT +"while( ++patch_row < KERNEL_HEIGHT );", // NOLINT +"", // NOLINT +"src0_read += slice_pitch - ( KERNEL_HEIGHT * ROW_PITCH * DILATION_Y ); // reset to start of next slice of patch", // NOLINT "}", // NOLINT -"#if TILE_N_LAST > 0", // NOLINT -"else", // NOLINT +"//while ( ++patch_depth < 1 ); //debug", // NOLINT +"while ( ++patch_depth < INPUT_DEPTH );", // NOLINT +"", // NOLINT +"// Dst resembles a cube of width x height x (output channel * batches). Each tile writes:", // NOLINT +"// (SIMD * TILE_M) x 1 x TILE_N. 
Partial writes most likely generated if padding used.", // NOLINT +"int_tp out_offset = global_z * out_pitch_z // batch offset", // NOLINT +"+ ( group_x * TILE_N ) * out_pitch_y // channel offset", // NOLINT +"+ ( ( global_y * TILE_M ) / output_width + OUT_PADDING_HEIGHT) * OUT_PITCH_X // y offset", // NOLINT +"+ ( ( global_y * TILE_M ) % output_width ) + OUT_PADDING_LEFT; // x offset", // NOLINT +"__global Dtype *out = dst + out_offset;", // NOLINT +"", // NOLINT +"Dtype bias[2];", // NOLINT +"Dtype2 *bias_vec;", // NOLINT +"bias_vec = (Dtype2*)bias;", // NOLINT +"*bias_vec = as_Dtype2(SUB_GROUP_BLOCK_READ2((__global INT_TYPE *)biases + group_x * TILE_N));", // NOLINT +"// Work around a potential compiler bug.", // NOLINT +"if (group_x > 0xFFFFFFFEul) {", // NOLINT +"out[0] = bias[0] + bias[1];", // NOLINT +"}", // NOLINT +"INTERLEAVED_SIMD16_OUTPUT(dst, out_offset, 0);", // NOLINT +"}", // NOLINT +"#endif", // NOLINT +"", // NOLINT +"#ifdef GEMM_LIKE_CONV_32_2_SIMD16", // NOLINT +"", // NOLINT +"//////////////////////////////////////////////////////////////////////////////", // NOLINT +"// Conv_Interleaved_32_2_SIMD16", // NOLINT +"//", // NOLINT +"// Convolution: each workitem computes 1 patch x 32 filters worth of output", // NOLINT +"// data.", // NOLINT +"#define TILE_M 2", // NOLINT +"#define TILE_K KERNEL_WIDTH", // NOLINT +"#define TILE_N 32", // NOLINT +"", // NOLINT +"#ifndef __BEIGNET__", // NOLINT +"__attribute__((intel_reqd_sub_group_size(16)))", // NOLINT +"#endif", // NOLINT +"__kernel void Conv_Interleaved(GEMM_LIKE_KERNEL_ARGS)", // NOLINT "{", // NOLINT +"const int group_x = get_group_id(0);", // NOLINT +"const int group_y = get_group_id(1);", // NOLINT +"const int global_x = get_global_id(0);", // NOLINT +"const int global_y = get_global_id(1);", // NOLINT +"const int global_z = get_global_id(2);", // NOLINT +"int interleaved_y;", // NOLINT +"int kernel_y;", // NOLINT +"int kernel_idx;", // NOLINT +"#define DOT_PRODUCT_16( _result, _rowA, colB ) 
{ _result.s0 = mad( _rowA, sub_group_broadcast( colB, 0 ), _result.s0 ); _result.s1 = mad( _rowA, sub_group_broadcast( colB, 1 ), _result.s1 ); _result.s2 = mad( _rowA, sub_group_broadcast( colB, 2 ), _result.s2 ); _result.s3 = mad( _rowA, sub_group_broadcast( colB, 3 ), _result.s3 ); _result.s4 = mad( _rowA, sub_group_broadcast( colB, 4 ), _result.s4 ); _result.s5 = mad( _rowA, sub_group_broadcast( colB, 5 ), _result.s5 ); _result.s6 = mad( _rowA, sub_group_broadcast( colB, 6 ), _result.s6 ); _result.s7 = mad( _rowA, sub_group_broadcast( colB, 7 ), _result.s7 ); _result.s8 = mad( _rowA, sub_group_broadcast( colB, 8 ), _result.s8 ); _result.s9 = mad( _rowA, sub_group_broadcast( colB, 9 ), _result.s9 ); _result.sa = mad( _rowA, sub_group_broadcast( colB, 10 ), _result.sa ); _result.sb = mad( _rowA, sub_group_broadcast( colB, 11 ), _result.sb ); _result.sc = mad( _rowA, sub_group_broadcast( colB, 12 ), _result.sc ); _result.sd = mad( _rowA, sub_group_broadcast( colB, 13 ), _result.sd ); _result.se = mad( _rowA, sub_group_broadcast( colB, 14 ), _result.se ); _result.sf = mad( _rowA, sub_group_broadcast( colB, 15 ), _result.sf ); }", // NOLINT +"typedef CAT( Dtype, KERNEL_WIDTH ) Dtype_t;", // NOLINT "", // NOLINT +"// True for all threads if filter_width is multiple of TILE_N", // NOLINT +"// else, true for all but right-most column of threads.", // NOLINT +"{", // NOLINT "// Result ctile (*dst) is M rows x N columns", // NOLINT "// LWG size is 1x8. 
Thus each thread calculates 8*M rows x N cols of ctile.", // NOLINT -"int i = 0;", // NOLINT -"float8 blockC0[TILE_N_LAST_DIV8];", // NOLINT -"float8 blockC1[TILE_N_LAST_DIV8];", // NOLINT -"LOOP(TILE_N_LAST_DIV8, i,", // NOLINT -"{", // NOLINT -"blockC0[i] = 0.f;", // NOLINT -"blockC1[i] = 0.f;", // NOLINT -"} )", // NOLINT +"Dtype16 blockC00 = 0.f;", // NOLINT +"Dtype16 blockC10 = 0.f;", // NOLINT +"Dtype16 blockC01 = 0.f;", // NOLINT +"Dtype16 blockC11 = 0.f;", // NOLINT "", // NOLINT "// Src0 (patch input) is directly used as atile.", // NOLINT "// Each work item points to the start of a different patch.", // NOLINT @@ -1685,11 +1928,11 @@ static std::vector> cl_kernels{ "int saved_y0 = curr_y0;", // NOLINT "int saved_y1 = curr_y1;", // NOLINT "#endif", // NOLINT -"const __global float *src0_read0 = src0", // NOLINT +"const __global Dtype *src0_read0 = src0", // NOLINT "+ aligned_input_size * global_z // batch offset", // NOLINT "+ (curr_y0 - INPUT_PAD_H) * ROW_PITCH // y offset", // NOLINT "+ curr_x0 - INPUT_PAD_W; // x offset", // NOLINT -"const __global float *src0_read1 = src0", // NOLINT +"const __global Dtype *src0_read1 = src0", // NOLINT "+ aligned_input_size * global_z // batch offset", // NOLINT "+ (curr_y1 - INPUT_PAD_H) * ROW_PITCH // y offset", // NOLINT "+ curr_x1 - INPUT_PAD_W; // x offset", // NOLINT @@ -1697,7 +1940,7 @@ static std::vector> cl_kernels{ "// Src1 (filter) is directly used as btile.", // NOLINT "// It starts at the top of src1 and walks down.", // NOLINT "// btile is K rows x N columns.", // NOLINT -"const __global float *src1_read = src1 + ( global_x * TILE_N * 2);", // NOLINT +"const __global Dtype *src1_read = src1 + ( global_x * TILE_N * 2);", // NOLINT "", // NOLINT "// Walk DOWN src0 (patch 0, 1, 2, ...) 
and DOWN src1.", // NOLINT "// Inner loop loads and FMADs one row (KERNEL_WIDTH) of each input patch", // NOLINT @@ -1708,16 +1951,25 @@ static std::vector> cl_kernels{ "int patch_row = 0;", // NOLINT "do", // NOLINT "{", // NOLINT -"// Load atile and interleaved btile.", // NOLINT +"// Load atile and btile.", // NOLINT +"// Kernel data is partially interleaved. Every 2 rows are interleaved at Dtype8 granularity.", // NOLINT +"// The exception is that if KERNEL_WIDTH is odd the last row is not interleaved. The non", // NOLINT +"// interleaved row is padded with zero to ensure same size as interleaved rows. This", // NOLINT +"// interleaving is done to ensure 0% GDR bank conflicts. For example, this is how the", // NOLINT +"// kernel data would be arranged before/after interleaving for KERNEL_WIDTH=3.", // NOLINT +"// (0, 0) (8, 0) (16, 0) (24, 0) ... (0, 0) (0, 1) (8, 0) (0, 1) (16, 0) (0, 1) (24, 0) ..", // NOLINT +"// (0, 1) (8, 1) (16, 1) (24, 1) ... => (0, 2) (8, 2) (16, 2) (24, 2) ...", // NOLINT +"// (0, 2) (8, 2) (16, 2) (24, 2) ... 
...", // NOLINT +"// ...", // NOLINT "const bool kernel_width_is_odd = KERNEL_WIDTH % 2 == 1;", // NOLINT "#if INPUT_PAD_H == 0 && INPUT_PAD_W == 0 && DILATION_X == 1 && DILATION_Y == 1", // NOLINT -"float_t blockA00 = ( (const __global float_t*)src0_read0 )[ 0 ]; src0_read0 += ROW_PITCH;", // NOLINT -"float_t blockA01 = ( (const __global float_t*)src0_read1 )[ 0 ]; src0_read1 += ROW_PITCH;", // NOLINT -"float* pblockA00 = (float*)(&blockA00);", // NOLINT -"float* pblockA01 = (float*)(&blockA01);", // NOLINT +"Dtype_t blockA00 = ( (const __global Dtype_t*)src0_read0 )[ 0 ]; src0_read0 += ROW_PITCH;", // NOLINT +"Dtype_t blockA01 = ( (const __global Dtype_t*)src0_read1 )[ 0 ]; src0_read1 += ROW_PITCH;", // NOLINT +"Dtype* pblockA00 = (Dtype*)(&blockA00);", // NOLINT +"Dtype* pblockA01 = (Dtype*)(&blockA01);", // NOLINT "#else", // NOLINT -"float_t blockA00;", // NOLINT -"float* pblockA00 = (float*)(&blockA00);", // NOLINT +"Dtype_t blockA00;", // NOLINT +"Dtype* pblockA00 = (Dtype*)(&blockA00);", // NOLINT "int pos = 0;", // NOLINT "LOOP(KERNEL_WIDTH, pos,", // NOLINT "{", // NOLINT @@ -1727,8 +1979,8 @@ static std::vector> cl_kernels{ "pblockA00[pos] = 0;", // NOLINT "})", // NOLINT "curr_y0 += DILATION_Y;", // NOLINT -"float_t blockA01;", // NOLINT -"float* pblockA01 = (float*)(&blockA01);", // NOLINT +"Dtype_t blockA01;", // NOLINT +"Dtype* pblockA01 = (Dtype*)(&blockA01);", // NOLINT "pos = 0;", // NOLINT "LOOP(KERNEL_WIDTH, pos,", // NOLINT "{", // NOLINT @@ -1741,78 +1993,44 @@ static std::vector> cl_kernels{ "src0_read0 += (ROW_PITCH * DILATION_Y);", // NOLINT "src0_read1 += (ROW_PITCH * DILATION_Y);", // NOLINT "#endif", // NOLINT -"float blockB[KERNEL_WIDTH * TILE_N_LAST_DIV8];", // NOLINT +"Dtype blockB00[KERNEL_WIDTH*2];", // NOLINT +"Dtype4* p4BlockB00 = (Dtype4*)blockB00;", // NOLINT +"Dtype2* p2BlockB00 = (Dtype2*)blockB00;", // NOLINT +"Dtype* pBlockB00 = (Dtype* )blockB00;", // NOLINT "", // NOLINT "interleaved_y = 0;", // NOLINT 
"LOOP(KERNEL_WIDTH_DIV2, interleaved_y,", // NOLINT "{", // NOLINT -"#if TILE_N_LAST_DIV8 == 1", // NOLINT -"float2* p2BlockB = (float2* )blockB;", // NOLINT -"p2BlockB[interleaved_y] = as_float2( intel_sub_group_block_read2( (const __global uint*)src1_read ) );", // NOLINT -"#elif TILE_N_LAST_DIV8 == 2", // NOLINT -"float4* p4BlockB = (float4* )blockB;", // NOLINT -"p4BlockB[interleaved_y] = as_float4( intel_sub_group_block_read4( (const __global uint*)src1_read ) );", // NOLINT -"#elif TILE_N_LAST_DIV8 == 3", // NOLINT -"//TODO: broken. No block_read6", // NOLINT -"float6* p6BlockB = (float6* )blockB;", // NOLINT -"(*((float8*)(&p6BlockB[interleaved_y]))).s0123 = as_float4( intel_sub_group_block_read4( (const __global uint*)src1_read ) );", // NOLINT -"(*((float8*)(&p6BlockB[interleaved_y]))).s45 = as_float2( intel_sub_group_block_read2( (const __global uint*)(src1_read + 4 * 8) ) );", // NOLINT -"#endif", // NOLINT +"p4BlockB00[interleaved_y] = as_Dtype4( SUB_GROUP_BLOCK_READ4( (const __global INT_TYPE*)src1_read ) );", // NOLINT "src1_read += WIDTH1 * 2;", // NOLINT "} )", // NOLINT "if ( kernel_width_is_odd )", // NOLINT "{", // NOLINT -"#if TILE_N_LAST_DIV8 == 1", // NOLINT -"float* pBlockB = (float* )blockB;", // NOLINT -"pBlockB[KERNEL_WIDTH - 1] = as_float( intel_sub_group_block_read( (const __global uint*)src1_read ) );", // NOLINT -"#elif TILE_N_LAST_DIV8 == 2", // NOLINT -"float2* p2BlockB = (float2* )blockB;", // NOLINT -"p2BlockB[KERNEL_WIDTH - 1] = as_float2( intel_sub_group_block_read2( (const __global uint*)src1_read ) );", // NOLINT -"#elif TILE_N_LAST_DIV8 == 3", // NOLINT -"float3* p3BlockB = (float3* )blockB;", // NOLINT -"p3BlockB[KERNEL_WIDTH - 1].s01 = as_float2( intel_sub_group_block_read2( (const __global uint*)src1_read ) );", // NOLINT -"p3BlockB[KERNEL_WIDTH - 1].s2 = as_float( intel_sub_group_block_read( (const __global uint*) (src1_read + 8) ) );", // NOLINT -"#endif", // NOLINT +"p2BlockB00[KERNEL_WIDTH - 1] = as_Dtype2( 
SUB_GROUP_BLOCK_READ2( (const __global INT_TYPE*)src1_read ) );", // NOLINT "src1_read += WIDTH1 * 2;", // NOLINT "}", // NOLINT -"", // NOLINT "// Perform MADs", // NOLINT -"float* pBlockB = (float*)blockB;", // NOLINT "kernel_idx = 0;", // NOLINT "interleaved_y = 0;", // NOLINT "LOOP(KERNEL_WIDTH_DIV2, interleaved_y,", // NOLINT "{", // NOLINT "kernel_y = interleaved_y * 2;", // NOLINT -"DOT_PRODUCT_8( blockC0[0], pblockA00[kernel_y ], pBlockB[kernel_idx] );", // NOLINT -"DOT_PRODUCT_8( blockC1[0], pblockA01[kernel_y ], pBlockB[kernel_idx] ); kernel_idx++;", // NOLINT -"DOT_PRODUCT_8( blockC0[0], pblockA00[kernel_y + 1], pBlockB[kernel_idx] );", // NOLINT -"DOT_PRODUCT_8( blockC1[0], pblockA01[kernel_y + 1], pBlockB[kernel_idx] ); kernel_idx++;", // NOLINT -"#if TILE_N_LAST_DIV8 >= 2", // NOLINT -"DOT_PRODUCT_8( blockC0[1], pblockA00[kernel_y ], pBlockB[kernel_idx] );", // NOLINT -"DOT_PRODUCT_8( blockC1[1], pblockA01[kernel_y ], pBlockB[kernel_idx] ); kernel_idx++;", // NOLINT -"DOT_PRODUCT_8( blockC0[1], pblockA00[kernel_y + 1], pBlockB[kernel_idx] );", // NOLINT -"DOT_PRODUCT_8( blockC1[1], pblockA01[kernel_y + 1], pBlockB[kernel_idx] ); kernel_idx++;", // NOLINT -"#if TILE_N_LAST_DIV8 >= 3", // NOLINT -"DOT_PRODUCT_8( blockC0[2], pblockA00[kernel_y ], pBlockB[kernel_idx] );", // NOLINT -"DOT_PRODUCT_8( blockC1[2], pblockA01[kernel_y ], pBlockB[kernel_idx] ); kernel_idx++;", // NOLINT -"DOT_PRODUCT_8( blockC0[2], pblockA00[kernel_y + 1], pBlockB[kernel_idx] );", // NOLINT -"DOT_PRODUCT_8( blockC1[2], pblockA01[kernel_y + 1], pBlockB[kernel_idx] ); kernel_idx++;", // NOLINT -"#endif", // NOLINT -"#endif", // NOLINT +"DOT_PRODUCT_16( blockC00, pblockA00[kernel_y ], pBlockB00[kernel_idx] );", // NOLINT +"DOT_PRODUCT_16( blockC01, pblockA01[kernel_y ], pBlockB00[kernel_idx] ); kernel_idx++;", // NOLINT +"DOT_PRODUCT_16( blockC00, pblockA00[kernel_y + 1], pBlockB00[kernel_idx] );", // NOLINT +"DOT_PRODUCT_16( blockC01, pblockA01[kernel_y + 1], pBlockB00[kernel_idx] 
); kernel_idx++;", // NOLINT +"DOT_PRODUCT_16( blockC10, pblockA00[kernel_y ], pBlockB00[kernel_idx] );", // NOLINT +"DOT_PRODUCT_16( blockC11, pblockA01[kernel_y ], pBlockB00[kernel_idx] ); kernel_idx++;", // NOLINT +"DOT_PRODUCT_16( blockC10, pblockA00[kernel_y + 1], pBlockB00[kernel_idx] );", // NOLINT +"DOT_PRODUCT_16( blockC11, pblockA01[kernel_y + 1], pBlockB00[kernel_idx] ); kernel_idx++;", // NOLINT "} )", // NOLINT -"kernel_y = interleaved_y * 2;", // NOLINT "if ( kernel_width_is_odd )", // NOLINT "{", // NOLINT -"DOT_PRODUCT_8( blockC0[0], pblockA00[kernel_y], pBlockB[kernel_idx] );", // NOLINT -"DOT_PRODUCT_8( blockC1[0], pblockA01[kernel_y], pBlockB[kernel_idx] ); kernel_idx++;", // NOLINT -"#if TILE_N_LAST_DIV8 >= 2", // NOLINT -"DOT_PRODUCT_8( blockC0[1], pblockA00[kernel_y], pBlockB[kernel_idx] );", // NOLINT -"DOT_PRODUCT_8( blockC1[1], pblockA01[kernel_y], pBlockB[kernel_idx] ); kernel_idx++;", // NOLINT -"#if TILE_N_LAST_DIV8 >= 3", // NOLINT -"DOT_PRODUCT_8( blockC0[2], pblockA00[kernel_y], pBlockB[kernel_idx] );", // NOLINT -"DOT_PRODUCT_8( blockC1[2], pblockA01[kernel_y], pBlockB[kernel_idx] ); kernel_idx++;", // NOLINT -"#endif", // NOLINT -"#endif", // NOLINT +"kernel_y = interleaved_y * 2;", // NOLINT +"DOT_PRODUCT_16( blockC00, pblockA00[kernel_y], pBlockB00[kernel_idx] );", // NOLINT +"DOT_PRODUCT_16( blockC01, pblockA01[kernel_y], pBlockB00[kernel_idx] ); kernel_idx++;", // NOLINT +"DOT_PRODUCT_16( blockC10, pblockA00[kernel_y], pBlockB00[kernel_idx] );", // NOLINT +"DOT_PRODUCT_16( blockC11, pblockA01[kernel_y], pBlockB00[kernel_idx] ); kernel_idx++;", // NOLINT "}", // NOLINT "}", // NOLINT "", // NOLINT @@ -1822,8 +2040,8 @@ static std::vector> cl_kernels{ "curr_y0 = saved_y0;", // NOLINT "curr_y1 = saved_y1;", // NOLINT "#endif", // NOLINT -"src0_read0 += slice_pitch - ( KERNEL_HEIGHT * ROW_PITCH * DILATION_Y ); // reset to start of next slice of patch", // NOLINT -"src0_read1 += slice_pitch - ( KERNEL_HEIGHT * ROW_PITCH * DILATION_Y 
);", // NOLINT +"src0_read0 += slice_pitch - ( KERNEL_HEIGHT * ROW_PITCH * DILATION_Y); // reset to start of next slice of patch", // NOLINT +"src0_read1 += slice_pitch - ( KERNEL_HEIGHT * ROW_PITCH * DILATION_Y);", // NOLINT "}", // NOLINT "//while ( ++patch_depth < 1 ); //debug", // NOLINT "while ( ++patch_depth < INPUT_DEPTH );", // NOLINT @@ -1839,32 +2057,14 @@ static std::vector> cl_kernels{ "+ ( ( global_y * TILE_M + 1 ) / output_width + OUT_PADDING_HEIGHT ) * OUT_PITCH_X // y offset", // NOLINT "+ ( ( global_y * TILE_M + 1 ) % output_width ) + OUT_PADDING_LEFT; // x offset", // NOLINT "", // NOLINT -"float bias[4];", // NOLINT -"float4 *bias_vec;", // NOLINT -"bias_vec = (float4*)bias;", // NOLINT -"*bias_vec = as_float4(intel_sub_group_block_read4((__global uint *)biases + group_x * TILE_N));", // NOLINT -"if( global_y * TILE_M < output_width * output_height )", // NOLINT -"{", // NOLINT -"for( int i = 0; i < 8; i++ )", // NOLINT -"{", // NOLINT -"if ( TILE_N_LAST_DIV8 > 0 ) ACTIVATION_FUNCTION(dst, out0_offset + ( 0+i) * out_pitch_y, blockC0[0][i] + intel_sub_group_shuffle(bias[0], i));", // NOLINT -"if ( TILE_N_LAST_DIV8 > 1 ) ACTIVATION_FUNCTION(dst, out0_offset + ( 8+i) * out_pitch_y, blockC0[1][i] + intel_sub_group_shuffle(bias[1], i));", // NOLINT -"if ( TILE_N_LAST_DIV8 > 2 ) ACTIVATION_FUNCTION(dst, out0_offset + (16+i) * out_pitch_y, blockC0[2][i] + intel_sub_group_shuffle(bias[2], i));", // NOLINT -"if ( TILE_N_LAST_DIV8 > 3 ) ACTIVATION_FUNCTION(dst, out0_offset + (24+i) * out_pitch_y, blockC0[3][i] + intel_sub_group_shuffle(bias[3], i));", // NOLINT -"}", // NOLINT -"}", // NOLINT -"if( global_y * TILE_M + 1 < output_width * output_height )", // NOLINT -"{", // NOLINT -"for( int i = 0; i < 8; i++ )", // NOLINT -"{", // NOLINT -"if ( TILE_N_LAST_DIV8 > 0 ) ACTIVATION_FUNCTION(dst, out1_offset + ( 0+i) * out_pitch_y, blockC1[0][i] + intel_sub_group_shuffle(bias[0], i));", // NOLINT -"if ( TILE_N_LAST_DIV8 > 1 ) ACTIVATION_FUNCTION(dst, 
out1_offset + ( 8+i) * out_pitch_y, blockC1[1][i] + intel_sub_group_shuffle(bias[1], i));", // NOLINT -"if ( TILE_N_LAST_DIV8 > 2 ) ACTIVATION_FUNCTION(dst, out1_offset + (16+i) * out_pitch_y, blockC1[2][i] + intel_sub_group_shuffle(bias[2], i));", // NOLINT -"if ( TILE_N_LAST_DIV8 > 3 ) ACTIVATION_FUNCTION(dst, out1_offset + (24+i) * out_pitch_y, blockC1[3][i] + intel_sub_group_shuffle(bias[3], i));", // NOLINT -"}", // NOLINT -"}", // NOLINT +"Dtype bias[2];", // NOLINT +"Dtype2 *bias_vec;", // NOLINT +"bias_vec = (Dtype2*)bias;", // NOLINT +"*bias_vec = as_Dtype2(SUB_GROUP_BLOCK_READ2((__global INT_TYPE *)biases + group_x * TILE_N));", // NOLINT +"", // NOLINT +"INTERLEAVED_SIMD16_OUTPUT(dst, out0_offset, 0);", // NOLINT +"INTERLEAVED_SIMD16_OUTPUT(dst, out1_offset, 1);", // NOLINT "}", // NOLINT -"#endif", // NOLINT "}", // NOLINT "#endif", // NOLINT ""}, // NOLINT @@ -1996,20 +2196,20 @@ static std::vector> cl_kernels{ "__global const Dtype* in,", // NOLINT "__global const uint_tp* mask,", // NOLINT "const uint_tp threshold,", // NOLINT -"const Dtype scale,", // NOLINT +"const KERNEL_ARG_DTYPE scale,", // NOLINT "__global Dtype* out) {", // NOLINT "for (int_tp index = get_global_id(0); index < n; index += get_global_size(0)) {", // NOLINT -"out[index] = in[index] * ((mask[index] > threshold)?1.0:0.0) * scale;", // NOLINT +"out[index] = in[index] * (Dtype)((mask[index] > threshold)?1.0:0.0) * scale;", // NOLINT "}", // NOLINT "}", // NOLINT "", // NOLINT "__kernel void TEMPLATE(dropout_backward,Dtype)(", // NOLINT "const int_tp n, __global const Dtype* in_diff,", // NOLINT "__global const uint_tp* mask, const uint_tp threshold,", // NOLINT -"const Dtype scale,", // NOLINT +"const KERNEL_ARG_DTYPE scale,", // NOLINT "__global Dtype* out_diff) {", // NOLINT "for (int_tp index = get_global_id(0); index < n; index += get_global_size(0)) {", // NOLINT -"out_diff[index] = in_diff[index] * ((mask[index] > threshold)?1.0:0.0) * scale;", // NOLINT +"out_diff[index] = 
in_diff[index] * (Dtype)((mask[index] > threshold)?1.0:0.0) * scale;", // NOLINT "}", // NOLINT "}", // NOLINT ""}, // NOLINT @@ -2024,7 +2224,7 @@ static std::vector> cl_kernels{ "__global int_tp* mask) {", // NOLINT "for (int_tp index = get_global_id(0); index < nthreads;", // NOLINT "index += get_global_size(0)) {", // NOLINT -"Dtype maxval = -FLT_MAX;", // NOLINT +"Dtype maxval = -DTYPE_MAX;", // NOLINT "int_tp maxidx = -1;", // NOLINT "if (bottom_data_a[index] > bottom_data_b[index]) {", // NOLINT "// only update for very first bottom_data blob (blob_idx == 0)", // NOLINT @@ -2065,9 +2265,9 @@ static std::vector> cl_kernels{ "", // NOLINT "__kernel void TEMPLATE(elu_forward,Dtype)(const int n, __global const Dtype* in,", // NOLINT "__global Dtype* out,", // NOLINT -"Dtype alpha) {", // NOLINT +"KERNEL_ARG_DTYPE alpha) {", // NOLINT "for (int_tp index = get_global_id(0); index < n; index += get_global_size(0)) {", // NOLINT -"out[index] = in[index] > 0 ? in[index] : alpha * (exp(in[index]) - 1.0);", // NOLINT +"out[index] = in[index] > 0 ? 
in[index] : alpha * (exp(in[index]) - (Dtype)1.0);", // NOLINT "}", // NOLINT "}", // NOLINT "", // NOLINT @@ -2075,7 +2275,7 @@ static std::vector> cl_kernels{ "__global const Dtype* out_data,", // NOLINT "__global const Dtype* in_data,", // NOLINT "__global Dtype* out_diff,", // NOLINT -"Dtype alpha) {", // NOLINT +"KERNEL_ARG_DTYPE alpha) {", // NOLINT "for (int_tp index = get_global_id(0); index < n; index += get_global_size(0)) {", // NOLINT "out_diff[index] =", // NOLINT "in_data[index] > 0 ?", // NOLINT @@ -2104,6 +2304,45 @@ static std::vector> cl_kernels{ "}", // NOLINT "", // NOLINT "// atomic_add from: http://suhorukov.blogspot.com/2011/12/opencl-11-atomic-operations-on-floating.html", // NOLINT +"", // NOLINT +"// atomic_add fddrom: http://suhorukov.blogspot.com/2011/12/opencl-11-atomic-operations-on-floating.html", // NOLINT +"#if (TYPE == TYPE_HALF)", // NOLINT +"", // NOLINT +"// FIXME, has bug which may hang GPU.", // NOLINT +"inline void TEMPLATE(atomic_add,Dtype)(volatile __global Dtype *source, const Dtype operand) {", // NOLINT +"union {", // NOLINT +"uint_tp intVal;", // NOLINT +"Dtype floatVal[2];", // NOLINT +"} newVal;", // NOLINT +"union {", // NOLINT +"uint_tp intVal;", // NOLINT +"Dtype floatVal[2];", // NOLINT +"} prevVal;", // NOLINT +"do {", // NOLINT +"// FIXME, need to consider buffer overflow.", // NOLINT +"prevVal.floatVal[0] = *source;", // NOLINT +"prevVal.floatVal[1] = *(source+1);", // NOLINT +"newVal.floatVal[0] = prevVal.floatVal[0] + operand;", // NOLINT +"newVal.floatVal[1] = prevVal.floatVal[1];", // NOLINT +"} while (atomic_cmpxchg((volatile __global unsigned int *)source, prevVal.intVal, newVal.intVal) != prevVal.intVal);", // NOLINT +"}", // NOLINT +"", // NOLINT +"__kernel void TEMPLATE(embed_backward,Dtype)(const int_tp nthreads, __global const Dtype* bottom_data,", // NOLINT +"__global const Dtype* top_diff, const int_tp M, const int_tp N, const int_tp K,", // NOLINT +"__global Dtype* weight_diff) {", // NOLINT +"for 
(int_tp top_index = get_global_id(0); top_index < nthreads;", // NOLINT +"top_index += get_global_size(0)) {", // NOLINT +"const int_tp n = top_index / N;", // NOLINT +"const int_tp d = top_index % N;", // NOLINT +"const int_tp index = (int_tp)(bottom_data[n]);", // NOLINT +"const int_tp weight_index = index * N + d;", // NOLINT +"", // NOLINT +"TEMPLATE(atomic_add,Dtype)((weight_diff + weight_index), *(top_diff + top_index));", // NOLINT +"}", // NOLINT +"}", // NOLINT +"#endif", // NOLINT +"", // NOLINT +"", // NOLINT "#if (TYPE == TYPE_FLOAT)", // NOLINT "#ifdef ATOMICS_32_AVAILABLE", // NOLINT "inline void TEMPLATE(atomic_add,Dtype)(volatile __global Dtype *source, const Dtype operand) {", // NOLINT @@ -2174,7 +2413,7 @@ static std::vector> cl_kernels{ "#include \"header.cl\"", // NOLINT "#endif", // NOLINT "", // NOLINT -"__kernel void TEMPLATE(fft_phony,Dtype)(Dtype arg) {", // NOLINT +"__kernel void TEMPLATE(fft_phony,Dtype)(KERNEL_ARG_DTYPE arg) {", // NOLINT "Dtype out = arg;", // NOLINT "}", // NOLINT "", // NOLINT @@ -2985,8 +3224,8 @@ static std::vector> cl_kernels{ "cdotc4.xz += mad( s1.xz, s2.xz, s1.yw * s2.yw);", // NOLINT "cdotc4.yw += mad(-s1.xz, s2.yw, s1.yw * s2.xz);", // NOLINT "}", // NOLINT -"cdotc.x += dot(cdotc4.xz, (float2)(1));", // NOLINT -"cdotc.y += dot(cdotc4.yw, (float2)(1));", // NOLINT +"cdotc.x += dot(cdotc4.xz, (Dtype2)(1));", // NOLINT +"cdotc.y += dot(cdotc4.yw, (Dtype2)(1));", // NOLINT "if (r == 1) {", // NOLINT "const __global Dtype* src1_ptr2 =", // NOLINT "(const __global Dtype*)(((const __global Dtype4*)(src1_ptr)) + n);", // NOLINT @@ -3013,7 +3252,7 @@ static std::vector> cl_kernels{ "}", // NOLINT "}", // NOLINT "", // NOLINT -"__kernel void TEMPLATE(fill,Dtype)(const int_tp n, const Dtype alpha, __global Dtype* x,", // NOLINT +"__kernel void TEMPLATE(fill,Dtype)(const int_tp n, const KERNEL_ARG_DTYPE alpha, __global Dtype* x,", // NOLINT "const int_tp offx) {", // NOLINT "for (int_tp index = get_global_id(0); index < 
n; index += get_global_size(0)) {", // NOLINT "x[index + offx] = alpha;", // NOLINT @@ -3052,38 +3291,38 @@ static std::vector> cl_kernels{ "//#define USE_IMAGE_C", // NOLINT "#ifdef USE_IMAGE_C", // NOLINT "#if TYPE == TYPE_HALF", // NOLINT -"#define BLOCKC_READ8( _C, _coordC ) as_float8( intel_sub_group_block_read_us8( _C, _coordC ) )", // NOLINT +"#define BLOCKC_READ8( _C, _coordC ) as_Dtype8( intel_sub_group_block_read_us8( _C, _coordC ) )", // NOLINT "#define BLOCKC_WRITE8( _C, _coordC, _val ) intel_sub_group_block_write_us8( _C, _coordC, as_ushort8( _val ) )", // NOLINT "#else", // NOLINT -"#define BLOCKC_READ8( _C, _coordC ) as_float8( intel_sub_group_block_read8( _C, _coordC ) )", // NOLINT +"#define BLOCKC_READ8( _C, _coordC ) as_Dtype8( intel_sub_group_block_read8( _C, _coordC ) )", // NOLINT "#define BLOCKC_WRITE8( _C, _coordC, _val ) intel_sub_group_block_write8( _C, _coordC, as_uint8( _val ) )", // NOLINT "#endif", // NOLINT "#define MATC_PARAMETER __read_only image2d_t C, __write_only image2d_t dst", // NOLINT "#define GEMM_OUTPUT(ALPHA1, BETA_NOT0) GEMM_OUTPUT_EXT(ALPHA1, BETA_NOT0, C, dst, sizeof(uint))", // NOLINT "#else", // NOLINT -"#define BLOCKC_READ8( _C, _coordC ) (float8) ( (_coordC.x + get_local_id(0) < N && _coordC.y < M) ? _C[ _coordC.y * ldc + _coordC.x + get_local_id(0) ] : 0, (_coordC.x + get_local_id(0) < N && _coordC.y + 1 < M) ? _C[ ( _coordC.y + 1 ) * ldc + _coordC.x + get_local_id(0) ] : 0, (_coordC.x + get_local_id(0) < N && _coordC.y + 2 < M) ? _C[ ( _coordC.y + 2 ) * ldc + _coordC.x + get_local_id(0) ] : 0, (_coordC.x + get_local_id(0) < N && _coordC.y + 3 < M) ? _C[ ( _coordC.y + 3 ) * ldc + _coordC.x + get_local_id(0) ] : 0, (_coordC.x + get_local_id(0) < N && _coordC.y + 4 < M) ? _C[ ( _coordC.y + 4 ) * ldc + _coordC.x + get_local_id(0) ] : 0, (_coordC.x + get_local_id(0) < N && _coordC.y + 5 < M) ? _C[ ( _coordC.y + 5 ) * ldc + _coordC.x + get_local_id(0) ] : 0, (_coordC.x + get_local_id(0) < N && _coordC.y + 6 < M) ? 
_C[ ( _coordC.y + 6 ) * ldc + _coordC.x + get_local_id(0) ] : 0, (_coordC.x + get_local_id(0) < N && _coordC.y + 7 < M) ? _C[ ( _coordC.y + 7 ) * ldc + _coordC.x + get_local_id(0) ] : 0)", // NOLINT +"#define BLOCKC_READ8( _C, _coordC ) (Dtype8) ( (_coordC.x + get_local_id(0) < N && _coordC.y < M) ? _C[ _coordC.y * ldc + _coordC.x + get_local_id(0) ] : 0, (_coordC.x + get_local_id(0) < N && _coordC.y + 1 < M) ? _C[ ( _coordC.y + 1 ) * ldc + _coordC.x + get_local_id(0) ] : 0, (_coordC.x + get_local_id(0) < N && _coordC.y + 2 < M) ? _C[ ( _coordC.y + 2 ) * ldc + _coordC.x + get_local_id(0) ] : 0, (_coordC.x + get_local_id(0) < N && _coordC.y + 3 < M) ? _C[ ( _coordC.y + 3 ) * ldc + _coordC.x + get_local_id(0) ] : 0, (_coordC.x + get_local_id(0) < N && _coordC.y + 4 < M) ? _C[ ( _coordC.y + 4 ) * ldc + _coordC.x + get_local_id(0) ] : 0, (_coordC.x + get_local_id(0) < N && _coordC.y + 5 < M) ? _C[ ( _coordC.y + 5 ) * ldc + _coordC.x + get_local_id(0) ] : 0, (_coordC.x + get_local_id(0) < N && _coordC.y + 6 < M) ? _C[ ( _coordC.y + 6 ) * ldc + _coordC.x + get_local_id(0) ] : 0, (_coordC.x + get_local_id(0) < N && _coordC.y + 7 < M) ? 
_C[ ( _coordC.y + 7 ) * ldc + _coordC.x + get_local_id(0) ] : 0)", // NOLINT "", // NOLINT "#define BLOCKC_WRITE8( _C, _coordC, _val) do { if (_coordC.x + get_local_id(0) < N) { if (_coordC.y < M) _C[ _coordC.y * ldc + _coordC.x + get_local_id(0) ] = _val.s0; if (_coordC.y + 1 < M) _C[ ( _coordC.y + 1 )* ldc + _coordC.x + get_local_id(0) ] = _val.s1; if (_coordC.y + 2 < M) _C[ ( _coordC.y + 2 )* ldc + _coordC.x + get_local_id(0) ] = _val.s2; if (_coordC.y + 3 < M) _C[ ( _coordC.y + 3 )* ldc + _coordC.x + get_local_id(0) ] = _val.s3; if (_coordC.y + 4 < M) _C[ ( _coordC.y + 4 )* ldc + _coordC.x + get_local_id(0) ] = _val.s4; if (_coordC.y + 5 < M) _C[ ( _coordC.y + 5 )* ldc + _coordC.x + get_local_id(0) ] = _val.s5; if (_coordC.y + 6 < M) _C[ ( _coordC.y + 6 )* ldc + _coordC.x + get_local_id(0) ] = _val.s6; if (_coordC.y + 7 < M) _C[ ( _coordC.y + 7 )* ldc + _coordC.x + get_local_id(0) ] = _val.s7; }} while(0)", // NOLINT -"#define MATC_PARAMETER __global float * C, const int offC, const int M, const int N, const int ldc", // NOLINT +"#define MATC_PARAMETER __global Dtype * C, const int offC, const int M, const int N, const int ldc", // NOLINT "#define GEMM_OUTPUT(ALPHA1, BETA_NOT0) GEMM_OUTPUT_EXT(ALPHA1, BETA_NOT0, (C + offC), (C + offC), 1)", // NOLINT "#endif", // NOLINT "", // NOLINT -"#define GEMM_OUTPUT_EXT(ALPHA1, BETA_NOT0, _C, _dst, _C_step) int2 coordDst = (int2)( ( group_x * TILE_N ) * _C_step, ( group_y * TILE_M ) ); int2 coordC = coordDst; float8 blockC00; float8 blockC01; float8 blockC02; float8 blockC03; if (BETA_NOT0) { blockC00 = isFirstColBlock ? BLOCKC_READ8( _C, coordC ) * beta : BLOCKC_READ8( _C, coordC ); coordC.y += 8; blockC01 = isFirstColBlock ? BLOCKC_READ8( _C, coordC ) * beta : BLOCKC_READ8( _C, coordC ); coordC.y += 8; blockC02 = isFirstColBlock ? BLOCKC_READ8( _C, coordC ) * beta : BLOCKC_READ8( _C, coordC ); coordC.y += 8; blockC03 = isFirstColBlock ? 
BLOCKC_READ8( _C, coordC ) * beta : BLOCKC_READ8( _C, coordC ); if (!ALPHA1) { blockC00 = mad(blockAxB00, (float8)alpha, blockC00); blockC01 = mad(blockAxB01, (float8)alpha, blockC01); blockC02 = mad(blockAxB02, (float8)alpha, blockC02); blockC03 = mad(blockAxB03, (float8)alpha, blockC03); } else { blockC00 += blockAxB00; blockC01 += blockAxB01; blockC02 += blockAxB02; blockC03 += blockAxB03; } } else { blockC00 = isFirstColBlock ? BLOCKC_READ8( _C, coordC ) * beta : BLOCKC_READ8( _C, coordC ); coordC.y += 8; blockC01 = isFirstColBlock ? BLOCKC_READ8( _C, coordC ) * beta : BLOCKC_READ8( _C, coordC ); coordC.y += 8; blockC02 = isFirstColBlock ? BLOCKC_READ8( _C, coordC ) * beta : BLOCKC_READ8( _C, coordC ); coordC.y += 8; blockC03 = isFirstColBlock ? BLOCKC_READ8( _C, coordC ) * beta : BLOCKC_READ8( _C, coordC ); if (!ALPHA1) { blockC00 = mad(blockAxB00, (float8)alpha, blockC00); blockC01 = mad(blockAxB01, (float8)alpha, blockC01); blockC02 = mad(blockAxB02, (float8)alpha, blockC02); blockC03 = mad(blockAxB03, (float8)alpha, blockC03); } else { blockC00 += blockAxB00; blockC01 += blockAxB01; blockC02 += blockAxB02; blockC03 += blockAxB03; } } BLOCKC_WRITE8( _dst, coordDst, blockC00 ); coordDst.y += 8; BLOCKC_WRITE8( _dst, coordDst, blockC01 ); coordDst.y += 8; BLOCKC_WRITE8( _dst, coordDst, blockC02 ); coordDst.y += 8; BLOCKC_WRITE8( _dst, coordDst, blockC03 );", // NOLINT +"#define GEMM_OUTPUT_EXT(ALPHA1, BETA_NOT0, _C, _dst, _C_step) int2 coordDst = (int2)( ( group_x * TILE_N ) * _C_step, ( group_y * TILE_M ) ); int2 coordC = coordDst; Dtype8 blockC00; Dtype8 blockC01; Dtype8 blockC02; Dtype8 blockC03; if (BETA_NOT0) { blockC00 = isFirstColBlock ? BLOCKC_READ8( _C, coordC ) * beta : BLOCKC_READ8( _C, coordC ); coordC.y += 8; blockC01 = isFirstColBlock ? BLOCKC_READ8( _C, coordC ) * beta : BLOCKC_READ8( _C, coordC ); coordC.y += 8; blockC02 = isFirstColBlock ? 
BLOCKC_READ8( _C, coordC ) * beta : BLOCKC_READ8( _C, coordC ); coordC.y += 8; blockC03 = isFirstColBlock ? BLOCKC_READ8( _C, coordC ) * beta : BLOCKC_READ8( _C, coordC ); if (!ALPHA1) { blockC00 = mad(blockAxB00, (Dtype8)alpha, blockC00); blockC01 = mad(blockAxB01, (Dtype8)alpha, blockC01); blockC02 = mad(blockAxB02, (Dtype8)alpha, blockC02); blockC03 = mad(blockAxB03, (Dtype8)alpha, blockC03); } else { blockC00 += blockAxB00; blockC01 += blockAxB01; blockC02 += blockAxB02; blockC03 += blockAxB03; } } else { blockC00 = isFirstColBlock ? BLOCKC_READ8( _C, coordC ) * beta : BLOCKC_READ8( _C, coordC ); coordC.y += 8; blockC01 = isFirstColBlock ? BLOCKC_READ8( _C, coordC ) * beta : BLOCKC_READ8( _C, coordC ); coordC.y += 8; blockC02 = isFirstColBlock ? BLOCKC_READ8( _C, coordC ) * beta : BLOCKC_READ8( _C, coordC ); coordC.y += 8; blockC03 = isFirstColBlock ? BLOCKC_READ8( _C, coordC ) * beta : BLOCKC_READ8( _C, coordC ); if (!ALPHA1) { blockC00 = mad(blockAxB00, (Dtype8)alpha, blockC00); blockC01 = mad(blockAxB01, (Dtype8)alpha, blockC01); blockC02 = mad(blockAxB02, (Dtype8)alpha, blockC02); blockC03 = mad(blockAxB03, (Dtype8)alpha, blockC03); } else { blockC00 += blockAxB00; blockC01 += blockAxB01; blockC02 += blockAxB02; blockC03 += blockAxB03; } } BLOCKC_WRITE8( _dst, coordDst, blockC00 ); coordDst.y += 8; BLOCKC_WRITE8( _dst, coordDst, blockC01 ); coordDst.y += 8; BLOCKC_WRITE8( _dst, coordDst, blockC02 ); coordDst.y += 8; BLOCKC_WRITE8( _dst, coordDst, blockC03 );", // NOLINT "", // NOLINT "// Get the specified column of the block of the block", // NOLINT -"#define TRANSPOSE_BLOCK_8( _block, _col ) (float8)( intel_sub_group_shuffle( _block.s0, _col ), intel_sub_group_shuffle( _block.s1, _col ), intel_sub_group_shuffle( _block.s2, _col ), intel_sub_group_shuffle( _block.s3, _col ), intel_sub_group_shuffle( _block.s4, _col ), intel_sub_group_shuffle( _block.s5, _col ), intel_sub_group_shuffle( _block.s6, _col ), intel_sub_group_shuffle( _block.s7, _col ) );", // 
NOLINT +"#define TRANSPOSE_BLOCK_8( _block, _col ) (Dtype8)( intel_sub_group_shuffle( _block.s0, _col ), intel_sub_group_shuffle( _block.s1, _col ), intel_sub_group_shuffle( _block.s2, _col ), intel_sub_group_shuffle( _block.s3, _col ), intel_sub_group_shuffle( _block.s4, _col ), intel_sub_group_shuffle( _block.s5, _col ), intel_sub_group_shuffle( _block.s6, _col ), intel_sub_group_shuffle( _block.s7, _col ) );", // NOLINT "", // NOLINT "// A's column block multiply B 's row block.", // NOLINT "#if TYPE == TYPE_HALF", // NOLINT -"#define MULTIPLY_BLOCKS_8x8( _result, _blockA, _blockB00, _blockB01 ) { const float8 acol0 = TRANSPOSE_BLOCK_8( _blockA, 0 ); const float8 acol1 = TRANSPOSE_BLOCK_8( _blockA, 1 ); const float8 acol2 = TRANSPOSE_BLOCK_8( _blockA, 2 ); const float8 acol3 = TRANSPOSE_BLOCK_8( _blockA, 3 ); const float8 acol4 = TRANSPOSE_BLOCK_8( _blockA, 4 ); const float8 acol5 = TRANSPOSE_BLOCK_8( _blockA, 5 ); const float8 acol6 = TRANSPOSE_BLOCK_8( _blockA, 6 ); const float8 acol7 = TRANSPOSE_BLOCK_8( _blockA, 7 ); const float8 acol8 = TRANSPOSE_BLOCK_8( _blockA, 8 ); const float8 acol9 = TRANSPOSE_BLOCK_8( _blockA, 9 ); const float8 acola = TRANSPOSE_BLOCK_8( _blockA, 10 ); const float8 acolb = TRANSPOSE_BLOCK_8( _blockA, 11 ); const float8 acolc = TRANSPOSE_BLOCK_8( _blockA, 12 ); const float8 acold = TRANSPOSE_BLOCK_8( _blockA, 13 ); const float8 acole = TRANSPOSE_BLOCK_8( _blockA, 14 ); const float8 acolf = TRANSPOSE_BLOCK_8( _blockA, 15 ); _result = mad( (float8)(_blockB00.s0), acol0, _result ); _result = mad( (float8)(_blockB00.s1), acol1, _result ); _result = mad( (float8)(_blockB00.s2), acol2, _result ); _result = mad( (float8)(_blockB00.s3), acol3, _result ); _result = mad( (float8)(_blockB00.s4), acol4, _result ); _result = mad( (float8)(_blockB00.s5), acol5, _result ); _result = mad( (float8)(_blockB00.s6), acol6, _result ); _result = mad( (float8)(_blockB00.s7), acol7, _result ); _result = mad( (float8)(_blockB01.s0), acol8, _result ); _result 
= mad( (float8)(_blockB01.s1), acol9, _result ); _result = mad( (float8)(_blockB01.s2), acola, _result ); _result = mad( (float8)(_blockB01.s3), acolb, _result ); _result = mad( (float8)(_blockB01.s4), acolc, _result ); _result = mad( (float8)(_blockB01.s5), acold, _result ); _result = mad( (float8)(_blockB01.s6), acole, _result ); _result = mad( (float8)(_blockB01.s7), acolf, _result ); }", // NOLINT +"#define MULTIPLY_BLOCKS_8x8( _result, _blockA, _blockB00, _blockB01 ) { const Dtype8 acol0 = TRANSPOSE_BLOCK_8( _blockA, 0 ); const Dtype8 acol1 = TRANSPOSE_BLOCK_8( _blockA, 1 ); const Dtype8 acol2 = TRANSPOSE_BLOCK_8( _blockA, 2 ); const Dtype8 acol3 = TRANSPOSE_BLOCK_8( _blockA, 3 ); const Dtype8 acol4 = TRANSPOSE_BLOCK_8( _blockA, 4 ); const Dtype8 acol5 = TRANSPOSE_BLOCK_8( _blockA, 5 ); const Dtype8 acol6 = TRANSPOSE_BLOCK_8( _blockA, 6 ); const Dtype8 acol7 = TRANSPOSE_BLOCK_8( _blockA, 7 ); const Dtype8 acol8 = TRANSPOSE_BLOCK_8( _blockA, 8 ); const Dtype8 acol9 = TRANSPOSE_BLOCK_8( _blockA, 9 ); const Dtype8 acola = TRANSPOSE_BLOCK_8( _blockA, 10 ); const Dtype8 acolb = TRANSPOSE_BLOCK_8( _blockA, 11 ); const Dtype8 acolc = TRANSPOSE_BLOCK_8( _blockA, 12 ); const Dtype8 acold = TRANSPOSE_BLOCK_8( _blockA, 13 ); const Dtype8 acole = TRANSPOSE_BLOCK_8( _blockA, 14 ); const Dtype8 acolf = TRANSPOSE_BLOCK_8( _blockA, 15 ); _result = mad( (Dtype8)(_blockB00.s0), acol0, _result ); _result = mad( (Dtype8)(_blockB00.s1), acol1, _result ); _result = mad( (Dtype8)(_blockB00.s2), acol2, _result ); _result = mad( (Dtype8)(_blockB00.s3), acol3, _result ); _result = mad( (Dtype8)(_blockB00.s4), acol4, _result ); _result = mad( (Dtype8)(_blockB00.s5), acol5, _result ); _result = mad( (Dtype8)(_blockB00.s6), acol6, _result ); _result = mad( (Dtype8)(_blockB00.s7), acol7, _result ); _result = mad( (Dtype8)(_blockB01.s0), acol8, _result ); _result = mad( (Dtype8)(_blockB01.s1), acol9, _result ); _result = mad( (Dtype8)(_blockB01.s2), acola, _result ); _result = mad( 
(Dtype8)(_blockB01.s3), acolb, _result ); _result = mad( (Dtype8)(_blockB01.s4), acolc, _result ); _result = mad( (Dtype8)(_blockB01.s5), acold, _result ); _result = mad( (Dtype8)(_blockB01.s6), acole, _result ); _result = mad( (Dtype8)(_blockB01.s7), acolf, _result ); }", // NOLINT "#else", // NOLINT -"#define MULTIPLY_BLOCKS_8x8( _result, _blockA, _blockB ) { const float8 acol0 = TRANSPOSE_BLOCK_8( _blockA, 0 ); const float8 acol1 = TRANSPOSE_BLOCK_8( _blockA, 1 ); const float8 acol2 = TRANSPOSE_BLOCK_8( _blockA, 2 ); const float8 acol3 = TRANSPOSE_BLOCK_8( _blockA, 3 ); const float8 acol4 = TRANSPOSE_BLOCK_8( _blockA, 4 ); const float8 acol5 = TRANSPOSE_BLOCK_8( _blockA, 5 ); const float8 acol6 = TRANSPOSE_BLOCK_8( _blockA, 6 ); const float8 acol7 = TRANSPOSE_BLOCK_8( _blockA, 7 ); _result = mad( (float8)(_blockB.s0), acol0, _result ); _result = mad( (float8)(_blockB.s1), acol1, _result ); _result = mad( (float8)(_blockB.s2), acol2, _result ); _result = mad( (float8)(_blockB.s3), acol3, _result ); _result = mad( (float8)(_blockB.s4), acol4, _result ); _result = mad( (float8)(_blockB.s5), acol5, _result ); _result = mad( (float8)(_blockB.s6), acol6, _result ); _result = mad( (float8)(_blockB.s7), acol7, _result ); }", // NOLINT +"#define MULTIPLY_BLOCKS_8x8( _result, _blockA, _blockB ) { const Dtype8 acol0 = TRANSPOSE_BLOCK_8( _blockA, 0 ); const Dtype8 acol1 = TRANSPOSE_BLOCK_8( _blockA, 1 ); const Dtype8 acol2 = TRANSPOSE_BLOCK_8( _blockA, 2 ); const Dtype8 acol3 = TRANSPOSE_BLOCK_8( _blockA, 3 ); const Dtype8 acol4 = TRANSPOSE_BLOCK_8( _blockA, 4 ); const Dtype8 acol5 = TRANSPOSE_BLOCK_8( _blockA, 5 ); const Dtype8 acol6 = TRANSPOSE_BLOCK_8( _blockA, 6 ); const Dtype8 acol7 = TRANSPOSE_BLOCK_8( _blockA, 7 ); _result = mad( (Dtype8)(_blockB.s0), acol0, _result ); _result = mad( (Dtype8)(_blockB.s1), acol1, _result ); _result = mad( (Dtype8)(_blockB.s2), acol2, _result ); _result = mad( (Dtype8)(_blockB.s3), acol3, _result ); _result = mad( (Dtype8)(_blockB.s4), 
acol4, _result ); _result = mad( (Dtype8)(_blockB.s5), acol5, _result ); _result = mad( (Dtype8)(_blockB.s6), acol6, _result ); _result = mad( (Dtype8)(_blockB.s7), acol7, _result ); }", // NOLINT "#endif", // NOLINT "", // NOLINT "#if TYPE == TYPE_HALF", // NOLINT -"#define GEMM_NN(ALPHA1, BETA_NOT0) __attribute__((intel_reqd_sub_group_size(SIMD_SIZE_GEMM))) __attribute__((reqd_work_group_size(SIMD_SIZE_GEMM, 1, 1))) __kernel void TEMPLATE(gemm_32_1_NN_ ##ALPHA1 ##_ ##BETA_NOT0, Dtype)( __read_only image2d_t A, __read_only image2d_t B, MATC_PARAMETER, float alpha_in, float beta_in, int width0, int isFirstColBlock) { const float alpha = (float)alpha_in; const float beta = (float)beta_in; const int group_x = get_group_id(0); const int group_y = get_group_id(1); float8 blockAxB00 = 0; float8 blockAxB01 = 0; float8 blockAxB02 = 0; float8 blockAxB03 = 0; int2 coordA = (int2)( 0, group_y * TILE_M ); int2 coordB = (int2)( ( group_x * TILE_N ) * SIZE_OF_ELEMENT, 0 ); do { int2 coordBTemp = coordB; float8 blockB00 = as_float8( SUBGROUP_BLOCK_READ8( B, coordBTemp ) ); coordB.y += TILE_K; float8 blockB01 = as_float8( SUBGROUP_BLOCK_READ8( B, coordBTemp ) ); coordB.y += TILE_K; int2 coordATemp = coordA; float8 blockA00 = as_float8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.y += 8; float8 blockA01 = as_float8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.y += 8; float8 blockA02 = as_float8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.y += 8; float8 blockA03 = as_float8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordA.x += TILE_K * SIZE_OF_ELEMENT * 2; MULTIPLY_BLOCKS_8x8( blockAxB00, blockA00, blockB00, blockB01 ); MULTIPLY_BLOCKS_8x8( blockAxB01, blockA01, blockB00, blockB01 ); MULTIPLY_BLOCKS_8x8( blockAxB02, blockA02, blockB00, blockB01 ); MULTIPLY_BLOCKS_8x8( blockAxB03, blockA03, blockB00, blockB01 ); } while( coordB.y < width0 ); GEMM_OUTPUT(ALPHA1, BETA_NOT0); }", // NOLINT +"#define GEMM_NN(ALPHA1, BETA_NOT0) 
__attribute__((intel_reqd_sub_group_size(SIMD_SIZE_GEMM))) __attribute__((reqd_work_group_size(SIMD_SIZE_GEMM, 1, 1))) __kernel void TEMPLATE(gemm_32_1_NN_ ##ALPHA1 ##_ ##BETA_NOT0, Dtype)( __read_only image2d_t A, __read_only image2d_t B, MATC_PARAMETER, KERNEL_ARG_DTYPE alpha_in, KERNEL_ARG_DTYPE beta_in, int width0, int isFirstColBlock) { const Dtype alpha = (Dtype)alpha_in; const Dtype beta = (Dtype)beta_in; const int group_x = get_group_id(0); const int group_y = get_group_id(1); Dtype8 blockAxB00 = 0; Dtype8 blockAxB01 = 0; Dtype8 blockAxB02 = 0; Dtype8 blockAxB03 = 0; int2 coordA = (int2)( 0, group_y * TILE_M ); int2 coordB = (int2)( ( group_x * TILE_N ) * SIZE_OF_ELEMENT, 0 ); do { int2 coordBTemp = coordB; Dtype8 blockB00 = as_Dtype8( SUBGROUP_BLOCK_READ8( B, coordBTemp ) ); coordB.y += TILE_K; Dtype8 blockB01 = as_Dtype8( SUBGROUP_BLOCK_READ8( B, coordBTemp ) ); coordB.y += TILE_K; int2 coordATemp = coordA; Dtype8 blockA00 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.y += 8; Dtype8 blockA01 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.y += 8; Dtype8 blockA02 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.y += 8; Dtype8 blockA03 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordA.x += TILE_K * SIZE_OF_ELEMENT * 2; MULTIPLY_BLOCKS_8x8( blockAxB00, blockA00, blockB00, blockB01 ); MULTIPLY_BLOCKS_8x8( blockAxB01, blockA01, blockB00, blockB01 ); MULTIPLY_BLOCKS_8x8( blockAxB02, blockA02, blockB00, blockB01 ); MULTIPLY_BLOCKS_8x8( blockAxB03, blockA03, blockB00, blockB01 ); } while( coordB.y < width0 ); GEMM_OUTPUT(ALPHA1, BETA_NOT0); }", // NOLINT "#else", // NOLINT -"#define GEMM_NN(ALPHA1, BETA_NOT0) __attribute__((intel_reqd_sub_group_size(SIMD_SIZE_GEMM))) __attribute__((reqd_work_group_size(SIMD_SIZE_GEMM, 1, 1))) __kernel void TEMPLATE(gemm_32_1_NN_ ##ALPHA1 ##_ ##BETA_NOT0, Dtype)( __read_only image2d_t A, __read_only image2d_t B, MATC_PARAMETER, float alpha_in, float beta_in, int 
width0, int isFirstColBlock) { const float alpha = (float)alpha_in; const float beta = (float)beta_in; const int group_x = get_group_id(0); const int group_y = get_group_id(1); float8 blockAxB00 = 0.0f; float8 blockAxB01 = 0.0f; float8 blockAxB02 = 0.0f; float8 blockAxB03 = 0.0f; int2 coordA = (int2)( 0, group_y * TILE_M ); int2 coordB = (int2)( ( group_x * TILE_N ) * SIZE_OF_ELEMENT, 0 ); do { int2 coordBTemp = coordB; float8 blockB00 = as_float8( SUBGROUP_BLOCK_READ8( B, coordBTemp ) ); coordB.y += TILE_K; int2 coordATemp = coordA; float8 blockA00 = as_float8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.y += 8; float8 blockA01 = as_float8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.y += 8; float8 blockA02 = as_float8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.y += 8; float8 blockA03 = as_float8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordA.x += TILE_K * SIZE_OF_ELEMENT; MULTIPLY_BLOCKS_8x8( blockAxB00, blockA00, blockB00 ); MULTIPLY_BLOCKS_8x8( blockAxB01, blockA01, blockB00 ); MULTIPLY_BLOCKS_8x8( blockAxB02, blockA02, blockB00 ); MULTIPLY_BLOCKS_8x8( blockAxB03, blockA03, blockB00 ); } while( coordB.y < width0 ); GEMM_OUTPUT(ALPHA1, BETA_NOT0); }", // NOLINT +"#define GEMM_NN(ALPHA1, BETA_NOT0) __attribute__((intel_reqd_sub_group_size(SIMD_SIZE_GEMM))) __attribute__((reqd_work_group_size(SIMD_SIZE_GEMM, 1, 1))) __kernel void TEMPLATE(gemm_32_1_NN_ ##ALPHA1 ##_ ##BETA_NOT0, Dtype)( __read_only image2d_t A, __read_only image2d_t B, MATC_PARAMETER, KERNEL_ARG_DTYPE alpha_in, KERNEL_ARG_DTYPE beta_in, int width0, int isFirstColBlock) { const Dtype alpha = (Dtype)alpha_in; const Dtype beta = (Dtype)beta_in; const int group_x = get_group_id(0); const int group_y = get_group_id(1); Dtype8 blockAxB00 = 0.0f; Dtype8 blockAxB01 = 0.0f; Dtype8 blockAxB02 = 0.0f; Dtype8 blockAxB03 = 0.0f; int2 coordA = (int2)( 0, group_y * TILE_M ); int2 coordB = (int2)( ( group_x * TILE_N ) * SIZE_OF_ELEMENT, 0 ); do { int2 coordBTemp = coordB; Dtype8 
blockB00 = as_Dtype8( SUBGROUP_BLOCK_READ8( B, coordBTemp ) ); coordB.y += TILE_K; int2 coordATemp = coordA; Dtype8 blockA00 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.y += 8; Dtype8 blockA01 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.y += 8; Dtype8 blockA02 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.y += 8; Dtype8 blockA03 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordA.x += TILE_K * SIZE_OF_ELEMENT; MULTIPLY_BLOCKS_8x8( blockAxB00, blockA00, blockB00 ); MULTIPLY_BLOCKS_8x8( blockAxB01, blockA01, blockB00 ); MULTIPLY_BLOCKS_8x8( blockAxB02, blockA02, blockB00 ); MULTIPLY_BLOCKS_8x8( blockAxB03, blockA03, blockB00 ); } while( coordB.y < width0 ); GEMM_OUTPUT(ALPHA1, BETA_NOT0); }", // NOLINT "#endif", // NOLINT "", // NOLINT "GEMM_NN(1, 0) // ALPHA == 1, BETA == 0", // NOLINT @@ -3093,16 +3332,17 @@ static std::vector> cl_kernels{ "", // NOLINT "#undef TRANSPOSE_BLOCK_8", // NOLINT "#undef MULTIPLY_BLOCKS_8x8", // NOLINT +"#undef GEMM_NN", // NOLINT "", // NOLINT "// replicate the first row to column block.", // NOLINT -"#define TRANSPOSE_BLOCK_8(_vec, _col) (float8)( intel_sub_group_shuffle(_vec, _col + 0), intel_sub_group_shuffle(_vec, _col + 1), intel_sub_group_shuffle(_vec, _col + 2), intel_sub_group_shuffle(_vec, _col + 3), intel_sub_group_shuffle(_vec, _col + 4), intel_sub_group_shuffle(_vec, _col + 5), intel_sub_group_shuffle(_vec, _col + 6), intel_sub_group_shuffle(_vec, _col + 7) )", // NOLINT +"#define TRANSPOSE_BLOCK_8(_vec, _col) (Dtype8)( intel_sub_group_shuffle(_vec, _col + 0), intel_sub_group_shuffle(_vec, _col + 1), intel_sub_group_shuffle(_vec, _col + 2), intel_sub_group_shuffle(_vec, _col + 3), intel_sub_group_shuffle(_vec, _col + 4), intel_sub_group_shuffle(_vec, _col + 5), intel_sub_group_shuffle(_vec, _col + 6), intel_sub_group_shuffle(_vec, _col + 7) )", // NOLINT "", // NOLINT -"#define MULTIPLY_BLOCKS_8x8( _result, _blockA, _blockB, _col ) { _result = mad( 
(float8)(_blockB.s0), TRANSPOSE_BLOCK_8(_blockA.s0, _col), _result ); _result = mad( (float8)(_blockB.s1), TRANSPOSE_BLOCK_8(_blockA.s1, _col), _result ); _result = mad( (float8)(_blockB.s2), TRANSPOSE_BLOCK_8(_blockA.s2, _col), _result ); _result = mad( (float8)(_blockB.s3), TRANSPOSE_BLOCK_8(_blockA.s3, _col), _result ); _result = mad( (float8)(_blockB.s4), TRANSPOSE_BLOCK_8(_blockA.s4, _col), _result ); _result = mad( (float8)(_blockB.s5), TRANSPOSE_BLOCK_8(_blockA.s5, _col), _result ); _result = mad( (float8)(_blockB.s6), TRANSPOSE_BLOCK_8(_blockA.s6, _col), _result ); _result = mad( (float8)(_blockB.s7), TRANSPOSE_BLOCK_8(_blockA.s7, _col), _result ); }", // NOLINT +"#define MULTIPLY_BLOCKS_8x8( _result, _blockA, _blockB, _col ) { _result = mad( (Dtype8)(_blockB.s0), TRANSPOSE_BLOCK_8(_blockA.s0, _col), _result ); _result = mad( (Dtype8)(_blockB.s1), TRANSPOSE_BLOCK_8(_blockA.s1, _col), _result ); _result = mad( (Dtype8)(_blockB.s2), TRANSPOSE_BLOCK_8(_blockA.s2, _col), _result ); _result = mad( (Dtype8)(_blockB.s3), TRANSPOSE_BLOCK_8(_blockA.s3, _col), _result ); _result = mad( (Dtype8)(_blockB.s4), TRANSPOSE_BLOCK_8(_blockA.s4, _col), _result ); _result = mad( (Dtype8)(_blockB.s5), TRANSPOSE_BLOCK_8(_blockA.s5, _col), _result ); _result = mad( (Dtype8)(_blockB.s6), TRANSPOSE_BLOCK_8(_blockA.s6, _col), _result ); _result = mad( (Dtype8)(_blockB.s7), TRANSPOSE_BLOCK_8(_blockA.s7, _col), _result ); }", // NOLINT "", // NOLINT "#if TYPE == TYPE_HALF", // NOLINT -"#define GEMM_TN(ALPHA1, BETA_NOT0) __attribute__((intel_reqd_sub_group_size(SIMD_SIZE_GEMM))) __attribute__((reqd_work_group_size(SIMD_SIZE_GEMM, 1, 1))) __kernel void TEMPLATE(gemm_32_1_TN_ ##ALPHA1 ##_ ##BETA_NOT0,Dtype)( __read_only image2d_t A, __read_only image2d_t B, MATC_PARAMETER, float alpha_in, float beta_in, int width0, int isFirstColBlock) { const float alpha = (float)alpha_in; const float beta = (float)beta_in; const int group_x = get_group_id(0); const int group_y = get_group_id(1); float8 
blockAxB00 = 0; float8 blockAxB01 = 0; float8 blockAxB02 = 0; float8 blockAxB03 = 0; int2 coordA = (int2)( group_y * TILE_M * SIZE_OF_ELEMENT, 0 ); int2 coordB = (int2)( ( group_x * TILE_N ) * SIZE_OF_ELEMENT, 0 ); do { int2 coordBTemp = coordB; float8 blockB00 = as_float8( SUBGROUP_BLOCK_READ8( B, coordBTemp ) ); coordB.y += TILE_K; int2 coordATemp = coordA; float8 blockA00 = as_half8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.x += 16 * SIZE_OF_ELEMENT; float8 blockA01 = as_half8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordA.y += TILE_K; MULTIPLY_BLOCKS_8x8( blockAxB00, blockA00, blockB00, 0); MULTIPLY_BLOCKS_8x8( blockAxB01, blockA00, blockB00, 8); MULTIPLY_BLOCKS_8x8( blockAxB02, blockA01, blockB00, 0); MULTIPLY_BLOCKS_8x8( blockAxB03, blockA01, blockB00, 8); } while( coordB.y < width0 ); GEMM_OUTPUT(ALPHA1, BETA_NOT0); }", // NOLINT +"#define GEMM_TN(ALPHA1, BETA_NOT0) __attribute__((intel_reqd_sub_group_size(SIMD_SIZE_GEMM))) __attribute__((reqd_work_group_size(SIMD_SIZE_GEMM, 1, 1))) __kernel void TEMPLATE(gemm_32_1_TN_ ##ALPHA1 ##_ ##BETA_NOT0,Dtype)( __read_only image2d_t A, __read_only image2d_t B, MATC_PARAMETER, KERNEL_ARG_DTYPE alpha_in, KERNEL_ARG_DTYPE beta_in, int width0, int isFirstColBlock) { const Dtype alpha = (Dtype)alpha_in; const Dtype beta = (Dtype)beta_in; const int group_x = get_group_id(0); const int group_y = get_group_id(1); Dtype8 blockAxB00 = 0; Dtype8 blockAxB01 = 0; Dtype8 blockAxB02 = 0; Dtype8 blockAxB03 = 0; int2 coordA = (int2)( group_y * TILE_M * SIZE_OF_ELEMENT, 0 ); int2 coordB = (int2)( ( group_x * TILE_N ) * SIZE_OF_ELEMENT, 0 ); do { int2 coordBTemp = coordB; Dtype8 blockB00 = as_Dtype8( SUBGROUP_BLOCK_READ8( B, coordBTemp ) ); coordB.y += TILE_K; int2 coordATemp = coordA; Dtype8 blockA00 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.x += 16 * SIZE_OF_ELEMENT; Dtype8 blockA01 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordA.y += TILE_K; MULTIPLY_BLOCKS_8x8( blockAxB00, blockA00, 
blockB00, 0); MULTIPLY_BLOCKS_8x8( blockAxB01, blockA00, blockB00, 8); MULTIPLY_BLOCKS_8x8( blockAxB02, blockA01, blockB00, 0); MULTIPLY_BLOCKS_8x8( blockAxB03, blockA01, blockB00, 8); } while( coordB.y < width0 ); GEMM_OUTPUT(ALPHA1, BETA_NOT0); }", // NOLINT "#else", // NOLINT -"#define GEMM_TN(ALPHA1, BETA_NOT0) __attribute__((intel_reqd_sub_group_size(SIMD_SIZE_GEMM))) __attribute__((reqd_work_group_size(SIMD_SIZE_GEMM, 1, 1))) __kernel void TEMPLATE(gemm_32_1_TN_ ##ALPHA1 ##_ ##BETA_NOT0,Dtype)( __read_only image2d_t A, __read_only image2d_t B, MATC_PARAMETER, float alpha_in, float beta_in, int width0, int isFirstColBlock) { const float alpha = (float)alpha_in; const float beta = (float)beta_in; const int group_x = get_group_id(0); const int group_y = get_group_id(1); float8 blockAxB00 = 0.0f; float8 blockAxB01 = 0.0f; float8 blockAxB02 = 0.0f; float8 blockAxB03 = 0.0f; int2 coordA = (int2)( group_y * TILE_M * SIZE_OF_ELEMENT, 0 ); int2 coordB = (int2)( ( group_x * TILE_N ) * SIZE_OF_ELEMENT, 0 ); do { int2 coordBTemp = coordB; float8 blockB00 = as_float8( SUBGROUP_BLOCK_READ8( B, coordBTemp ) ); coordB.y += TILE_K; int2 coordATemp = coordA; float8 blockA00 = as_float8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.x += 8 * SIZE_OF_ELEMENT; float8 blockA01 = as_float8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.x += 8 * SIZE_OF_ELEMENT; float8 blockA02 = as_float8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.x += 8 * SIZE_OF_ELEMENT; float8 blockA03 = as_float8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordA.y += TILE_K; MULTIPLY_BLOCKS_8x8( blockAxB00, blockA00, blockB00, 0 ); MULTIPLY_BLOCKS_8x8( blockAxB01, blockA01, blockB00, 0 ); MULTIPLY_BLOCKS_8x8( blockAxB02, blockA02, blockB00, 0 ); MULTIPLY_BLOCKS_8x8( blockAxB03, blockA03, blockB00, 0 ); } while( coordB.y < width0 ); GEMM_OUTPUT(ALPHA1, BETA_NOT0); }", // NOLINT +"#define GEMM_TN(ALPHA1, BETA_NOT0) __attribute__((intel_reqd_sub_group_size(SIMD_SIZE_GEMM))) 
__attribute__((reqd_work_group_size(SIMD_SIZE_GEMM, 1, 1))) __kernel void TEMPLATE(gemm_32_1_TN_ ##ALPHA1 ##_ ##BETA_NOT0,Dtype)( __read_only image2d_t A, __read_only image2d_t B, MATC_PARAMETER, KERNEL_ARG_DTYPE alpha_in, KERNEL_ARG_DTYPE beta_in, int width0, int isFirstColBlock) { const Dtype alpha = (Dtype)alpha_in; const Dtype beta = (Dtype)beta_in; const int group_x = get_group_id(0); const int group_y = get_group_id(1); Dtype8 blockAxB00 = 0.0f; Dtype8 blockAxB01 = 0.0f; Dtype8 blockAxB02 = 0.0f; Dtype8 blockAxB03 = 0.0f; int2 coordA = (int2)( group_y * TILE_M * SIZE_OF_ELEMENT, 0 ); int2 coordB = (int2)( ( group_x * TILE_N ) * SIZE_OF_ELEMENT, 0 ); do { int2 coordBTemp = coordB; Dtype8 blockB00 = as_Dtype8( SUBGROUP_BLOCK_READ8( B, coordBTemp ) ); coordB.y += TILE_K; int2 coordATemp = coordA; Dtype8 blockA00 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.x += 8 * SIZE_OF_ELEMENT; Dtype8 blockA01 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.x += 8 * SIZE_OF_ELEMENT; Dtype8 blockA02 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.x += 8 * SIZE_OF_ELEMENT; Dtype8 blockA03 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordA.y += TILE_K; MULTIPLY_BLOCKS_8x8( blockAxB00, blockA00, blockB00, 0 ); MULTIPLY_BLOCKS_8x8( blockAxB01, blockA01, blockB00, 0 ); MULTIPLY_BLOCKS_8x8( blockAxB02, blockA02, blockB00, 0 ); MULTIPLY_BLOCKS_8x8( blockAxB03, blockA03, blockB00, 0 ); } while( coordB.y < width0 ); GEMM_OUTPUT(ALPHA1, BETA_NOT0); }", // NOLINT "#endif", // NOLINT "", // NOLINT "GEMM_TN(1, 0) // ALPHA == 1, BETA == 0", // NOLINT @@ -3112,20 +3352,21 @@ static std::vector> cl_kernels{ "", // NOLINT "#undef MULTIPLY_BLOCKS_8x8", // NOLINT "#undef TRANSPOSE_BLOCK_8", // NOLINT +"#undef GEMM_TN", // NOLINT "", // NOLINT "// The same as GEMM_NN", // NOLINT -"#define TRANSPOSE_BLOCK_8( _block, _col ) (float8)( intel_sub_group_shuffle( _block.s0, _col), intel_sub_group_shuffle( _block.s1, _col), 
intel_sub_group_shuffle( _block.s2, _col), intel_sub_group_shuffle( _block.s3, _col), intel_sub_group_shuffle( _block.s4, _col), intel_sub_group_shuffle( _block.s5, _col), intel_sub_group_shuffle( _block.s6, _col), intel_sub_group_shuffle( _block.s7, _col) )", // NOLINT +"#define TRANSPOSE_BLOCK_8( _block, _col ) (Dtype8)( intel_sub_group_shuffle( _block.s0, _col), intel_sub_group_shuffle( _block.s1, _col), intel_sub_group_shuffle( _block.s2, _col), intel_sub_group_shuffle( _block.s3, _col), intel_sub_group_shuffle( _block.s4, _col), intel_sub_group_shuffle( _block.s5, _col), intel_sub_group_shuffle( _block.s6, _col), intel_sub_group_shuffle( _block.s7, _col) )", // NOLINT "", // NOLINT "#if TYPE == TYPE_HALF", // NOLINT -"#define MULTIPLY_BLOCKS_8x8( _result, _blockA, _blockB ) { const float8 acol0 = TRANSPOSE_BLOCK_8( _blockA, 0 ); const float8 acol1 = TRANSPOSE_BLOCK_8( _blockA, 1 ); const float8 acol2 = TRANSPOSE_BLOCK_8( _blockA, 2 ); const float8 acol3 = TRANSPOSE_BLOCK_8( _blockA, 3 ); const float8 acol4 = TRANSPOSE_BLOCK_8( _blockA, 4 ); const float8 acol5 = TRANSPOSE_BLOCK_8( _blockA, 5 ); const float8 acol6 = TRANSPOSE_BLOCK_8( _blockA, 6 ); const float8 acol7 = TRANSPOSE_BLOCK_8( _blockA, 7 ); const float8 acol8 = TRANSPOSE_BLOCK_8( _blockA, 8 ); const float8 acol9 = TRANSPOSE_BLOCK_8( _blockA, 9 ); const float8 acola = TRANSPOSE_BLOCK_8( _blockA, 10 ); const float8 acolb = TRANSPOSE_BLOCK_8( _blockA, 11 ); const float8 acolc = TRANSPOSE_BLOCK_8( _blockA, 12 ); const float8 acold = TRANSPOSE_BLOCK_8( _blockA, 13 ); const float8 acole = TRANSPOSE_BLOCK_8( _blockA, 14 ); const float8 acolf = TRANSPOSE_BLOCK_8( _blockA, 15 ); _result = mad( (float8)_blockB.s0, acol0, _result ); _result = mad( (float8)_blockB.s1, acol1, _result ); _result = mad( (float8)_blockB.s2, acol2, _result ); _result = mad( (float8)_blockB.s3, acol3, _result ); _result = mad( (float8)_blockB.s4, acol4, _result ); _result = mad( (float8)_blockB.s5, acol5, _result ); _result = mad( 
(float8)_blockB.s6, acol6, _result ); _result = mad( (float8)_blockB.s7, acol7, _result ); _result = mad( (float8)_blockB.s8, acol8, _result ); _result = mad( (float8)_blockB.s9, acol9, _result ); _result = mad( (float8)_blockB.sa, acola, _result ); _result = mad( (float8)_blockB.sb, acolb, _result ); _result = mad( (float8)_blockB.sc, acolc, _result ); _result = mad( (float8)_blockB.sd, acold, _result ); _result = mad( (float8)_blockB.se, acole, _result ); _result = mad( (float8)_blockB.sf, acolf, _result ); }", // NOLINT +"#define MULTIPLY_BLOCKS_8x8( _result, _blockA, _blockB ) { const Dtype8 acol0 = TRANSPOSE_BLOCK_8( _blockA, 0 ); const Dtype8 acol1 = TRANSPOSE_BLOCK_8( _blockA, 1 ); const Dtype8 acol2 = TRANSPOSE_BLOCK_8( _blockA, 2 ); const Dtype8 acol3 = TRANSPOSE_BLOCK_8( _blockA, 3 ); const Dtype8 acol4 = TRANSPOSE_BLOCK_8( _blockA, 4 ); const Dtype8 acol5 = TRANSPOSE_BLOCK_8( _blockA, 5 ); const Dtype8 acol6 = TRANSPOSE_BLOCK_8( _blockA, 6 ); const Dtype8 acol7 = TRANSPOSE_BLOCK_8( _blockA, 7 ); const Dtype8 acol8 = TRANSPOSE_BLOCK_8( _blockA, 8 ); const Dtype8 acol9 = TRANSPOSE_BLOCK_8( _blockA, 9 ); const Dtype8 acola = TRANSPOSE_BLOCK_8( _blockA, 10 ); const Dtype8 acolb = TRANSPOSE_BLOCK_8( _blockA, 11 ); const Dtype8 acolc = TRANSPOSE_BLOCK_8( _blockA, 12 ); const Dtype8 acold = TRANSPOSE_BLOCK_8( _blockA, 13 ); const Dtype8 acole = TRANSPOSE_BLOCK_8( _blockA, 14 ); const Dtype8 acolf = TRANSPOSE_BLOCK_8( _blockA, 15 ); _result = mad( (Dtype8)_blockB.s0, acol0, _result ); _result = mad( (Dtype8)_blockB.s1, acol1, _result ); _result = mad( (Dtype8)_blockB.s2, acol2, _result ); _result = mad( (Dtype8)_blockB.s3, acol3, _result ); _result = mad( (Dtype8)_blockB.s4, acol4, _result ); _result = mad( (Dtype8)_blockB.s5, acol5, _result ); _result = mad( (Dtype8)_blockB.s6, acol6, _result ); _result = mad( (Dtype8)_blockB.s7, acol7, _result ); _result = mad( (Dtype8)_blockB.s8, acol8, _result ); _result = mad( (Dtype8)_blockB.s9, acol9, _result ); _result = 
mad( (Dtype8)_blockB.sa, acola, _result ); _result = mad( (Dtype8)_blockB.sb, acolb, _result ); _result = mad( (Dtype8)_blockB.sc, acolc, _result ); _result = mad( (Dtype8)_blockB.sd, acold, _result ); _result = mad( (Dtype8)_blockB.se, acole, _result ); _result = mad( (Dtype8)_blockB.sf, acolf, _result ); }", // NOLINT "#else", // NOLINT -"#define MULTIPLY_BLOCKS_8x8( _result, _blockA, _blockB ) { const float8 acol0 = TRANSPOSE_BLOCK_8( _blockA, 0 ); const float8 acol1 = TRANSPOSE_BLOCK_8( _blockA, 1 ); const float8 acol2 = TRANSPOSE_BLOCK_8( _blockA, 2 ); const float8 acol3 = TRANSPOSE_BLOCK_8( _blockA, 3 ); const float8 acol4 = TRANSPOSE_BLOCK_8( _blockA, 4 ); const float8 acol5 = TRANSPOSE_BLOCK_8( _blockA, 5 ); const float8 acol6 = TRANSPOSE_BLOCK_8( _blockA, 6 ); const float8 acol7 = TRANSPOSE_BLOCK_8( _blockA, 7 ); _result = mad( (float8)_blockB.s0, acol0, _result ); _result = mad( (float8)_blockB.s1, acol1, _result ); _result = mad( (float8)_blockB.s2, acol2, _result ); _result = mad( (float8)_blockB.s3, acol3, _result ); _result = mad( (float8)_blockB.s4, acol4, _result ); _result = mad( (float8)_blockB.s5, acol5, _result ); _result = mad( (float8)_blockB.s6, acol6, _result ); _result = mad( (float8)_blockB.s7, acol7, _result ); }", // NOLINT +"#define MULTIPLY_BLOCKS_8x8( _result, _blockA, _blockB ) { const Dtype8 acol0 = TRANSPOSE_BLOCK_8( _blockA, 0 ); const Dtype8 acol1 = TRANSPOSE_BLOCK_8( _blockA, 1 ); const Dtype8 acol2 = TRANSPOSE_BLOCK_8( _blockA, 2 ); const Dtype8 acol3 = TRANSPOSE_BLOCK_8( _blockA, 3 ); const Dtype8 acol4 = TRANSPOSE_BLOCK_8( _blockA, 4 ); const Dtype8 acol5 = TRANSPOSE_BLOCK_8( _blockA, 5 ); const Dtype8 acol6 = TRANSPOSE_BLOCK_8( _blockA, 6 ); const Dtype8 acol7 = TRANSPOSE_BLOCK_8( _blockA, 7 ); _result = mad( (Dtype8)_blockB.s0, acol0, _result ); _result = mad( (Dtype8)_blockB.s1, acol1, _result ); _result = mad( (Dtype8)_blockB.s2, acol2, _result ); _result = mad( (Dtype8)_blockB.s3, acol3, _result ); _result = mad( 
(Dtype8)_blockB.s4, acol4, _result ); _result = mad( (Dtype8)_blockB.s5, acol5, _result ); _result = mad( (Dtype8)_blockB.s6, acol6, _result ); _result = mad( (Dtype8)_blockB.s7, acol7, _result ); }", // NOLINT "#endif", // NOLINT "", // NOLINT "#if TYPE == TYPE_HALF", // NOLINT -"#define GEMM_NT(ALPHA1, BETA_NOT0, VECSCALAR, VECSIZE) __attribute__((intel_reqd_sub_group_size(SIMD_SIZE_GEMM))) __attribute__((reqd_work_group_size(SIMD_SIZE_GEMM, 1, 1))) __kernel void TEMPLATE(gemm_32_1_NT_ ##VECSCALAR ##_ ##ALPHA1 ##_ ##BETA_NOT0,Dtype)( __read_only image2d_t A, MATB_PARAMETER, MATC_PARAMETER, float alpha_in, float beta_in, int padded_k, int k, int isFirstColBlock) { const float alpha = (float)alpha_in; const float beta = (float)beta_in; const int group_x = get_group_id(0); const int group_y = get_group_id(1); float8 blockAxB00 = 0; float8 blockAxB01 = 0; float8 blockAxB02 = 0; float8 blockAxB03 = 0; int2 coordA = (int2)( 0, group_y * TILE_M ); int2 coordB = (int2)( 0, ( group_x * TILE_N )); const sampler_t sampler = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; do { float16 blockB00; BLOCKB_READ8(blockB00, B, coordB); int2 coordATemp = coordA; float8 blockA00 = as_float8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.y += 8; float8 blockA01 = as_float8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.y += 8; float8 blockA02 = as_float8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.y += 8; float8 blockA03 = as_float8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordA.x += TILE_K * SIZE_OF_ELEMENT * 2; MULTIPLY_BLOCKS_8x8( blockAxB00, blockA00, blockB00 ); MULTIPLY_BLOCKS_8x8( blockAxB01, blockA01, blockB00 ); MULTIPLY_BLOCKS_8x8( blockAxB02, blockA02, blockB00 ); MULTIPLY_BLOCKS_8x8( blockAxB03, blockA03, blockB00 ); } while( coordB.x < padded_k / VECSIZE ); GEMM_OUTPUT(ALPHA1, BETA_NOT0); }", // NOLINT +"#define GEMM_NT(ALPHA1, BETA_NOT0, VECSCALAR, VECSIZE) __attribute__((intel_reqd_sub_group_size(SIMD_SIZE_GEMM))) 
__attribute__((reqd_work_group_size(SIMD_SIZE_GEMM, 1, 1))) __kernel void TEMPLATE(gemm_32_1_NT_ ##VECSCALAR ##_ ##ALPHA1 ##_ ##BETA_NOT0,Dtype)( __read_only image2d_t A, MATB_PARAMETER, MATC_PARAMETER, KERNEL_ARG_DTYPE alpha_in, KERNEL_ARG_DTYPE beta_in, int padded_k, int k, int isFirstColBlock) { const Dtype alpha = (Dtype)alpha_in; const Dtype beta = (Dtype)beta_in; const int group_x = get_group_id(0); const int group_y = get_group_id(1); Dtype8 blockAxB00 = 0; Dtype8 blockAxB01 = 0; Dtype8 blockAxB02 = 0; Dtype8 blockAxB03 = 0; int2 coordA = (int2)( 0, group_y * TILE_M ); int2 coordB = (int2)( 0, ( group_x * TILE_N )); const sampler_t sampler = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; do { Dtype16 blockB00; BLOCKB_READ8(blockB00, B, coordB); int2 coordATemp = coordA; Dtype8 blockA00 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.y += 8; Dtype8 blockA01 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.y += 8; Dtype8 blockA02 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.y += 8; Dtype8 blockA03 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordA.x += TILE_K * SIZE_OF_ELEMENT * 2; MULTIPLY_BLOCKS_8x8( blockAxB00, blockA00, blockB00 ); MULTIPLY_BLOCKS_8x8( blockAxB01, blockA01, blockB00 ); MULTIPLY_BLOCKS_8x8( blockAxB02, blockA02, blockB00 ); MULTIPLY_BLOCKS_8x8( blockAxB03, blockA03, blockB00 ); } while( coordB.x < padded_k / VECSIZE ); GEMM_OUTPUT(ALPHA1, BETA_NOT0); }", // NOLINT "#else", // NOLINT -"#define GEMM_NT(ALPHA1, BETA_NOT0, VECSCALAR, VECSIZE) __attribute__((intel_reqd_sub_group_size(SIMD_SIZE_GEMM))) __attribute__((reqd_work_group_size(SIMD_SIZE_GEMM, 1, 1))) __kernel void TEMPLATE(gemm_32_1_NT_ ##VECSCALAR ##_ ##ALPHA1 ##_ ##BETA_NOT0,Dtype)( __read_only image2d_t A, MATB_PARAMETER, MATC_PARAMETER, float alpha_in, float beta_in, int padded_k, int k, int isFirstColBlock) { const float alpha = (float)alpha_in; const float beta = (float)beta_in; const int 
group_x = get_group_id(0); const int group_y = get_group_id(1); float8 blockAxB00 = 0.0f; float8 blockAxB01 = 0.0f; float8 blockAxB02 = 0.0f; float8 blockAxB03 = 0.0f; int2 coordA = (int2)( 0, group_y * TILE_M ); int2 coordB = (int2)( 0, ( group_x * TILE_N )); const sampler_t sampler = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; do { float8 blockB00; BLOCKB_READ8(blockB00, B, coordB); int2 coordATemp = coordA; float8 blockA00 = as_float8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.y += 8; float8 blockA01 = as_float8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.y += 8; float8 blockA02 = as_float8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.y += 8; float8 blockA03 = as_float8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordA.x += TILE_K * SIZE_OF_ELEMENT; MULTIPLY_BLOCKS_8x8( blockAxB00, blockA00, blockB00 ); MULTIPLY_BLOCKS_8x8( blockAxB01, blockA01, blockB00 ); MULTIPLY_BLOCKS_8x8( blockAxB02, blockA02, blockB00 ); MULTIPLY_BLOCKS_8x8( blockAxB03, blockA03, blockB00 ); } while( coordB.x < padded_k / VECSIZE ); GEMM_OUTPUT(ALPHA1, BETA_NOT0); }", // NOLINT +"#define GEMM_NT(ALPHA1, BETA_NOT0, VECSCALAR, VECSIZE) __attribute__((intel_reqd_sub_group_size(SIMD_SIZE_GEMM))) __attribute__((reqd_work_group_size(SIMD_SIZE_GEMM, 1, 1))) __kernel void TEMPLATE(gemm_32_1_NT_ ##VECSCALAR ##_ ##ALPHA1 ##_ ##BETA_NOT0,Dtype)( __read_only image2d_t A, MATB_PARAMETER, MATC_PARAMETER, KERNEL_ARG_DTYPE alpha_in, KERNEL_ARG_DTYPE beta_in, int padded_k, int k, int isFirstColBlock) { const Dtype alpha = (Dtype)alpha_in; const Dtype beta = (Dtype)beta_in; const int group_x = get_group_id(0); const int group_y = get_group_id(1); Dtype8 blockAxB00 = 0.0f; Dtype8 blockAxB01 = 0.0f; Dtype8 blockAxB02 = 0.0f; Dtype8 blockAxB03 = 0.0f; int2 coordA = (int2)( 0, group_y * TILE_M ); int2 coordB = (int2)( 0, ( group_x * TILE_N )); const sampler_t sampler = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; do { Dtype8 blockB00; 
BLOCKB_READ8(blockB00, B, coordB); int2 coordATemp = coordA; Dtype8 blockA00 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.y += 8; Dtype8 blockA01 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.y += 8; Dtype8 blockA02 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.y += 8; Dtype8 blockA03 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordA.x += TILE_K * SIZE_OF_ELEMENT; MULTIPLY_BLOCKS_8x8( blockAxB00, blockA00, blockB00 ); MULTIPLY_BLOCKS_8x8( blockAxB01, blockA01, blockB00 ); MULTIPLY_BLOCKS_8x8( blockAxB02, blockA02, blockB00 ); MULTIPLY_BLOCKS_8x8( blockAxB03, blockA03, blockB00 ); } while( coordB.x < padded_k / VECSIZE ); GEMM_OUTPUT(ALPHA1, BETA_NOT0); }", // NOLINT "#endif", // NOLINT "", // NOLINT "#if TYPE == TYPE_HALF", // NOLINT @@ -3144,12 +3385,12 @@ static std::vector> cl_kernels{ "#undef MATB_PARAMETER", // NOLINT "", // NOLINT "#if TYPE == TYPE_HALF", // NOLINT -"#define BLOCKB_READ8(_blockb, _B, _coordB) int2 _coordBTemp = _coordB; _coordBTemp.y += get_local_id(0); const __global float *B_read = (__global float *)(_B + (_coordBTemp.y * ldb) + _coordBTemp.x + offB); _blockb = as_half16(as_ushort16(vload8(0, B_read))); _coordB.x += TILE_K * 2;", // NOLINT +"#define BLOCKB_READ8(_blockb, _B, _coordB) int2 _coordBTemp = _coordB; _coordBTemp.y += get_local_id(0); const __global float *B_read = (__global float *)(_B + (_coordBTemp.y * ldb) + _coordBTemp.x + offB); _blockb = as_Dtype16(as_ushort16(vload8(0, B_read))); _coordB.x += TILE_K * 2;", // NOLINT "#else", // NOLINT -"#define BLOCKB_READ8(_blockb, _B, _coordB) int2 _coordBTemp = _coordB; _coordBTemp.y += get_local_id(0); const __global float *B_read = (__global float *)(_B + (_coordBTemp.y * ldb) + _coordBTemp.x + offB); _blockb = vload8(0, B_read); _coordB.x += TILE_K;", // NOLINT +"#define BLOCKB_READ8(_blockb, _B, _coordB) int2 _coordBTemp = _coordB; _coordBTemp.y += get_local_id(0); const __global Dtype *B_read = (__global Dtype 
*)(_B + (_coordBTemp.y * ldb) + _coordBTemp.x + offB); _blockb = vload8(0, B_read); _coordB.x += TILE_K;", // NOLINT "#endif", // NOLINT "", // NOLINT -"#define MATB_PARAMETER __global float *B, int offB, int ldb", // NOLINT +"#define MATB_PARAMETER __global Dtype *B, int offB, int ldb", // NOLINT "", // NOLINT "GEMM_NT(1, 0, BUFFER, 1) // ALPHA == 1, BETA == 0", // NOLINT "GEMM_NT(1, 1, BUFFER, 1) // ALPHA == 1, BETA != 0", // NOLINT @@ -3159,9 +3400,9 @@ static std::vector> cl_kernels{ "#undef MATB_PARAMETER", // NOLINT "", // NOLINT "#if TYPE == TYPE_HALF", // NOLINT -"#define BLOCKB_READ8(_blockb, _B, _coordB) int2 _coordBTemp = _coordB; _coordBTemp.y += get_local_id(0); float4 temp; temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; _blockb.s0 = temp.s0; temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; _blockb.s1 = temp.s0; temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; _blockb.s2 = temp.s0; temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; _blockb.s3 = temp.s0; temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; _blockb.s4 = temp.s0; temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; _blockb.s5 = temp.s0; temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; _blockb.s6 = temp.s0; temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; _blockb.s7 = temp.s0; temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; _blockb.s8 = temp.s0; temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; _blockb.s9 = temp.s0; temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; _blockb.sa = temp.s0; temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; _blockb.sb = temp.s0; temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; _blockb.sc = temp.s0; temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; _blockb.sd = temp.s0; temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; _blockb.se = temp.s0; temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; _blockb.sf = temp.s0; _coordB.x += 16;", // NOLINT +"#define 
BLOCKB_READ8(_blockb, _B, _coordB) int2 _coordBTemp = _coordB; _coordBTemp.y += get_local_id(0); Dtype4 temp; temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; _blockb.s0 = temp.s0; temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; _blockb.s1 = temp.s0; temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; _blockb.s2 = temp.s0; temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; _blockb.s3 = temp.s0; temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; _blockb.s4 = temp.s0; temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; _blockb.s5 = temp.s0; temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; _blockb.s6 = temp.s0; temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; _blockb.s7 = temp.s0; temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; _blockb.s8 = temp.s0; temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; _blockb.s9 = temp.s0; temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; _blockb.sa = temp.s0; temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; _blockb.sb = temp.s0; temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; _blockb.sc = temp.s0; temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; _blockb.sd = temp.s0; temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; _blockb.se = temp.s0; temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; _blockb.sf = temp.s0; _coordB.x += 16;", // NOLINT "#else", // NOLINT -"#define BLOCKB_READ8(_blockb, _B, _coordB) int2 _coordBTemp = _coordB; _coordBTemp.y += get_local_id(0); float4 temp; temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; _blockb.s0 = temp.s0; temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; _blockb.s1 = temp.s0; temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; _blockb.s2 = temp.s0; temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; _blockb.s3 = temp.s0; temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; _blockb.s4 = temp.s0; temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; _blockb.s5 = temp.s0; temp 
= READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; _blockb.s6 = temp.s0; temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; _blockb.s7 = temp.s0; _coordB.x += 8;", // NOLINT +"#define BLOCKB_READ8(_blockb, _B, _coordB) int2 _coordBTemp = _coordB; _coordBTemp.y += get_local_id(0); Dtype4 temp; temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; _blockb.s0 = temp.s0; temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; _blockb.s1 = temp.s0; temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; _blockb.s2 = temp.s0; temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; _blockb.s3 = temp.s0; temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; _blockb.s4 = temp.s0; temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; _blockb.s5 = temp.s0; temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; _blockb.s6 = temp.s0; temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; _blockb.s7 = temp.s0; _coordB.x += 8;", // NOLINT "#endif", // NOLINT "", // NOLINT "#define MATB_PARAMETER __read_only image2d_t B", // NOLINT @@ -3175,16 +3416,17 @@ static std::vector> cl_kernels{ "", // NOLINT "#undef MULTIPLY_BLOCKS_8x8", // NOLINT "#undef TRANSPOSE_BLOCK_8", // NOLINT +"#undef GEMM_NT", // NOLINT "", // NOLINT "//The same as GEMM_TN.", // NOLINT -"#define TRANSPOSE_BLOCK_8(_vec, _col) (float8)( intel_sub_group_shuffle(_vec, _col + 0), intel_sub_group_shuffle(_vec, _col + 1), intel_sub_group_shuffle(_vec, _col + 2), intel_sub_group_shuffle(_vec, _col + 3), intel_sub_group_shuffle(_vec, _col + 4), intel_sub_group_shuffle(_vec, _col + 5), intel_sub_group_shuffle(_vec, _col + 6), intel_sub_group_shuffle(_vec, _col + 7) );", // NOLINT +"#define TRANSPOSE_BLOCK_8(_vec, _col) (Dtype8)( intel_sub_group_shuffle(_vec, _col + 0), intel_sub_group_shuffle(_vec, _col + 1), intel_sub_group_shuffle(_vec, _col + 2), intel_sub_group_shuffle(_vec, _col + 3), intel_sub_group_shuffle(_vec, _col + 4), intel_sub_group_shuffle(_vec, _col + 5), intel_sub_group_shuffle(_vec, _col + 
6), intel_sub_group_shuffle(_vec, _col + 7) );", // NOLINT "", // NOLINT -"#define MULTIPLY_BLOCKS_8x8( _result, _blockA, _blockB, _col ) { const float8 acol0 = TRANSPOSE_BLOCK_8( _blockA.s0, _col ); const float8 acol1 = TRANSPOSE_BLOCK_8( _blockA.s1, _col ); const float8 acol2 = TRANSPOSE_BLOCK_8( _blockA.s2, _col ); const float8 acol3 = TRANSPOSE_BLOCK_8( _blockA.s3, _col ); const float8 acol4 = TRANSPOSE_BLOCK_8( _blockA.s4, _col ); const float8 acol5 = TRANSPOSE_BLOCK_8( _blockA.s5, _col ); const float8 acol6 = TRANSPOSE_BLOCK_8( _blockA.s6, _col ); const float8 acol7 = TRANSPOSE_BLOCK_8( _blockA.s7, _col ); _result = mad( (float8)_blockB.s0, acol0, _result ); _result = mad( (float8)_blockB.s1, acol1, _result ); _result = mad( (float8)_blockB.s2, acol2, _result ); _result = mad( (float8)_blockB.s3, acol3, _result ); _result = mad( (float8)_blockB.s4, acol4, _result ); _result = mad( (float8)_blockB.s5, acol5, _result ); _result = mad( (float8)_blockB.s6, acol6, _result ); _result = mad( (float8)_blockB.s7, acol7, _result ); }", // NOLINT +"#define MULTIPLY_BLOCKS_8x8( _result, _blockA, _blockB, _col ) { const Dtype8 acol0 = TRANSPOSE_BLOCK_8( _blockA.s0, _col ); const Dtype8 acol1 = TRANSPOSE_BLOCK_8( _blockA.s1, _col ); const Dtype8 acol2 = TRANSPOSE_BLOCK_8( _blockA.s2, _col ); const Dtype8 acol3 = TRANSPOSE_BLOCK_8( _blockA.s3, _col ); const Dtype8 acol4 = TRANSPOSE_BLOCK_8( _blockA.s4, _col ); const Dtype8 acol5 = TRANSPOSE_BLOCK_8( _blockA.s5, _col ); const Dtype8 acol6 = TRANSPOSE_BLOCK_8( _blockA.s6, _col ); const Dtype8 acol7 = TRANSPOSE_BLOCK_8( _blockA.s7, _col ); _result = mad( (Dtype8)_blockB.s0, acol0, _result ); _result = mad( (Dtype8)_blockB.s1, acol1, _result ); _result = mad( (Dtype8)_blockB.s2, acol2, _result ); _result = mad( (Dtype8)_blockB.s3, acol3, _result ); _result = mad( (Dtype8)_blockB.s4, acol4, _result ); _result = mad( (Dtype8)_blockB.s5, acol5, _result ); _result = mad( (Dtype8)_blockB.s6, acol6, _result ); _result = mad( 
(Dtype8)_blockB.s7, acol7, _result ); }", // NOLINT "", // NOLINT "#if TYPE == TYPE_HALF", // NOLINT -"#define GEMM_TT(ALPHA1, BETA_NOT0, VECSCALAR, VECSIZE) __attribute__((intel_reqd_sub_group_size(SIMD_SIZE_GEMM))) __attribute__((reqd_work_group_size(SIMD_SIZE_GEMM, 1, 1))) __kernel void TEMPLATE(gemm_32_1_TT_ ##VECSCALAR ##_ ##ALPHA1 ##_ ##BETA_NOT0, Dtype)( __read_only image2d_t A, MATB_PARAMETER, MATC_PARAMETER, float alpha_in, float beta_in, int padded_k, int k, int isFirstColBlock) { const float alpha = (float)alpha_in; const float beta = (float)beta_in; const int group_x = get_group_id(0); const int group_y = get_group_id(1); float8 blockAxB00 = 0; float8 blockAxB01 = 0; float8 blockAxB02 = 0; float8 blockAxB03 = 0; int2 coordA = (int2)( group_y * TILE_M * SIZE_OF_ELEMENT, 0 ); int2 coordB = (int2)( 0, ( group_x * TILE_N )); const sampler_t sampler = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; do { float8 blockB00; BLOCKB_READ8(blockB00, B, coordB); int2 coordATemp = coordA; float8 blockA00 = as_float8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.x += 16 * SIZE_OF_ELEMENT; float8 blockA01 = as_float8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordA.y += TILE_K; MULTIPLY_BLOCKS_8x8( blockAxB00, blockA00, blockB00, 0); MULTIPLY_BLOCKS_8x8( blockAxB01, blockA00, blockB00, 8); MULTIPLY_BLOCKS_8x8( blockAxB02, blockA01, blockB00, 0); MULTIPLY_BLOCKS_8x8( blockAxB03, blockA01, blockB00, 8); } while( coordB.x < padded_k / VECSIZE ); GEMM_OUTPUT(ALPHA1, BETA_NOT0);}", // NOLINT +"#define GEMM_TT(ALPHA1, BETA_NOT0, VECSCALAR, VECSIZE) __attribute__((intel_reqd_sub_group_size(SIMD_SIZE_GEMM))) __attribute__((reqd_work_group_size(SIMD_SIZE_GEMM, 1, 1))) __kernel void TEMPLATE(gemm_32_1_TT_ ##VECSCALAR ##_ ##ALPHA1 ##_ ##BETA_NOT0, Dtype)( __read_only image2d_t A, MATB_PARAMETER, MATC_PARAMETER, KERNEL_ARG_DTYPE alpha_in, KERNEL_ARG_DTYPE beta_in, int padded_k, int k, int isFirstColBlock) { const Dtype alpha = (Dtype)alpha_in; const 
Dtype beta = (Dtype)beta_in; const int group_x = get_group_id(0); const int group_y = get_group_id(1); Dtype8 blockAxB00 = 0; Dtype8 blockAxB01 = 0; Dtype8 blockAxB02 = 0; Dtype8 blockAxB03 = 0; int2 coordA = (int2)( group_y * TILE_M * SIZE_OF_ELEMENT, 0 ); int2 coordB = (int2)( 0, ( group_x * TILE_N )); const sampler_t sampler = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; do { Dtype8 blockB00; BLOCKB_READ8(blockB00, B, coordB); int2 coordATemp = coordA; Dtype8 blockA00 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.x += 16 * SIZE_OF_ELEMENT; Dtype8 blockA01 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordA.y += TILE_K; MULTIPLY_BLOCKS_8x8( blockAxB00, blockA00, blockB00, 0); MULTIPLY_BLOCKS_8x8( blockAxB01, blockA00, blockB00, 8); MULTIPLY_BLOCKS_8x8( blockAxB02, blockA01, blockB00, 0); MULTIPLY_BLOCKS_8x8( blockAxB03, blockA01, blockB00, 8); } while( coordB.x < padded_k / VECSIZE ); GEMM_OUTPUT(ALPHA1, BETA_NOT0);}", // NOLINT "#else", // NOLINT -"#define GEMM_TT(ALPHA1, BETA_NOT0, VECSCALAR, VECSIZE) __attribute__((intel_reqd_sub_group_size(SIMD_SIZE_GEMM))) __attribute__((reqd_work_group_size(SIMD_SIZE_GEMM, 1, 1))) __kernel void TEMPLATE(gemm_32_1_TT_ ##VECSCALAR ##_ ##ALPHA1 ##_ ##BETA_NOT0, Dtype)( __read_only image2d_t A, MATB_PARAMETER, MATC_PARAMETER, float alpha_in, float beta_in, int padded_k, int k, int isFirstColBlock) { const float alpha = (float)alpha_in; const float beta = (float)beta_in; const int group_x = get_group_id(0); const int group_y = get_group_id(1); float8 blockAxB00 = 0.0f; float8 blockAxB01 = 0.0f; float8 blockAxB02 = 0.0f; float8 blockAxB03 = 0.0f; int2 coordA = (int2)( group_y * TILE_M * SIZE_OF_ELEMENT, 0 ); int2 coordB = (int2)( 0, ( group_x * TILE_N )); const sampler_t sampler = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; do { float8 blockB00; BLOCKB_READ8(blockB00, B, coordB); int2 coordATemp = coordA; float8 blockA00 = as_float8( 
SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.x += 8 * SIZE_OF_ELEMENT; float8 blockA01 = as_float8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.x += 8 * SIZE_OF_ELEMENT; float8 blockA02 = as_float8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.x += 8 * SIZE_OF_ELEMENT; float8 blockA03 = as_float8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordA.y += TILE_K; MULTIPLY_BLOCKS_8x8( blockAxB00, blockA00 , blockB00, 0 ); MULTIPLY_BLOCKS_8x8( blockAxB01, blockA01 , blockB00, 0 ); MULTIPLY_BLOCKS_8x8( blockAxB02, blockA02 , blockB00, 0 ); MULTIPLY_BLOCKS_8x8( blockAxB03, blockA03 , blockB00, 0 ); } while( coordB.x < padded_k / VECSIZE ); GEMM_OUTPUT(ALPHA1, BETA_NOT0);}", // NOLINT +"#define GEMM_TT(ALPHA1, BETA_NOT0, VECSCALAR, VECSIZE) __attribute__((intel_reqd_sub_group_size(SIMD_SIZE_GEMM))) __attribute__((reqd_work_group_size(SIMD_SIZE_GEMM, 1, 1))) __kernel void TEMPLATE(gemm_32_1_TT_ ##VECSCALAR ##_ ##ALPHA1 ##_ ##BETA_NOT0, Dtype)( __read_only image2d_t A, MATB_PARAMETER, MATC_PARAMETER, KERNEL_ARG_DTYPE alpha_in, KERNEL_ARG_DTYPE beta_in, int padded_k, int k, int isFirstColBlock) { const Dtype alpha = (Dtype)alpha_in; const Dtype beta = (Dtype)beta_in; const int group_x = get_group_id(0); const int group_y = get_group_id(1); Dtype8 blockAxB00 = 0.0f; Dtype8 blockAxB01 = 0.0f; Dtype8 blockAxB02 = 0.0f; Dtype8 blockAxB03 = 0.0f; int2 coordA = (int2)( group_y * TILE_M * SIZE_OF_ELEMENT, 0 ); int2 coordB = (int2)( 0, ( group_x * TILE_N )); const sampler_t sampler = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; do { Dtype8 blockB00; BLOCKB_READ8(blockB00, B, coordB); int2 coordATemp = coordA; Dtype8 blockA00 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.x += 8 * SIZE_OF_ELEMENT; Dtype8 blockA01 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.x += 8 * SIZE_OF_ELEMENT; Dtype8 blockA02 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.x += 8 * SIZE_OF_ELEMENT; Dtype8 blockA03 
= as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordA.y += TILE_K; MULTIPLY_BLOCKS_8x8( blockAxB00, blockA00 , blockB00, 0 ); MULTIPLY_BLOCKS_8x8( blockAxB01, blockA01 , blockB00, 0 ); MULTIPLY_BLOCKS_8x8( blockAxB02, blockA02 , blockB00, 0 ); MULTIPLY_BLOCKS_8x8( blockAxB03, blockA03 , blockB00, 0 ); } while( coordB.x < padded_k / VECSIZE ); GEMM_OUTPUT(ALPHA1, BETA_NOT0);}", // NOLINT "#endif", // NOLINT "", // NOLINT "#define BLOCKB_READ8(_blockb, _B, _coordB) int2 _coordBTemp = _coordB; _coordBTemp.y += get_local_id(0); _blockb.s0123 = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; _blockb.s4567 = READ_IMAGE(_B, _coordBTemp); _coordB.x += 2;", // NOLINT @@ -3199,12 +3441,12 @@ static std::vector> cl_kernels{ "#undef MATB_PARAMETER", // NOLINT "", // NOLINT "#if TYPE == TYPE_HALF", // NOLINT -"#define BLOCKB_READ8(_blockb, _B, _coordB) int2 _coordBTemp = _coordB; _coordBTemp.y += get_local_id(0); const __global float *B_read = (__global float *)(_B + (_coordBTemp.y * k) + _coordBTemp.x + offB); _blockb = as_half8(as_ushort8(vload4(0, B_read))); _coordB.x += TILE_K;", // NOLINT +"#define BLOCKB_READ8(_blockb, _B, _coordB) int2 _coordBTemp = _coordB; _coordBTemp.y += get_local_id(0); const __global float *B_read = (__global float *)(_B + (_coordBTemp.y * k) + _coordBTemp.x + offB); _blockb = as_Dtype8(as_ushort8(vload4(0, B_read))); _coordB.x += TILE_K;", // NOLINT "#else", // NOLINT -"#define BLOCKB_READ8(_blockb, _B, _coordB) int2 _coordBTemp = _coordB; _coordBTemp.y += get_local_id(0); const __global float *B_read = (__global float *)(_B + (_coordBTemp.y * k) + _coordBTemp.x + offB); _blockb = vload8(0, B_read); _coordB.x += TILE_K;", // NOLINT +"#define BLOCKB_READ8(_blockb, _B, _coordB) int2 _coordBTemp = _coordB; _coordBTemp.y += get_local_id(0); const __global Dtype *B_read = (__global Dtype *)(_B + (_coordBTemp.y * k) + _coordBTemp.x + offB); _blockb = vload8(0, B_read); _coordB.x += TILE_K;", // NOLINT "#endif", // NOLINT "", // NOLINT -"#define 
MATB_PARAMETER __global float *B, int offB, int ldb", // NOLINT +"#define MATB_PARAMETER __global Dtype *B, int offB, int ldb", // NOLINT "", // NOLINT "GEMM_TT(1, 0, BUFFER, 1) // ALPHA == 1, BETA == 0", // NOLINT "GEMM_TT(1, 1, BUFFER, 1) // ALPHA == 1, BETA != 0", // NOLINT @@ -3213,7 +3455,7 @@ static std::vector> cl_kernels{ "#undef BLOCKB_READ8", // NOLINT "#undef MATB_PARAMETER", // NOLINT "", // NOLINT -"#define BLOCKB_READ8(_blockb, _B, _coordB) int2 _coordBTemp = _coordB; _coordBTemp.y += get_local_id(0); float4 temp; temp = READ_IMAGE(B, _coordBTemp); _coordBTemp.x += 1; _blockb.s0 = temp.s0; temp = READ_IMAGE(B, _coordBTemp); _coordBTemp.x += 1; _blockb.s1 = temp.s0; temp = READ_IMAGE(B, _coordBTemp); _coordBTemp.x += 1; _blockb.s2 = temp.s0; temp = READ_IMAGE(B, _coordBTemp); _coordBTemp.x += 1; _blockb.s3 = temp.s0; temp = READ_IMAGE(B, _coordBTemp); _coordBTemp.x += 1; _blockb.s4 = temp.s0; temp = READ_IMAGE(B, _coordBTemp); _coordBTemp.x += 1; _blockb.s5 = temp.s0; temp = READ_IMAGE(B, _coordBTemp); _coordBTemp.x += 1; _blockb.s6 = temp.s0; temp = READ_IMAGE(B, _coordBTemp); _coordBTemp.x += 1; _blockb.s7 = temp.s0; _coordB.x += 8;", // NOLINT +"#define BLOCKB_READ8(_blockb, _B, _coordB) int2 _coordBTemp = _coordB; _coordBTemp.y += get_local_id(0); Dtype4 temp; temp = READ_IMAGE(B, _coordBTemp); _coordBTemp.x += 1; _blockb.s0 = temp.s0; temp = READ_IMAGE(B, _coordBTemp); _coordBTemp.x += 1; _blockb.s1 = temp.s0; temp = READ_IMAGE(B, _coordBTemp); _coordBTemp.x += 1; _blockb.s2 = temp.s0; temp = READ_IMAGE(B, _coordBTemp); _coordBTemp.x += 1; _blockb.s3 = temp.s0; temp = READ_IMAGE(B, _coordBTemp); _coordBTemp.x += 1; _blockb.s4 = temp.s0; temp = READ_IMAGE(B, _coordBTemp); _coordBTemp.x += 1; _blockb.s5 = temp.s0; temp = READ_IMAGE(B, _coordBTemp); _coordBTemp.x += 1; _blockb.s6 = temp.s0; temp = READ_IMAGE(B, _coordBTemp); _coordBTemp.x += 1; _blockb.s7 = temp.s0; _coordB.x += 8;", // NOLINT "", // NOLINT "#define MATB_PARAMETER __read_only 
image2d_t B", // NOLINT "", // NOLINT @@ -3226,13 +3468,17 @@ static std::vector> cl_kernels{ "", // NOLINT "#undef MULTIPLY_BLOCKS_8x8", // NOLINT "#undef TRANSPOSE_BLOCK_8", // NOLINT +"#undef GEMM_TT", // NOLINT "", // NOLINT "#undef TILE_M", // NOLINT "#undef TILE_K", // NOLINT "#undef TILE_N", // NOLINT +"#undef SUBGROUP_BLOCK_READ8", // NOLINT +"#undef READ_IMAGE", // NOLINT +"#undef SIZE_OF_ELEMENT", // NOLINT "", // NOLINT "__kernel void TEMPLATE(gemm_buffer_copy_image_transpose, Dtype)(", // NOLINT -"__global float* A,", // NOLINT +"__global Dtype* A,", // NOLINT "__write_only image2d_t ImA,", // NOLINT "int offA,", // NOLINT "int width,", // NOLINT @@ -3242,17 +3488,17 @@ static std::vector> cl_kernels{ "const int gidx = get_global_id(0);", // NOLINT "const int gidy = get_global_id(1);", // NOLINT "int2 coord_dst = (int2)(gidx, gidy);", // NOLINT -"__global float* A_off = A + offA;", // NOLINT -"float srcA = A_off[gidy * ldA + gidx];", // NOLINT +"__global Dtype* A_off = A + offA;", // NOLINT +"Dtype srcA = A_off[gidy * ldA + gidx];", // NOLINT "#if TYPE == TYPE_HALF", // NOLINT -"write_imageh(ImA, coord_dst, (float4)srcA);", // NOLINT +"write_imageh(ImA, coord_dst, (Dtype4)srcA);", // NOLINT "#else", // NOLINT -"write_imagef(ImA, coord_dst, (float4)srcA);", // NOLINT +"write_imagef(ImA, coord_dst, (Dtype4)srcA);", // NOLINT "#endif", // NOLINT "}", // NOLINT "", // NOLINT "__kernel void TEMPLATE(gemm_buffer_copy_image_no_transpose, Dtype)(", // NOLINT -"__global float* A,", // NOLINT +"__global Dtype* A,", // NOLINT "__write_only image2d_t ImA,", // NOLINT "int offA,", // NOLINT "int width,", // NOLINT @@ -3267,14 +3513,14 @@ static std::vector> cl_kernels{ "write_imageh(ImA, coord_dst, 0);", // NOLINT "return;", // NOLINT "}", // NOLINT -"__global float* A_off = A + offA;", // NOLINT +"__global Dtype* A_off = A + offA;", // NOLINT "write_imageh(ImA, coord_dst, A_off[gidy * ldA + gidx]);", // NOLINT "#else", // NOLINT "if (gidx >= width || gidy >= 
height) {", // NOLINT "write_imageui(ImA, coord_dst, (uint4)0);", // NOLINT "return;", // NOLINT "}", // NOLINT -"__global float* A_off = A + offA;", // NOLINT +"__global Dtype* A_off = A + offA;", // NOLINT "uint4 srcA = convert_uint4(as_uchar4(A_off[gidy * ldA + gidx]));", // NOLINT "write_imageui(ImA, coord_dst, srcA);", // NOLINT "#endif", // NOLINT @@ -3295,18 +3541,18 @@ static std::vector> cl_kernels{ "__attribute__((reqd_work_group_size(SIMD_SIZE_GEMM, LWG_HEIGHT, 1)))", // NOLINT "__attribute__((intel_reqd_sub_group_size(SIMD_SIZE_GEMM)))", // NOLINT "__kernel void TEMPLATE(gemm_buffer_NN, Dtype)(", // NOLINT -"const __global float *src0, int off0,", // NOLINT -"const __global float *src1, int off1,", // NOLINT -"__global float *dst, int offd,", // NOLINT +"const __global Dtype *src0, int off0,", // NOLINT +"const __global Dtype *src1, int off1,", // NOLINT +"__global Dtype *dst, int offd,", // NOLINT "int M,", // NOLINT "int N,", // NOLINT "int K,", // NOLINT -"float alpha_in,", // NOLINT -"float beta_in,", // NOLINT +"KERNEL_ARG_DTYPE alpha_in,", // NOLINT +"KERNEL_ARG_DTYPE beta_in,", // NOLINT "int start_index)", // NOLINT "{", // NOLINT -"const float alpha = (float)alpha_in;", // NOLINT -"const float beta = (float)beta_in;", // NOLINT +"const Dtype alpha = (Dtype)alpha_in;", // NOLINT +"const Dtype beta = (Dtype)beta_in;", // NOLINT "const int group_x = get_group_id(0);", // NOLINT "const int group_y = get_group_id(1);", // NOLINT "const int local_x = get_local_id(0);", // NOLINT @@ -3314,14 +3560,14 @@ static std::vector> cl_kernels{ "const int global_x = get_global_id(0);", // NOLINT "const int global_y = get_global_id(1);", // NOLINT "", // NOLINT -"float4 brow;", // NOLINT -"float2 arow0, arow1, arow2, arow3, arow4, arow5, arow6, arow7;", // NOLINT +"Dtype4 brow;", // NOLINT +"Dtype2 arow0, arow1, arow2, arow3, arow4, arow5, arow6, arow7;", // NOLINT "", // NOLINT -"__global float *dst_write0 = dst + local_x * VEC_SIZE + (group_x * TILE_N) + 
(group_y * LWG_HEIGHT * TILE_M + local_y * TILE_M) * N + offd;", // NOLINT +"__global Dtype *dst_write0 = dst + local_x * VEC_SIZE + (group_x * TILE_N) + (group_y * LWG_HEIGHT * TILE_M + local_y * TILE_M) * N + offd;", // NOLINT "", // NOLINT -"const __global float *src0_read = src0 + local_x * (TILE_K / SIMD_SIZE_GEMM) + (group_y * LWG_HEIGHT * TILE_M + local_y * TILE_M) * K + start_index + off0;", // NOLINT +"const __global Dtype *src0_read = src0 + local_x * (TILE_K / SIMD_SIZE_GEMM) + (group_y * LWG_HEIGHT * TILE_M + local_y * TILE_M) * K + start_index + off0;", // NOLINT "", // NOLINT -"const __global float *src1_read0 = src1 + local_x * VEC_SIZE + (group_x * TILE_N) + start_index * N + off1;", // NOLINT +"const __global Dtype *src1_read0 = src1 + local_x * VEC_SIZE + (group_x * TILE_N) + start_index * N + off1;", // NOLINT "", // NOLINT "int border = -(group_y * LWG_HEIGHT * TILE_M + local_y * TILE_M);", // NOLINT "", // NOLINT @@ -3334,14 +3580,14 @@ static std::vector> cl_kernels{ "int row6 = mad24(global_y, TILE_M, 6) < M ? 6 : border;", // NOLINT "int row7 = mad24(global_y, TILE_M, 7) < M ? 7 : border;", // NOLINT "", // NOLINT -"float4 dot00 = (start_index != 0) ? vload4(0, dst_write0) : beta * vload4(0, dst_write0);", // NOLINT -"float4 dot01 = (start_index != 0) ? vload4(0, dst_write0 + 1 * N) : beta * vload4(0, dst_write0 + 1 * N);", // NOLINT -"float4 dot02 = (start_index != 0) ? vload4(0, dst_write0 + 2 * N) : beta * vload4(0, dst_write0 + 2 * N);", // NOLINT -"float4 dot03 = (start_index != 0) ? vload4(0, dst_write0 + 3 * N) : beta * vload4(0, dst_write0 + 3 * N);", // NOLINT -"float4 dot04 = (start_index != 0) ? vload4(0, dst_write0 + 4 * N) : beta * vload4(0, dst_write0 + 4 * N);", // NOLINT -"float4 dot05 = (start_index != 0) ? vload4(0, dst_write0 + 5 * N) : beta * vload4(0, dst_write0 + 5 * N);", // NOLINT -"float4 dot06 = (start_index != 0) ? 
vload4(0, dst_write0 + 6 * N) : beta * vload4(0, dst_write0 + 6 * N);", // NOLINT -"float4 dot07 = (start_index != 0) ? vload4(0, dst_write0 + 7 * N) : beta * vload4(0, dst_write0 + 7 * N);", // NOLINT +"Dtype4 dot00 = (start_index != 0) ? vload4(0, dst_write0) : beta * vload4(0, dst_write0);", // NOLINT +"Dtype4 dot01 = (start_index != 0) ? vload4(0, dst_write0 + 1 * N) : beta * vload4(0, dst_write0 + 1 * N);", // NOLINT +"Dtype4 dot02 = (start_index != 0) ? vload4(0, dst_write0 + 2 * N) : beta * vload4(0, dst_write0 + 2 * N);", // NOLINT +"Dtype4 dot03 = (start_index != 0) ? vload4(0, dst_write0 + 3 * N) : beta * vload4(0, dst_write0 + 3 * N);", // NOLINT +"Dtype4 dot04 = (start_index != 0) ? vload4(0, dst_write0 + 4 * N) : beta * vload4(0, dst_write0 + 4 * N);", // NOLINT +"Dtype4 dot05 = (start_index != 0) ? vload4(0, dst_write0 + 5 * N) : beta * vload4(0, dst_write0 + 5 * N);", // NOLINT +"Dtype4 dot06 = (start_index != 0) ? vload4(0, dst_write0 + 6 * N) : beta * vload4(0, dst_write0 + 6 * N);", // NOLINT +"Dtype4 dot07 = (start_index != 0) ? 
vload4(0, dst_write0 + 7 * N) : beta * vload4(0, dst_write0 + 7 * N);", // NOLINT "", // NOLINT "int end_index = min(start_index + 256, K);", // NOLINT "int w = start_index;", // NOLINT @@ -3355,7 +3601,7 @@ static std::vector> cl_kernels{ "arow6 = alpha * vload2(0, src0_read + row6 * K);", // NOLINT "arow7 = alpha * vload2(0, src0_read + row7 * K);", // NOLINT "", // NOLINT -"#define MM_DOT_PRODUCT( index, suffix ) brow = vload4(0, src1_read0); src1_read0 += N; dot00 = mad( (float4)(as_float2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow0), index )).s##suffix), brow, dot00 ); dot01 = mad( (float4)(as_float2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow1), index )).s##suffix), brow, dot01 ); dot02 = mad( (float4)(as_float2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow2), index )).s##suffix), brow, dot02 ); dot03 = mad( (float4)(as_float2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow3), index )).s##suffix), brow, dot03 ); dot04 = mad( (float4)(as_float2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow4), index )).s##suffix), brow, dot04 ); dot05 = mad( (float4)(as_float2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow5), index )).s##suffix), brow, dot05 ); dot06 = mad( (float4)(as_float2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow6), index )).s##suffix), brow, dot06 ); dot07 = mad( (float4)(as_float2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow7), index )).s##suffix), brow, dot07 );", // NOLINT +"#define MM_DOT_PRODUCT( index, suffix ) brow = vload4(0, src1_read0); src1_read0 += N; dot00 = mad( (Dtype4)(as_Dtype2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow0), index )).s##suffix), brow, dot00 ); dot01 = mad( (Dtype4)(as_Dtype2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow1), index )).s##suffix), brow, dot01 ); dot02 = mad( (Dtype4)(as_Dtype2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow2), index )).s##suffix), brow, dot02 ); dot03 = mad( (Dtype4)(as_Dtype2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow3), index )).s##suffix), brow, dot03 ); dot04 = mad( 
(Dtype4)(as_Dtype2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow4), index )).s##suffix), brow, dot04 ); dot05 = mad( (Dtype4)(as_Dtype2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow5), index )).s##suffix), brow, dot05 ); dot06 = mad( (Dtype4)(as_Dtype2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow6), index )).s##suffix), brow, dot06 ); dot07 = mad( (Dtype4)(as_Dtype2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow7), index )).s##suffix), brow, dot07 );", // NOLINT "MM_DOT_PRODUCT(0, 0);", // NOLINT "MM_DOT_PRODUCT(0, 1);", // NOLINT "MM_DOT_PRODUCT(1, 0);", // NOLINT @@ -3414,7 +3660,7 @@ static std::vector> cl_kernels{ "arow7.x = ((w + local_x * 2) < K) ? alpha * (src0_read + row7 * K)[0] : 0.0f;", // NOLINT "arow7.y = ((w + local_x * 2 + 1) < K) ? alpha * (src0_read + row7 * K)[1] : 0.0f;", // NOLINT "", // NOLINT -"#define MM_DOT_PRODUCT( index, suffix ) brow = (w < K) ? vload4(0, src1_read0) : (float4)0.0f; src1_read0 += N; w++; dot00 = mad( (float4)(as_float2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow0), index )).s##suffix), brow, dot00 ); dot01 = mad( (float4)(as_float2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow1), index )).s##suffix), brow, dot01 ); dot02 = mad( (float4)(as_float2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow2), index )).s##suffix), brow, dot02 ); dot03 = mad( (float4)(as_float2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow3), index )).s##suffix), brow, dot03 ); dot04 = mad( (float4)(as_float2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow4), index )).s##suffix), brow, dot04 ); dot05 = mad( (float4)(as_float2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow5), index )).s##suffix), brow, dot05 ); dot06 = mad( (float4)(as_float2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow6), index )).s##suffix), brow, dot06 ); dot07 = mad( (float4)(as_float2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow7), index )).s##suffix), brow, dot07 );", // NOLINT +"#define MM_DOT_PRODUCT( index, suffix ) brow = (w < K) ? 
vload4(0, src1_read0) : (Dtype4)0.0f; src1_read0 += N; w++; dot00 = mad( (Dtype4)(as_Dtype2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow0), index )).s##suffix), brow, dot00 ); dot01 = mad( (Dtype4)(as_Dtype2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow1), index )).s##suffix), brow, dot01 ); dot02 = mad( (Dtype4)(as_Dtype2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow2), index )).s##suffix), brow, dot02 ); dot03 = mad( (Dtype4)(as_Dtype2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow3), index )).s##suffix), brow, dot03 ); dot04 = mad( (Dtype4)(as_Dtype2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow4), index )).s##suffix), brow, dot04 ); dot05 = mad( (Dtype4)(as_Dtype2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow5), index )).s##suffix), brow, dot05 ); dot06 = mad( (Dtype4)(as_Dtype2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow6), index )).s##suffix), brow, dot06 ); dot07 = mad( (Dtype4)(as_Dtype2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow7), index )).s##suffix), brow, dot07 );", // NOLINT "MM_DOT_PRODUCT(0, 0);", // NOLINT "MM_DOT_PRODUCT(0, 1);", // NOLINT "MM_DOT_PRODUCT(1, 0);", // NOLINT @@ -3563,17 +3809,17 @@ static std::vector> cl_kernels{ "__attribute__((reqd_work_group_size(8, LWG_HEIGHT, 1)))", // NOLINT "__attribute__((intel_reqd_sub_group_size(8)))", // NOLINT "__kernel void TEMPLATE(gemm_buffer_NT, Dtype)(", // NOLINT -"const __global float *src0, int off0,", // NOLINT -"const __global float *src1, int off1,", // NOLINT -"__global float *dst, int offd,", // NOLINT +"const __global Dtype *src0, int off0,", // NOLINT +"const __global Dtype *src1, int off1,", // NOLINT +"__global Dtype *dst, int offd,", // NOLINT "int M,", // NOLINT "int N,", // NOLINT "int K,", // NOLINT -"float alpha_in,", // NOLINT -"float beta_in)", // NOLINT +"KERNEL_ARG_DTYPE alpha_in,", // NOLINT +"KERNEL_ARG_DTYPE beta_in)", // NOLINT "{", // NOLINT -"const float alpha = (float)alpha_in;", // NOLINT -"const float beta = (float)beta_in;", // NOLINT +"const Dtype alpha = (Dtype)alpha_in;", // 
NOLINT +"const Dtype beta = (Dtype)beta_in;", // NOLINT "const int group_x = get_group_id(0);", // NOLINT "const int group_y = get_group_id(1);", // NOLINT "const int local_x = get_local_id(0);", // NOLINT @@ -3581,32 +3827,32 @@ static std::vector> cl_kernels{ "const int global_x = get_global_id(0);", // NOLINT "const int global_y = get_global_id(1);", // NOLINT "", // NOLINT -"float8 dot00 = 0.f;", // NOLINT -"float8 dot01 = 0.f;", // NOLINT -"float8 dot02 = 0.f;", // NOLINT -"float8 dot03 = 0.f;", // NOLINT -"float8 dot04 = 0.f;", // NOLINT -"float8 dot05 = 0.f;", // NOLINT -"float8 dot06 = 0.f;", // NOLINT -"float8 dot07 = 0.f;", // NOLINT +"Dtype8 dot00 = 0.f;", // NOLINT +"Dtype8 dot01 = 0.f;", // NOLINT +"Dtype8 dot02 = 0.f;", // NOLINT +"Dtype8 dot03 = 0.f;", // NOLINT +"Dtype8 dot04 = 0.f;", // NOLINT +"Dtype8 dot05 = 0.f;", // NOLINT +"Dtype8 dot06 = 0.f;", // NOLINT +"Dtype8 dot07 = 0.f;", // NOLINT "", // NOLINT -"float4 brow0;", // NOLINT -"float4 brow1;", // NOLINT -"float4 brow2;", // NOLINT -"float4 brow3;", // NOLINT -"float4 brow4;", // NOLINT -"float4 brow5;", // NOLINT -"float4 brow6;", // NOLINT -"float4 brow7;", // NOLINT +"Dtype4 brow0;", // NOLINT +"Dtype4 brow1;", // NOLINT +"Dtype4 brow2;", // NOLINT +"Dtype4 brow3;", // NOLINT +"Dtype4 brow4;", // NOLINT +"Dtype4 brow5;", // NOLINT +"Dtype4 brow6;", // NOLINT +"Dtype4 brow7;", // NOLINT "", // NOLINT -"__global float *dst_write0 = dst + local_x * VEC_SIZE + (group_x * TILE_N) + (group_y * LWG_HEIGHT * TILE_M + local_y * TILE_M) * N + offd;", // NOLINT +"__global Dtype *dst_write0 = dst + local_x * VEC_SIZE + (group_x * TILE_N) + (group_y * LWG_HEIGHT * TILE_M + local_y * TILE_M) * N + offd;", // NOLINT "", // NOLINT -"const __global float *src0_read = src0 + local_x * (TILE_K / 8) + (group_y * LWG_HEIGHT * TILE_M + local_y * TILE_M) * K + off0;", // NOLINT +"const __global Dtype *src0_read = src0 + local_x * (TILE_K / 8) + (group_y * LWG_HEIGHT * TILE_M + local_y * TILE_M) * K + off0;", 
// NOLINT "", // NOLINT -"const __global float *src1_read0 = src1 + (group_x * TILE_N) * K + off1;", // NOLINT +"const __global Dtype *src1_read0 = src1 + (group_x * TILE_N) * K + off1;", // NOLINT "", // NOLINT -"__local float slm_brow[8 * SLM_BLOCK];", // NOLINT -"__local float* slm_brow0;", // NOLINT +"__local Dtype slm_brow[8 * SLM_BLOCK];", // NOLINT +"__local Dtype* slm_brow0;", // NOLINT "", // NOLINT "int local_index = mad24(local_y, 8, local_x) * 4;", // NOLINT "int w;", // NOLINT @@ -3626,7 +3872,7 @@ static std::vector> cl_kernels{ "w = b_tile;", // NOLINT "int end_w = min(b_tile + SLM_BLOCK, K);", // NOLINT "while( w + TILE_K <= end_w ) {", // NOLINT -"float4 arow;", // NOLINT +"Dtype4 arow;", // NOLINT "", // NOLINT "brow0 = vload4(0, slm_brow0 + 0 * SLM_BLOCK);", // NOLINT "brow1 = vload4(0, slm_brow0 + 1 * SLM_BLOCK);", // NOLINT @@ -3637,7 +3883,7 @@ static std::vector> cl_kernels{ "brow6 = vload4(0, slm_brow0 + 6 * SLM_BLOCK);", // NOLINT "brow7 = vload4(0, slm_brow0 + 7 * SLM_BLOCK);", // NOLINT "", // NOLINT -"#define MM_DOT_PRODUCT( _row, _dot ) arow = vload4(0, src0_read + _row * K); _dot = mad( (float8)(arow.x), (float8)(brow0.x, brow1.x, brow2.x, brow3.x, brow4.x, brow5.x, brow6.x, brow7.x), _dot ); _dot = mad( (float8)(arow.y), (float8)(brow0.y, brow1.y, brow2.y, brow3.y, brow4.y, brow5.y, brow6.y, brow7.y), _dot ); _dot = mad( (float8)(arow.z), (float8)(brow0.z, brow1.z, brow2.z, brow3.z, brow4.z, brow5.z, brow6.z, brow7.z), _dot ); _dot = mad( (float8)(arow.w), (float8)(brow0.w, brow1.w, brow2.w, brow3.w, brow4.w, brow5.w, brow6.w, brow7.w), _dot );", // NOLINT +"#define MM_DOT_PRODUCT( _row, _dot ) arow = vload4(0, src0_read + _row * K); _dot = mad( (Dtype8)(arow.x), (Dtype8)(brow0.x, brow1.x, brow2.x, brow3.x, brow4.x, brow5.x, brow6.x, brow7.x), _dot ); _dot = mad( (Dtype8)(arow.y), (Dtype8)(brow0.y, brow1.y, brow2.y, brow3.y, brow4.y, brow5.y, brow6.y, brow7.y), _dot ); _dot = mad( (Dtype8)(arow.z), (Dtype8)(brow0.z, brow1.z, brow2.z, 
brow3.z, brow4.z, brow5.z, brow6.z, brow7.z), _dot ); _dot = mad( (Dtype8)(arow.w), (Dtype8)(brow0.w, brow1.w, brow2.w, brow3.w, brow4.w, brow5.w, brow6.w, brow7.w), _dot );", // NOLINT "MM_DOT_PRODUCT( 0, dot00 );", // NOLINT "MM_DOT_PRODUCT( 1, dot01 );", // NOLINT "MM_DOT_PRODUCT( 2, dot02 );", // NOLINT @@ -3656,7 +3902,7 @@ static std::vector> cl_kernels{ "}", // NOLINT "", // NOLINT "if(w < K) {", // NOLINT -"float4 arow;", // NOLINT +"Dtype4 arow;", // NOLINT "", // NOLINT "#define READ_BROW(_brow, _row) _brow = vload4(0, slm_brow0 + _row * SLM_BLOCK); _brow.x = (mad24(local_x, 4, w) < K) ? _brow.x : 0.0f; _brow.y = (mad24(local_x, 4, w + 1) < K) ? _brow.y : 0.0f; _brow.z = (mad24(local_x, 4, w + 2) < K) ? _brow.z : 0.0f; _brow.w = (mad24(local_x, 4, w + 3) < K) ? _brow.w : 0.0f;", // NOLINT "READ_BROW(brow0, 0);", // NOLINT @@ -3668,7 +3914,7 @@ static std::vector> cl_kernels{ "READ_BROW(brow6, 6);", // NOLINT "READ_BROW(brow7, 7);", // NOLINT "", // NOLINT -"#define MM_DOT_PRODUCT( _row, _dot ) arow = vload4(0, src0_read + _row * K); arow.x = (mad24(local_x, 4, w) < K) ? arow.x : 0.0f; arow.y = (mad24(local_x, 4, w + 1) < K) ? arow.y : 0.0f; arow.z = (mad24(local_x, 4, w + 2) < K) ? arow.z : 0.0f; arow.w = (mad24(local_x, 4, w + 3) < K) ? arow.w : 0.0f; _dot = mad( (float8)(arow.x), (float8)(brow0.x, brow1.x, brow2.x, brow3.x, brow4.x, brow5.x, brow6.x, brow7.x), _dot ); _dot = mad( (float8)(arow.y), (float8)(brow0.y, brow1.y, brow2.y, brow3.y, brow4.y, brow5.y, brow6.y, brow7.y), _dot ); _dot = mad( (float8)(arow.z), (float8)(brow0.z, brow1.z, brow2.z, brow3.z, brow4.z, brow5.z, brow6.z, brow7.z), _dot ); _dot = mad( (float8)(arow.w), (float8)(brow0.w, brow1.w, brow2.w, brow3.w, brow4.w, brow5.w, brow6.w, brow7.w), _dot );", // NOLINT +"#define MM_DOT_PRODUCT( _row, _dot ) arow = vload4(0, src0_read + _row * K); arow.x = (mad24(local_x, 4, w) < K) ? arow.x : 0.0f; arow.y = (mad24(local_x, 4, w + 1) < K) ? 
arow.y : 0.0f; arow.z = (mad24(local_x, 4, w + 2) < K) ? arow.z : 0.0f; arow.w = (mad24(local_x, 4, w + 3) < K) ? arow.w : 0.0f; _dot = mad( (Dtype8)(arow.x), (Dtype8)(brow0.x, brow1.x, brow2.x, brow3.x, brow4.x, brow5.x, brow6.x, brow7.x), _dot ); _dot = mad( (Dtype8)(arow.y), (Dtype8)(brow0.y, brow1.y, brow2.y, brow3.y, brow4.y, brow5.y, brow6.y, brow7.y), _dot ); _dot = mad( (Dtype8)(arow.z), (Dtype8)(brow0.z, brow1.z, brow2.z, brow3.z, brow4.z, brow5.z, brow6.z, brow7.z), _dot ); _dot = mad( (Dtype8)(arow.w), (Dtype8)(brow0.w, brow1.w, brow2.w, brow3.w, brow4.w, brow5.w, brow6.w, brow7.w), _dot );", // NOLINT "MM_DOT_PRODUCT( 0, dot00 );", // NOLINT "MM_DOT_PRODUCT( 1, dot01 );", // NOLINT "MM_DOT_PRODUCT( 2, dot02 );", // NOLINT @@ -3680,7 +3926,7 @@ static std::vector> cl_kernels{ "#undef MM_DOT_PRODUCT", // NOLINT "}", // NOLINT "", // NOLINT -"#define REDUCE(_dot) _dot = as_float8(intel_sub_group_shuffle(SHUFFLE_TYPE8(_dot), 0)) + as_float8(intel_sub_group_shuffle(SHUFFLE_TYPE8(_dot), 1)) + as_float8(intel_sub_group_shuffle(SHUFFLE_TYPE8(_dot), 2)) + as_float8(intel_sub_group_shuffle(SHUFFLE_TYPE8(_dot), 3)) + as_float8(intel_sub_group_shuffle(SHUFFLE_TYPE8(_dot), 4)) + as_float8(intel_sub_group_shuffle(SHUFFLE_TYPE8(_dot), 5)) + as_float8(intel_sub_group_shuffle(SHUFFLE_TYPE8(_dot), 6)) + as_float8(intel_sub_group_shuffle(SHUFFLE_TYPE8(_dot), 7));", // NOLINT +"#define REDUCE(_dot) _dot = as_Dtype8(intel_sub_group_shuffle(SHUFFLE_TYPE8(_dot), 0)) + as_Dtype8(intel_sub_group_shuffle(SHUFFLE_TYPE8(_dot), 1)) + as_Dtype8(intel_sub_group_shuffle(SHUFFLE_TYPE8(_dot), 2)) + as_Dtype8(intel_sub_group_shuffle(SHUFFLE_TYPE8(_dot), 3)) + as_Dtype8(intel_sub_group_shuffle(SHUFFLE_TYPE8(_dot), 4)) + as_Dtype8(intel_sub_group_shuffle(SHUFFLE_TYPE8(_dot), 5)) + as_Dtype8(intel_sub_group_shuffle(SHUFFLE_TYPE8(_dot), 6)) + as_Dtype8(intel_sub_group_shuffle(SHUFFLE_TYPE8(_dot), 7));", // NOLINT "REDUCE(dot00);", // NOLINT "REDUCE(dot01);", // NOLINT "REDUCE(dot02);", // 
NOLINT @@ -3691,7 +3937,7 @@ static std::vector> cl_kernels{ "REDUCE(dot07);", // NOLINT "#undef REDUCE", // NOLINT "", // NOLINT -"float output = 0.0f;", // NOLINT +"Dtype output = 0.0f;", // NOLINT "#define OUTPUT( _dot) output = (local_x == 0) ? _dot.s0 : output; output = (local_x == 1) ? _dot.s1 : output; output = (local_x == 2) ? _dot.s2 : output; output = (local_x == 3) ? _dot.s3 : output; output = (local_x == 4) ? _dot.s4 : output; output = (local_x == 5) ? _dot.s5 : output; output = (local_x == 6) ? _dot.s6 : output; output = (local_x == 7) ? _dot.s7 : output; dst_write0[0] = mad(output, alpha, beta * dst_write0[0]); dst_write0 += N;", // NOLINT "", // NOLINT "if(global_x < N && global_y * 8 < M) {", // NOLINT @@ -3716,32 +3962,32 @@ static std::vector> cl_kernels{ "", // NOLINT "#define SLM_SIZE 64", // NOLINT "void TEMPLATE(gemm_buffer_NT_M_2_edgerows,Dtype)(", // NOLINT -"const __global float* srca_read0,", // NOLINT -"const __global float* srca_read1,", // NOLINT -"const __global float* srcb_read,", // NOLINT -"__local float4* work0,", // NOLINT -"__local float4* work1,", // NOLINT +"const __global Dtype* srca_read0,", // NOLINT +"const __global Dtype* srca_read1,", // NOLINT +"const __global Dtype* srcb_read,", // NOLINT +"__local Dtype4* work0,", // NOLINT +"__local Dtype4* work1,", // NOLINT "int N,", // NOLINT "int K,", // NOLINT "int x_gid,", // NOLINT "int lid,", // NOLINT -"float alpha,", // NOLINT -"float beta,", // NOLINT -"__global float* dstc0,", // NOLINT -"__global float* dstc1)", // NOLINT +"Dtype alpha,", // NOLINT +"Dtype beta,", // NOLINT +"__global Dtype* dstc0,", // NOLINT +"__global Dtype* dstc1)", // NOLINT "{", // NOLINT -"__local float* work_each0 = (__local float*)work0;", // NOLINT -"__local float* work_each1 = (__local float*)work1;", // NOLINT +"__local Dtype* work_each0 = (__local Dtype*)work0;", // NOLINT +"__local Dtype* work_each1 = (__local Dtype*)work1;", // NOLINT "", // NOLINT "int rows = N - x_gid * 4;", // NOLINT "", 
// NOLINT -"float4 dot0[3] = {(float4)(0.), (float4)(0.), (float4)(0.)};", // NOLINT -"float4 dot1[3] = {(float4)(0.), (float4)(0.), (float4)(0.)};", // NOLINT +"Dtype4 dot0[3] = {(Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.)};", // NOLINT +"Dtype4 dot1[3] = {(Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.)};", // NOLINT "", // NOLINT "int i = lid;", // NOLINT "while( i < K / 4) {", // NOLINT -"const float4 b0 = {srca_read0[i*4], srca_read0[(i*4+1)], srca_read0[(i*4+2)], srca_read0[(i*4+3)]};", // NOLINT -"const float4 b1 = {srca_read1[i*4], srca_read1[(i*4+1)], srca_read1[(i*4+2)], srca_read1[(i*4+3)]};", // NOLINT +"const Dtype4 b0 = {srca_read0[i*4], srca_read0[(i*4+1)], srca_read0[(i*4+2)], srca_read0[(i*4+3)]};", // NOLINT +"const Dtype4 b1 = {srca_read1[i*4], srca_read1[(i*4+1)], srca_read1[(i*4+2)], srca_read1[(i*4+3)]};", // NOLINT "#pragma unroll", // NOLINT "for(int j = 0; j < rows; ++j) {", // NOLINT "dot0[j] += b0 * vload4(i, srcb_read + j * K);", // NOLINT @@ -3760,13 +4006,13 @@ static std::vector> cl_kernels{ "short tail_items = K % 4;", // NOLINT "", // NOLINT "if(tail_items != 0) {", // NOLINT -"const __global float *srcb_tail = srcb_read + i * 4;", // NOLINT -"const __global float *srca_tail0 = srca_read0 + i * 4;", // NOLINT -"const __global float *srca_tail1 = srca_read1 + i * 4;", // NOLINT +"const __global Dtype *srcb_tail = srcb_read + i * 4;", // NOLINT +"const __global Dtype *srca_tail0 = srca_read0 + i * 4;", // NOLINT +"const __global Dtype *srca_tail1 = srca_read1 + i * 4;", // NOLINT "#pragma unroll", // NOLINT "for(short i = 0; i < tail_items; ++i) {", // NOLINT -"const float at0 = srca_tail0[i];", // NOLINT -"const float at1 = srca_tail1[i];", // NOLINT +"const Dtype at0 = srca_tail0[i];", // NOLINT +"const Dtype at1 = srca_tail1[i];", // NOLINT "#pragma unroll", // NOLINT "for(int j = 0; j < rows; ++j) {", // NOLINT "work_each0[lid * 4 + j] += at0 * srcb_tail[i + j * K];", // NOLINT @@ -3794,48 +4040,48 @@ static std::vector> cl_kernels{ "}", // 
NOLINT "", // NOLINT "__kernel void TEMPLATE(gemm_buffer_NT_M_2,Dtype)(", // NOLINT -"__global const float * A,", // NOLINT +"__global const Dtype * A,", // NOLINT "int offA,", // NOLINT -"__global const float * B,", // NOLINT +"__global const Dtype * B,", // NOLINT "int offB,", // NOLINT -"__global float * C,", // NOLINT +"__global Dtype * C,", // NOLINT "int offC,", // NOLINT "int M,", // NOLINT "int N,", // NOLINT "int K,", // NOLINT -"float alpha_f,", // NOLINT -"float beta_f)", // NOLINT +"KERNEL_ARG_DTYPE alpha_f,", // NOLINT +"KERNEL_ARG_DTYPE beta_f)", // NOLINT "{", // NOLINT -"float alpha = (float)alpha_f;", // NOLINT -"float beta = (float)beta_f;", // NOLINT +"Dtype alpha = (Dtype)alpha_f;", // NOLINT +"Dtype beta = (Dtype)beta_f;", // NOLINT "int x_gid = get_group_id(0);", // NOLINT "int lid = get_local_id(0);", // NOLINT "", // NOLINT -"const __global float *srca_read0 = A + offA;", // NOLINT -"const __global float *srca_read1 = srca_read0 + K;", // NOLINT +"const __global Dtype *srca_read0 = A + offA;", // NOLINT +"const __global Dtype *srca_read1 = srca_read0 + K;", // NOLINT "", // NOLINT -"const __global float *srcb_read = B + x_gid * 4 * K + offB;", // NOLINT +"const __global Dtype *srcb_read = B + x_gid * 4 * K + offB;", // NOLINT "", // NOLINT -"__global float4 *dstc0 = (__global float4*)(C + offC);", // NOLINT -"__global float4 *dstc1 = (__global float4*)((__global float*)(dstc0) + N);", // NOLINT +"__global Dtype4 *dstc0 = (__global Dtype4*)(C + offC);", // NOLINT +"__global Dtype4 *dstc1 = (__global Dtype4*)((__global Dtype*)(dstc0) + N);", // NOLINT "", // NOLINT -"__local float4 work0[SLM_SIZE];", // NOLINT -"__local float4 work1[SLM_SIZE];", // NOLINT -"__local float* work_each0 = (__local float*)work0;", // NOLINT -"__local float* work_each1 = (__local float*)work1;", // NOLINT +"__local Dtype4 work0[SLM_SIZE];", // NOLINT +"__local Dtype4 work1[SLM_SIZE];", // NOLINT +"__local Dtype* work_each0 = (__local Dtype*)work0;", // NOLINT 
+"__local Dtype* work_each1 = (__local Dtype*)work1;", // NOLINT "", // NOLINT "if(x_gid == N / 4) {", // NOLINT -"TEMPLATE(gemm_buffer_NT_M_2_edgerows,Dtype) (srca_read0, srca_read1, srcb_read, work0, work1, N, K, x_gid, lid, alpha, beta, (__global float*)dstc0, (__global float*)dstc1);", // NOLINT +"TEMPLATE(gemm_buffer_NT_M_2_edgerows,Dtype) (srca_read0, srca_read1, srcb_read, work0, work1, N, K, x_gid, lid, alpha, beta, (__global Dtype*)dstc0, (__global Dtype*)dstc1);", // NOLINT "} else {", // NOLINT -"float4 dot0[4] = {(float4)(0.), (float4)(0.), (float4)(0.), (float4)(0.)};", // NOLINT -"float4 dot1[4] = {(float4)(0.), (float4)(0.), (float4)(0.), (float4)(0.)};", // NOLINT +"Dtype4 dot0[4] = {(Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.)};", // NOLINT +"Dtype4 dot1[4] = {(Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.)};", // NOLINT "int i = lid;", // NOLINT "while( i < K / 4) {", // NOLINT -"const float4 b0 = vload4(i, srca_read0);", // NOLINT -"const float4 b1 = vload4(i, srca_read1);", // NOLINT +"const Dtype4 b0 = vload4(i, srca_read0);", // NOLINT +"const Dtype4 b1 = vload4(i, srca_read1);", // NOLINT "#pragma unroll", // NOLINT "for(int j = 0; j < 4; ++j) {", // NOLINT -"float4 a = vload4(i, srcb_read + j * K);", // NOLINT +"Dtype4 a = vload4(i, srcb_read + j * K);", // NOLINT "dot0[j] += b0 * a;", // NOLINT "dot1[j] += b1 * a;", // NOLINT "}", // NOLINT @@ -3851,14 +4097,14 @@ static std::vector> cl_kernels{ "if(i == K / 4) {", // NOLINT "short tail_items = K % 4;", // NOLINT "if(tail_items != 0) {", // NOLINT -"const __global float *srcb_tail = srcb_read + i * 4;", // NOLINT +"const __global Dtype *srcb_tail = srcb_read + i * 4;", // NOLINT "", // NOLINT -"const __global float *srca_tail0 = srca_read0 + i * 4;", // NOLINT -"const __global float *srca_tail1 = srca_read1 + i * 4;", // NOLINT +"const __global Dtype *srca_tail0 = srca_read0 + i * 4;", // NOLINT +"const __global Dtype *srca_tail1 = srca_read1 + i * 4;", // NOLINT "#pragma 
unroll", // NOLINT "for(short i = 0; i < tail_items; ++i) {", // NOLINT -"const float at0 = srca_tail0[i];", // NOLINT -"const float at1 = srca_tail1[i];", // NOLINT +"const Dtype at0 = srca_tail0[i];", // NOLINT +"const Dtype at1 = srca_tail1[i];", // NOLINT "#pragma unroll", // NOLINT "for(int j = 0; j < 4; ++j) {", // NOLINT "work_each0[lid * 4 + j] += at0 * srcb_tail[i + j * K];", // NOLINT @@ -3886,44 +4132,44 @@ static std::vector> cl_kernels{ "", // NOLINT "#define SLM_SIZE 32", // NOLINT "void TEMPLATE(gemm_buffer_NT_M_4_edgerows,Dtype)(", // NOLINT -"const __global float* srca_read0,", // NOLINT -"const __global float* srca_read1,", // NOLINT -"const __global float* srca_read2,", // NOLINT -"const __global float* srca_read3,", // NOLINT -"const __global float* srcb_read,", // NOLINT -"__local float4* work0,", // NOLINT -"__local float4* work1,", // NOLINT -"__local float4* work2,", // NOLINT -"__local float4* work3,", // NOLINT +"const __global Dtype* srca_read0,", // NOLINT +"const __global Dtype* srca_read1,", // NOLINT +"const __global Dtype* srca_read2,", // NOLINT +"const __global Dtype* srca_read3,", // NOLINT +"const __global Dtype* srcb_read,", // NOLINT +"__local Dtype4* work0,", // NOLINT +"__local Dtype4* work1,", // NOLINT +"__local Dtype4* work2,", // NOLINT +"__local Dtype4* work3,", // NOLINT "int N,", // NOLINT "int K,", // NOLINT "int x_gid,", // NOLINT "int lid,", // NOLINT -"float alpha,", // NOLINT -"float beta,", // NOLINT -"__global float* dstc0,", // NOLINT -"__global float* dstc1,", // NOLINT -"__global float* dstc2,", // NOLINT -"__global float* dstc3)", // NOLINT +"Dtype alpha,", // NOLINT +"Dtype beta,", // NOLINT +"__global Dtype* dstc0,", // NOLINT +"__global Dtype* dstc1,", // NOLINT +"__global Dtype* dstc2,", // NOLINT +"__global Dtype* dstc3)", // NOLINT "{", // NOLINT -"__local float* work_each0 = (__local float*)(work0 + lid);", // NOLINT -"__local float* work_each1 = (__local float*)(work1 + lid);", // NOLINT -"__local 
float* work_each2 = (__local float*)(work2 + lid);", // NOLINT -"__local float* work_each3 = (__local float*)(work3 + lid);", // NOLINT +"__local Dtype* work_each0 = (__local Dtype*)(work0 + lid);", // NOLINT +"__local Dtype* work_each1 = (__local Dtype*)(work1 + lid);", // NOLINT +"__local Dtype* work_each2 = (__local Dtype*)(work2 + lid);", // NOLINT +"__local Dtype* work_each3 = (__local Dtype*)(work3 + lid);", // NOLINT "", // NOLINT "int rows = N - x_gid * 4;", // NOLINT "", // NOLINT -"float4 dot0[3] = {(float4)(0.), (float4)(0.), (float4)(0.)};", // NOLINT -"float4 dot1[3] = {(float4)(0.), (float4)(0.), (float4)(0.)};", // NOLINT -"float4 dot2[3] = {(float4)(0.), (float4)(0.), (float4)(0.)};", // NOLINT -"float4 dot3[3] = {(float4)(0.), (float4)(0.), (float4)(0.)};", // NOLINT +"Dtype4 dot0[3] = {(Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.)};", // NOLINT +"Dtype4 dot1[3] = {(Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.)};", // NOLINT +"Dtype4 dot2[3] = {(Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.)};", // NOLINT +"Dtype4 dot3[3] = {(Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.)};", // NOLINT "", // NOLINT "int i = lid;", // NOLINT "while( i < K / 4) {", // NOLINT -"const float4 a0 = {srca_read0[i*4], srca_read0[(i*4+1)], srca_read0[(i*4+2)], srca_read0[(i*4+3)]};", // NOLINT -"const float4 a1 = {srca_read1[i*4], srca_read1[(i*4+1)], srca_read1[(i*4+2)], srca_read1[(i*4+3)]};", // NOLINT -"const float4 a2 = {srca_read2[i*4], srca_read2[(i*4+1)], srca_read2[(i*4+2)], srca_read2[(i*4+3)]};", // NOLINT -"const float4 a3 = {srca_read3[i*4], srca_read3[(i*4+1)], srca_read3[(i*4+2)], srca_read3[(i*4+3)]};", // NOLINT +"const Dtype4 a0 = {srca_read0[i*4], srca_read0[(i*4+1)], srca_read0[(i*4+2)], srca_read0[(i*4+3)]};", // NOLINT +"const Dtype4 a1 = {srca_read1[i*4], srca_read1[(i*4+1)], srca_read1[(i*4+2)], srca_read1[(i*4+3)]};", // NOLINT +"const Dtype4 a2 = {srca_read2[i*4], srca_read2[(i*4+1)], srca_read2[(i*4+2)], srca_read2[(i*4+3)]};", // NOLINT +"const Dtype4 a3 = 
{srca_read3[i*4], srca_read3[(i*4+1)], srca_read3[(i*4+2)], srca_read3[(i*4+3)]};", // NOLINT "#pragma unrol", // NOLINT "for(int j = 0; j < rows; ++j) {", // NOLINT "dot0[j] += a0 * vload4(i, srcb_read + j * K);", // NOLINT @@ -3946,18 +4192,18 @@ static std::vector> cl_kernels{ "short tail_items = K % 4;", // NOLINT "", // NOLINT "if(tail_items != 0) {", // NOLINT -"const __global float *srcb_tail = srcb_read + i * 4;", // NOLINT +"const __global Dtype *srcb_tail = srcb_read + i * 4;", // NOLINT "", // NOLINT -"const __global float *srca_tail0 = srca_read0 + i * 4;", // NOLINT -"const __global float *srca_tail1 = srca_read1 + i * 4;", // NOLINT -"const __global float *srca_tail2 = srca_read2 + i * 4;", // NOLINT -"const __global float *srca_tail3 = srca_read3 + i * 4;", // NOLINT +"const __global Dtype *srca_tail0 = srca_read0 + i * 4;", // NOLINT +"const __global Dtype *srca_tail1 = srca_read1 + i * 4;", // NOLINT +"const __global Dtype *srca_tail2 = srca_read2 + i * 4;", // NOLINT +"const __global Dtype *srca_tail3 = srca_read3 + i * 4;", // NOLINT "#pragma unroll", // NOLINT "for(short i = 0; i < tail_items; ++i) {", // NOLINT -"const float at0 = srca_tail0[i];", // NOLINT -"const float at1 = srca_tail1[i];", // NOLINT -"const float at2 = srca_tail2[i];", // NOLINT -"const float at3 = srca_tail3[i];", // NOLINT +"const Dtype at0 = srca_tail0[i];", // NOLINT +"const Dtype at1 = srca_tail1[i];", // NOLINT +"const Dtype at2 = srca_tail2[i];", // NOLINT +"const Dtype at3 = srca_tail3[i];", // NOLINT "#pragma unroll", // NOLINT "for(int j = 0; j < rows; ++j) {", // NOLINT "work_each0[j] += at0 * srcb_tail[i + j * K];", // NOLINT @@ -3991,62 +4237,62 @@ static std::vector> cl_kernels{ "}", // NOLINT "", // NOLINT "__kernel void TEMPLATE(gemm_buffer_NT_M_4,Dtype)(", // NOLINT -"__global const float * A,", // NOLINT +"__global const Dtype * A,", // NOLINT "int offA,", // NOLINT -"__global const float * B,", // NOLINT +"__global const Dtype * B,", // NOLINT "int 
offB,", // NOLINT -"__global float * C,", // NOLINT +"__global Dtype * C,", // NOLINT "int offC,", // NOLINT "int M,", // NOLINT "int N,", // NOLINT "int K,", // NOLINT -"float alpha_f,", // NOLINT -"float beta_f)", // NOLINT +"KERNEL_ARG_DTYPE alpha_f,", // NOLINT +"KERNEL_ARG_DTYPE beta_f)", // NOLINT "{", // NOLINT -"float alpha = (float)alpha_f;", // NOLINT -"float beta = (float)beta_f;", // NOLINT +"Dtype alpha = (Dtype)alpha_f;", // NOLINT +"Dtype beta = (Dtype)beta_f;", // NOLINT "int x_gid = get_group_id(0);", // NOLINT "int lid = get_local_id(0);", // NOLINT "int lsize = get_local_size(0);", // NOLINT "", // NOLINT -"const __global float *srca_read0 = A + offA;", // NOLINT -"const __global float *srca_read1 = srca_read0 + K;", // NOLINT -"const __global float *srca_read2 = srca_read1 + K;", // NOLINT -"const __global float *srca_read3 = srca_read2 + K;", // NOLINT +"const __global Dtype *srca_read0 = A + offA;", // NOLINT +"const __global Dtype *srca_read1 = srca_read0 + K;", // NOLINT +"const __global Dtype *srca_read2 = srca_read1 + K;", // NOLINT +"const __global Dtype *srca_read3 = srca_read2 + K;", // NOLINT "", // NOLINT -"const __global float *srcb_read = B + x_gid * 4 * K + offB;", // NOLINT +"const __global Dtype *srcb_read = B + x_gid * 4 * K + offB;", // NOLINT "", // NOLINT -"__global float4 *dstc0 = (__global float4*)(C + offC);", // NOLINT -"__global float4 *dstc1 = (__global float4*)((__global float*)(dstc0) + N);", // NOLINT -"__global float4 *dstc2 = (__global float4*)((__global float*)(dstc1) + N);", // NOLINT -"__global float4 *dstc3 = (__global float4*)((__global float*)(dstc2) + N);", // NOLINT +"__global Dtype4 *dstc0 = (__global Dtype4*)(C + offC);", // NOLINT +"__global Dtype4 *dstc1 = (__global Dtype4*)((__global Dtype*)(dstc0) + N);", // NOLINT +"__global Dtype4 *dstc2 = (__global Dtype4*)((__global Dtype*)(dstc1) + N);", // NOLINT +"__global Dtype4 *dstc3 = (__global Dtype4*)((__global Dtype*)(dstc2) + N);", // NOLINT "", // 
NOLINT -"__local float4 work0[SLM_SIZE];", // NOLINT -"__local float4 work1[SLM_SIZE];", // NOLINT -"__local float4 work2[SLM_SIZE];", // NOLINT -"__local float4 work3[SLM_SIZE];", // NOLINT -"__local float* work_each0 = (__local float*)(work0 + lid);", // NOLINT -"__local float* work_each1 = (__local float*)(work1 + lid);", // NOLINT -"__local float* work_each2 = (__local float*)(work2 + lid);", // NOLINT -"__local float* work_each3 = (__local float*)(work3 + lid);", // NOLINT +"__local Dtype4 work0[SLM_SIZE];", // NOLINT +"__local Dtype4 work1[SLM_SIZE];", // NOLINT +"__local Dtype4 work2[SLM_SIZE];", // NOLINT +"__local Dtype4 work3[SLM_SIZE];", // NOLINT +"__local Dtype* work_each0 = (__local Dtype*)(work0 + lid);", // NOLINT +"__local Dtype* work_each1 = (__local Dtype*)(work1 + lid);", // NOLINT +"__local Dtype* work_each2 = (__local Dtype*)(work2 + lid);", // NOLINT +"__local Dtype* work_each3 = (__local Dtype*)(work3 + lid);", // NOLINT "", // NOLINT "if(x_gid == N / 4) {", // NOLINT -"TEMPLATE(gemm_buffer_NT_M_4_edgerows,Dtype) (srca_read0, srca_read1, srca_read2, srca_read3, srcb_read, work0, work1, work2, work3, N, K, x_gid, lid, alpha, beta, (__global float*)dstc0, (__global float*)dstc1, (__global float*)dstc2, (__global float*)dstc3);", // NOLINT +"TEMPLATE(gemm_buffer_NT_M_4_edgerows,Dtype) (srca_read0, srca_read1, srca_read2, srca_read3, srcb_read, work0, work1, work2, work3, N, K, x_gid, lid, alpha, beta, (__global Dtype*)dstc0, (__global Dtype*)dstc1, (__global Dtype*)dstc2, (__global Dtype*)dstc3);", // NOLINT "} else {", // NOLINT -"float4 dot0[4] = {(float4)(0.), (float4)(0.), (float4)(0.), (float4)(0.)};", // NOLINT -"float4 dot1[4] = {(float4)(0.), (float4)(0.), (float4)(0.), (float4)(0.)};", // NOLINT -"float4 dot2[4] = {(float4)(0.), (float4)(0.), (float4)(0.), (float4)(0.)};", // NOLINT -"float4 dot3[4] = {(float4)(0.), (float4)(0.), (float4)(0.), (float4)(0.)};", // NOLINT +"Dtype4 dot0[4] = {(Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.), 
(Dtype4)(0.)};", // NOLINT +"Dtype4 dot1[4] = {(Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.)};", // NOLINT +"Dtype4 dot2[4] = {(Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.)};", // NOLINT +"Dtype4 dot3[4] = {(Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.)};", // NOLINT "", // NOLINT "int kid = lid;", // NOLINT "while( kid < K / 4) {", // NOLINT -"const float4 b0 = vload4(kid, srca_read0);", // NOLINT -"const float4 b1 = vload4(kid, srca_read1);", // NOLINT -"const float4 b2 = vload4(kid, srca_read2);", // NOLINT -"const float4 b3 = vload4(kid, srca_read3);", // NOLINT +"const Dtype4 b0 = vload4(kid, srca_read0);", // NOLINT +"const Dtype4 b1 = vload4(kid, srca_read1);", // NOLINT +"const Dtype4 b2 = vload4(kid, srca_read2);", // NOLINT +"const Dtype4 b3 = vload4(kid, srca_read3);", // NOLINT "#pragma unroll", // NOLINT "for(int j = 0; j < 4; ++j) {", // NOLINT -"float4 a = vload4(kid, srcb_read + j * K);", // NOLINT +"Dtype4 a = vload4(kid, srcb_read + j * K);", // NOLINT "dot0[j] += b0 * a;", // NOLINT "dot1[j] += b1 * a;", // NOLINT "dot2[j] += b2 * a;", // NOLINT @@ -4066,18 +4312,18 @@ static std::vector> cl_kernels{ "short tail_items = K % 4;", // NOLINT "if(tail_items != 0) {", // NOLINT "int offset = kid << 2;", // NOLINT -"const __global float *srcb_tail = srcb_read + offset;", // NOLINT +"const __global Dtype *srcb_tail = srcb_read + offset;", // NOLINT "", // NOLINT -"const __global float *srca_tail0 = srca_read0 + offset;", // NOLINT -"const __global float *srca_tail1 = srca_read1 + offset;", // NOLINT -"const __global float *srca_tail2 = srca_read2 + offset;", // NOLINT -"const __global float *srca_tail3 = srca_read3 + offset;", // NOLINT +"const __global Dtype *srca_tail0 = srca_read0 + offset;", // NOLINT +"const __global Dtype *srca_tail1 = srca_read1 + offset;", // NOLINT +"const __global Dtype *srca_tail2 = srca_read2 + offset;", // NOLINT +"const __global Dtype *srca_tail3 = srca_read3 + offset;", // NOLINT "#pragma 
unroll", // NOLINT "for(short i = 0; i < tail_items; ++i) {", // NOLINT -"const float at0 = srca_tail0[i];", // NOLINT -"const float at1 = srca_tail1[i];", // NOLINT -"const float at2 = srca_tail2[i];", // NOLINT -"const float at3 = srca_tail3[i];", // NOLINT +"const Dtype at0 = srca_tail0[i];", // NOLINT +"const Dtype at1 = srca_tail1[i];", // NOLINT +"const Dtype at2 = srca_tail2[i];", // NOLINT +"const Dtype at3 = srca_tail3[i];", // NOLINT "#pragma unroll", // NOLINT "for(int j = 0; j < 4; ++j) {", // NOLINT "work_each0[j] += at0 * srcb_tail[i + j * K];", // NOLINT @@ -4111,73 +4357,73 @@ static std::vector> cl_kernels{ "", // NOLINT "#define SLM_SIZE 16", // NOLINT "__kernel void TEMPLATE(gemm_buffer_NT_M_8,Dtype)(", // NOLINT -"__global const float * A,", // NOLINT +"__global const Dtype * A,", // NOLINT "int offA,", // NOLINT -"__global const float * B,", // NOLINT +"__global const Dtype * B,", // NOLINT "int offB,", // NOLINT -"__global float * C,", // NOLINT +"__global Dtype * C,", // NOLINT "int offC,", // NOLINT "int M,", // NOLINT "int N,", // NOLINT "int K,", // NOLINT -"float alpha_f,", // NOLINT -"float beta_f)", // NOLINT +"KERNEL_ARG_DTYPE alpha_f,", // NOLINT +"KERNEL_ARG_DTYPE beta_f)", // NOLINT "{", // NOLINT -"float alpha = (float)alpha_f;", // NOLINT -"float beta = (float)beta_f;", // NOLINT +"Dtype alpha = (Dtype)alpha_f;", // NOLINT +"Dtype beta = (Dtype)beta_f;", // NOLINT "int x_gid = get_group_id(0);", // NOLINT "int lid = get_local_id(0);", // NOLINT "int lsize = get_local_size(0);", // NOLINT "", // NOLINT -"const __global float *srca_read0 = A + offA;", // NOLINT -"const __global float *srca_read1 = srca_read0 + K;", // NOLINT -"const __global float *srca_read2 = srca_read1 + K;", // NOLINT -"const __global float *srca_read3 = srca_read2 + K;", // NOLINT -"const __global float *srca_read4 = srca_read3 + K;", // NOLINT -"const __global float *srca_read5 = srca_read4 + K;", // NOLINT -"const __global float *srca_read6 = srca_read5 + 
K;", // NOLINT -"const __global float *srca_read7 = srca_read6 + K;", // NOLINT -"", // NOLINT -"const __global float *srcb_read = B + x_gid * K + offB;", // NOLINT -"", // NOLINT -"__global float *dstc0 = C + offC;", // NOLINT -"__global float *dstc1 = dstc0 + N;", // NOLINT -"__global float *dstc2 = dstc1 + N;", // NOLINT -"__global float *dstc3 = dstc2 + N;", // NOLINT -"__global float *dstc4 = dstc3 + N;", // NOLINT -"__global float *dstc5 = dstc4 + N;", // NOLINT -"__global float *dstc6 = dstc5 + N;", // NOLINT -"__global float *dstc7 = dstc6 + N;", // NOLINT -"", // NOLINT -"__local float work0[SLM_SIZE];", // NOLINT -"__local float work1[SLM_SIZE];", // NOLINT -"__local float work2[SLM_SIZE];", // NOLINT -"__local float work3[SLM_SIZE];", // NOLINT -"__local float work4[SLM_SIZE];", // NOLINT -"__local float work5[SLM_SIZE];", // NOLINT -"__local float work6[SLM_SIZE];", // NOLINT -"__local float work7[SLM_SIZE];", // NOLINT -"", // NOLINT -"float4 dot0 = (float4)(0.);", // NOLINT -"float4 dot1 = (float4)(0.);", // NOLINT -"float4 dot2 = (float4)(0.);", // NOLINT -"float4 dot3 = (float4)(0.);", // NOLINT -"float4 dot4 = (float4)(0.);", // NOLINT -"float4 dot5 = (float4)(0.);", // NOLINT -"float4 dot6 = (float4)(0.);", // NOLINT -"float4 dot7 = (float4)(0.);", // NOLINT +"const __global Dtype *srca_read0 = A + offA;", // NOLINT +"const __global Dtype *srca_read1 = srca_read0 + K;", // NOLINT +"const __global Dtype *srca_read2 = srca_read1 + K;", // NOLINT +"const __global Dtype *srca_read3 = srca_read2 + K;", // NOLINT +"const __global Dtype *srca_read4 = srca_read3 + K;", // NOLINT +"const __global Dtype *srca_read5 = srca_read4 + K;", // NOLINT +"const __global Dtype *srca_read6 = srca_read5 + K;", // NOLINT +"const __global Dtype *srca_read7 = srca_read6 + K;", // NOLINT +"", // NOLINT +"const __global Dtype *srcb_read = B + x_gid * K + offB;", // NOLINT +"", // NOLINT +"__global Dtype *dstc0 = C + offC;", // NOLINT +"__global Dtype *dstc1 = dstc0 + N;", 
// NOLINT +"__global Dtype *dstc2 = dstc1 + N;", // NOLINT +"__global Dtype *dstc3 = dstc2 + N;", // NOLINT +"__global Dtype *dstc4 = dstc3 + N;", // NOLINT +"__global Dtype *dstc5 = dstc4 + N;", // NOLINT +"__global Dtype *dstc6 = dstc5 + N;", // NOLINT +"__global Dtype *dstc7 = dstc6 + N;", // NOLINT +"", // NOLINT +"__local Dtype work0[SLM_SIZE];", // NOLINT +"__local Dtype work1[SLM_SIZE];", // NOLINT +"__local Dtype work2[SLM_SIZE];", // NOLINT +"__local Dtype work3[SLM_SIZE];", // NOLINT +"__local Dtype work4[SLM_SIZE];", // NOLINT +"__local Dtype work5[SLM_SIZE];", // NOLINT +"__local Dtype work6[SLM_SIZE];", // NOLINT +"__local Dtype work7[SLM_SIZE];", // NOLINT +"", // NOLINT +"Dtype4 dot0 = (Dtype4)(0.);", // NOLINT +"Dtype4 dot1 = (Dtype4)(0.);", // NOLINT +"Dtype4 dot2 = (Dtype4)(0.);", // NOLINT +"Dtype4 dot3 = (Dtype4)(0.);", // NOLINT +"Dtype4 dot4 = (Dtype4)(0.);", // NOLINT +"Dtype4 dot5 = (Dtype4)(0.);", // NOLINT +"Dtype4 dot6 = (Dtype4)(0.);", // NOLINT +"Dtype4 dot7 = (Dtype4)(0.);", // NOLINT "", // NOLINT "int kid = lid;", // NOLINT "while( kid < K / 4) {", // NOLINT -"const float4 a0 = vload4(kid, srca_read0);", // NOLINT -"const float4 a1 = vload4(kid, srca_read1);", // NOLINT -"const float4 a2 = vload4(kid, srca_read2);", // NOLINT -"const float4 a3 = vload4(kid, srca_read3);", // NOLINT -"const float4 a4 = vload4(kid, srca_read4);", // NOLINT -"const float4 a5 = vload4(kid, srca_read5);", // NOLINT -"const float4 a6 = vload4(kid, srca_read6);", // NOLINT -"const float4 a7 = vload4(kid, srca_read7);", // NOLINT -"float4 b = vload4(kid, srcb_read);", // NOLINT +"const Dtype4 a0 = vload4(kid, srca_read0);", // NOLINT +"const Dtype4 a1 = vload4(kid, srca_read1);", // NOLINT +"const Dtype4 a2 = vload4(kid, srca_read2);", // NOLINT +"const Dtype4 a3 = vload4(kid, srca_read3);", // NOLINT +"const Dtype4 a4 = vload4(kid, srca_read4);", // NOLINT +"const Dtype4 a5 = vload4(kid, srca_read5);", // NOLINT +"const Dtype4 a6 = vload4(kid, 
srca_read6);", // NOLINT +"const Dtype4 a7 = vload4(kid, srca_read7);", // NOLINT +"Dtype4 b = vload4(kid, srcb_read);", // NOLINT "dot0 += a0 * b;", // NOLINT "dot1 += a1 * b;", // NOLINT "dot2 += a2 * b;", // NOLINT @@ -4202,16 +4448,16 @@ static std::vector> cl_kernels{ "short tail_items = K % 4;", // NOLINT "if(tail_items != 0) {", // NOLINT "int offset = kid << 2;", // NOLINT -"const __global float *srcb_tail = srcb_read + offset;", // NOLINT -"", // NOLINT -"const __global float *srca_tail0 = srca_read0 + offset;", // NOLINT -"const __global float *srca_tail1 = srca_read1 + offset;", // NOLINT -"const __global float *srca_tail2 = srca_read2 + offset;", // NOLINT -"const __global float *srca_tail3 = srca_read3 + offset;", // NOLINT -"const __global float *srca_tail4 = srca_read4 + offset;", // NOLINT -"const __global float *srca_tail5 = srca_read5 + offset;", // NOLINT -"const __global float *srca_tail6 = srca_read6 + offset;", // NOLINT -"const __global float *srca_tail7 = srca_read7 + offset;", // NOLINT +"const __global Dtype *srcb_tail = srcb_read + offset;", // NOLINT +"", // NOLINT +"const __global Dtype *srca_tail0 = srca_read0 + offset;", // NOLINT +"const __global Dtype *srca_tail1 = srca_read1 + offset;", // NOLINT +"const __global Dtype *srca_tail2 = srca_read2 + offset;", // NOLINT +"const __global Dtype *srca_tail3 = srca_read3 + offset;", // NOLINT +"const __global Dtype *srca_tail4 = srca_read4 + offset;", // NOLINT +"const __global Dtype *srca_tail5 = srca_read5 + offset;", // NOLINT +"const __global Dtype *srca_tail6 = srca_read6 + offset;", // NOLINT +"const __global Dtype *srca_tail7 = srca_read7 + offset;", // NOLINT "#pragma unroll", // NOLINT "for(short item = 0; item < tail_items; ++item) {", // NOLINT "work0[lid] += srca_tail0[item] * srcb_tail[item];", // NOLINT @@ -4267,19 +4513,19 @@ static std::vector> cl_kernels{ "__attribute__((reqd_work_group_size(SIMD_SIZE_GEMM, LWG_HEIGHT, 1)))", // NOLINT 
"__attribute__((intel_reqd_sub_group_size(SIMD_SIZE_GEMM)))", // NOLINT "__kernel void TEMPLATE(gemm_buffer_TN, Dtype)(", // NOLINT -"const __global float *src0, int off0,", // NOLINT -"const __global float *src1, int off1,", // NOLINT -"__global float *dst, int offd,", // NOLINT +"const __global Dtype *src0, int off0,", // NOLINT +"const __global Dtype *src1, int off1,", // NOLINT +"__global Dtype *dst, int offd,", // NOLINT "int M,", // NOLINT "int N,", // NOLINT "int K,", // NOLINT -"float alpha_in,", // NOLINT -"float beta_in,", // NOLINT +"KERNEL_ARG_DTYPE alpha_in,", // NOLINT +"KERNEL_ARG_DTYPE beta_in,", // NOLINT "int start_index)", // NOLINT "", // NOLINT "{", // NOLINT -"const float alpha = (float)alpha_in;", // NOLINT -"const float beta = (float)beta_in;", // NOLINT +"const Dtype alpha = (Dtype)alpha_in;", // NOLINT +"const Dtype beta = (Dtype)beta_in;", // NOLINT "const int group_x = get_group_id(0);", // NOLINT "const int group_y = get_group_id(1);", // NOLINT "const int local_x = get_local_id(0);", // NOLINT @@ -4287,62 +4533,62 @@ static std::vector> cl_kernels{ "const int global_x = get_global_id(0);", // NOLINT "const int global_y = get_global_id(1);", // NOLINT "", // NOLINT -"float4 brow;", // NOLINT +"Dtype4 brow;", // NOLINT "", // NOLINT -"__global float *dst_write0 = dst + local_x * VEC_SIZE + (group_x * TILE_N) + (group_y * LWG_HEIGHT * TILE_M + local_y * TILE_M) * N + offd;", // NOLINT +"__global Dtype *dst_write0 = dst + local_x * VEC_SIZE + (group_x * TILE_N) + (group_y * LWG_HEIGHT * TILE_M + local_y * TILE_M) * N + offd;", // NOLINT "", // NOLINT -"const __global float *src0_read = src0 + (local_x * (TILE_K / SIMD_SIZE_GEMM) + start_index) * M + group_y * LWG_HEIGHT * TILE_M + local_y * TILE_M + off0;", // NOLINT +"const __global Dtype *src0_read = src0 + (local_x * (TILE_K / SIMD_SIZE_GEMM) + start_index) * M + group_y * LWG_HEIGHT * TILE_M + local_y * TILE_M + off0;", // NOLINT "", // NOLINT -"const __global float *src1_read0 = src1 
+ local_x * VEC_SIZE + (group_x * TILE_N) + start_index * N + off1;", // NOLINT +"const __global Dtype *src1_read0 = src1 + local_x * VEC_SIZE + (group_x * TILE_N) + start_index * N + off1;", // NOLINT "", // NOLINT -"float4 dot00 = (start_index != 0) ? vload4(0, dst_write0) : beta * vload4(0, dst_write0);", // NOLINT -"float4 dot01 = (start_index != 0) ? vload4(0, dst_write0 + N) : beta * vload4(0, dst_write0 + N);", // NOLINT -"float4 dot02 = (start_index != 0) ? vload4(0, dst_write0 + 2 * N) : beta * vload4(0, dst_write0 + 2 * N);", // NOLINT -"float4 dot03 = (start_index != 0) ? vload4(0, dst_write0 + 3 * N) : beta * vload4(0, dst_write0 + 3 * N);", // NOLINT -"float4 dot04 = (start_index != 0) ? vload4(0, dst_write0 + 4 * N) : beta * vload4(0, dst_write0 + 4 * N);", // NOLINT -"float4 dot05 = (start_index != 0) ? vload4(0, dst_write0 + 5 * N) : beta * vload4(0, dst_write0 + 5 * N);", // NOLINT -"float4 dot06 = (start_index != 0) ? vload4(0, dst_write0 + 6 * N) : beta * vload4(0, dst_write0 + 6 * N);", // NOLINT -"float4 dot07 = (start_index != 0) ? vload4(0, dst_write0 + 7 * N) : beta * vload4(0, dst_write0 + 7 * N);", // NOLINT +"Dtype4 dot00 = (start_index != 0) ? vload4(0, dst_write0) : beta * vload4(0, dst_write0);", // NOLINT +"Dtype4 dot01 = (start_index != 0) ? vload4(0, dst_write0 + N) : beta * vload4(0, dst_write0 + N);", // NOLINT +"Dtype4 dot02 = (start_index != 0) ? vload4(0, dst_write0 + 2 * N) : beta * vload4(0, dst_write0 + 2 * N);", // NOLINT +"Dtype4 dot03 = (start_index != 0) ? vload4(0, dst_write0 + 3 * N) : beta * vload4(0, dst_write0 + 3 * N);", // NOLINT +"Dtype4 dot04 = (start_index != 0) ? vload4(0, dst_write0 + 4 * N) : beta * vload4(0, dst_write0 + 4 * N);", // NOLINT +"Dtype4 dot05 = (start_index != 0) ? vload4(0, dst_write0 + 5 * N) : beta * vload4(0, dst_write0 + 5 * N);", // NOLINT +"Dtype4 dot06 = (start_index != 0) ? 
vload4(0, dst_write0 + 6 * N) : beta * vload4(0, dst_write0 + 6 * N);", // NOLINT +"Dtype4 dot07 = (start_index != 0) ? vload4(0, dst_write0 + 7 * N) : beta * vload4(0, dst_write0 + 7 * N);", // NOLINT "", // NOLINT "int end_index = min(start_index + 256, K);", // NOLINT "while( start_index + TILE_K <= end_index ) {", // NOLINT -"float8 arow0 = alpha * vload8(0, src0_read);", // NOLINT -"float8 arow1 = alpha * vload8(0, src0_read + M);", // NOLINT -"", // NOLINT -"#define MM_DOT_PRODUCT( _arow ) brow = vload4(0, src1_read0); src1_read0 += N; dot00 = mad( (float4)(_arow.s0), brow, dot00 ); dot01 = mad( (float4)(_arow.s1), brow, dot01 ); dot02 = mad( (float4)(_arow.s2), brow, dot02 ); dot03 = mad( (float4)(_arow.s3), brow, dot03 ); dot04 = mad( (float4)(_arow.s4), brow, dot04 ); dot05 = mad( (float4)(_arow.s5), brow, dot05 ); dot06 = mad( (float4)(_arow.s6), brow, dot06 ); dot07 = mad( (float4)(_arow.s7), brow, dot07 );", // NOLINT -"MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 0 )) );", // NOLINT -"MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 0 )) );", // NOLINT -"MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 1 )) );", // NOLINT -"MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 1 )) );", // NOLINT -"MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 2 )) );", // NOLINT -"MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 2 )) );", // NOLINT -"MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 3 )) );", // NOLINT -"MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 3 )) );", // NOLINT -"MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 4 )) );", // NOLINT -"MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 4 )) );", // NOLINT -"MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 5 )) );", // NOLINT 
-"MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 5 )) );", // NOLINT -"MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 6 )) );", // NOLINT -"MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 6 )) );", // NOLINT -"MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 7 )) );", // NOLINT -"MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 7 )) );", // NOLINT +"Dtype8 arow0 = alpha * vload8(0, src0_read);", // NOLINT +"Dtype8 arow1 = alpha * vload8(0, src0_read + M);", // NOLINT +"", // NOLINT +"#define MM_DOT_PRODUCT( _arow ) brow = vload4(0, src1_read0); src1_read0 += N; dot00 = mad( (Dtype4)(_arow.s0), brow, dot00 ); dot01 = mad( (Dtype4)(_arow.s1), brow, dot01 ); dot02 = mad( (Dtype4)(_arow.s2), brow, dot02 ); dot03 = mad( (Dtype4)(_arow.s3), brow, dot03 ); dot04 = mad( (Dtype4)(_arow.s4), brow, dot04 ); dot05 = mad( (Dtype4)(_arow.s5), brow, dot05 ); dot06 = mad( (Dtype4)(_arow.s6), brow, dot06 ); dot07 = mad( (Dtype4)(_arow.s7), brow, dot07 );", // NOLINT +"MM_DOT_PRODUCT( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 0 )) );", // NOLINT +"MM_DOT_PRODUCT( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 0 )) );", // NOLINT +"MM_DOT_PRODUCT( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 1 )) );", // NOLINT +"MM_DOT_PRODUCT( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 1 )) );", // NOLINT +"MM_DOT_PRODUCT( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 2 )) );", // NOLINT +"MM_DOT_PRODUCT( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 2 )) );", // NOLINT +"MM_DOT_PRODUCT( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 3 )) );", // NOLINT +"MM_DOT_PRODUCT( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 3 )) );", // NOLINT +"MM_DOT_PRODUCT( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 4 )) );", // NOLINT +"MM_DOT_PRODUCT( as_Dtype8(intel_sub_group_shuffle( 
SHUFFLE_TYPE8(arow1), 4 )) );", // NOLINT +"MM_DOT_PRODUCT( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 5 )) );", // NOLINT +"MM_DOT_PRODUCT( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 5 )) );", // NOLINT +"MM_DOT_PRODUCT( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 6 )) );", // NOLINT +"MM_DOT_PRODUCT( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 6 )) );", // NOLINT +"MM_DOT_PRODUCT( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 7 )) );", // NOLINT +"MM_DOT_PRODUCT( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 7 )) );", // NOLINT "#if TYPE == TYPE_HALF", // NOLINT -"MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 8 )) );", // NOLINT -"MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 8 )) );", // NOLINT -"MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 9 )) );", // NOLINT -"MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 9 )) );", // NOLINT -"MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 10 )) );", // NOLINT -"MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 10 )) );", // NOLINT -"MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 11 )) );", // NOLINT -"MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 11 )) );", // NOLINT -"MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 12 )) );", // NOLINT -"MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 12 )) );", // NOLINT -"MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 13 )) );", // NOLINT -"MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 13 )) );", // NOLINT -"MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 14 )) );", // NOLINT -"MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 14 )) );", // NOLINT 
-"MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 15 )) );", // NOLINT -"MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 15 )) );", // NOLINT +"MM_DOT_PRODUCT( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 8 )) );", // NOLINT +"MM_DOT_PRODUCT( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 8 )) );", // NOLINT +"MM_DOT_PRODUCT( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 9 )) );", // NOLINT +"MM_DOT_PRODUCT( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 9 )) );", // NOLINT +"MM_DOT_PRODUCT( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 10 )) );", // NOLINT +"MM_DOT_PRODUCT( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 10 )) );", // NOLINT +"MM_DOT_PRODUCT( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 11 )) );", // NOLINT +"MM_DOT_PRODUCT( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 11 )) );", // NOLINT +"MM_DOT_PRODUCT( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 12 )) );", // NOLINT +"MM_DOT_PRODUCT( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 12 )) );", // NOLINT +"MM_DOT_PRODUCT( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 13 )) );", // NOLINT +"MM_DOT_PRODUCT( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 13 )) );", // NOLINT +"MM_DOT_PRODUCT( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 14 )) );", // NOLINT +"MM_DOT_PRODUCT( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 14 )) );", // NOLINT +"MM_DOT_PRODUCT( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 15 )) );", // NOLINT +"MM_DOT_PRODUCT( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 15 )) );", // NOLINT "#endif", // NOLINT "#undef MM_DOT_PRODUCT", // NOLINT "", // NOLINT @@ -4351,43 +4597,43 @@ static std::vector> cl_kernels{ "}", // NOLINT "", // NOLINT "if(start_index < end_index) {", // NOLINT -"float8 arow0 = ((start_index + local_x * 2) < K) ? 
alpha * vload8(0, src0_read) : (float8)0.0f;", // NOLINT -"float8 arow1 = ((start_index + local_x * 2 + 1) < K) ? alpha * vload8(0, src0_read + M) : (float8)0.0f;", // NOLINT -"", // NOLINT -"#define MM_DOT_PRODUCT( _arow ) brow = (start_index < K) ? vload4(0, src1_read0) : (float4)0.0f; src1_read0 += N; start_index++; dot00 = mad( (float4)(_arow.s0), brow, dot00 ); dot01 = mad( (float4)(_arow.s1), brow, dot01 ); dot02 = mad( (float4)(_arow.s2), brow, dot02 ); dot03 = mad( (float4)(_arow.s3), brow, dot03 ); dot04 = mad( (float4)(_arow.s4), brow, dot04 ); dot05 = mad( (float4)(_arow.s5), brow, dot05 ); dot06 = mad( (float4)(_arow.s6), brow, dot06 ); dot07 = mad( (float4)(_arow.s7), brow, dot07 );", // NOLINT -"MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 0 )) );", // NOLINT -"MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 0 )) );", // NOLINT -"MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 1 )) );", // NOLINT -"MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 1 )) );", // NOLINT -"MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 2 )) );", // NOLINT -"MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 2 )) );", // NOLINT -"MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 3 )) );", // NOLINT -"MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 3 )) );", // NOLINT -"MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 4 )) );", // NOLINT -"MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 4 )) );", // NOLINT -"MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 5 )) );", // NOLINT -"MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 5 )) );", // NOLINT -"MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 6 )) );", // NOLINT -"MM_DOT_PRODUCT( 
as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 6 )) );", // NOLINT -"MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 7 )) );", // NOLINT -"MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 7 )) );", // NOLINT +"Dtype8 arow0 = ((start_index + local_x * 2) < K) ? alpha * vload8(0, src0_read) : (Dtype8)0.0f;", // NOLINT +"Dtype8 arow1 = ((start_index + local_x * 2 + 1) < K) ? alpha * vload8(0, src0_read + M) : (Dtype8)0.0f;", // NOLINT +"", // NOLINT +"#define MM_DOT_PRODUCT( _arow ) brow = (start_index < K) ? vload4(0, src1_read0) : (Dtype4)0.0f; src1_read0 += N; start_index++; dot00 = mad( (Dtype4)(_arow.s0), brow, dot00 ); dot01 = mad( (Dtype4)(_arow.s1), brow, dot01 ); dot02 = mad( (Dtype4)(_arow.s2), brow, dot02 ); dot03 = mad( (Dtype4)(_arow.s3), brow, dot03 ); dot04 = mad( (Dtype4)(_arow.s4), brow, dot04 ); dot05 = mad( (Dtype4)(_arow.s5), brow, dot05 ); dot06 = mad( (Dtype4)(_arow.s6), brow, dot06 ); dot07 = mad( (Dtype4)(_arow.s7), brow, dot07 );", // NOLINT +"MM_DOT_PRODUCT( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 0 )) );", // NOLINT +"MM_DOT_PRODUCT( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 0 )) );", // NOLINT +"MM_DOT_PRODUCT( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 1 )) );", // NOLINT +"MM_DOT_PRODUCT( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 1 )) );", // NOLINT +"MM_DOT_PRODUCT( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 2 )) );", // NOLINT +"MM_DOT_PRODUCT( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 2 )) );", // NOLINT +"MM_DOT_PRODUCT( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 3 )) );", // NOLINT +"MM_DOT_PRODUCT( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 3 )) );", // NOLINT +"MM_DOT_PRODUCT( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 4 )) );", // NOLINT +"MM_DOT_PRODUCT( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 4 )) );", // NOLINT 
+"MM_DOT_PRODUCT( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 5 )) );", // NOLINT +"MM_DOT_PRODUCT( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 5 )) );", // NOLINT +"MM_DOT_PRODUCT( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 6 )) );", // NOLINT +"MM_DOT_PRODUCT( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 6 )) );", // NOLINT +"MM_DOT_PRODUCT( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 7 )) );", // NOLINT +"MM_DOT_PRODUCT( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 7 )) );", // NOLINT "#if TYPE == TYPE_HALF", // NOLINT -"MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 8 )) );", // NOLINT -"MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 8 )) );", // NOLINT -"MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 9 )) );", // NOLINT -"MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 9 )) );", // NOLINT -"MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 10 )) );", // NOLINT -"MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 10 )) );", // NOLINT -"MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 11 )) );", // NOLINT -"MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 11 )) );", // NOLINT -"MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 12 )) );", // NOLINT -"MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 12 )) );", // NOLINT -"MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 13 )) );", // NOLINT -"MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 13 )) );", // NOLINT -"MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 14 )) );", // NOLINT -"MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 14 )) );", // NOLINT -"MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( 
SHUFFLE_TYPE8(arow0), 15 )) );", // NOLINT -"MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 15 )) );", // NOLINT +"MM_DOT_PRODUCT( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 8 )) );", // NOLINT +"MM_DOT_PRODUCT( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 8 )) );", // NOLINT +"MM_DOT_PRODUCT( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 9 )) );", // NOLINT +"MM_DOT_PRODUCT( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 9 )) );", // NOLINT +"MM_DOT_PRODUCT( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 10 )) );", // NOLINT +"MM_DOT_PRODUCT( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 10 )) );", // NOLINT +"MM_DOT_PRODUCT( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 11 )) );", // NOLINT +"MM_DOT_PRODUCT( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 11 )) );", // NOLINT +"MM_DOT_PRODUCT( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 12 )) );", // NOLINT +"MM_DOT_PRODUCT( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 12 )) );", // NOLINT +"MM_DOT_PRODUCT( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 13 )) );", // NOLINT +"MM_DOT_PRODUCT( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 13 )) );", // NOLINT +"MM_DOT_PRODUCT( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 14 )) );", // NOLINT +"MM_DOT_PRODUCT( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 14 )) );", // NOLINT +"MM_DOT_PRODUCT( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 15 )) );", // NOLINT +"MM_DOT_PRODUCT( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 15 )) );", // NOLINT "#endif", // NOLINT "#undef MM_DOT_PRODUCT", // NOLINT "}", // NOLINT @@ -4493,19 +4739,19 @@ static std::vector> cl_kernels{ "__attribute__((reqd_work_group_size(8, LWG_HEIGHT, 1)))", // NOLINT "__attribute__((intel_reqd_sub_group_size(8)))", // NOLINT "__kernel void TEMPLATE(gemm_buffer_TT, Dtype)(", // NOLINT -"const 
__global float *src0, int off0,", // NOLINT -"const __global float *src1, int off1,", // NOLINT -"__global float *dst, int offd,", // NOLINT +"const __global Dtype *src0, int off0,", // NOLINT +"const __global Dtype *src1, int off1,", // NOLINT +"__global Dtype *dst, int offd,", // NOLINT "int M,", // NOLINT "int N,", // NOLINT "int K,", // NOLINT -"float alpha_in,", // NOLINT -"float beta_in,", // NOLINT +"KERNEL_ARG_DTYPE alpha_in,", // NOLINT +"KERNEL_ARG_DTYPE beta_in,", // NOLINT "int start_index)", // NOLINT "", // NOLINT "{", // NOLINT -"const float alpha = (float)alpha_in;", // NOLINT -"const float beta = (float)beta_in;", // NOLINT +"const Dtype alpha = (Dtype)alpha_in;", // NOLINT +"const Dtype beta = (Dtype)beta_in;", // NOLINT "const int group_x = get_group_id(0);", // NOLINT "const int group_y = get_group_id(1);", // NOLINT "const int local_x = get_local_id(0);", // NOLINT @@ -4513,30 +4759,30 @@ static std::vector> cl_kernels{ "const int global_x = get_global_id(0);", // NOLINT "const int global_y = get_global_id(1);", // NOLINT "", // NOLINT -"float8 dot0 = 0.f;", // NOLINT -"float8 dot1 = 0.f;", // NOLINT -"float8 dot2 = 0.f;", // NOLINT -"float8 dot3 = 0.f;", // NOLINT +"Dtype8 dot0 = 0.f;", // NOLINT +"Dtype8 dot1 = 0.f;", // NOLINT +"Dtype8 dot2 = 0.f;", // NOLINT +"Dtype8 dot3 = 0.f;", // NOLINT "", // NOLINT -"float16 brow0;", // NOLINT -"float16 brow1;", // NOLINT -"float16 brow2;", // NOLINT -"float16 brow3;", // NOLINT +"Dtype16 brow0;", // NOLINT +"Dtype16 brow1;", // NOLINT +"Dtype16 brow2;", // NOLINT +"Dtype16 brow3;", // NOLINT "", // NOLINT -"__global float *dst_write0 = dst + local_x * VEC_SIZE + (group_x * TILE_N) + (group_y * LWG_HEIGHT * TILE_M + local_y * TILE_M) * N + offd;", // NOLINT +"__global Dtype *dst_write0 = dst + local_x * VEC_SIZE + (group_x * TILE_N) + (group_y * LWG_HEIGHT * TILE_M + local_y * TILE_M) * N + offd;", // NOLINT "", // NOLINT -"const __global float *src0_read = src0 + (local_x * (TILE_K / 8) + 
start_index) * M + group_y * LWG_HEIGHT * TILE_M + local_y * TILE_M + off0;", // NOLINT +"const __global Dtype *src0_read = src0 + (local_x * (TILE_K / 8) + start_index) * M + group_y * LWG_HEIGHT * TILE_M + local_y * TILE_M + off0;", // NOLINT "", // NOLINT -"const __global float *src1_read0 = src1 + (local_x * VEC_SIZE + (group_x * TILE_N)) * K + start_index + off1;", // NOLINT +"const __global Dtype *src1_read0 = src1 + (local_x * VEC_SIZE + (group_x * TILE_N)) * K + start_index + off1;", // NOLINT "", // NOLINT -"float4 dot00 = (start_index != 0) ? vload4(0, dst_write0) : beta * vload4(0, dst_write0);", // NOLINT -"float4 dot01 = (start_index != 0) ? vload4(0, dst_write0 + N) : beta * vload4(0, dst_write0 + N);", // NOLINT -"float4 dot02 = (start_index != 0) ? vload4(0, dst_write0 + 2 * N) : beta * vload4(0, dst_write0 + 2 * N);", // NOLINT -"float4 dot03 = (start_index != 0) ? vload4(0, dst_write0 + 3 * N) : beta * vload4(0, dst_write0 + 3 * N);", // NOLINT -"float4 dot04 = (start_index != 0) ? vload4(0, dst_write0 + 4 * N) : beta * vload4(0, dst_write0 + 4 * N);", // NOLINT -"float4 dot05 = (start_index != 0) ? vload4(0, dst_write0 + 5 * N) : beta * vload4(0, dst_write0 + 5 * N);", // NOLINT -"float4 dot06 = (start_index != 0) ? vload4(0, dst_write0 + 6 * N) : beta * vload4(0, dst_write0 + 6 * N);", // NOLINT -"float4 dot07 = (start_index != 0) ? vload4(0, dst_write0 + 7 * N) : beta * vload4(0, dst_write0 + 7 * N);", // NOLINT +"Dtype4 dot00 = (start_index != 0) ? vload4(0, dst_write0) : beta * vload4(0, dst_write0);", // NOLINT +"Dtype4 dot01 = (start_index != 0) ? vload4(0, dst_write0 + N) : beta * vload4(0, dst_write0 + N);", // NOLINT +"Dtype4 dot02 = (start_index != 0) ? vload4(0, dst_write0 + 2 * N) : beta * vload4(0, dst_write0 + 2 * N);", // NOLINT +"Dtype4 dot03 = (start_index != 0) ? vload4(0, dst_write0 + 3 * N) : beta * vload4(0, dst_write0 + 3 * N);", // NOLINT +"Dtype4 dot04 = (start_index != 0) ? 
vload4(0, dst_write0 + 4 * N) : beta * vload4(0, dst_write0 + 4 * N);", // NOLINT +"Dtype4 dot05 = (start_index != 0) ? vload4(0, dst_write0 + 5 * N) : beta * vload4(0, dst_write0 + 5 * N);", // NOLINT +"Dtype4 dot06 = (start_index != 0) ? vload4(0, dst_write0 + 6 * N) : beta * vload4(0, dst_write0 + 6 * N);", // NOLINT +"Dtype4 dot07 = (start_index != 0) ? vload4(0, dst_write0 + 7 * N) : beta * vload4(0, dst_write0 + 7 * N);", // NOLINT "", // NOLINT "int end_index = min(start_index + 256, K);", // NOLINT "while( start_index + TILE_K <= end_index ) {", // NOLINT @@ -4545,10 +4791,10 @@ static std::vector> cl_kernels{ "brow2 = vload16(0, src1_read0 + 2 * K);", // NOLINT "brow3 = vload16(0, src1_read0 + 3 * K);", // NOLINT "", // NOLINT -"float8 arow0 = alpha * vload8(0, src0_read);", // NOLINT -"float8 arow1 = alpha * vload8(0, src0_read + M);", // NOLINT +"Dtype8 arow0 = alpha * vload8(0, src0_read);", // NOLINT +"Dtype8 arow1 = alpha * vload8(0, src0_read + M);", // NOLINT "", // NOLINT -"#define MM_DOT_PRODUCT( _brow, _dot) _dot = mad( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 0 )), (float8)_brow.s0, _dot ); _dot = mad( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 0 )), (float8)_brow.s1, _dot ); _dot = mad( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 1 )), (float8)_brow.s2, _dot ); _dot = mad( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 1 )), (float8)_brow.s3, _dot ); _dot = mad( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 2 )), (float8)_brow.s4, _dot ); _dot = mad( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 2 )), (float8)_brow.s5, _dot ); _dot = mad( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 3 )), (float8)_brow.s6, _dot ); _dot = mad( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 3 )), (float8)_brow.s7, _dot ); _dot = mad( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 4 )), (float8)_brow.s8, _dot ); _dot = mad( as_float8(intel_sub_group_shuffle( 
SHUFFLE_TYPE8(arow1), 4 )), (float8)_brow.s9, _dot ); _dot = mad( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 5 )), (float8)_brow.sa, _dot ); _dot = mad( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 5 )), (float8)_brow.sb, _dot ); _dot = mad( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 6 )), (float8)_brow.sc, _dot ); _dot = mad( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 6 )), (float8)_brow.sd, _dot ); _dot = mad( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 7 )), (float8)_brow.se, _dot ); _dot = mad( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 7 )), (float8)_brow.sf, _dot );", // NOLINT +"#define MM_DOT_PRODUCT( _brow, _dot) _dot = mad( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 0 )), (Dtype8)_brow.s0, _dot ); _dot = mad( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 0 )), (Dtype8)_brow.s1, _dot ); _dot = mad( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 1 )), (Dtype8)_brow.s2, _dot ); _dot = mad( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 1 )), (Dtype8)_brow.s3, _dot ); _dot = mad( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 2 )), (Dtype8)_brow.s4, _dot ); _dot = mad( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 2 )), (Dtype8)_brow.s5, _dot ); _dot = mad( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 3 )), (Dtype8)_brow.s6, _dot ); _dot = mad( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 3 )), (Dtype8)_brow.s7, _dot ); _dot = mad( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 4 )), (Dtype8)_brow.s8, _dot ); _dot = mad( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 4 )), (Dtype8)_brow.s9, _dot ); _dot = mad( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 5 )), (Dtype8)_brow.sa, _dot ); _dot = mad( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 5 )), (Dtype8)_brow.sb, _dot ); _dot = mad( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 6 )), 
(Dtype8)_brow.sc, _dot ); _dot = mad( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 6 )), (Dtype8)_brow.sd, _dot ); _dot = mad( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 7 )), (Dtype8)_brow.se, _dot ); _dot = mad( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 7 )), (Dtype8)_brow.sf, _dot );", // NOLINT "MM_DOT_PRODUCT( brow0, dot0 );", // NOLINT "MM_DOT_PRODUCT( brow1, dot1 );", // NOLINT "MM_DOT_PRODUCT( brow2, dot2 );", // NOLINT @@ -4566,10 +4812,10 @@ static std::vector> cl_kernels{ "brow2 = vload16(0, src1_read0); src1_read0 += K;", // NOLINT "brow3 = vload16(0, src1_read0);", // NOLINT "", // NOLINT -"float8 arow0 = alpha * vload8(0, src0_read);", // NOLINT -"float8 arow1 = alpha * vload8(0, src0_read + M);", // NOLINT +"Dtype8 arow0 = alpha * vload8(0, src0_read);", // NOLINT +"Dtype8 arow1 = alpha * vload8(0, src0_read + M);", // NOLINT "", // NOLINT -"#define MM_DOT_PRODUCT( _brow, _dot) _dot = (w++ < K) ? mad( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 0 )), (float8)_brow.s0, _dot ) : _dot; _dot = (w++ < K) ? mad( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 0 )), (float8)_brow.s1, _dot ) : _dot; _dot = (w++ < K) ? mad( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 1 )), (float8)_brow.s2, _dot ) : _dot; _dot = (w++ < K) ? mad( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 1 )), (float8)_brow.s3, _dot ) : _dot; _dot = (w++ < K) ? mad( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 2 )), (float8)_brow.s4, _dot ) : _dot; _dot = (w++ < K) ? mad( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 2 )), (float8)_brow.s5, _dot ) : _dot; _dot = (w++ < K) ? mad( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 3 )), (float8)_brow.s6, _dot ) : _dot; _dot = (w++ < K) ? mad( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 3 )), (float8)_brow.s7, _dot ) : _dot; _dot = (w++ < K) ? 
mad( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 4 )), (float8)_brow.s8, _dot ) : _dot; _dot = (w++ < K) ? mad( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 4 )), (float8)_brow.s9, _dot ) : _dot; _dot = (w++ < K) ? mad( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 5 )), (float8)_brow.sa, _dot ) : _dot; _dot = (w++ < K) ? mad( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 5 )), (float8)_brow.sb, _dot ) : _dot; _dot = (w++ < K) ? mad( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 6 )), (float8)_brow.sc, _dot ) : _dot; _dot = (w++ < K) ? mad( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 6 )), (float8)_brow.sd, _dot ) : _dot; _dot = (w++ < K) ? mad( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 7 )), (float8)_brow.se, _dot ) : _dot; _dot = (w++ < K) ? mad( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 7 )), (float8)_brow.sf, _dot ) : _dot;", // NOLINT +"#define MM_DOT_PRODUCT( _brow, _dot) _dot = (w++ < K) ? mad( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 0 )), (Dtype8)_brow.s0, _dot ) : _dot; _dot = (w++ < K) ? mad( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 0 )), (Dtype8)_brow.s1, _dot ) : _dot; _dot = (w++ < K) ? mad( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 1 )), (Dtype8)_brow.s2, _dot ) : _dot; _dot = (w++ < K) ? mad( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 1 )), (Dtype8)_brow.s3, _dot ) : _dot; _dot = (w++ < K) ? mad( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 2 )), (Dtype8)_brow.s4, _dot ) : _dot; _dot = (w++ < K) ? mad( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 2 )), (Dtype8)_brow.s5, _dot ) : _dot; _dot = (w++ < K) ? mad( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 3 )), (Dtype8)_brow.s6, _dot ) : _dot; _dot = (w++ < K) ? mad( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 3 )), (Dtype8)_brow.s7, _dot ) : _dot; _dot = (w++ < K) ? 
mad( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 4 )), (Dtype8)_brow.s8, _dot ) : _dot; _dot = (w++ < K) ? mad( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 4 )), (Dtype8)_brow.s9, _dot ) : _dot; _dot = (w++ < K) ? mad( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 5 )), (Dtype8)_brow.sa, _dot ) : _dot; _dot = (w++ < K) ? mad( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 5 )), (Dtype8)_brow.sb, _dot ) : _dot; _dot = (w++ < K) ? mad( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 6 )), (Dtype8)_brow.sc, _dot ) : _dot; _dot = (w++ < K) ? mad( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 6 )), (Dtype8)_brow.sd, _dot ) : _dot; _dot = (w++ < K) ? mad( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 7 )), (Dtype8)_brow.se, _dot ) : _dot; _dot = (w++ < K) ? mad( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 7 )), (Dtype8)_brow.sf, _dot ) : _dot;", // NOLINT "int w = start_index;", // NOLINT "MM_DOT_PRODUCT( brow0, dot0 );", // NOLINT "w = start_index;", // NOLINT @@ -4581,14 +4827,14 @@ static std::vector> cl_kernels{ "#undef MM_DOT_PRODUCT", // NOLINT "}", // NOLINT "", // NOLINT -"dot00 += (float4)(dot0.s0, dot1.s0, dot2.s0, dot3.s0);", // NOLINT -"dot01 += (float4)(dot0.s1, dot1.s1, dot2.s1, dot3.s1);", // NOLINT -"dot02 += (float4)(dot0.s2, dot1.s2, dot2.s2, dot3.s2);", // NOLINT -"dot03 += (float4)(dot0.s3, dot1.s3, dot2.s3, dot3.s3);", // NOLINT -"dot04 += (float4)(dot0.s4, dot1.s4, dot2.s4, dot3.s4);", // NOLINT -"dot05 += (float4)(dot0.s5, dot1.s5, dot2.s5, dot3.s5);", // NOLINT -"dot06 += (float4)(dot0.s6, dot1.s6, dot2.s6, dot3.s6);", // NOLINT -"dot07 += (float4)(dot0.s7, dot1.s7, dot2.s7, dot3.s7);", // NOLINT +"dot00 += (Dtype4)(dot0.s0, dot1.s0, dot2.s0, dot3.s0);", // NOLINT +"dot01 += (Dtype4)(dot0.s1, dot1.s1, dot2.s1, dot3.s1);", // NOLINT +"dot02 += (Dtype4)(dot0.s2, dot1.s2, dot2.s2, dot3.s2);", // NOLINT +"dot03 += (Dtype4)(dot0.s3, dot1.s3, dot2.s3, dot3.s3);", 
// NOLINT +"dot04 += (Dtype4)(dot0.s4, dot1.s4, dot2.s4, dot3.s4);", // NOLINT +"dot05 += (Dtype4)(dot0.s5, dot1.s5, dot2.s5, dot3.s5);", // NOLINT +"dot06 += (Dtype4)(dot0.s6, dot1.s6, dot2.s6, dot3.s6);", // NOLINT +"dot07 += (Dtype4)(dot0.s7, dot1.s7, dot2.s7, dot3.s7);", // NOLINT "", // NOLINT "if(global_x * 4 < N && global_y * 8 < M) {", // NOLINT "if(mad24(global_x, 4, 3) < N) {", // NOLINT @@ -4681,6 +4927,9 @@ static std::vector> cl_kernels{ "#undef TILE_M", // NOLINT "#undef TILE_K", // NOLINT "#undef TILE_N", // NOLINT +"#undef SIMD_SIZE_GEMM", // NOLINT +"#undef SHUFFLE_TYPE2", // NOLINT +"#undef SHUFFLE_TYPE8", // NOLINT "", // NOLINT "#endif", // NOLINT ""}, // NOLINT @@ -5002,18 +5251,18 @@ static std::vector> cl_kernels{ "__kernel void TEMPLATE(lrn_compute_output,Dtype)(const int_tp nthreads,", // NOLINT "__global const Dtype* in,", // NOLINT "__global const Dtype* scale,", // NOLINT -"const Dtype negative_beta,", // NOLINT +"const KERNEL_ARG_DTYPE negative_beta,", // NOLINT "__global Dtype* out) {", // NOLINT "for (int_tp index = get_global_id(0); index < nthreads;", // NOLINT "index += get_global_size(0)) {", // NOLINT -"out[index] = in[index] * pow(scale[index], negative_beta);", // NOLINT +"out[index] = in[index] * pow(scale[index], (Dtype)negative_beta);", // NOLINT "}", // NOLINT "}", // NOLINT "", // NOLINT "__kernel void TEMPLATE(lrn_fill_scale,Dtype)(const int_tp nthreads, __global const Dtype* in,", // NOLINT "const int_tp num, const int_tp channels,", // NOLINT "const int_tp height, const int_tp width, const int_tp size,", // NOLINT -"const Dtype alpha_over_size, const Dtype k,", // NOLINT +"const KERNEL_ARG_DTYPE alpha_over_size, const KERNEL_ARG_DTYPE k,", // NOLINT "__global Dtype* const scale) {", // NOLINT "for (int_tp index = get_global_id(0); index < nthreads;", // NOLINT "index += get_global_size(0)) {", // NOLINT @@ -5064,8 +5313,8 @@ static std::vector> cl_kernels{ "__global const Dtype* top_diff, const int_tp num,", // NOLINT 
"const int_tp channels, const int_tp height,", // NOLINT "const int_tp width, const int_tp size,", // NOLINT -"const Dtype negative_beta,", // NOLINT -"const Dtype cache_ratio,", // NOLINT +"const KERNEL_ARG_DTYPE negative_beta,", // NOLINT +"const KERNEL_ARG_DTYPE cache_ratio,", // NOLINT "__global Dtype* bottom_diff) {", // NOLINT "for (int_tp index = get_global_id(0); index < nthreads;", // NOLINT "index += get_global_size(0)) {", // NOLINT @@ -5099,7 +5348,7 @@ static std::vector> cl_kernels{ "* top_off[(head - size) * step] / scale_off[(head - size) * step];", // NOLINT "}", // NOLINT "bottom_diff_off[(head - post_pad) * step] = top_diff_off[(head - post_pad)", // NOLINT -"* step] * pow(scale_off[(head - post_pad) * step], negative_beta)", // NOLINT +"* step] * pow(scale_off[(head - post_pad) * step], (Dtype)negative_beta)", // NOLINT "- cache_ratio * bottom_off[(head - post_pad) * step] * accum_ratio;", // NOLINT "++head;", // NOLINT "}", // NOLINT @@ -5110,7 +5359,7 @@ static std::vector> cl_kernels{ "* top_off[(head - size) * step] / scale_off[(head - size) * step];", // NOLINT "}", // NOLINT "bottom_diff_off[(head - post_pad) * step] = top_diff_off[(head - post_pad)", // NOLINT -"* step] * pow(scale_off[(head - post_pad) * step], negative_beta)", // NOLINT +"* step] * pow(scale_off[(head - post_pad) * step], (Dtype)negative_beta)", // NOLINT "- cache_ratio * bottom_off[(head - post_pad) * step] * accum_ratio;", // NOLINT "++head;", // NOLINT "}", // NOLINT @@ -5133,9 +5382,9 @@ static std::vector> cl_kernels{ "const int_tp height, const int_tp width,", // NOLINT "const int_tp tiled_height, int_tp tiled_width,", // NOLINT "const int_tp size,", // NOLINT -"const Dtype alpha_over_size, const Dtype k,", // NOLINT +"const KERNEL_ARG_DTYPE alpha_over_size, const KERNEL_ARG_DTYPE k,", // NOLINT "__global Dtype* const out,", // NOLINT -"const Dtype negative_beta,", // NOLINT +"const KERNEL_ARG_DTYPE negative_beta,", // NOLINT "const int_tp pool_h, const int_tp 
pool_w, const int_tp pool_stride_h, int_tp pool_stride_w,", // NOLINT "const int_tp pooled_height, const int_tp pooled_width,", // NOLINT "const int_tp tile_pooled_block_h, const int_tp tile_pooled_block_w) {", // NOLINT @@ -5165,7 +5414,7 @@ static std::vector> cl_kernels{ "while ( head < channels + post_pad ) {", // NOLINT "int ph = 0;", // NOLINT "int cur_out_h = 0;", // NOLINT -"Dtype output_val = -FLT_MAX;", // NOLINT +"Dtype output_val = -DTYPE_MAX;", // NOLINT "// fill the scale at [n, :, h, w]", // NOLINT "// accumulate values", // NOLINT "for( int lrn_out_h = 0; lrn_out_h < TILE_H && (lrn_out_h + h) < height; lrn_out_h++) {", // NOLINT @@ -5181,11 +5430,11 @@ static std::vector> cl_kernels{ "// compute output.", // NOLINT "if (head >= post_pad) {", // NOLINT "scale_val = k + prev_val * alpha_over_size;", // NOLINT -"Dtype tmp = -FLT_MAX;", // NOLINT +"Dtype tmp = -DTYPE_MAX;", // NOLINT "//if (w + get_local_id(1) < width)", // NOLINT -"tmp = in_off[(head - post_pad) * step + width * lrn_out_h] * native_powr(scale_val, negative_beta);", // NOLINT +"tmp = in_off[(head - post_pad) * step + width * lrn_out_h] * native_powr(scale_val, (Dtype)negative_beta);", // NOLINT "", // NOLINT -"Dtype h_max_val = -FLT_MAX;", // NOLINT +"Dtype h_max_val = -DTYPE_MAX;", // NOLINT "int index = (get_local_id(1) * pool_stride_w) % SIMD_WIDTH;", // NOLINT "for(int i = 0; i < pool_w; i++) {", // NOLINT "Dtype val = intel_sub_group_shuffle(tmp, index);", // NOLINT @@ -5230,9 +5479,9 @@ static std::vector> cl_kernels{ "__kernel void TEMPLATE(lrn_full_no_scale,Dtype)(const int_tp nthreads, __global const Dtype* in,", // NOLINT "const int_tp num, const int_tp channels,", // NOLINT "const int_tp height, const int_tp width, const int_tp size,", // NOLINT -"const Dtype alpha_over_size, const Dtype k,", // NOLINT +"const KERNEL_ARG_DTYPE alpha_over_size, const KERNEL_ARG_DTYPE k,", // NOLINT "__global Dtype* const out,", // NOLINT -"const Dtype negative_beta) {", // NOLINT +"const 
KERNEL_ARG_DTYPE negative_beta) {", // NOLINT "for (int_tp index = get_global_id(0); index < nthreads;", // NOLINT "index += get_global_size(0)) {", // NOLINT "// find out the local offset", // NOLINT @@ -5262,7 +5511,7 @@ static std::vector> cl_kernels{ "* in_off[(head - size) * step];", // NOLINT "}", // NOLINT "scale_val = k + accum_scale * alpha_over_size;", // NOLINT -"out_off[(head - post_pad) * step] = in_off[(head - post_pad) * step] * (Dtype)native_powr((float)scale_val, (float)negative_beta);", // NOLINT +"out_off[(head - post_pad) * step] = in_off[(head - post_pad) * step] * (Dtype)native_powr((Dtype)scale_val, (Dtype)negative_beta);", // NOLINT "++head;", // NOLINT "}", // NOLINT "// subtract only", // NOLINT @@ -5272,7 +5521,7 @@ static std::vector> cl_kernels{ "* in_off[(head - size) * step];", // NOLINT "}", // NOLINT "scale_val = k + accum_scale * alpha_over_size;", // NOLINT -"out_off[(head - post_pad) * step] = in_off[(head - post_pad) * step] * (Dtype)native_powr((float)scale_val, (float)negative_beta);", // NOLINT +"out_off[(head - post_pad) * step] = in_off[(head - post_pad) * step] * (Dtype)native_powr((Dtype)scale_val, (Dtype)negative_beta);", // NOLINT "++head;", // NOLINT "}", // NOLINT "}", // NOLINT @@ -5281,10 +5530,10 @@ static std::vector> cl_kernels{ "__kernel void TEMPLATE(lrn_full,Dtype)(const int_tp nthreads, __global const Dtype* in,", // NOLINT "const int_tp num, const int_tp channels,", // NOLINT "const int_tp height, const int_tp width, const int_tp size,", // NOLINT -"const Dtype alpha_over_size, const Dtype k,", // NOLINT +"const KERNEL_ARG_DTYPE alpha_over_size, const KERNEL_ARG_DTYPE k,", // NOLINT "__global Dtype* const scale,", // NOLINT "__global Dtype* const out,", // NOLINT -"const Dtype negative_beta) {", // NOLINT +"const KERNEL_ARG_DTYPE negative_beta) {", // NOLINT "for (int_tp index = get_global_id(0); index < nthreads;", // NOLINT "index += get_global_size(0)) {", // NOLINT "// find out the local offset", // 
NOLINT @@ -5316,7 +5565,7 @@ static std::vector> cl_kernels{ "}", // NOLINT "scale_val = k + accum_scale * alpha_over_size;", // NOLINT "scale_off[(head - post_pad) * step] = scale_val;", // NOLINT -"out_off[(head - post_pad) * step] = in_off[(head - post_pad) * step] * (Dtype)native_powr((float)scale_val, (float)negative_beta);", // NOLINT +"out_off[(head - post_pad) * step] = in_off[(head - post_pad) * step] * (Dtype)native_powr((Dtype)scale_val, (Dtype)negative_beta);", // NOLINT "++head;", // NOLINT "}", // NOLINT "// subtract only", // NOLINT @@ -5327,7 +5576,7 @@ static std::vector> cl_kernels{ "}", // NOLINT "scale_val = k + accum_scale * alpha_over_size;", // NOLINT "scale_off[(head - post_pad) * step] = scale_val;", // NOLINT -"out_off[(head - post_pad) * step] = in_off[(head - post_pad) * step] * (Dtype)native_powr((float)scale_val, (float)negative_beta);", // NOLINT +"out_off[(head - post_pad) * step] = in_off[(head - post_pad) * step] * (Dtype)native_powr((Dtype)scale_val, (Dtype)negative_beta);", // NOLINT "++head;", // NOLINT "}", // NOLINT "}", // NOLINT @@ -5451,7 +5700,7 @@ static std::vector> cl_kernels{ "}", // NOLINT "}", // NOLINT "", // NOLINT -"__kernel void TEMPLATE(add_scalar,Dtype)(const int_tp N, const Dtype alpha,", // NOLINT +"__kernel void TEMPLATE(add_scalar,Dtype)(const int_tp N, const KERNEL_ARG_DTYPE alpha,", // NOLINT "__global Dtype* Y,", // NOLINT "const int_tp offY) {", // NOLINT "for (int_tp index = get_global_id(0); index < N; index += get_global_size(0)) {", // NOLINT @@ -5510,7 +5759,7 @@ static std::vector> cl_kernels{ "}", // NOLINT "", // NOLINT "__kernel void TEMPLATE(powx,Dtype)(const int_tp n, __global const Dtype* a,", // NOLINT -"const int_tp offa, Dtype alpha,", // NOLINT +"const int_tp offa, KERNEL_ARG_DTYPE alpha,", // NOLINT "__global Dtype* y,", // NOLINT "const int_tp offy) {", // NOLINT "for (int_tp index = get_global_id(0); index < n; index += get_global_size(0)) {", // NOLINT @@ -5539,157 +5788,182 @@ 
static std::vector> cl_kernels{ "}", // NOLINT "}", // NOLINT ""}, // NOLINT - {"#ifndef __OPENCL_VERSION__", // NOLINT -"#include \"header.cl\"", // NOLINT -"#endif", // NOLINT -"", // NOLINT -"__kernel void TEMPLATE(matvec_mul4,Dtype)(", // NOLINT -"__global const float * A,", // NOLINT -"int offA,", // NOLINT -"unsigned int A_col_size,", // NOLINT -"unsigned int trail_item,", // NOLINT -"__global const float * v,", // NOLINT -"int offv,", // NOLINT -"float alpha,", // NOLINT -"float beta,", // NOLINT -"__global float4 * result,", // NOLINT -"int offr,", // NOLINT -"__local float4 * work)", // NOLINT + {"void TEMPLATE(matvec_mul_trail_rows,Dtype)(unsigned int M,", // NOLINT +"unsigned int N,", // NOLINT +"int row_gid,", // NOLINT +"int lid,", // NOLINT +"const __global Dtype* src0_read,", // NOLINT +"int lda,", // NOLINT +"const __global Dtype* src1_read,", // NOLINT +"int incv,", // NOLINT +"__local Dtype4* work,", // NOLINT +"Dtype alpha,", // NOLINT +"Dtype beta,", // NOLINT +"__global Dtype* result,", // NOLINT +"int incr)", // NOLINT "{", // NOLINT -"unsigned int row_gid = get_group_id(0);", // NOLINT -"unsigned int lid = get_local_id(0);", // NOLINT -"const __global float *src0_read = A + row_gid * 4 * A_col_size + offA;", // NOLINT -"const __global float *src1_read = v + offv;", // NOLINT -"result = (__global float4*)((__global float*)result + offr);", // NOLINT -"float4 dot0 = (float4)(0.f);", // NOLINT -"float4 dot1 = (float4)(0.f);", // NOLINT -"float4 dot2 = (float4)(0.f);", // NOLINT -"float4 dot3 = (float4)(0.f);", // NOLINT -"", // NOLINT -"unsigned int i = lid;", // NOLINT -"while( i < A_col_size / 4) {", // NOLINT -"const float4 a0 = vload4(i, src0_read);", // NOLINT -"const float4 a1 = vload4(i, src0_read + A_col_size);", // NOLINT -"const float4 a2 = vload4(i, src0_read + 2 * A_col_size);", // NOLINT -"const float4 a3 = vload4(i, src0_read + 3 * A_col_size);", // NOLINT -"", // NOLINT -"const float4 b0 = vload4(i, src1_read);", // NOLINT -"", // 
NOLINT -"dot0 += a0 * b0;", // NOLINT -"dot1 += a1 * b0;", // NOLINT -"dot2 += a2 * b0;", // NOLINT -"dot3 += a3 * b0;", // NOLINT -"", // NOLINT -"i += get_local_size(0);", // NOLINT -"}", // NOLINT +"__local Dtype* work_each = (__local Dtype*)work;", // NOLINT "", // NOLINT -"work[lid].s0 = dot0.x + dot0.y + dot0.z + dot0.w;", // NOLINT -"work[lid].s1 = dot1.x + dot1.y + dot1.z + dot1.w;", // NOLINT -"work[lid].s2 = dot2.x + dot2.y + dot2.z + dot2.w;", // NOLINT -"work[lid].s3 = dot3.x + dot3.y + dot3.z + dot3.w;", // NOLINT +"int rows = M - row_gid * 4;", // NOLINT "", // NOLINT -"if(i == A_col_size / 4)", // NOLINT -"{", // NOLINT -"if(trail_item != 0)", // NOLINT -"{", // NOLINT -"const __global float *src0_trail = src0_read + i * 4;", // NOLINT -"const __global float *src1_trail = src1_read + i * 4;", // NOLINT -"for(unsigned int i = 0; i < trail_item; ++i) {", // NOLINT -"const float at0 = src0_trail[i];", // NOLINT -"const float at1 = src0_trail[i + A_col_size];", // NOLINT -"const float at2 = src0_trail[i + 2 * A_col_size];", // NOLINT -"const float at3 = src0_trail[i + 3 * A_col_size];", // NOLINT +"Dtype4 dot[3] = {(Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.)};", // NOLINT "", // NOLINT -"const float bt = src1_trail[i];", // NOLINT +"int i = lid;", // NOLINT +"while( i < N / 4) {", // NOLINT +"const Dtype4 b0 = {src1_read[i*4*incv], src1_read[(i*4+1)*incv], src1_read[(i*4+2)*incv], src1_read[(i*4+3)*incv]};", // NOLINT +"#pragma unroll", // NOLINT +"for(int j = 0; j < rows; ++j) {", // NOLINT +"dot[j] += b0 * vload4(i, src0_read + j * lda);", // NOLINT +"}", // NOLINT "", // NOLINT -"work[lid].s0 += at0 * bt;", // NOLINT -"work[lid].s1 += at1 * bt;", // NOLINT -"work[lid].s2 += at2 * bt;", // NOLINT -"work[lid].s3 += at3 * bt;", // NOLINT +"i += get_local_size(0);", // NOLINT "}", // NOLINT +"#pragma unroll", // NOLINT +"for(int j = 0; j < rows; ++j) {", // NOLINT +"work_each[lid * 4 + j] = dot[j].x + dot[j].y + dot[j].z + dot[j].w;", // NOLINT "}", // 
NOLINT "", // NOLINT +"if(i == N / 4) {", // NOLINT +"short trail_item = N % 4;", // NOLINT +"", // NOLINT +"if(trail_item != 0) {", // NOLINT +"const __global Dtype *src0_trail = src0_read + i * 4;", // NOLINT +"const __global Dtype *src1_trail = src1_read + i * 4 * incv;", // NOLINT +"#pragma unroll", // NOLINT +"for(short i = 0; i < trail_item; ++i) {", // NOLINT +"const Dtype bt = src1_trail[i*incv];", // NOLINT +"#pragma unroll", // NOLINT +"for(int j = 0; j < rows; ++j) {", // NOLINT +"work_each[lid * 4 + j] += bt * src0_trail[i + j * lda];", // NOLINT +"}", // NOLINT +"}", // NOLINT +"}", // NOLINT "}", // NOLINT "", // NOLINT -"for(unsigned int stride=get_local_size(0)/2 ; stride>0 ; stride>>=1) {", // NOLINT +"for(int stride = get_local_size(0) >> 1; stride > 0 ; stride >>= 1) {", // NOLINT "barrier(CLK_LOCAL_MEM_FENCE);", // NOLINT "if(lid < stride)", // NOLINT "work[lid] += work[lid+stride];", // NOLINT "}", // NOLINT "", // NOLINT "if(lid == 0) {", // NOLINT -"if(beta == (Dtype)0)", // NOLINT -"result[row_gid] = alpha * work[0];", // NOLINT -"else", // NOLINT -"result[row_gid] = alpha * work[0] + beta * result[row_gid];", // NOLINT +"#pragma unroll", // NOLINT +"for(int j = 0; j < rows; ++j) {", // NOLINT +"result[(row_gid * 4 + j) * incr] = alpha * work_each[j] + beta * result[(row_gid * 4 + j) * incr];", // NOLINT +"}", // NOLINT "}", // NOLINT "", // NOLINT "}", // NOLINT "", // NOLINT -"/* This kernel used for the trailing rows when row_of_A %4 !=0 */", // NOLINT -"__kernel void TEMPLATE(matvec_mul1,Dtype)(", // NOLINT -"__global const float * A,", // NOLINT +"__kernel void TEMPLATE(matvec_mul,Dtype)(", // NOLINT +"unsigned int M,", // NOLINT +"unsigned int N,", // NOLINT +"__global const Dtype * A,", // NOLINT "int offA,", // NOLINT -"unsigned int A_col_size,", // NOLINT -"unsigned int row_offset,", // NOLINT -"unsigned int trail_item,", // NOLINT -"__global const float * v,", // NOLINT +"int lda,", // NOLINT +"__global const Dtype * v,", // NOLINT 
"int offv,", // NOLINT -"float alpha,", // NOLINT -"float beta,", // NOLINT -"__global float * result,", // NOLINT +"int incv,", // NOLINT +"KERNEL_ARG_DTYPE alpha,", // NOLINT +"KERNEL_ARG_DTYPE beta,", // NOLINT +"__global Dtype * result,", // NOLINT "int offr,", // NOLINT -"__local float * work)", // NOLINT +"int incr)", // NOLINT "{", // NOLINT -"unsigned int row_gid = get_group_id(0);", // NOLINT -"unsigned int lid = get_local_id(0);", // NOLINT -"", // NOLINT -"const __global float *src0_read = A + (row_offset + row_gid) * A_col_size + offA;", // NOLINT -"const __global float *src1_read = v + + offv;", // NOLINT +"int row_gid = get_group_id(0);", // NOLINT +"int lid = get_local_id(0);", // NOLINT +"const __global Dtype *src0_read = A + row_gid * 4 * lda + offA;", // NOLINT +"const __global Dtype *src1_read = v + offv;", // NOLINT "result = result + offr;", // NOLINT -"float4 dot0 = (float4)(0.f);", // NOLINT "", // NOLINT -"unsigned int i = lid;", // NOLINT -"while( i < A_col_size / 4)", // NOLINT -"{", // NOLINT -"const float4 a0 = vload4(i, src0_read);", // NOLINT -"const float4 b0 = vload4(i, src1_read);", // NOLINT +"src1_read += incv > 0 ? 0 : (1 - N) * incv;", // NOLINT +"result += incr > 0 ? 
0 : (1 - M) * incr;", // NOLINT +"__local Dtype4 work[128];", // NOLINT +"__local Dtype* work_each = (__local Dtype*)work;", // NOLINT "", // NOLINT -"dot0 += a0 * b0;", // NOLINT +"if(row_gid == M / 4)", // NOLINT +"TEMPLATE(matvec_mul_trail_rows,Dtype)(M, N, row_gid, lid, src0_read, lda, src1_read, incv, work, alpha, beta, result, incr);", // NOLINT +"else", // NOLINT +"{", // NOLINT +"Dtype4 dot[4] = {(Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.f), (Dtype4)(0.f)};", // NOLINT +"int i = lid;", // NOLINT +"while( i < N / 4) {", // NOLINT +"const Dtype4 b0 = {src1_read[i*4*incv], src1_read[(i*4+1)*incv], src1_read[(i*4+2)*incv], src1_read[(i*4+3)*incv]};", // NOLINT +"#pragma unroll", // NOLINT +"for(int j = 0; j < 4; ++j) {", // NOLINT +"dot[j] += b0 * vload4(i, src0_read + j * lda);", // NOLINT +"}", // NOLINT "i += get_local_size(0);", // NOLINT "}", // NOLINT +"#pragma unroll", // NOLINT +"for(int j = 0; j < 4; ++j) {", // NOLINT +"work_each[lid * 4 + j] = dot[j].x + dot[j].y + dot[j].z + dot[j].w;", // NOLINT +"}", // NOLINT "", // NOLINT -"work[lid] = dot0.x + dot0.y + dot0.z + dot0.w;", // NOLINT -"", // NOLINT -"if(i == A_col_size / 4)", // NOLINT -"{", // NOLINT -"if(trail_item != 0)", // NOLINT -"{", // NOLINT -"const __global float *src0_trail = src0_read + i * 4;", // NOLINT -"const __global float *src1_trail = src1_read + i * 4;", // NOLINT -"for(unsigned int i = 0; i < trail_item; ++i) {", // NOLINT -"const float at0 = src0_trail[i];", // NOLINT -"const float bt = src1_trail[i];", // NOLINT -"", // NOLINT -"work[lid] += at0 * bt;", // NOLINT +"if(i == N / 4) {", // NOLINT +"short trail_item = N % 4;", // NOLINT +"if(trail_item != 0) {", // NOLINT +"const __global Dtype *src0_trail = src0_read + i * 4;", // NOLINT +"const __global Dtype *src1_trail = src1_read + i * 4 * incv;", // NOLINT +"#pragma unroll", // NOLINT +"for(short i = 0; i < trail_item; ++i) {", // NOLINT +"const Dtype bt = src1_trail[i * incv];", // NOLINT +"#pragma unroll", // NOLINT 
+"for(int j = 0; j < 4; ++j) {", // NOLINT +"work_each[lid * 4 + j] += bt * src0_trail[i + j * lda];", // NOLINT "}", // NOLINT "}", // NOLINT -"", // NOLINT "}", // NOLINT -"for(unsigned int stride=get_local_size(0)/2 ; stride>0 ; stride>>=1) {", // NOLINT +"}", // NOLINT +"", // NOLINT +"for(int stride = get_local_size(0) >> 1; stride > 0 ; stride >>= 1) {", // NOLINT "barrier(CLK_LOCAL_MEM_FENCE);", // NOLINT "if(lid < stride)", // NOLINT "work[lid] += work[lid+stride];", // NOLINT "}", // NOLINT "", // NOLINT "if(lid == 0) {", // NOLINT -"if(beta == (Dtype)0) {", // NOLINT -"result[row_gid+row_offset] = alpha * work[0];", // NOLINT -"} else {", // NOLINT -"result[row_gid+row_offset] *= beta;", // NOLINT -"result[row_gid+row_offset] += alpha * work[0];", // NOLINT +"// vstore4(alpha * work[0] + beta * vload4(row_gid, result), row_gid, result);", // NOLINT +"result[row_gid*4*incr] = alpha * work[0].s0 + beta * result[row_gid*4*incr];", // NOLINT +"result[(row_gid*4+1)*incr] = alpha * work[0].s1 + beta * result[(row_gid*4+1)*incr];", // NOLINT +"result[(row_gid*4+2)*incr] = alpha * work[0].s2 + beta * result[(row_gid*4+2)*incr];", // NOLINT +"result[(row_gid*4+3)*incr] = alpha * work[0].s3 + beta * result[(row_gid*4+3)*incr];", // NOLINT +"}", // NOLINT +"}", // NOLINT "}", // NOLINT +"", // NOLINT +"__kernel void TEMPLATE(trans_matvec_mul,Dtype)(", // NOLINT +"unsigned int M,", // NOLINT +"unsigned int N,", // NOLINT +"__global const Dtype * A,", // NOLINT +"int offA,", // NOLINT +"int lda,", // NOLINT +"__global const Dtype * v,", // NOLINT +"int offv,", // NOLINT +"int incv,", // NOLINT +"KERNEL_ARG_DTYPE alpha,", // NOLINT +"KERNEL_ARG_DTYPE beta,", // NOLINT +"__global Dtype * result,", // NOLINT +"int offr,", // NOLINT +"int incr)", // NOLINT +"{", // NOLINT +"int col_gid = get_global_id(0);", // NOLINT +"A += offA + col_gid;", // NOLINT +"v += offv;", // NOLINT +"result += offr;", // NOLINT +"", // NOLINT +"v += incv > 0 ? 
0 : (1 - M) * incv;", // NOLINT +"result += incr > 0 ? 0 : (1 - N) * incr;", // NOLINT +"", // NOLINT +"Dtype dot_prod = 0;", // NOLINT +"int row_id = 0;", // NOLINT +"#pragma unroll", // NOLINT +"for(int row = 0; row < M; ++row) {", // NOLINT +"dot_prod += A[row_id] * v[row * incv];", // NOLINT +"row_id += lda;", // NOLINT "}", // NOLINT +"result[col_gid * incr] = beta * result[col_gid * incr];", // NOLINT +"result[col_gid * incr] += alpha * dot_prod;", // NOLINT "}", // NOLINT ""}, // NOLINT {"#ifndef __OPENCL_VERSION__", // NOLINT @@ -5928,7 +6202,7 @@ static std::vector> cl_kernels{ "const int_tp wend = min(wstart + kernel_w, width);", // NOLINT "hstart = max(hstart, (int_tp)0);", // NOLINT "wstart = max(wstart, (int_tp)0);", // NOLINT -"Dtype maxval = -FLT_MAX;", // NOLINT +"Dtype maxval = -DTYPE_MAX;", // NOLINT "int_tp maxidx = -1;", // NOLINT "__global const Dtype* bottom_slice = bottom_data", // NOLINT "+ (n * channels + c) * height * width;", // NOLINT @@ -6043,7 +6317,7 @@ static std::vector> cl_kernels{ "cumsum += bottom_slice[h * width + w];", // NOLINT "}", // NOLINT "}", // NOLINT -"const float thres = rand_idx[index] * cumsum;", // NOLINT +"const Dtype thres = rand_idx[index] * cumsum;", // NOLINT "// Second pass: get value, and set index.", // NOLINT "cumsum = 0;", // NOLINT "for (int_tp h = hstart; h < hend; ++h) {", // NOLINT @@ -6077,7 +6351,7 @@ static std::vector> cl_kernels{ "const int_tp wstart = pw * stride_w;", // NOLINT "const int_tp wend = min(wstart + kernel_w, width);", // NOLINT "// We set cumsum to be 0 to avoid divide-by-zero problems", // NOLINT -"Dtype cumsum = FLT_MIN;", // NOLINT +"Dtype cumsum = DTYPE_MIN;", // NOLINT "Dtype cumvalues = 0.;", // NOLINT "__global const Dtype* bottom_slice = bottom_data", // NOLINT "+ (n * channels + c) * height * width;", // NOLINT @@ -6273,7 +6547,7 @@ static std::vector> cl_kernels{ "d_iter[i] = d_start[i];", // NOLINT "", // NOLINT "if (d_start[i] >= d_end[i]) {", // NOLINT -"top_data[index] 
= -FLT_MAX;", // NOLINT +"top_data[index] = -DTYPE_MAX;", // NOLINT "if (use_mask) {", // NOLINT "mask[index] = -1;", // NOLINT "} else {", // NOLINT @@ -6291,7 +6565,7 @@ static std::vector> cl_kernels{ "num /= channels;", // NOLINT "offset *= (num * channels + chan);", // NOLINT "", // NOLINT -"Dtype maxval = -FLT_MAX;", // NOLINT +"Dtype maxval = -DTYPE_MAX;", // NOLINT "int_tp maxidx = -1;", // NOLINT "int_tp final_offset = 0;", // NOLINT "", // NOLINT @@ -6470,7 +6744,7 @@ static std::vector> cl_kernels{ "while (wstart < 0) {", // NOLINT "wstart += dilation_w;", // NOLINT "}", // NOLINT -"Dtype maxval = -FLT_MAX;", // NOLINT +"Dtype maxval = -DTYPE_MAX;", // NOLINT "int_tp maxidx = -1;", // NOLINT "__global Dtype* bottom_data_ptr = bottom_data", // NOLINT "+ (n * channels + c) * height * width;", // NOLINT @@ -6682,7 +6956,7 @@ static std::vector> cl_kernels{ "cumsum += bottom_data_ptr[h * width + w];", // NOLINT "}", // NOLINT "}", // NOLINT -"float thres = rand_idx[index] * cumsum;", // NOLINT +"Dtype thres = rand_idx[index] * cumsum;", // NOLINT "// Second pass: get value, and set index.", // NOLINT "cumsum = 0;", // NOLINT "for (int_tp h = hstart; h < hend; h += dilation_h) {", // NOLINT @@ -6719,7 +6993,7 @@ static std::vector> cl_kernels{ "int_tp wstart = pw * stride_w;", // NOLINT "int_tp wend = min(wstart + ext_kernel_w, width);", // NOLINT "// We set cumsum to be 0 to avoid divide-by-zero problems", // NOLINT -"Dtype cumsum = FLT_MIN;", // NOLINT +"Dtype cumsum = DTYPE_MIN;", // NOLINT "Dtype cumvalues = 0.;", // NOLINT "__global const Dtype* bottom_data_ptr = bottom_data;", // NOLINT "bottom_data_ptr += (n * channels + c) * height * width;", // NOLINT @@ -6786,7 +7060,7 @@ static std::vector> cl_kernels{ "int_tp n = get_global_id(1);", // NOLINT "for (int_tp index = get_global_id(0), s = 0; index < spatial_dim * get_local_size(0); index +=", // NOLINT "get_global_size(0), ++s) {", // NOLINT -"float maxval = -FLT_MAX;", // NOLINT +"Dtype maxval = 
-DTYPE_MAX;", // NOLINT "for (int_tp c = get_global_id(0); c < channels; c += get_global_size(0)) {", // NOLINT "Dtype tmp = data[(n * channels + c) * spatial_dim + s];", // NOLINT "maxval = max((Dtype)tmp, (Dtype)maxval);", // NOLINT @@ -6852,7 +7126,7 @@ static std::vector> cl_kernels{ "__global Dtype *group_tmp = scale + spatial_dim * num + n * get_max_sub_group_size() * spatial_dim;", // NOLINT "for (int_tp index = get_global_id(0), s = 0; index < spatial_dim * get_local_size(0); index +=", // NOLINT "get_global_size(0), ++s) {", // NOLINT -"float maxval = -FLT_MAX;", // NOLINT +"Dtype maxval = -DTYPE_MAX;", // NOLINT "for (int_tp c = get_global_id(0); c < channels; c += get_global_size(0)) {", // NOLINT "Dtype tmp = data[(n * channels + c) * spatial_dim + s];", // NOLINT "maxval = max((Dtype)tmp, (Dtype)maxval);", // NOLINT @@ -6996,7 +7270,7 @@ static std::vector> cl_kernels{ "} else {", // NOLINT "loss[index] = -log((Dtype)(", // NOLINT "max((Dtype) (prob_data[n * dim + label_value * spatial_dim + s]),", // NOLINT -"(Dtype) FLT_MIN)));", // NOLINT +"(Dtype) DTYPE_MIN)));", // NOLINT "counts[index] = 1;", // NOLINT "}", // NOLINT "}", // NOLINT @@ -7037,22 +7311,22 @@ static std::vector> cl_kernels{ "__kernel void TEMPLATE(ada_delta_update,Dtype)(int_tp N, __global Dtype* g,", // NOLINT "__global Dtype* h,", // NOLINT "__global Dtype* h2,", // NOLINT -"Dtype momentum,", // NOLINT -"Dtype delta,", // NOLINT -"Dtype local_rate) {", // NOLINT +"KERNEL_ARG_DTYPE momentum,", // NOLINT +"KERNEL_ARG_DTYPE delta,", // NOLINT +"KERNEL_ARG_DTYPE local_rate) {", // NOLINT "for (int_tp i = get_global_id(0); i < N; i += get_global_size(0)) {", // NOLINT "Dtype gi = g[i];", // NOLINT -"Dtype hi = h[i] = momentum * h[i] + (1.0 - momentum) * gi * gi;", // NOLINT +"Dtype hi = h[i] = momentum * h[i] + ((Dtype)1.0 - momentum) * gi * gi;", // NOLINT "gi = gi * sqrt((h2[i] + delta) / (hi + delta));", // NOLINT -"h2[i] = momentum * h2[i] + (1.0 - momentum) * gi * gi;", // NOLINT 
+"h2[i] = momentum * h2[i] + ((Dtype)1.0 - momentum) * gi * gi;", // NOLINT "g[i] = local_rate * gi;", // NOLINT "}", // NOLINT "}", // NOLINT "", // NOLINT "__kernel void TEMPLATE(ada_grad_update,Dtype)(int_tp N, __global Dtype* g,", // NOLINT "__global Dtype* h,", // NOLINT -"Dtype delta,", // NOLINT -"Dtype local_rate) {", // NOLINT +"KERNEL_ARG_DTYPE delta,", // NOLINT +"KERNEL_ARG_DTYPE local_rate) {", // NOLINT "for (int_tp i = get_global_id(0); i < N; i += get_global_size(0)) {", // NOLINT "Dtype gi = g[i];", // NOLINT "Dtype hi = h[i] = h[i] + gi * gi;", // NOLINT @@ -7063,10 +7337,10 @@ static std::vector> cl_kernels{ "__kernel void TEMPLATE(adam_update,Dtype)(int_tp N, __global Dtype* g,", // NOLINT "__global Dtype* m,", // NOLINT "__global Dtype* v,", // NOLINT -"Dtype beta1,", // NOLINT -"Dtype beta2,", // NOLINT -"Dtype eps_hat,", // NOLINT -"Dtype corrected_local_rate) {", // NOLINT +"KERNEL_ARG_DTYPE beta1,", // NOLINT +"KERNEL_ARG_DTYPE beta2,", // NOLINT +"KERNEL_ARG_DTYPE eps_hat,", // NOLINT +"KERNEL_ARG_DTYPE corrected_local_rate) {", // NOLINT "for (int_tp i = get_global_id(0); i < N; i += get_global_size(0)) {", // NOLINT "Dtype gi = g[i];", // NOLINT "Dtype mi = m[i] = m[i] * beta1 + gi * (1 - beta1);", // NOLINT @@ -7078,8 +7352,8 @@ static std::vector> cl_kernels{ "", // NOLINT "__kernel void TEMPLATE(nesterov_update,Dtype)(int_tp N, __global Dtype* g,", // NOLINT "__global Dtype* h,", // NOLINT -"Dtype momentum,", // NOLINT -"Dtype local_rate) {", // NOLINT +"KERNEL_ARG_DTYPE momentum,", // NOLINT +"KERNEL_ARG_DTYPE local_rate) {", // NOLINT "for (int_tp i = get_global_id(0); i < N; i += get_global_size(0)) {", // NOLINT "Dtype hi = h[i];", // NOLINT "Dtype hi_new = h[i] = momentum * hi + local_rate * g[i];", // NOLINT @@ -7089,9 +7363,9 @@ static std::vector> cl_kernels{ "", // NOLINT "__kernel void TEMPLATE(rms_prop_update,Dtype)(int_tp N, __global Dtype* g,", // NOLINT "__global Dtype* h,", // NOLINT -"Dtype rms_decay,", // NOLINT 
-"Dtype delta,", // NOLINT -"Dtype local_rate) {", // NOLINT +"KERNEL_ARG_DTYPE rms_decay,", // NOLINT +"KERNEL_ARG_DTYPE delta,", // NOLINT +"KERNEL_ARG_DTYPE local_rate) {", // NOLINT "for (int_tp i = get_global_id(0); i < N; i += get_global_size(0)) {", // NOLINT "Dtype gi = g[i];", // NOLINT "Dtype hi = h[i] = rms_decay * h[i] + (1 - rms_decay) * gi * gi;", // NOLINT @@ -7101,8 +7375,8 @@ static std::vector> cl_kernels{ "", // NOLINT "__kernel void TEMPLATE(sgd_update,Dtype)(int_tp N, __global Dtype* g,", // NOLINT "__global Dtype* h,", // NOLINT -"Dtype momentum,", // NOLINT -"Dtype local_rate) {", // NOLINT +"KERNEL_ARG_DTYPE momentum,", // NOLINT +"KERNEL_ARG_DTYPE local_rate) {", // NOLINT "for (int_tp i = get_global_id(0); i < N; i += get_global_size(0)) {", // NOLINT "g[i] = h[i] = momentum * h[i] + local_rate * g[i];", // NOLINT "}", // NOLINT @@ -7199,7 +7473,15 @@ viennacl::ocl::program & RegisterKernels(viennacl::ocl::context *ctx) { ss << "#define Dtype4 float4" << "\n\n"; // NOLINT ss << "#define Dtype8 float8" << "\n\n"; // NOLINT ss << "#define Dtype16 float16" << "\n\n"; // NOLINT + ss << "#define as_Dtype as_float" << "\n\n"; // NOLINT + ss << "#define as_Dtype2 as_float2" << "\n\n"; // NOLINT + ss << "#define as_Dtype4 as_float4" << "\n\n"; // NOLINT + ss << "#define as_Dtype8 as_float8" << "\n\n"; // NOLINT + ss << "#define as_Dtype16 as_float16" << "\n\n"; // NOLINT ss << "#define TYPE TYPE_FLOAT" << "\n\n"; // NOLINT + ss << "#define KERNEL_ARG_DTYPE float" << "\n\n"; // NOLINT + ss << "#define DTYPE_MAX FLT_MAX" << "\n\n"; // NOLINT + ss << "#define DTYPE_MIN FLT_MIN" << "\n\n"; // NOLINT for (int i = 0; i < cl_kernels.size(); ++i) { for (int j = 0; j < cl_kernels[i].size(); ++j) { ss << cl_kernels[i][j] << "\n\n"; @@ -7216,8 +7498,24 @@ viennacl::ocl::program & RegisterKernels(viennacl::ocl::context *ctx) { ss << "#define Dtype4 double4" << "\n\n"; // NOLINT ss << "#define Dtype8 double8" << "\n\n"; // NOLINT ss << "#define Dtype16 
double16" << "\n\n"; // NOLINT + ss << "#undef as_Dtype" << "\n\n"; // NOLINT + ss << "#undef as_Dtype2" << "\n\n"; // NOLINT + ss << "#undef as_Dtype4" << "\n\n"; // NOLINT + ss << "#undef as_Dtype8" << "\n\n"; // NOLINT + ss << "#undef as_Dtype16" << "\n\n"; // NOLINT + ss << "#define as_Dtype as_double" << "\n\n"; // NOLINT + ss << "#define as_Dtype2 as_double2" << "\n\n"; // NOLINT + ss << "#define as_Dtype4 as_double4" << "\n\n"; // NOLINT + ss << "#define as_Dtype8 as_double8" << "\n\n"; // NOLINT + ss << "#define as_Dtype16 as_double16" << "\n\n"; // NOLINT ss << "#undef TYPE" << "\n\n"; // NOLINT ss << "#define TYPE TYPE_DOUBLE" << "\n\n"; // NOLINT + ss << "#undef KERNEL_ARG_DTYPE" << "\n\n"; // NOLINT + ss << "#define KERNEL_ARG_DTYPE double" << "\n\n"; // NOLINT + ss << "#undef DTYPE_MAX" << "\n\n"; // NOLINT + ss << "#undef DTYPE_MIN" << "\n\n"; // NOLINT + ss << "#define DTYPE_MAX FLT_MAX" << "\n\n"; // NOLINT + ss << "#define DTYPE_MIN FLT_MIN" << "\n\n"; // NOLINT for (int i = 0; i < cl_kernels.size(); ++i) { if (cl_kernel_names[i] != std::string("fft")) { for (int j = 0; j < cl_kernels[i].size(); ++j) { @@ -7226,11 +7524,51 @@ viennacl::ocl::program & RegisterKernels(viennacl::ocl::context *ctx) { } } ss << "#endif // DOUBLE_SUPPORT_AVAILABLE" << "\n\n"; // NOLINT + ss << "#if defined(HALF_SUPPORT_AVAILABLE) && defined(HAS_HALF_SUPPORT)" << "\n\n"; // NOLINT + ss << "#undef Dtype" << "\n\n"; // NOLINT + ss << "#undef Dtype2" << "\n\n"; // NOLINT + ss << "#undef Dtype4" << "\n\n"; // NOLINT + ss << "#undef Dtype8" << "\n\n"; // NOLINT + ss << "#undef Dtype16" << "\n\n"; // NOLINT + ss << "#define Dtype half" << "\n\n"; // NOLINT + ss << "#define Dtype2 half2" << "\n\n"; // NOLINT + ss << "#define Dtype4 half4" << "\n\n"; // NOLINT + ss << "#define Dtype8 half8" << "\n\n"; // NOLINT + ss << "#define Dtype16 half16" << "\n\n"; // NOLINT + ss << "#undef as_Dtype" << "\n\n"; // NOLINT + ss << "#undef as_Dtype2" << "\n\n"; // NOLINT + ss << "#undef 
as_Dtype4" << "\n\n"; // NOLINT + ss << "#undef as_Dtype8" << "\n\n"; // NOLINT + ss << "#undef as_Dtype16" << "\n\n"; // NOLINT + ss << "#define as_Dtype as_half" << "\n\n"; // NOLINT + ss << "#define as_Dtype2 as_half2" << "\n\n"; // NOLINT + ss << "#define as_Dtype4 as_half4" << "\n\n"; // NOLINT + ss << "#define as_Dtype8 as_half8" << "\n\n"; // NOLINT + ss << "#define as_Dtype16 as_half16" << "\n\n"; // NOLINT + ss << "#undef TYPE" << "\n\n"; // NOLINT + ss << "#define TYPE TYPE_HALF" << "\n\n"; // NOLINT + ss << "#undef KERNEL_ARG_DTYPE" << "\n\n"; // NOLINT + ss << "#undef DTYPE_MAX" << "\n\n"; // NOLINT + ss << "#undef DTYPE_MIN" << "\n\n"; // NOLINT + ss << "#define DTYPE_MAX HALF_MAX" << "\n\n"; // NOLINT + ss << "#define DTYPE_MIN HALF_MIN" << "\n\n"; // NOLINT + ss << "#define KERNEL_ARG_DTYPE float" << "\n\n"; // NOLINT + for (int i = 0; i < cl_kernels.size(); ++i) { + if (cl_kernel_names[i] != std::string("fft")) { + for (int j = 0; j < cl_kernels[i].size(); ++j) { + ss << cl_kernels[i][j] << "\n\n"; + } + } + } + ss << "#endif // HALF_SUPPORT_AVAILABLE" << "\n\n"; // NOLINT std::string kernel_string = ss.str(); const char* kernel_program = kernel_string.c_str(); string options; #ifdef USE_FFT - options = " -DFFT " + options = " -DFFT "; +#endif +#ifdef HAS_HALF_SUPPORT + options += " -DHAS_HALF_SUPPORT; "; #endif bool is_beignet = ctx->devices()[0].opencl_c_version().find("beignet") != std::string::npos; @@ -7241,18 +7579,45 @@ viennacl::ocl::program & RegisterKernels(viennacl::ocl::context *ctx) { "kernel_program"); return program; } +template viennacl::ocl::program & submit_conv_spatial_program( viennacl::ocl::context *ctx, string name, string options) { - static const char* core_defines = + static const char* float_core_defines = "#define Dtype float\n" "#define Dtype2 float2\n" "#define Dtype4 float4\n" "#define Dtype8 float8\n" "#define Dtype16 float16\n" - "#define OCL_KERNEL_LOOP(i, n)" - " for (int i = get_global_id(0); i < (n); i += 
get_global_size(0))\n"; + "#define as_Dtype as_float\n" + "#define as_Dtype2 as_float2\n" + "#define as_Dtype4 as_float4\n" + "#define as_Dtype8 as_float8\n" + "#define as_Dtype16 as_float16\n" + "#define TYPE TYPE_FLOAT\n" + "#define DTYPE_MAX FLT_MAX\n" + "#define DTYPE_MIN FLT_MIN\n" + "#define KERNEL_ARG_DTYPE float\n"; + static const char* half_core_defines = + "#define Dtype half\n" + "#define Dtype2 half2\n" + "#define Dtype4 half4\n" + "#define Dtype8 half8\n" + "#define Dtype16 half16\n" + "#define as_Dtype as_half\n" + "#define as_Dtype2 as_half2\n" + "#define as_Dtype4 as_half4\n" + "#define as_Dtype8 as_half8\n" + "#define as_Dtype16 as_half16\n" + "#define TYPE TYPE_HALF\n" + "#define DTYPE_MAX HALF_MAX\n" + "#define DTYPE_MIN HALF_MIN\n" + "#define KERNEL_ARG_DTYPE float\n"; std::stringstream ss; - ss << core_defines; + if (std::is_same::value) { + ss << float_core_defines; + } else { + ss << half_core_defines; + } #ifdef USE_INDEX_64 ss << header + "\n"; ss << definitions_64 + "\n"; @@ -7275,6 +7640,15 @@ viennacl::ocl::context *ctx, string name, string options) { viennacl::ocl::program &program = ctx->add_program(ss.str(), name); return program; } +template +viennacl::ocl::program & submit_conv_spatial_program( +viennacl::ocl::context *ctx, string name, string options); +template +viennacl::ocl::program & submit_conv_spatial_program( +viennacl::ocl::context *ctx, string name, string options); +template +viennacl::ocl::program & submit_conv_spatial_program( +viennacl::ocl::context *ctx, string name, string options); int getKernelBundleCount() { return cl_kernels.size(); } @@ -7295,7 +7669,15 @@ std::string getKernelBundleSource(int index) { ss << "#define Dtype8 float8" << "\n\n"; // NOLINT ss << "#define Dtype16 float16" << "\n\n"; // NOLINT ss << "#define TYPE TYPE_FLOAT" << "\n\n"; // NOLINT - } else { + ss << "#define as_Dtype as_float" << "\n\n"; // NOLINT + ss << "#define as_Dtype2 as_float2" << "\n\n"; // NOLINT + ss << "#define as_Dtype4 
as_float4" << "\n\n"; // NOLINT + ss << "#define as_Dtype8 as_float8" << "\n\n"; // NOLINT + ss << "#define as_Dtype16 as_float16" << "\n\n"; // NOLINT + ss << "#define KERNEL_ARG_DTYPE float" << "\n\n"; // NOLINT + ss << "#define DTYPE_MAX FLT_MAX" << "\n\n"; // NOLINT + ss << "#define DTYPE_MIN FLT_MIN" << "\n\n"; // NOLINT + } else if (std::is_same::value) { ss << "#ifdef DOUBLE_SUPPORT_AVAILABLE" << "\n\n"; // NOLINT ss << "#define Dtype double" << "\n\n"; // NOLINT ss << "#define Dtype2 double2" << "\n\n"; // NOLINT @@ -7303,6 +7685,30 @@ std::string getKernelBundleSource(int index) { ss << "#define Dtype8 double8" << "\n\n"; // NOLINT ss << "#define Dtype16 double16" << "\n\n"; // NOLINT ss << "#define TYPE TYPE_DOUBLE" << "\n\n"; // NOLINT + ss << "#define as_Dtype as_double" << "\n\n"; // NOLINT + ss << "#define as_Dtype2 as_double2" << "\n\n"; // NOLINT + ss << "#define as_Dtype4 as_double4" << "\n\n"; // NOLINT + ss << "#define as_Dtype8 as_double8" << "\n\n"; // NOLINT + ss << "#define as_Dtype16 as_double16" << "\n\n"; // NOLINT + ss << "#define KERNEL_ARG_DTYPE double" << "\n\n"; // NOLINT + ss << "#define DTYPE_MAX FLT_MAX" << "\n\n"; // NOLINT + ss << "#define DTYPE_MIN FLT_MIN" << "\n\n"; // NOLINT + } else { + ss << "#if defined(HALF_SUPPORT_AVAILABLE) && defined(HAS_HALF_SUPPORT)" << "\n\n"; // NOLINT + ss << "#define Dtype half" << "\n\n"; // NOLINT + ss << "#define Dtype2 half2" << "\n\n"; // NOLINT + ss << "#define Dtype4 half4" << "\n\n"; // NOLINT + ss << "#define Dtype8 half8" << "\n\n"; // NOLINT + ss << "#define Dtype16 half16" << "\n\n"; // NOLINT + ss << "#define TYPE TYPE_HALF" << "\n\n"; // NOLINT + ss << "#define as_Dtype as_half" << "\n\n"; // NOLINT + ss << "#define as_Dtype2 as_half2" << "\n\n"; // NOLINT + ss << "#define as_Dtype4 as_half4" << "\n\n"; // NOLINT + ss << "#define as_Dtype8 as_half8" << "\n\n"; // NOLINT + ss << "#define as_Dtype16 as_half16" << "\n\n"; // NOLINT + ss << "#define KERNEL_ARG_DTYPE float" << "\n\n"; // 
NOLINT + ss << "#define DTYPE_MAX HALF_MAX" << "\n\n"; // NOLINT + ss << "#define DTYPE_MIN HALF_MIN" << "\n\n"; // NOLINT } for (int j = 0; j < cl_kernels[index].size(); ++j) { ss << cl_kernels[index][j] << "\n\n"; @@ -7313,6 +7719,7 @@ std::string getKernelBundleSource(int index) { } return ss.str(); } +template std::string getKernelBundleSource(int index); template std::string getKernelBundleSource(int index); template std::string getKernelBundleSource(int index); std::string getKernelBundleName(int index) { diff --git a/src/caffe/greentea/cl_kernels.sh b/src/caffe/greentea/cl_kernels.sh index 5c1fd66b129..bea5c919666 100755 --- a/src/caffe/greentea/cl_kernels.sh +++ b/src/caffe/greentea/cl_kernels.sh @@ -42,6 +42,7 @@ echo "#endif // DISABLE_DOUBLE_SUPPORT" >> $SOURCE echo "namespace caffe {" >> $SOURCE echo "viennacl::ocl::program & RegisterKernels(viennacl::ocl::context *ctx);" >> $HEADER +echo "template " >> $HEADER echo "viennacl::ocl::program & submit_conv_spatial_program(" >> $HEADER echo "viennacl::ocl::context *ctx, string name, string options);" >> $HEADER echo "std::string getKernelBundleName(int index);" >> $HEADER @@ -154,8 +155,15 @@ echo " ss << \"#define Dtype2 float2\" << \"\\n\\n\"; // NOLINT" >> $SOURCE echo " ss << \"#define Dtype4 float4\" << \"\\n\\n\"; // NOLINT" >> $SOURCE echo " ss << \"#define Dtype8 float8\" << \"\\n\\n\"; // NOLINT" >> $SOURCE echo " ss << \"#define Dtype16 float16\" << \"\\n\\n\"; // NOLINT" >> $SOURCE - +echo " ss << \"#define as_Dtype as_float\" << \"\\n\\n\"; // NOLINT" >> $SOURCE +echo " ss << \"#define as_Dtype2 as_float2\" << \"\\n\\n\"; // NOLINT" >> $SOURCE +echo " ss << \"#define as_Dtype4 as_float4\" << \"\\n\\n\"; // NOLINT" >> $SOURCE +echo " ss << \"#define as_Dtype8 as_float8\" << \"\\n\\n\"; // NOLINT" >> $SOURCE +echo " ss << \"#define as_Dtype16 as_float16\" << \"\\n\\n\"; // NOLINT" >> $SOURCE echo " ss << \"#define TYPE TYPE_FLOAT\" << \"\\n\\n\"; // NOLINT" >> $SOURCE +echo " ss << \"#define 
KERNEL_ARG_DTYPE float\" << \"\\n\\n\"; // NOLINT" >> $SOURCE +echo " ss << \"#define DTYPE_MAX FLT_MAX\" << \"\\n\\n\"; // NOLINT" >> $SOURCE +echo " ss << \"#define DTYPE_MIN FLT_MIN\" << \"\\n\\n\"; // NOLINT" >> $SOURCE echo " for (int i = 0; i < cl_kernels.size(); ++i) {" >> $SOURCE echo " for (int j = 0; j < cl_kernels[i].size(); ++j) {" >> $SOURCE echo " ss << cl_kernels[i][j] << \"\n\n\";" >> $SOURCE @@ -173,8 +181,24 @@ echo " ss << \"#define Dtype2 double2\" << \"\\n\\n\"; // NOLINT" >> $SOURCE echo " ss << \"#define Dtype4 double4\" << \"\\n\\n\"; // NOLINT" >> $SOURCE echo " ss << \"#define Dtype8 double8\" << \"\\n\\n\"; // NOLINT" >> $SOURCE echo " ss << \"#define Dtype16 double16\" << \"\\n\\n\"; // NOLINT" >> $SOURCE +echo " ss << \"#undef as_Dtype\" << \"\\n\\n\"; // NOLINT" >> $SOURCE +echo " ss << \"#undef as_Dtype2\" << \"\\n\\n\"; // NOLINT" >> $SOURCE +echo " ss << \"#undef as_Dtype4\" << \"\\n\\n\"; // NOLINT" >> $SOURCE +echo " ss << \"#undef as_Dtype8\" << \"\\n\\n\"; // NOLINT" >> $SOURCE +echo " ss << \"#undef as_Dtype16\" << \"\\n\\n\"; // NOLINT" >> $SOURCE +echo " ss << \"#define as_Dtype as_double\" << \"\\n\\n\"; // NOLINT" >> $SOURCE +echo " ss << \"#define as_Dtype2 as_double2\" << \"\\n\\n\"; // NOLINT" >> $SOURCE +echo " ss << \"#define as_Dtype4 as_double4\" << \"\\n\\n\"; // NOLINT" >> $SOURCE +echo " ss << \"#define as_Dtype8 as_double8\" << \"\\n\\n\"; // NOLINT" >> $SOURCE +echo " ss << \"#define as_Dtype16 as_double16\" << \"\\n\\n\"; // NOLINT" >> $SOURCE echo " ss << \"#undef TYPE\" << \"\\n\\n\"; // NOLINT" >> $SOURCE echo " ss << \"#define TYPE TYPE_DOUBLE\" << \"\\n\\n\"; // NOLINT" >> $SOURCE +echo " ss << \"#undef KERNEL_ARG_DTYPE\" << \"\\n\\n\"; // NOLINT" >> $SOURCE +echo " ss << \"#define KERNEL_ARG_DTYPE double\" << \"\\n\\n\"; // NOLINT" >> $SOURCE +echo " ss << \"#undef DTYPE_MAX\" << \"\\n\\n\"; // NOLINT" >> $SOURCE +echo " ss << \"#undef DTYPE_MIN\" << \"\\n\\n\"; // NOLINT" >> $SOURCE +echo " ss << 
\"#define DTYPE_MAX FLT_MAX\" << \"\\n\\n\"; // NOLINT" >> $SOURCE +echo " ss << \"#define DTYPE_MIN FLT_MIN\" << \"\\n\\n\"; // NOLINT" >> $SOURCE shopt -s nullglob echo " for (int i = 0; i < cl_kernels.size(); ++i) {" >> $SOURCE @@ -186,11 +210,55 @@ echo " }" >> $SOURCE echo " }" >> $SOURCE echo " ss << \"#endif // DOUBLE_SUPPORT_AVAILABLE\" << \"\\n\\n\"; // NOLINT" >> $SOURCE +echo " ss << \"#if defined(HALF_SUPPORT_AVAILABLE) && defined(HAS_HALF_SUPPORT)\" << \"\\n\\n\"; // NOLINT" >> $SOURCE +echo " ss << \"#undef Dtype\" << \"\\n\\n\"; // NOLINT" >> $SOURCE +echo " ss << \"#undef Dtype2\" << \"\\n\\n\"; // NOLINT" >> $SOURCE +echo " ss << \"#undef Dtype4\" << \"\\n\\n\"; // NOLINT" >> $SOURCE +echo " ss << \"#undef Dtype8\" << \"\\n\\n\"; // NOLINT" >> $SOURCE +echo " ss << \"#undef Dtype16\" << \"\\n\\n\"; // NOLINT" >> $SOURCE +echo " ss << \"#define Dtype half\" << \"\\n\\n\"; // NOLINT" >> $SOURCE +echo " ss << \"#define Dtype2 half2\" << \"\\n\\n\"; // NOLINT" >> $SOURCE +echo " ss << \"#define Dtype4 half4\" << \"\\n\\n\"; // NOLINT" >> $SOURCE +echo " ss << \"#define Dtype8 half8\" << \"\\n\\n\"; // NOLINT" >> $SOURCE +echo " ss << \"#define Dtype16 half16\" << \"\\n\\n\"; // NOLINT" >> $SOURCE +echo " ss << \"#undef as_Dtype\" << \"\\n\\n\"; // NOLINT" >> $SOURCE +echo " ss << \"#undef as_Dtype2\" << \"\\n\\n\"; // NOLINT" >> $SOURCE +echo " ss << \"#undef as_Dtype4\" << \"\\n\\n\"; // NOLINT" >> $SOURCE +echo " ss << \"#undef as_Dtype8\" << \"\\n\\n\"; // NOLINT" >> $SOURCE +echo " ss << \"#undef as_Dtype16\" << \"\\n\\n\"; // NOLINT" >> $SOURCE +echo " ss << \"#define as_Dtype as_half\" << \"\\n\\n\"; // NOLINT" >> $SOURCE +echo " ss << \"#define as_Dtype2 as_half2\" << \"\\n\\n\"; // NOLINT" >> $SOURCE +echo " ss << \"#define as_Dtype4 as_half4\" << \"\\n\\n\"; // NOLINT" >> $SOURCE +echo " ss << \"#define as_Dtype8 as_half8\" << \"\\n\\n\"; // NOLINT" >> $SOURCE +echo " ss << \"#define as_Dtype16 as_half16\" << \"\\n\\n\"; // NOLINT" >> $SOURCE 
+echo " ss << \"#undef TYPE\" << \"\\n\\n\"; // NOLINT" >> $SOURCE +echo " ss << \"#define TYPE TYPE_HALF\" << \"\\n\\n\"; // NOLINT" >> $SOURCE +echo " ss << \"#undef KERNEL_ARG_DTYPE\" << \"\\n\\n\"; // NOLINT" >> $SOURCE +echo " ss << \"#undef DTYPE_MAX\" << \"\\n\\n\"; // NOLINT" >> $SOURCE +echo " ss << \"#undef DTYPE_MIN\" << \"\\n\\n\"; // NOLINT" >> $SOURCE +echo " ss << \"#define DTYPE_MAX HALF_MAX\" << \"\\n\\n\"; // NOLINT" >> $SOURCE +echo " ss << \"#define DTYPE_MIN HALF_MIN\" << \"\\n\\n\"; // NOLINT" >> $SOURCE +echo " ss << \"#define KERNEL_ARG_DTYPE float\" << \"\\n\\n\"; // NOLINT" >> $SOURCE + +shopt -s nullglob +echo " for (int i = 0; i < cl_kernels.size(); ++i) {" >> $SOURCE +echo " if (cl_kernel_names[i] != std::string(\"fft\")) {" >> $SOURCE +echo " for (int j = 0; j < cl_kernels[i].size(); ++j) {" >> $SOURCE +echo " ss << cl_kernels[i][j] << \"\n\n\";" >> $SOURCE +echo " }" >> $SOURCE +echo " }" >> $SOURCE +echo " }" >> $SOURCE +echo " ss << \"#endif // HALF_SUPPORT_AVAILABLE\" << \"\\n\\n\"; // NOLINT" >> $SOURCE + + echo " std::string kernel_string = ss.str();" >> $SOURCE echo " const char* kernel_program = kernel_string.c_str();" >> $SOURCE echo " string options;" >> $SOURCE echo "#ifdef USE_FFT" >> $SOURCE -echo " options = \" -DFFT \"" >> $SOURCE +echo " options = \" -DFFT \";" >> $SOURCE +echo "#endif" >> $SOURCE +echo "#ifdef HAS_HALF_SUPPORT" >> $SOURCE +echo " options += \" -DHAS_HALF_SUPPORT; \";" >> $SOURCE echo "#endif" >> $SOURCE echo " bool is_beignet = ctx->devices()[0].opencl_c_version().find(\"beignet\")" >> $SOURCE echo " != std::string::npos;" >> $SOURCE @@ -201,18 +269,47 @@ echo " viennacl::ocl::program &program = ctx->add_program(kernel_program," >> $ echo " \"kernel_program\");" >> $SOURCE echo " return program;" >> $SOURCE echo "}" >> $SOURCE +echo "template" >> $SOURCE echo "viennacl::ocl::program & submit_conv_spatial_program(" >> $SOURCE echo "viennacl::ocl::context *ctx, string name, string options) {" >> $SOURCE 
-echo " static const char* core_defines =" >> $SOURCE +echo " static const char* float_core_defines =" >> $SOURCE echo " \"#define Dtype float\n\"" >> $SOURCE echo " \"#define Dtype2 float2\n\"" >> $SOURCE echo " \"#define Dtype4 float4\n\"" >> $SOURCE echo " \"#define Dtype8 float8\n\"" >> $SOURCE echo " \"#define Dtype16 float16\n\"" >> $SOURCE -echo " \"#define OCL_KERNEL_LOOP(i, n)\"" >> $SOURCE -echo " \" for (int i = get_global_id(0); i < (n); i += get_global_size(0))\n\";" >> $SOURCE +echo " \"#define as_Dtype as_float\n\"" >> $SOURCE +echo " \"#define as_Dtype2 as_float2\n\"" >> $SOURCE +echo " \"#define as_Dtype4 as_float4\n\"" >> $SOURCE +echo " \"#define as_Dtype8 as_float8\n\"" >> $SOURCE +echo " \"#define as_Dtype16 as_float16\n\"" >> $SOURCE +echo " \"#define TYPE TYPE_FLOAT\n\"" >> $SOURCE +echo " \"#define DTYPE_MAX FLT_MAX\n\"" >> $SOURCE +echo " \"#define DTYPE_MIN FLT_MIN\n\"" >> $SOURCE +echo " \"#define KERNEL_ARG_DTYPE float\n\";" >> $SOURCE + +echo " static const char* half_core_defines =" >> $SOURCE +echo " \"#define Dtype half\n\"" >> $SOURCE +echo " \"#define Dtype2 half2\n\"" >> $SOURCE +echo " \"#define Dtype4 half4\n\"" >> $SOURCE +echo " \"#define Dtype8 half8\n\"" >> $SOURCE +echo " \"#define Dtype16 half16\n\"" >> $SOURCE +echo " \"#define as_Dtype as_half\n\"" >> $SOURCE +echo " \"#define as_Dtype2 as_half2\n\"" >> $SOURCE +echo " \"#define as_Dtype4 as_half4\n\"" >> $SOURCE +echo " \"#define as_Dtype8 as_half8\n\"" >> $SOURCE +echo " \"#define as_Dtype16 as_half16\n\"" >> $SOURCE +echo " \"#define TYPE TYPE_HALF\n\"" >> $SOURCE +echo " \"#define DTYPE_MAX HALF_MAX\n\"" >> $SOURCE +echo " \"#define DTYPE_MIN HALF_MIN\n\"" >> $SOURCE +echo " \"#define KERNEL_ARG_DTYPE float\n\";" >> $SOURCE + echo " std::stringstream ss;" >> $SOURCE -echo " ss << core_defines;" >> $SOURCE +echo " if (std::is_same::value) {" >> $SOURCE +echo " ss << float_core_defines;" >> $SOURCE +echo " } else {" >> $SOURCE +echo " ss << half_core_defines;" >> 
$SOURCE +echo " }" >> $SOURCE echo "#ifdef USE_INDEX_64" >> $SOURCE echo " ss << header + \"\n\";" >> $SOURCE echo " ss << definitions_64 + \"\n\";" >> $SOURCE @@ -235,6 +332,16 @@ echo " ctx->build_options(options);" >> $SOURCE echo " viennacl::ocl::program &program = ctx->add_program(ss.str(), name);" >> $SOURCE echo " return program;" >> $SOURCE echo "}" >> $SOURCE + +echo "template" >> $SOURCE +echo "viennacl::ocl::program & submit_conv_spatial_program(" >> $SOURCE +echo "viennacl::ocl::context *ctx, string name, string options);" >> $SOURCE +echo "template" >> $SOURCE +echo "viennacl::ocl::program & submit_conv_spatial_program(" >> $SOURCE +echo "viennacl::ocl::context *ctx, string name, string options);" >> $SOURCE +echo "template" >> $SOURCE +echo "viennacl::ocl::program & submit_conv_spatial_program(" >> $SOURCE +echo "viennacl::ocl::context *ctx, string name, string options);" >> $SOURCE echo "int getKernelBundleCount() {" >> $SOURCE echo " return cl_kernels.size();" >> $SOURCE echo "}" >> $SOURCE @@ -255,7 +362,15 @@ echo " ss << \"#define Dtype4 float4\" << \"\n\n\"; // NOLINT" >> $SOURCE echo " ss << \"#define Dtype8 float8\" << \"\n\n\"; // NOLINT" >> $SOURCE echo " ss << \"#define Dtype16 float16\" << \"\n\n\"; // NOLINT" >> $SOURCE echo " ss << \"#define TYPE TYPE_FLOAT\" << \"\n\n\"; // NOLINT" >> $SOURCE -echo " } else {" >> $SOURCE +echo " ss << \"#define as_Dtype as_float\" << \"\n\n\"; // NOLINT" >> $SOURCE +echo " ss << \"#define as_Dtype2 as_float2\" << \"\n\n\"; // NOLINT" >> $SOURCE +echo " ss << \"#define as_Dtype4 as_float4\" << \"\n\n\"; // NOLINT" >> $SOURCE +echo " ss << \"#define as_Dtype8 as_float8\" << \"\n\n\"; // NOLINT" >> $SOURCE +echo " ss << \"#define as_Dtype16 as_float16\" << \"\n\n\"; // NOLINT" >> $SOURCE +echo " ss << \"#define KERNEL_ARG_DTYPE float\" << \"\n\n\"; // NOLINT" >> $SOURCE +echo " ss << \"#define DTYPE_MAX FLT_MAX\" << \"\n\n\"; // NOLINT" >> $SOURCE +echo " ss << \"#define DTYPE_MIN FLT_MIN\" << \"\n\n\"; // 
NOLINT" >> $SOURCE +echo " } else if (std::is_same::value) {" >> $SOURCE echo " ss << \"#ifdef DOUBLE_SUPPORT_AVAILABLE\" << \"\n\n\"; // NOLINT" >> $SOURCE echo " ss << \"#define Dtype double\" << \"\n\n\"; // NOLINT" >> $SOURCE echo " ss << \"#define Dtype2 double2\" << \"\n\n\"; // NOLINT" >> $SOURCE @@ -263,6 +378,30 @@ echo " ss << \"#define Dtype4 double4\" << \"\n\n\"; // NOLINT" >> $SOURCE echo " ss << \"#define Dtype8 double8\" << \"\n\n\"; // NOLINT" >> $SOURCE echo " ss << \"#define Dtype16 double16\" << \"\n\n\"; // NOLINT" >> $SOURCE echo " ss << \"#define TYPE TYPE_DOUBLE\" << \"\n\n\"; // NOLINT" >> $SOURCE +echo " ss << \"#define as_Dtype as_double\" << \"\n\n\"; // NOLINT" >> $SOURCE +echo " ss << \"#define as_Dtype2 as_double2\" << \"\n\n\"; // NOLINT" >> $SOURCE +echo " ss << \"#define as_Dtype4 as_double4\" << \"\n\n\"; // NOLINT" >> $SOURCE +echo " ss << \"#define as_Dtype8 as_double8\" << \"\n\n\"; // NOLINT" >> $SOURCE +echo " ss << \"#define as_Dtype16 as_double16\" << \"\n\n\"; // NOLINT" >> $SOURCE +echo " ss << \"#define KERNEL_ARG_DTYPE double\" << \"\n\n\"; // NOLINT" >> $SOURCE +echo " ss << \"#define DTYPE_MAX FLT_MAX\" << \"\n\n\"; // NOLINT" >> $SOURCE +echo " ss << \"#define DTYPE_MIN FLT_MIN\" << \"\n\n\"; // NOLINT" >> $SOURCE +echo " } else {" >> $SOURCE +echo " ss << \"#if defined(HALF_SUPPORT_AVAILABLE) && defined(HAS_HALF_SUPPORT)\" << \"\n\n\"; // NOLINT" >> $SOURCE +echo " ss << \"#define Dtype half\" << \"\n\n\"; // NOLINT" >> $SOURCE +echo " ss << \"#define Dtype2 half2\" << \"\n\n\"; // NOLINT" >> $SOURCE +echo " ss << \"#define Dtype4 half4\" << \"\n\n\"; // NOLINT" >> $SOURCE +echo " ss << \"#define Dtype8 half8\" << \"\n\n\"; // NOLINT" >> $SOURCE +echo " ss << \"#define Dtype16 half16\" << \"\n\n\"; // NOLINT" >> $SOURCE +echo " ss << \"#define TYPE TYPE_HALF\" << \"\n\n\"; // NOLINT" >> $SOURCE +echo " ss << \"#define as_Dtype as_half\" << \"\n\n\"; // NOLINT" >> $SOURCE +echo " ss << \"#define as_Dtype2 as_half2\" 
<< \"\n\n\"; // NOLINT" >> $SOURCE +echo " ss << \"#define as_Dtype4 as_half4\" << \"\n\n\"; // NOLINT" >> $SOURCE +echo " ss << \"#define as_Dtype8 as_half8\" << \"\n\n\"; // NOLINT" >> $SOURCE +echo " ss << \"#define as_Dtype16 as_half16\" << \"\n\n\"; // NOLINT" >> $SOURCE +echo " ss << \"#define KERNEL_ARG_DTYPE float\" << \"\n\n\"; // NOLINT" >> $SOURCE +echo " ss << \"#define DTYPE_MAX HALF_MAX\" << \"\n\n\"; // NOLINT" >> $SOURCE +echo " ss << \"#define DTYPE_MIN HALF_MIN\" << \"\n\n\"; // NOLINT" >> $SOURCE echo " }" >> $SOURCE echo " for (int j = 0; j < cl_kernels[index].size(); ++j) {" >> $SOURCE echo " ss << cl_kernels[index][j] << \"\n\n\";" >> $SOURCE @@ -273,6 +412,7 @@ echo " ss << \"#endif\" << \"\n\n\"; // NOLINT" >> $SOURCE echo " }" >> $SOURCE echo " return ss.str();" >> $SOURCE echo "}" >> $SOURCE +echo "template std::string getKernelBundleSource(int index);" >> $SOURCE echo "template std::string getKernelBundleSource(int index);" >> $SOURCE echo "template std::string getKernelBundleSource(int index);" >> $SOURCE echo "std::string getKernelBundleName(int index) {" >> $SOURCE diff --git a/src/caffe/greentea/cl_kernels/activation.cl b/src/caffe/greentea/cl_kernels/activation.cl index 6f0eaedc4e1..cb5a2834717 100644 --- a/src/caffe/greentea/cl_kernels/activation.cl +++ b/src/caffe/greentea/cl_kernels/activation.cl @@ -5,7 +5,7 @@ __kernel void TEMPLATE(relu_forward,Dtype)(const int_tp n, __global const Dtype* in, __global Dtype* out, - Dtype negative_slope) { + KERNEL_ARG_DTYPE negative_slope) { for (int_tp index = get_global_id(0); index < n; index += get_global_size(0)) { out[index] = in[index] > 0 ? 
in[index] : in[index] * negative_slope; } @@ -15,7 +15,7 @@ __kernel void TEMPLATE(relu_backward,Dtype)(const int_tp n, __global const Dtype* in_diff, __global const Dtype* in_data, __global Dtype* out_diff, - Dtype negative_slope) { + KERNEL_ARG_DTYPE negative_slope) { for (int_tp index = get_global_id(0); index < n; index += get_global_size(0)) { out_diff[index] = in_diff[index] * ((Dtype)(in_data[index] > 0?1.0:0.0) + (Dtype)(in_data[index] <= 0?1.0:0.0) * negative_slope); @@ -58,7 +58,7 @@ __kernel void TEMPLATE(sigmoid_backward,Dtype)(const int_tp n, } } -__kernel void TEMPLATE(threshold,Dtype)(const int_tp n, const Dtype threshold, +__kernel void TEMPLATE(threshold,Dtype)(const int_tp n, const KERNEL_ARG_DTYPE threshold, __global const Dtype* in, __global Dtype* out) { for (int_tp index = get_global_id(0); index < n; index += get_global_size(0)) { diff --git a/src/caffe/greentea/cl_kernels/auxiliary.cl b/src/caffe/greentea/cl_kernels/auxiliary.cl index 940cecb7c5f..a6ec1a8eb89 100644 --- a/src/caffe/greentea/cl_kernels/auxiliary.cl +++ b/src/caffe/greentea/cl_kernels/auxiliary.cl @@ -2,7 +2,7 @@ #include "header.cl" #endif -__kernel void TEMPLATE(gpu_set,Dtype)(const int_tp n, const Dtype alpha, __global Dtype* y) { +__kernel void TEMPLATE(gpu_set,Dtype)(const int_tp n, const KERNEL_ARG_DTYPE alpha, __global Dtype* y) { for (int_tp index = get_global_id(0); index < n; index += get_global_size(0)) { y[index] = alpha; } diff --git a/src/caffe/greentea/cl_kernels/batch_norm.cl b/src/caffe/greentea/cl_kernels/batch_norm.cl index 08d0ddeff53..64d4d2d278e 100644 --- a/src/caffe/greentea/cl_kernels/batch_norm.cl +++ b/src/caffe/greentea/cl_kernels/batch_norm.cl @@ -3,11 +3,11 @@ #endif Dtype TEMPLATE(bn_common,Dtype)(const int_tp num, const int_tp channels, const int_tp spatial_dim, - const Dtype scale, const Dtype eps, - __global const Dtype* mean, - __global const Dtype* variance, - __global const Dtype* data, - int_tp *out_off) { + const Dtype scale, const Dtype 
eps, + __global const Dtype* mean, + __global const Dtype* variance, + __global const Dtype* data, + int_tp *out_off) { const int_tp idx_num = get_global_id(0); const int_tp idx_chans = get_global_id(1); const int_tp idx_spatial_dim = get_global_id(2); @@ -16,7 +16,7 @@ Dtype TEMPLATE(bn_common,Dtype)(const int_tp num, const int_tp channels, const i Dtype v = variance[idx_chans]; m = -scale * m; - v = (Dtype)native_powr((float)mad(scale, v, eps), (float)-0.5); + v = (Dtype)native_powr((Dtype)mad(scale, v, eps), (Dtype)-0.5); *out_off = (idx_num * channels + idx_chans) * spatial_dim + idx_spatial_dim; return (v * (data[*out_off] + m)); @@ -24,7 +24,7 @@ Dtype TEMPLATE(bn_common,Dtype)(const int_tp num, const int_tp channels, const i __kernel void TEMPLATE(bn_use_global_stats_in_place,Dtype)(const int_tp num, const int_tp channels, const int_tp spatial_dim, - const Dtype scale, const Dtype eps, + const KERNEL_ARG_DTYPE scale, const KERNEL_ARG_DTYPE eps, __global const Dtype* mean, __global const Dtype* variance, __global Dtype* top) { @@ -34,7 +34,7 @@ __kernel void TEMPLATE(bn_use_global_stats_in_place,Dtype)(const int_tp num, con } __kernel void TEMPLATE(bn_use_global_stats_in_place_fused_relu,Dtype)(const int_tp num, const int_tp channels, const int_tp spatial_dim, - const Dtype scale, const Dtype eps, + const KERNEL_ARG_DTYPE scale, const KERNEL_ARG_DTYPE eps, __global const Dtype* mean, __global const Dtype* variance, __global Dtype* top) { @@ -44,7 +44,7 @@ __kernel void TEMPLATE(bn_use_global_stats_in_place_fused_relu,Dtype)(const int_ } __kernel void TEMPLATE(bn_use_global_stats,Dtype)(const int_tp num, const int_tp channels, const int_tp spatial_dim, - const Dtype scale, const Dtype eps, + const KERNEL_ARG_DTYPE scale, const KERNEL_ARG_DTYPE eps, __global const Dtype* mean, __global const Dtype* variance, __global const Dtype* bottom, @@ -55,7 +55,7 @@ __kernel void TEMPLATE(bn_use_global_stats,Dtype)(const int_tp num, const int_tp } __kernel void 
TEMPLATE(bn_use_global_stats_fused_relu,Dtype)(const int_tp num, const int_tp channels, const int_tp spatial_dim, - const Dtype scale, const Dtype eps, + const KERNEL_ARG_DTYPE scale, const KERNEL_ARG_DTYPE eps, __global const Dtype* mean, __global const Dtype* variance, __global const Dtype* bottom, diff --git a/src/caffe/greentea/cl_kernels/benchmark.cl b/src/caffe/greentea/cl_kernels/benchmark.cl index a4004c0fe60..4db2395e2c3 100644 --- a/src/caffe/greentea/cl_kernels/benchmark.cl +++ b/src/caffe/greentea/cl_kernels/benchmark.cl @@ -2,6 +2,6 @@ #include "header.cl" #endif -__kernel void TEMPLATE(null_kernel,Dtype)(Dtype arg) { +__kernel void TEMPLATE(null_kernel,Dtype)(KERNEL_ARG_DTYPE arg) { Dtype out = arg; } diff --git a/src/caffe/greentea/cl_kernels/channel.cl b/src/caffe/greentea/cl_kernels/channel.cl index bf65f536fb1..7fcdeaa9a96 100644 --- a/src/caffe/greentea/cl_kernels/channel.cl +++ b/src/caffe/greentea/cl_kernels/channel.cl @@ -10,7 +10,7 @@ __kernel void TEMPLATE(kernel_channel_max,Dtype)(const int_tp num, const int_tp get_global_size(0)) { int_tp n = index / spatial_dim; int_tp s = index % spatial_dim; - float maxval = -FLT_MAX; + Dtype maxval = -DTYPE_MAX; for (int_tp c = 0; c < channels; ++c) { maxval = max((Dtype)(data[(n * channels + c) * spatial_dim + s]), (Dtype)maxval); } diff --git a/src/caffe/greentea/cl_kernels/contrastive_loss.cl b/src/caffe/greentea/cl_kernels/contrastive_loss.cl index bb71dfec8b0..a4b111eb44d 100644 --- a/src/caffe/greentea/cl_kernels/contrastive_loss.cl +++ b/src/caffe/greentea/cl_kernels/contrastive_loss.cl @@ -3,7 +3,7 @@ #endif __kernel void TEMPLATE(cll_backward,Dtype)(const int_tp count, const int_tp channels, - const Dtype margin, const Dtype alpha, __global const Dtype* y, + const KERNEL_ARG_DTYPE margin, const KERNEL_ARG_DTYPE alpha, __global const Dtype* y, __global const Dtype* diff, __global const Dtype* dist_sq, __global Dtype *bottom_diff) { for (int_tp i = get_global_id(0); i < count; @@ -27,7 +27,7 @@ 
__kernel void TEMPLATE(cll_backward,Dtype)(const int_tp count, const int_tp chan } __kernel void TEMPLATE(cll_backward_legacy,Dtype)(const int count, const int channels, - const Dtype margin, const Dtype alpha, __global Dtype* y, + const KERNEL_ARG_DTYPE margin, const KERNEL_ARG_DTYPE alpha, __global Dtype* y, __global Dtype* diff, __global Dtype* dist_sq, __global Dtype* bottom_diff) { for (int_tp i = get_global_id(0); i < count; diff --git a/src/caffe/greentea/cl_kernels/conv_layer_spatial.cl b/src/caffe/greentea/cl_kernels/conv_layer_spatial.cl index e35b632d1ce..515ab46029e 100644 --- a/src/caffe/greentea/cl_kernels/conv_layer_spatial.cl +++ b/src/caffe/greentea/cl_kernels/conv_layer_spatial.cl @@ -2,7 +2,7 @@ #include "header.cl" #endif -__kernel void TEMPLATE(conv_layer_spatial_phony,Dtype)(Dtype arg) { +__kernel void TEMPLATE(conv_layer_spatial_phony,Dtype)(KERNEL_ARG_DTYPE arg) { Dtype out = arg; } @@ -122,8 +122,30 @@ __kernel void CFMultiNoPadding( } } } + #endif +#if defined(convolve_simd) || defined(Conv_Interleaved) +#if TYPE == TYPE_HALF +#define INT_TYPE ushort +#define INT_TYPE2 ushort2 +#define INT_TYPE4 ushort4 +#define INT_TYPE8 ushort8 +#define SUB_GROUP_BLOCK_READ2 intel_sub_group_block_read_us2 +#define SUB_GROUP_BLOCK_READ4 intel_sub_group_block_read_us4 +#define SUB_GROUP_BLOCK_READ8 intel_sub_group_block_read_us8 +#define SUB_GROUP_BLOCK_READ intel_sub_group_block_read_us +#else +#define INT_TYPE uint +#define INT_TYPE2 uint2 +#define INT_TYPE4 uint4 +#define INT_TYPE8 uint8 +#define SUB_GROUP_BLOCK_READ2 intel_sub_group_block_read2 +#define SUB_GROUP_BLOCK_READ4 intel_sub_group_block_read4 +#define SUB_GROUP_BLOCK_READ8 intel_sub_group_block_read8 +#define SUB_GROUP_BLOCK_READ intel_sub_group_block_read +#endif +#endif //Begin IDLF kernels below here #ifdef IDLF @@ -134,26 +156,28 @@ __kernel void CFMultiNoPadding( // Each work-group (which will be mapped to 1 SIMD16/SIMD8 EU thread) will compute 16/8 different feature maps, but each 
feature map is for the same region of the imput image. // NDRange: (output_width+pad)/ OUT_BLOCK_WIDTH, (output_height+pad)/OUT_BLOCK_HEIGHT, NUM_FILTERS/OUT_BLOCK_DEPTH -// NOTE: for beignet this reqd_work_group_size does not guarantee that SIMD16/8 mode will be used, the compiler could choose to use two SIMD8 threads, and if that happens the code will break. +// NOTE: for beignet this reqd_work_group_size does not guarantee that SIMD16 mode will be used, the compiler could choose to use two SIMD8 threads, and if that happens the code will break. +#ifndef __BEIGNET__ __attribute__((reqd_work_group_size(1, 1, SIMD_SIZE))) +#endif __kernel void -convolve_simd( // __global float *inputs, __global float* weights, __global float* outputs +convolve_simd( #ifdef FUSED_CONV_ELTWISE __global Dtype* eltwise_data, #endif - __global float* inputs_base, - filter_qualifier float* weights_base, - __global float* biases_base, - __global float* outputs_base, + __global Dtype* inputs_base, + filter_qualifier Dtype* weights_base, + __global Dtype* biases_base, + __global Dtype* outputs_base, const ushort input_width, const ushort input_height, const ushort output_width, const ushort output_height) { - __global float* outputs = outputs_base; - __global float* inputs = inputs_base; - filter_qualifier float* weights = weights_base; - __global float* biases = biases_base; + __global Dtype* outputs = outputs_base; + __global Dtype* inputs = inputs_base; + filter_qualifier Dtype* weights = weights_base; + __global Dtype* biases = biases_base; uint_tp oc = get_global_id(0) * OUT_BLOCK_WIDTH; // oc = Output Column uint_tp or = get_global_id(1) * OUT_BLOCK_HEIGHT;// or = Output Row @@ -161,7 +185,7 @@ convolve_simd( // __global float *inputs, __global float* weights, __global flo uint_tp fmg = get_group_id(2); uint_tp lid = get_local_id(2); - float out[OUT_BLOCK_SIZE]; + Dtype out[OUT_BLOCK_SIZE]; int_tp in_addr; @@ -187,8 +211,8 @@ convolve_simd( // __global float *inputs, __global float* 
weights, __global flo + (curr_y - INPUT_PAD_H) * input_width // y tile offset + curr_x - INPUT_PAD_W; // x tile offset union { - float4 in_vec[INVEC_SIZE]; - float in_array[INVEC_SIZE * 4]; + Dtype4 in_vec[INVEC_SIZE]; + Dtype in_array[INVEC_SIZE * 4]; } in_buf; for(int_tp kd = 0; kd < INPUT_DEPTH; kd++) @@ -212,7 +236,7 @@ convolve_simd( // __global float *inputs, __global float* weights, __global flo in_buf.in_vec[reg].s2 = 0; in_buf.in_vec[reg].s3 = *(inputs + in_offset + 3); } else { - in_buf.in_vec[reg] = *(global float4*)(inputs + in_offset); // read SIMD_SIZE elements + in_buf.in_vec[reg] = vload4(0, (inputs + in_offset)); // read 16 elements if (curr_x + 1 >= input_width + INPUT_PAD_W) in_buf.in_vec[reg].s1 = 0; if (curr_x + 2 >= input_width + INPUT_PAD_W) @@ -225,7 +249,7 @@ convolve_simd( // __global float *inputs, __global float* weights, __global flo } curr_y += TILE_Y_STRIDE; #else - in_buf.in_vec[reg] = *(global float4*)(inputs + in_offset); // read SIMD_SIZE elements + in_buf.in_vec[reg] = vload4(0, (inputs + in_offset)); // read 16 elements #endif } in_offset += input_width * TILE_Y_STRIDE; @@ -241,19 +265,19 @@ convolve_simd( // __global float *inputs, __global float* weights, __global flo #define WEIGHT_PREF 1 #endif union { - float w[WEIGHT_PREF]; + Dtype w[WEIGHT_PREF]; #if KERNEL_WIDTH * KERNEL_HEIGHT != 1 - uint8 ui8; + INT_TYPE8 ui8; #endif } weight_buf; int_tp w_idx=0; uint_tp orig_weight_addr = weight_addr; #if KERNEL_WIDTH * KERNEL_HEIGHT != 1 - weight_buf.ui8 = intel_sub_group_block_read8((__global uint *)&weights[weight_addr]); + weight_buf.ui8 = SUB_GROUP_BLOCK_READ8((__global INT_TYPE *)&weights[weight_addr]); weight_addr += SIMD_SIZE * WEIGHT_PREF; #else - weight_buf.w[0] = as_float(intel_sub_group_block_read((__global uint *)&weights[weight_addr])); + weight_buf.w[0] = as_Dtype(SUB_GROUP_BLOCK_READ((__global INT_TYPE *)&weights[weight_addr])); weight_addr += SIMD_SIZE * 1; #endif @@ -267,7 +291,7 @@ convolve_simd( // __global float 
*inputs, __global float* weights, __global flo { for(int_tp br=0; br < OUT_BLOCK_HEIGHT; br++) { for(int_tp bc=0; bc < OUT_BLOCK_WIDTH; bc++) { - float input = BLOCK_IN((br * STRIDEY + kr * DILATION_Y) * TILE_X + bc * STRIDEX + kc * DILATION_X); + Dtype input = BLOCK_IN((br * STRIDEY + kr * DILATION_Y) * TILE_X + bc * STRIDEX + kc * DILATION_X); out[br * OUT_BLOCK_WIDTH + bc] = mad(weight_buf.w[w_idx % WEIGHT_PREF], input, out[br * OUT_BLOCK_WIDTH + bc]); } } @@ -278,7 +302,7 @@ convolve_simd( // __global float *inputs, __global float* weights, __global flo && ((w_idx + 1) <= (KERNEL_WIDTH * KERNEL_HEIGHT - WEIGHT_PREF)) #endif ) { - weight_buf.ui8 = intel_sub_group_block_read8((__global uint *)&weights[weight_addr]); + weight_buf.ui8 = SUB_GROUP_BLOCK_READ8((__global INT_TYPE *)&weights[weight_addr]); weight_addr += SIMD_SIZE * WEIGHT_PREF; // weights must be stored in just the right SIMD swizzled format for this to work, see host code for details. } #if KERNEL_WIDTH*KERNEL_HEIGHT % 8 == 0 @@ -288,11 +312,11 @@ convolve_simd( // __global float *inputs, __global float* weights, __global flo #if KERNEL_WIDTH * KERNEL_HEIGHT % 8 == 1 weight_buf.w[0] = weights[weight_addr]; #elif KERNEL_WIDTH * KERNEL_HEIGHT % 8 == 2 - weight_buf.ui8.s01 = intel_sub_group_block_read2((__global uint *)&weights[weight_addr]); + weight_buf.ui8.s01 = SUB_GROUP_BLOCK_READ2((__global INT_TYPE *)&weights[weight_addr]); #elif KERNEL_WIDTH * KERNEL_HEIGHT % 8 <= 4 - weight_buf.ui8.s0123 = intel_sub_group_block_read4((__global uint *)&weights[weight_addr]); + weight_buf.ui8.s0123 = SUB_GROUP_BLOCK_READ4((__global INT_TYPE *)&weights[weight_addr]); #else - weight_buf.ui8 = intel_sub_group_block_read8((__global uint *)&weights[weight_addr]); + weight_buf.ui8 = SUB_GROUP_BLOCK_READ8((__global INT_TYPE *)&weights[weight_addr]); #endif #endif #endif @@ -311,7 +335,8 @@ convolve_simd( // __global float *inputs, __global float* weights, __global flo if ((ALIGNED_NUM_FILTERS == NUM_FILTERS || fm < 
NUM_FILTERS)) { uint_tp out_addr = OUT_BUFF_OFFSET + ( num_in_batch * TOTAL_OUTPUT_DEPTH + fm ) * output_width * output_height; out_addr += or * output_width + oc; - float bias = biases[(fm % ALIGNED_NUM_FILTERS)]; + // we need this address calculation for biases because we support views and batching + Dtype bias = biases[fm]; for(uint_tp r = 0; r < OUT_BLOCK_HEIGHT; r++) { if (r + or >= output_height) break; for(uint_tp c = 0; c < OUT_BLOCK_WIDTH; c++) { @@ -365,6 +390,25 @@ typedef struct float15 { float s0; float s1; float s2; float s3; float s4; float float s6; float s7; float s8; float s9; float sa; float sb; float sc; float sd; float se; } float15; typedef struct float0 { float s0; } float0; //never used but makes compiler happy. +typedef struct half1 { half s0; } half1; +typedef struct half5 { half s0; half s1; half s2; half s3; half s4; } half5; +typedef struct half6 { half s0; half s1; half s2; half s3; half s4; half s5; } half6; +typedef struct half7 { half s0; half s1; half s2; half s3; half s4; half s5; half s6; } half7; +typedef struct half9 { half s0; half s1; half s2; half s3; half s4; half s5; half s6; half s7; half s8; } half9; +typedef struct half10 { half s0; half s1; half s2; half s3; half s4; half s5; + half s6; half s7; half s8; half s9; } half10; +typedef struct half11 { half s0; half s1; half s2; half s3; half s4; half s5; + half s6; half s7; half s8; half s9; half sa; } half11; +typedef struct half12 { half s0; half s1; half s2; half s3; half s4; half s5; + half s6; half s7; half s8; half s9; half sa; half sb; } half12; +typedef struct half13 { half s0; half s1; half s2; half s3; half s4; half s5; + half s6; half s7; half s8; half s9; half sa; half sb; half sc; } half13; +typedef struct half14 { half s0; half s1; half s2; half s3; half s4; half s5; + half s6; half s7; half s8; half s9; half sa; half sb; half sc; half sd; } half14; +typedef struct half15 { half s0; half s1; half s2; half s3; half s4; half s5; + half s6; half s7; half s8; 
half s9; half sa; half sb; half sc; half sd; half se; } half15; +typedef struct half0 { half s0; } half0; //never used but makes compiler happy. + #define OUT_PITCH_X output_width #define ROW_PITCH input_width @@ -419,7 +463,7 @@ typedef struct float0 { float s0; } float0; //never used but makes compiler happ #define TILE_K KERNEL_WIDTH #define TILE_N 32 -#ifdef __BEIGNET__ +#ifndef __BEIGNET__ __attribute__((intel_reqd_sub_group_size(8))) #endif __kernel void Conv_Interleaved(GEMM_LIKE_KERNEL_ARGS) @@ -444,7 +488,7 @@ __kernel void Conv_Interleaved(GEMM_LIKE_KERNEL_ARGS) _result.s6 = mad( _rowA, sub_group_broadcast( colB, 6 ), _result.s6 ); \ _result.s7 = mad( _rowA, sub_group_broadcast( colB, 7 ), _result.s7 ); \ } - typedef CAT( float, KERNEL_WIDTH ) float_t; + typedef CAT( Dtype, KERNEL_WIDTH ) Dtype_t; // True for all threads if filter_width is multiple of TILE_N // else, true for all but right-most column of threads. @@ -452,10 +496,10 @@ __kernel void Conv_Interleaved(GEMM_LIKE_KERNEL_ARGS) { // Result ctile (*dst) is M rows x N columns // LWG size is 1x8. Thus each thread calculates 8*M rows x N cols of ctile. - float8 blockC00 = 0.f; - float8 blockC10 = 0.f; - float8 blockC20 = 0.f; - float8 blockC30 = 0.f; + Dtype8 blockC00 = 0.f; + Dtype8 blockC10 = 0.f; + Dtype8 blockC20 = 0.f; + Dtype8 blockC30 = 0.f; // Src0 (patch input) is directly used as atile. // Each work item points to the start of a different patch. @@ -465,7 +509,7 @@ __kernel void Conv_Interleaved(GEMM_LIKE_KERNEL_ARGS) #if INPUT_PAD_H != 0 || INPUT_PAD_W != 0 || DILATION_X != 1 || DILATION_Y != 1 int saved_y = curr_y; #endif - const __global float *src0_read = src0 + const __global Dtype *src0_read = src0 + aligned_input_size * global_z // batch offset + (curr_y - INPUT_PAD_H) * ROW_PITCH // y offset + (curr_x - INPUT_PAD_W); // x offset @@ -473,7 +517,7 @@ __kernel void Conv_Interleaved(GEMM_LIKE_KERNEL_ARGS) // Src1 (filter) is directly used as btile. 
// It starts at the top of src1 and walks down. // btile is K rows x N columns. - const __global float *src1_read = src1 + ( global_x * TILE_N * 2); + const __global Dtype *src1_read = src1 + ( global_x * TILE_N * 2); // Walk DOWN src0 (patch 0, 1, 2, ...) and DOWN src1. // Inner loop loads and FMADs one row (KERNEL_WIDTH) of each input patch @@ -489,7 +533,7 @@ __kernel void Conv_Interleaved(GEMM_LIKE_KERNEL_ARGS) do { // Load atile and btile. - // Kernel data is partially interleaved. Every 2 rows are interleaved at float8 granularity. + // Kernel data is partially interleaved. Every 2 rows are interleaved at Dtype8 granularity. // The exception is that if KERNEL_WIDTH is odd the last row is not interleaved. The non // interleaved row is padded with zero to ensure same size as interleaved rows. This // interleaving is done to ensure 0% GDR bank conflicts. For example, this is how the @@ -501,11 +545,11 @@ __kernel void Conv_Interleaved(GEMM_LIKE_KERNEL_ARGS) const bool kernel_width_is_odd = KERNEL_WIDTH % 2 == 1; #if INPUT_PAD_W == 0 && INPUT_PAD_H == 0 && DILATION_X == 1 && DILATION_Y == 1 - float_t blockA00 = ( (const __global float_t*)src0_read )[ 0 ]; - float* pblockA00 = (float*)(&blockA00); + Dtype_t blockA00 = ( (const __global Dtype_t*)src0_read )[ 0 ]; + Dtype* pblockA00 = (Dtype*)(&blockA00); #else - float_t blockA00; - float* pblockA00 = (float*)(&blockA00); + Dtype_t blockA00; + Dtype* pblockA00 = (Dtype*)(&blockA00); int pos = 0; LOOP(KERNEL_WIDTH, pos, { @@ -518,20 +562,20 @@ __kernel void Conv_Interleaved(GEMM_LIKE_KERNEL_ARGS) #endif src0_read += (ROW_PITCH * DILATION_Y); - float blockB00[KERNEL_WIDTH*4]; - float8* p8BlockB00 = (float8*)blockB00; - float4* p4BlockB00 = (float4*)blockB00; - float* pBlockB00 = (float* )blockB00; + Dtype blockB00[KERNEL_WIDTH*4]; + Dtype8* p8BlockB00 = (Dtype8*)blockB00; + Dtype4* p4BlockB00 = (Dtype4*)blockB00; + Dtype* pBlockB00 = (Dtype* )blockB00; interleaved_y = 0; LOOP(KERNEL_WIDTH_DIV2, interleaved_y, { - 
p8BlockB00[interleaved_y] = as_float8( intel_sub_group_block_read8( (const __global uint*)src1_read ) ); + p8BlockB00[interleaved_y] = as_Dtype8( SUB_GROUP_BLOCK_READ8( (const __global INT_TYPE *)src1_read ) ); src1_read += WIDTH1 * 2; } ) if ( kernel_width_is_odd ) { - p4BlockB00[KERNEL_WIDTH - 1] = as_float4( intel_sub_group_block_read4( (const __global uint*)src1_read ) ); + p4BlockB00[KERNEL_WIDTH - 1] = as_Dtype4( SUB_GROUP_BLOCK_READ4( (const __global INT_TYPE *)src1_read ) ); src1_read += WIDTH1 * 2; } @@ -575,11 +619,11 @@ __kernel void Conv_Interleaved(GEMM_LIKE_KERNEL_ARGS) + ( ( global_y * TILE_M ) / output_width + OUT_PADDING_HEIGHT) * OUT_PITCH_X // y offset + ( ( global_y * TILE_M ) % output_width ) + OUT_PADDING_LEFT; // x offset - __global float *out = dst + out_offset; - float bias[4]; - float4 *bias_vec; - bias_vec = (float4*)bias; - *bias_vec = as_float4(intel_sub_group_block_read4((__global uint *)biases + group_x * TILE_N)); + __global Dtype *out = dst + out_offset; + Dtype bias[4]; + Dtype4 *bias_vec; + bias_vec = (Dtype4*)bias; + *bias_vec = as_Dtype4(SUB_GROUP_BLOCK_READ4((__global INT_TYPE *)biases + group_x * TILE_N)); if (global_y * TILE_M < output_width * output_height ) { @@ -599,7 +643,7 @@ __kernel void Conv_Interleaved(GEMM_LIKE_KERNEL_ARGS) // Result ctile (*dst) is M rows x N columns // LWG size is 1x8. Thus each thread calculates 8*M rows x N cols of ctile. 
int i = 0; - float8 blockC[TILE_N_LAST_DIV8]; + Dtype8 blockC[TILE_N_LAST_DIV8]; LOOP(TILE_N_LAST_DIV8, i, { blockC[i] = 0.f; @@ -613,7 +657,7 @@ __kernel void Conv_Interleaved(GEMM_LIKE_KERNEL_ARGS) #if INPUT_PAD_H != 0 || INPUT_PAD_W != 0 || DILATION_X != 1 || DILATION_Y != 1 int saved_y = curr_y; #endif - const __global float *src0_read = src0 + const __global Dtype *src0_read = src0 + aligned_input_size * global_z // batch offset + (curr_y - INPUT_PAD_H) * ROW_PITCH // y offset + (curr_x - INPUT_PAD_W); // x offset @@ -621,7 +665,7 @@ __kernel void Conv_Interleaved(GEMM_LIKE_KERNEL_ARGS) // Src1 (filter) is directly used as btile. // It starts at the top of src1 and walks down. // btile is K rows x N columns. - const __global float *src1_read = src1 + ( global_x * TILE_N * 2); + const __global Dtype *src1_read = src1 + ( global_x * TILE_N * 2); // Walk DOWN src0 (patch 0, 1, 2, ...) and DOWN src1. // Inner loop loads and FMADs one row (KERNEL_WIDTH) of each input patch @@ -638,11 +682,11 @@ __kernel void Conv_Interleaved(GEMM_LIKE_KERNEL_ARGS) // Load atile and interleaved btile. 
const bool kernel_width_is_odd = KERNEL_WIDTH % 2 == 1; #if INPUT_PAD_W == 0 && INPUT_PAD_H == 0 && DILATION_X == 1 && DILATION_Y == 1 - float_t blockA00 = ( (const __global float_t*)src0_read )[ 0 ]; - float* pblockA00 = (float*)(&blockA00); + Dtype_t blockA00 = ( (const __global Dtype_t*)src0_read )[ 0 ]; + Dtype* pblockA00 = (Dtype*)(&blockA00); #else - float_t blockA00; - float* pblockA00 = (float*)(&blockA00); + Dtype_t blockA00; + Dtype* pblockA00 = (Dtype*)(&blockA00); int pos = 0; LOOP(KERNEL_WIDTH, pos, { @@ -654,43 +698,43 @@ __kernel void Conv_Interleaved(GEMM_LIKE_KERNEL_ARGS) curr_y += DILATION_Y; #endif src0_read += (ROW_PITCH * DILATION_Y); - float blockB[KERNEL_WIDTH * TILE_N_LAST_DIV8]; + Dtype blockB[KERNEL_WIDTH * TILE_N_LAST_DIV8]; interleaved_y = 0; LOOP(KERNEL_WIDTH_DIV2, interleaved_y, { #if TILE_N_LAST_DIV8 == 1 - float2* p2BlockB = (float2* )blockB; - p2BlockB[interleaved_y] = as_float2( intel_sub_group_block_read2( (const __global uint*)src1_read ) ); + Dtype2* p2BlockB = (Dtype2* )blockB; + p2BlockB[interleaved_y] = as_Dtype2( SUB_GROUP_BLOCK_READ2( (const __global INT_TYPE*)src1_read ) ); #elif TILE_N_LAST_DIV8 == 2 - float4* p4BlockB = (float4* )blockB; - p4BlockB[interleaved_y] = as_float4( intel_sub_group_block_read4( (const __global uint*)src1_read ) ); + Dtype4* p4BlockB = (Dtype4* )blockB; + p4BlockB[interleaved_y] = as_Dtype4( SUB_GROUP_BLOCK_READ4( (const __global INT_TYPE*)src1_read ) ); #elif TILE_N_LAST_DIV8 == 3 //TODO: broken. 
No block_read6 - float6* p6BlockB = (float6* )blockB; - (*((float8*)(&p6BlockB[interleaved_y]))).s0123 = as_float4( intel_sub_group_block_read4( (const __global uint*)src1_read ) ); - (*((float8*)(&p6BlockB[interleaved_y]))).s45 = as_float2( intel_sub_group_block_read2( (const __global uint*)(src1_read + 4 * 8) ) ); + Dtype6* p6BlockB = (Dtype6* )blockB; + (*((Dtype8*)(&p6BlockB[interleaved_y]))).s0123 = as_Dtype4( SUB_GROUP_BLOCK_READ4( (const __global INT_TYPE*)src1_read ) ); + (*((Dtype8*)(&p6BlockB[interleaved_y]))).s45 = as_Dtype2( SUB_GROUP_BLOCK_READ2( (const __global INT_TYPE*)(src1_read + 4 * 8) ) ); #endif src1_read += WIDTH1 * 2; } ) if ( kernel_width_is_odd ) { #if TILE_N_LAST_DIV8 == 1 - float* pBlockB = (float* )blockB; - pBlockB[KERNEL_WIDTH - 1] = as_float( intel_sub_group_block_read( (const __global uint*)src1_read ) ); + Dtype* pBlockB = (Dtype* )blockB; + pBlockB[KERNEL_WIDTH - 1] = as_Dtype( SUB_GROUP_BLOCK_READ( (const __global INT_TYPE*)src1_read ) ); #elif TILE_N_LAST_DIV8 == 2 - float2* p2BlockB = (float2* )blockB; - p2BlockB[KERNEL_WIDTH - 1] = as_float2( intel_sub_group_block_read2( (const __global uint*)src1_read ) ); + Dtype2* p2BlockB = (Dtype2* )blockB; + p2BlockB[KERNEL_WIDTH - 1] = as_Dtype2( SUB_GROUP_BLOCK_READ2( (const __global INT_TYPE*)src1_read ) ); #elif TILE_N_LAST_DIV8 == 3 - float3* p3BlockB = (float3* )blockB; - p3BlockB[KERNEL_WIDTH - 1].s01 = as_float2( intel_sub_group_block_read2( (const __global uint*)src1_read ) ); - p3BlockB[KERNEL_WIDTH - 1].s2 = as_float( intel_sub_group_block_read( (const __global uint*) (src1_read + 2 * 8) ) ); + Dtype3* p3BlockB = (Dtype3* )blockB; + p3BlockB[KERNEL_WIDTH - 1].s01 = as_Dtype2( SUB_GROUP_BLOCK_READ2( (const __global INT_TYPE*)src1_read ) ); + p3BlockB[KERNEL_WIDTH - 1].s2 = as_Dtype( SUB_GROUP_BLOCK_READ( (const __global INT_TYPE*) (src1_read + 2 * 8) ) ); #endif src1_read += WIDTH1 * 2; } // Perform MADs - float* pBlockB = (float*)blockB; + Dtype* pBlockB = (Dtype*)blockB; 
kernel_idx = 0; interleaved_y = 0; LOOP(KERNEL_WIDTH_DIV2, interleaved_y, @@ -734,12 +778,11 @@ __kernel void Conv_Interleaved(GEMM_LIKE_KERNEL_ARGS) + ( group_x * TILE_N ) * out_pitch_y // channel offset + ( ( global_y * TILE_M ) / output_width + OUT_PADDING_HEIGHT) * OUT_PITCH_X // y offset + ( ( global_y * TILE_M ) % output_width ) + OUT_PADDING_LEFT; // x offset - - __global float *out = dst + out_offset; - float bias[4]; - float4 *bias_vec; - bias_vec = (float4*)bias; - *bias_vec = as_float4(intel_sub_group_block_read4((__global uint *)biases + group_x * TILE_N)); + __global Dtype *out = dst + out_offset; + Dtype bias[4]; + Dtype4 *bias_vec; + bias_vec = (Dtype4*)bias; + *bias_vec = as_Dtype4(SUB_GROUP_BLOCK_READ4((__global INT_TYPE *)biases + group_x * TILE_N)); if (global_y * TILE_M < output_width * output_height ) { @@ -755,239 +798,6 @@ __kernel void Conv_Interleaved(GEMM_LIKE_KERNEL_ARGS) #endif } #endif - -#ifdef GEMM_LIKE_CONV_32_1_SIMD16 -#define TILE_M 1 -#define TILE_K KERNEL_WIDTH -#define TILE_N 32 - -#ifndef __BEIGNET__ -__attribute__((intel_reqd_sub_group_size(16))) -#endif -__kernel void Conv_Interleaved(GEMM_LIKE_KERNEL_ARGS) -{ - const int group_x = get_group_id(0); - const int group_y = get_group_id(1); - const int global_x = get_global_id(0); - const int global_y = get_global_id(1); - const int global_z = get_global_id(2); - int interleaved_y; - int kernel_y; - int kernel_idx; - - // Result ctile (*dst) is M rows x N columns - // LWG size is 1x16. Thus each thread calculates 16*M rows x N cols of ctile. - Dtype16 blockC00 = 0.f; - Dtype16 blockC10 = 0.f; - - // Src0 (patch input) is directly used as atile. - // Each work item points to the start of a different patch. - // atile is M rows x K columns. 
- int curr_x = ( global_y % output_width ) * STRIDE_X; - int curr_y = ( global_y / output_width ) * STRIDE_Y; -#if INPUT_PAD_H != 0 || INPUT_PAD_W != 0 || DILATION_X != 1 || DILATION_Y != 1 - int saved_y = curr_y; -#endif - - const __global Dtype *src0_read = src0 - + aligned_input_size * global_z // batch offset - + (curr_y - INPUT_PAD_H) * ROW_PITCH // y offset - + curr_x - INPUT_PAD_W; // x offset - const __global Dtype *src0_read_orig = src0_read; - - // Src1 (filter) is directly used as btile. - // It starts at the top of src1 and walks down. - // btile is K rows x N columns. - const __global Dtype *src1_read = src1 + ( global_x * TILE_N * 2 ); - -#define DOT_PRODUCT_16( _result, _rowA, colB ) \ - { \ - _result.s0 = mad( _rowA, sub_group_broadcast( colB, 0 ), _result.s0 ); \ - _result.s1 = mad( _rowA, sub_group_broadcast( colB, 1 ), _result.s1 ); \ - _result.s2 = mad( _rowA, sub_group_broadcast( colB, 2 ), _result.s2 ); \ - _result.s3 = mad( _rowA, sub_group_broadcast( colB, 3 ), _result.s3 ); \ - _result.s4 = mad( _rowA, sub_group_broadcast( colB, 4 ), _result.s4 ); \ - _result.s5 = mad( _rowA, sub_group_broadcast( colB, 5 ), _result.s5 ); \ - _result.s6 = mad( _rowA, sub_group_broadcast( colB, 6 ), _result.s6 ); \ - _result.s7 = mad( _rowA, sub_group_broadcast( colB, 7 ), _result.s7 ); \ - _result.s8 = mad( _rowA, sub_group_broadcast( colB, 8 ), _result.s8 ); \ - _result.s9 = mad( _rowA, sub_group_broadcast( colB, 9 ), _result.s9 ); \ - _result.sa = mad( _rowA, sub_group_broadcast( colB, 10 ), _result.sa ); \ - _result.sb = mad( _rowA, sub_group_broadcast( colB, 11 ), _result.sb ); \ - _result.sc = mad( _rowA, sub_group_broadcast( colB, 12 ), _result.sc ); \ - _result.sd = mad( _rowA, sub_group_broadcast( colB, 13 ), _result.sd ); \ - _result.se = mad( _rowA, sub_group_broadcast( colB, 14 ), _result.se ); \ - _result.sf = mad( _rowA, sub_group_broadcast( colB, 15 ), _result.sf ); \ - } - typedef CAT( Dtype, KERNEL_WIDTH ) Dtype_t; - // Walk DOWN src0 (patch 
0, 1, 2, ...) and DOWN src1. - // Inner loop loads and FMADs one row (KERNEL_WIDTH) of each input patch - // and KERNEL_WIDTH/2 rows of interleaved filter. - int patch_depth = 0; -#ifndef __BEIGNET__ - __attribute__((opencl_unroll_hint(1))) -#endif - do - { - int patch_row = 0; -#if INPUT_PAD_H != 0 || INPUT_PAD_W != 0 - curr_y = saved_y; -#endif -#ifndef __BEIGNET__ - __attribute__((opencl_unroll_hint(1))) -#endif - do - { - // Load atile and btile. - // Kernel data is partially interleaved. Every 2 rows are interleaved at Dtype16 granularity. - // The exception is that if KERNEL_WIDTH is odd the last row is not interleaved. The non - // interleaved row is padded with zero to ensure same size as interleaved rows. This - // interleaving is done to ensure 0% GDR bank conflicts. For example, this is how the - // kernel data would be arranged before/after interleaving for KERNEL_WIDTH=3. - // (0, 0) (16, 0) (32, 0) (48, 0) ... (0, 0) ( 0, 1) (16, 0) ( 0, 1) (32, 0) (0, 1) (48, 0) ... - // (0, 1) (16, 1) (32, 1) (48, 1) ... => (0, 2) (16, 2) (32, 2) (48, 2) ... - // (0, 2) (16, 2) (32, 2) (48, 2) ... ... - // ... 
- const bool kernel_width_is_odd = KERNEL_WIDTH % 2 == 1; - -#if INPUT_PAD_W == 0 && INPUT_PAD_H == 0 && DILATION_X == 1 && DILATION_Y == 1 - Dtype_t blockA00 = ( (const __global Dtype_t*)src0_read )[ 0 ]; - Dtype* pblockA00 = (Dtype*)(&blockA00); -#else - Dtype_t blockA00; - Dtype* pblockA00 = (Dtype*)(&blockA00); - int pos = 0; - LOOP(KERNEL_WIDTH, pos, - { - if (curr_y >= INPUT_PAD_H && curr_y < input_height + INPUT_PAD_H && curr_x + pos * DILATION_X >= INPUT_PAD_W && curr_x + pos * DILATION_X < input_width + INPUT_PAD_W) - pblockA00[pos] = src0_read[pos * DILATION_X]; - else - pblockA00[pos] = 0; - }) - curr_y += DILATION_Y; -#endif - src0_read += ROW_PITCH * DILATION_X; - uint blockB00[KERNEL_WIDTH * 2]; - uint4* p4BlockB00 = (uint4*)blockB00; - uint2* p2BlockB00 = (uint2*)blockB00; - Dtype* pBlockB00 = (Dtype*)blockB00; - - interleaved_y = 0; - LOOP(KERNEL_WIDTH_DIV2, interleaved_y, - { - p4BlockB00[interleaved_y] = intel_sub_group_block_read4( (const __global uint*)src1_read ); - src1_read += WIDTH1 * 2; - } ) - if ( kernel_width_is_odd ) - { - p2BlockB00[KERNEL_WIDTH - 1] = intel_sub_group_block_read2( (const __global uint*)src1_read ); - src1_read += WIDTH1 * 2; - } - - // Perform MADs - kernel_idx = 0; - interleaved_y = 0; - LOOP(KERNEL_WIDTH_DIV2, interleaved_y, - { - kernel_y = interleaved_y * 2; - DOT_PRODUCT_16( blockC00, pblockA00[kernel_y ], pBlockB00[kernel_idx] ); kernel_idx++; - DOT_PRODUCT_16( blockC00, pblockA00[kernel_y + 1], pBlockB00[kernel_idx] ); kernel_idx++; - DOT_PRODUCT_16( blockC10, pblockA00[kernel_y ], pBlockB00[kernel_idx] ); kernel_idx++; - DOT_PRODUCT_16( blockC10, pblockA00[kernel_y + 1], pBlockB00[kernel_idx] ); kernel_idx++; - } ) - if ( kernel_width_is_odd ) - { - kernel_y = interleaved_y * 2; - DOT_PRODUCT_16( blockC00, pblockA00[kernel_y], pBlockB00[kernel_idx] ); kernel_idx++; - DOT_PRODUCT_16( blockC10, pblockA00[kernel_y], pBlockB00[kernel_idx] ); kernel_idx++; - } - } - - //while( ++patch_row < 1 ); //debug - while( 
++patch_row < KERNEL_HEIGHT ); - - src0_read += slice_pitch - ( KERNEL_HEIGHT * ROW_PITCH * DILATION_Y ); // reset to start of next slice of patch - } - //while ( ++patch_depth < 1 ); //debug - while ( ++patch_depth < INPUT_DEPTH ); - - // Dst resembles a cube of width x height x (output channel * batches). Each tile writes: - // (SIMD * TILE_M) x 1 x TILE_N. Partial writes most likely generated if padding used. - int_tp out_offset = global_z * out_pitch_z // batch offset - + ( group_x * TILE_N ) * out_pitch_y // channel offset - + ( ( global_y * TILE_M ) / output_width + OUT_PADDING_HEIGHT) * OUT_PITCH_X // y offset - + ( ( global_y * TILE_M ) % output_width ) + OUT_PADDING_LEFT; // x offset - __global Dtype *out = dst + out_offset; - - Dtype bias[2]; - Dtype2 *bias_vec; - bias_vec = (Dtype2*)bias; - *bias_vec = as_float2(intel_sub_group_block_read2((__global uint *)biases + group_x * TILE_N)); - // Work around a potential compiler bug. - if (group_x > 0xFFFFFFFEul) { - out[0] = bias[0] + bias[1]; - } - - if (global_y * TILE_M < output_width * output_height ) - { -#if ( ( OUT_DEPTH % TILE_N ) == 0 ) - for (int i = 0; i < 16; i++) - { - ACTIVATION_FUNCTION(dst, out_offset + ( 0+i) * out_pitch_y, blockC00[i] + intel_sub_group_shuffle(bias[0], i)); - ACTIVATION_FUNCTION(dst, out_offset + (16+i) * out_pitch_y, blockC10[i] + intel_sub_group_shuffle(bias[1], i)); - } -#elif ( ( OUT_DEPTH % 16 ) == 0 ) - if ( ( global_x + 1 ) < get_global_size(0) ) - { - for ( int i = 0; i < 16; i++ ) - { - ACTIVATION_FUNCTION(dst, out_offset + ( 0+i) * out_pitch_y, blockC00[i] + intel_sub_group_shuffle(bias[0], i)); - ACTIVATION_FUNCTION(dst, out_offset + (16+i) * out_pitch_y, blockC10[i] + intel_sub_group_shuffle(bias[1], i)); - } - } - else - { - for (int i = 0; i < 16; i++) - { - ACTIVATION_FUNCTION(dst, out_offset + ( 0+i) * out_pitch_y, blockC00[i] + intel_sub_group_shuffle(bias[0], i)); - } - } -#else - if ( ( global_x + 1 ) < get_global_size(0) ) - { - for ( int i = 0; i < 16; 
i++ ) - { - ACTIVATION_FUNCTION(dst, out_offset + ( 0+i) * out_pitch_y, blockC00[i] + intel_sub_group_shuffle(bias[0], i)); - ACTIVATION_FUNCTION(dst, out_offset + (16+i) * out_pitch_y, blockC10[i] + intel_sub_group_shuffle(bias[1], i)); - } - } - else - { -#if ( (OUT_DEPTH % TILE_N) > 16 ) - { - for (int i = 0; i < 16 ; i++) - { - ACTIVATION_FUNCTION(dst, out_offset + ( 0+i) * out_pitch_y, blockC00[i] + intel_sub_group_shuffle(bias[0], i)); - } - for (int i = 0; i < OUT_DEPTH % 16 ; i++) - { - ACTIVATION_FUNCTION(dst, out_offset + (16+i) * out_pitch_y, blockC10[i] + intel_sub_group_shuffle(bias[1], i)); - } - } -#else - { - for (int i = 0; i < OUT_DEPTH % 16 ; i++) - { - ACTIVATION_FUNCTION(dst, out_offset + ( 0+i) * out_pitch_y, blockC00[i] + intel_sub_group_shuffle(bias[0], i)); - } - } -#endif - } -#endif - } -} -#endif - #ifdef GEMM_LIKE_CONV_32_2 ////////////////////////////////////////////////////////////////////////////// @@ -1005,7 +815,7 @@ __kernel void Conv_Interleaved(GEMM_LIKE_KERNEL_ARGS) #define TILE_K KERNEL_WIDTH #define TILE_N 32 -#ifdef __BEIGNET__ +#ifndef __BEIGNET__ __attribute__((intel_reqd_sub_group_size(8))) #endif __kernel void Conv_Interleaved(GEMM_LIKE_KERNEL_ARGS) @@ -1030,7 +840,7 @@ __kernel void Conv_Interleaved(GEMM_LIKE_KERNEL_ARGS) _result.s6 = mad( _rowA, sub_group_broadcast( colB, 6 ), _result.s6 ); \ _result.s7 = mad( _rowA, sub_group_broadcast( colB, 7 ), _result.s7 ); \ } - typedef CAT( float, KERNEL_WIDTH ) float_t; + typedef CAT( Dtype, KERNEL_WIDTH ) Dtype_t; // True for all threads if filter_width is multiple of TILE_N // else, true for all but right-most column of threads. @@ -1038,14 +848,14 @@ __kernel void Conv_Interleaved(GEMM_LIKE_KERNEL_ARGS) { // Result ctile (*dst) is M rows x N columns // LWG size is 1x8. Thus each thread calculates 8*M rows x N cols of ctile. 
- float8 blockC00 = 0.f; - float8 blockC10 = 0.f; - float8 blockC20 = 0.f; - float8 blockC30 = 0.f; - float8 blockC01 = 0.f; - float8 blockC11 = 0.f; - float8 blockC21 = 0.f; - float8 blockC31 = 0.f; + Dtype8 blockC00 = 0.f; + Dtype8 blockC10 = 0.f; + Dtype8 blockC20 = 0.f; + Dtype8 blockC30 = 0.f; + Dtype8 blockC01 = 0.f; + Dtype8 blockC11 = 0.f; + Dtype8 blockC21 = 0.f; + Dtype8 blockC31 = 0.f; // Src0 (patch input) is directly used as atile. // Each work item points to the start of a different patch. @@ -1058,11 +868,11 @@ __kernel void Conv_Interleaved(GEMM_LIKE_KERNEL_ARGS) int saved_y0 = curr_y0; int saved_y1 = curr_y1; #endif - const __global float *src0_read0 = src0 + const __global Dtype *src0_read0 = src0 + aligned_input_size * global_z // batch offset + (curr_y0 - INPUT_PAD_H) * ROW_PITCH // y offset + curr_x0 - INPUT_PAD_W; // x offset - const __global float *src0_read1 = src0 + const __global Dtype *src0_read1 = src0 + aligned_input_size * global_z // batch offset + (curr_y1 - INPUT_PAD_H) * ROW_PITCH // y offset + curr_x1 - INPUT_PAD_W; // x offset @@ -1070,7 +880,7 @@ __kernel void Conv_Interleaved(GEMM_LIKE_KERNEL_ARGS) // Src1 (filter) is directly used as btile. // It starts at the top of src1 and walks down. // btile is K rows x N columns. - const __global float *src1_read = src1 + ( global_x * TILE_N * 2); + const __global Dtype *src1_read = src1 + ( global_x * TILE_N * 2); // Walk DOWN src0 (patch 0, 1, 2, ...) and DOWN src1. // Inner loop loads and FMADs one row (KERNEL_WIDTH) of each input patch @@ -1082,7 +892,7 @@ __kernel void Conv_Interleaved(GEMM_LIKE_KERNEL_ARGS) do { // Load atile and btile. - // Kernel data is partially interleaved. Every 2 rows are interleaved at float8 granularity. + // Kernel data is partially interleaved. Every 2 rows are interleaved at Dtype8 granularity. // The exception is that if KERNEL_WIDTH is odd the last row is not interleaved. 
The non // interleaved row is padded with zero to ensure same size as interleaved rows. This // interleaving is done to ensure 0% GDR bank conflicts. For example, this is how the @@ -1093,13 +903,13 @@ __kernel void Conv_Interleaved(GEMM_LIKE_KERNEL_ARGS) // ... const bool kernel_width_is_odd = KERNEL_WIDTH % 2 == 1; #if INPUT_PAD_H == 0 && INPUT_PAD_W == 0 && DILATION_X == 1 && DILATION_Y == 1 - float_t blockA00 = ( (const __global float_t*)src0_read0 )[ 0 ]; src0_read0 += ROW_PITCH; - float_t blockA01 = ( (const __global float_t*)src0_read1 )[ 0 ]; src0_read1 += ROW_PITCH; - float* pblockA00 = (float*)(&blockA00); - float* pblockA01 = (float*)(&blockA01); + Dtype_t blockA00 = ( (const __global Dtype_t*)src0_read0 )[ 0 ]; src0_read0 += ROW_PITCH; + Dtype_t blockA01 = ( (const __global Dtype_t*)src0_read1 )[ 0 ]; src0_read1 += ROW_PITCH; + Dtype* pblockA00 = (Dtype*)(&blockA00); + Dtype* pblockA01 = (Dtype*)(&blockA01); #else - float_t blockA00; - float* pblockA00 = (float*)(&blockA00); + Dtype_t blockA00; + Dtype* pblockA00 = (Dtype*)(&blockA00); int pos = 0; LOOP(KERNEL_WIDTH, pos, { @@ -1109,8 +919,8 @@ __kernel void Conv_Interleaved(GEMM_LIKE_KERNEL_ARGS) pblockA00[pos] = 0; }) curr_y0 += DILATION_Y; - float_t blockA01; - float* pblockA01 = (float*)(&blockA01); + Dtype_t blockA01; + Dtype* pblockA01 = (Dtype*)(&blockA01); pos = 0; LOOP(KERNEL_WIDTH, pos, { @@ -1120,23 +930,23 @@ __kernel void Conv_Interleaved(GEMM_LIKE_KERNEL_ARGS) pblockA01[pos] = 0; }) curr_y1 += DILATION_Y; - src0_read0 += ROW_PITCH * DILATION_Y; - src0_read1 += ROW_PITCH * DILATION_Y; + src0_read0 += (ROW_PITCH * DILATION_Y); + src0_read1 += (ROW_PITCH * DILATION_Y); #endif - float blockB00[KERNEL_WIDTH*4]; - float8* p8BlockB00 = (float8*)blockB00; - float4* p4BlockB00 = (float4*)blockB00; - float* pBlockB00 = (float* )blockB00; + Dtype blockB00[KERNEL_WIDTH*4]; + Dtype8* p8BlockB00 = (Dtype8*)blockB00; + Dtype4* p4BlockB00 = (Dtype4*)blockB00; + Dtype* pBlockB00 = (Dtype* )blockB00; 
interleaved_y = 0; LOOP(KERNEL_WIDTH_DIV2, interleaved_y, { - p8BlockB00[interleaved_y] = as_float8( intel_sub_group_block_read8( (const __global uint*)src1_read ) ); + p8BlockB00[interleaved_y] = as_Dtype8( SUB_GROUP_BLOCK_READ8( (const __global INT_TYPE*)src1_read ) ); src1_read += WIDTH1 * 2; } ) if ( kernel_width_is_odd ) { - p4BlockB00[KERNEL_WIDTH - 1] = as_float4( intel_sub_group_block_read4( (const __global uint*)src1_read ) ); + p4BlockB00[KERNEL_WIDTH - 1] = as_Dtype4( SUB_GROUP_BLOCK_READ4( (const __global INT_TYPE*)src1_read ) ); src1_read += WIDTH1 * 2; } // Perform MADs @@ -1199,10 +1009,10 @@ __kernel void Conv_Interleaved(GEMM_LIKE_KERNEL_ARGS) + ( ( global_y * TILE_M + 1 ) / output_width + OUT_PADDING_HEIGHT ) * OUT_PITCH_X // y offset + ( ( global_y * TILE_M + 1 ) % output_width ) + OUT_PADDING_LEFT; // x offset - float bias[4]; - float4 *bias_vec; - bias_vec = (float4*)bias; - *bias_vec = as_float4(intel_sub_group_block_read4((__global uint *)biases + group_x * TILE_N)); + Dtype bias[4]; + Dtype4 *bias_vec; + bias_vec = (Dtype4*)bias; + *bias_vec = as_Dtype4(SUB_GROUP_BLOCK_READ4((__global INT_TYPE *)biases + group_x * TILE_N)); if( global_y * TILE_M < output_width * output_height ) { @@ -1232,8 +1042,8 @@ __kernel void Conv_Interleaved(GEMM_LIKE_KERNEL_ARGS) // Result ctile (*dst) is M rows x N columns // LWG size is 1x8. Thus each thread calculates 8*M rows x N cols of ctile. 
int i = 0; - float8 blockC0[TILE_N_LAST_DIV8]; - float8 blockC1[TILE_N_LAST_DIV8]; + Dtype8 blockC0[TILE_N_LAST_DIV8]; + Dtype8 blockC1[TILE_N_LAST_DIV8]; LOOP(TILE_N_LAST_DIV8, i, { blockC0[i] = 0.f; @@ -1251,11 +1061,11 @@ __kernel void Conv_Interleaved(GEMM_LIKE_KERNEL_ARGS) int saved_y0 = curr_y0; int saved_y1 = curr_y1; #endif - const __global float *src0_read0 = src0 + const __global Dtype *src0_read0 = src0 + aligned_input_size * global_z // batch offset + (curr_y0 - INPUT_PAD_H) * ROW_PITCH // y offset + curr_x0 - INPUT_PAD_W; // x offset - const __global float *src0_read1 = src0 + const __global Dtype *src0_read1 = src0 + aligned_input_size * global_z // batch offset + (curr_y1 - INPUT_PAD_H) * ROW_PITCH // y offset + curr_x1 - INPUT_PAD_W; // x offset @@ -1263,7 +1073,7 @@ __kernel void Conv_Interleaved(GEMM_LIKE_KERNEL_ARGS) // Src1 (filter) is directly used as btile. // It starts at the top of src1 and walks down. // btile is K rows x N columns. - const __global float *src1_read = src1 + ( global_x * TILE_N * 2); + const __global Dtype *src1_read = src1 + ( global_x * TILE_N * 2); // Walk DOWN src0 (patch 0, 1, 2, ...) and DOWN src1. // Inner loop loads and FMADs one row (KERNEL_WIDTH) of each input patch @@ -1277,13 +1087,13 @@ __kernel void Conv_Interleaved(GEMM_LIKE_KERNEL_ARGS) // Load atile and interleaved btile. 
const bool kernel_width_is_odd = KERNEL_WIDTH % 2 == 1; #if INPUT_PAD_H == 0 && INPUT_PAD_W == 0 && DILATION_X == 1 && DILATION_Y == 1 - float_t blockA00 = ( (const __global float_t*)src0_read0 )[ 0 ]; src0_read0 += ROW_PITCH; - float_t blockA01 = ( (const __global float_t*)src0_read1 )[ 0 ]; src0_read1 += ROW_PITCH; - float* pblockA00 = (float*)(&blockA00); - float* pblockA01 = (float*)(&blockA01); + Dtype_t blockA00 = ( (const __global Dtype_t*)src0_read0 )[ 0 ]; src0_read0 += ROW_PITCH; + Dtype_t blockA01 = ( (const __global Dtype_t*)src0_read1 )[ 0 ]; src0_read1 += ROW_PITCH; + Dtype* pblockA00 = (Dtype*)(&blockA00); + Dtype* pblockA01 = (Dtype*)(&blockA01); #else - float_t blockA00; - float* pblockA00 = (float*)(&blockA00); + Dtype_t blockA00; + Dtype* pblockA00 = (Dtype*)(&blockA00); int pos = 0; LOOP(KERNEL_WIDTH, pos, { @@ -1293,8 +1103,8 @@ __kernel void Conv_Interleaved(GEMM_LIKE_KERNEL_ARGS) pblockA00[pos] = 0; }) curr_y0 += DILATION_Y; - float_t blockA01; - float* pblockA01 = (float*)(&blockA01); + Dtype_t blockA01; + Dtype* pblockA01 = (Dtype*)(&blockA01); pos = 0; LOOP(KERNEL_WIDTH, pos, { @@ -1307,43 +1117,43 @@ __kernel void Conv_Interleaved(GEMM_LIKE_KERNEL_ARGS) src0_read0 += (ROW_PITCH * DILATION_Y); src0_read1 += (ROW_PITCH * DILATION_Y); #endif - float blockB[KERNEL_WIDTH * TILE_N_LAST_DIV8]; + Dtype blockB[KERNEL_WIDTH * TILE_N_LAST_DIV8]; interleaved_y = 0; LOOP(KERNEL_WIDTH_DIV2, interleaved_y, { #if TILE_N_LAST_DIV8 == 1 - float2* p2BlockB = (float2* )blockB; - p2BlockB[interleaved_y] = as_float2( intel_sub_group_block_read2( (const __global uint*)src1_read ) ); + Dtype2* p2BlockB = (Dtype2* )blockB; + p2BlockB[interleaved_y] = as_Dtype2( SUB_GROUP_BLOCK_READ2( (const __global INT_TYPE*)src1_read ) ); #elif TILE_N_LAST_DIV8 == 2 - float4* p4BlockB = (float4* )blockB; - p4BlockB[interleaved_y] = as_float4( intel_sub_group_block_read4( (const __global uint*)src1_read ) ); + Dtype4* p4BlockB = (Dtype4* )blockB; + p4BlockB[interleaved_y] = 
as_Dtype4( SUB_GROUP_BLOCK_READ4( (const __global INT_TYPE*)src1_read ) ); #elif TILE_N_LAST_DIV8 == 3 //TODO: broken. No block_read6 - float6* p6BlockB = (float6* )blockB; - (*((float8*)(&p6BlockB[interleaved_y]))).s0123 = as_float4( intel_sub_group_block_read4( (const __global uint*)src1_read ) ); - (*((float8*)(&p6BlockB[interleaved_y]))).s45 = as_float2( intel_sub_group_block_read2( (const __global uint*)(src1_read + 4 * 8) ) ); + Dtype6* p6BlockB = (Dtype6* )blockB; + (*((Dtype8*)(&p6BlockB[interleaved_y]))).s0123 = as_Dtype4( SUB_GROUP_BLOCK_READ4( (const __global INT_TYPE*)src1_read ) ); + (*((Dtype8*)(&p6BlockB[interleaved_y]))).s45 = as_Dtype2( SUB_GROUP_BLOCK_READ2( (const __global INT_TYPE*)(src1_read + 4 * 8) ) ); #endif src1_read += WIDTH1 * 2; } ) if ( kernel_width_is_odd ) { #if TILE_N_LAST_DIV8 == 1 - float* pBlockB = (float* )blockB; - pBlockB[KERNEL_WIDTH - 1] = as_float( intel_sub_group_block_read( (const __global uint*)src1_read ) ); + Dtype* pBlockB = (Dtype* )blockB; + pBlockB[KERNEL_WIDTH - 1] = as_Dtype( SUB_GROUP_BLOCK_READ( (const __global INT_TYPE*)src1_read ) ); #elif TILE_N_LAST_DIV8 == 2 - float2* p2BlockB = (float2* )blockB; - p2BlockB[KERNEL_WIDTH - 1] = as_float2( intel_sub_group_block_read2( (const __global uint*)src1_read ) ); + Dtype2* p2BlockB = (Dtype2* )blockB; + p2BlockB[KERNEL_WIDTH - 1] = as_Dtype2( SUB_GROUP_BLOCK_READ2( (const __global INT_TYPE*)src1_read ) ); #elif TILE_N_LAST_DIV8 == 3 - float3* p3BlockB = (float3* )blockB; - p3BlockB[KERNEL_WIDTH - 1].s01 = as_float2( intel_sub_group_block_read2( (const __global uint*)src1_read ) ); - p3BlockB[KERNEL_WIDTH - 1].s2 = as_float( intel_sub_group_block_read( (const __global uint*) (src1_read + 8) ) ); + Dtype3* p3BlockB = (Dtype3* )blockB; + p3BlockB[KERNEL_WIDTH - 1].s01 = as_Dtype2( SUB_GROUP_BLOCK_READ2( (const __global INT_TYPE*)src1_read ) ); + p3BlockB[KERNEL_WIDTH - 1].s2 = as_Dtype( SUB_GROUP_BLOCK_READ( (const __global INT_TYPE*) (src1_read + 8) ) ); #endif 
src1_read += WIDTH1 * 2; } // Perform MADs - float* pBlockB = (float*)blockB; + Dtype* pBlockB = (Dtype*)blockB; kernel_idx = 0; interleaved_y = 0; LOOP(KERNEL_WIDTH_DIV2, interleaved_y, @@ -1404,11 +1214,12 @@ __kernel void Conv_Interleaved(GEMM_LIKE_KERNEL_ARGS) + ( group_x * TILE_N ) * out_pitch_y // channel offset + ( ( global_y * TILE_M + 1 ) / output_width + OUT_PADDING_HEIGHT ) * OUT_PITCH_X // y offset + ( ( global_y * TILE_M + 1 ) % output_width ) + OUT_PADDING_LEFT; // x offset + __global Dtype *out1 = dst + out1_offset; - float bias[4]; - float4 *bias_vec; - bias_vec = (float4*)bias; - *bias_vec = as_float4(intel_sub_group_block_read4((__global uint *)biases + group_x * TILE_N)); + Dtype bias[4]; + Dtype4 *bias_vec; + bias_vec = (Dtype4*)bias; + *bias_vec = as_Dtype4(SUB_GROUP_BLOCK_READ4((__global INT_TYPE *)biases + group_x * TILE_N)); if( global_y * TILE_M < output_width * output_height ) { for( int i = 0; i < 8; i++ ) @@ -1433,3 +1244,440 @@ __kernel void Conv_Interleaved(GEMM_LIKE_KERNEL_ARGS) #endif } #endif +#if defined(GEMM_LIKE_CONV_32_2_SIMD16) || defined(GEMM_LIKE_CONV_32_1_SIMD16) + +#define INTERLEAVED_SIMD16_OUTPUT(_out_, _offset_, _m_) do {\ + if (global_y * TILE_M < output_width * output_height ) \ + { \ + if ( ( OUT_DEPTH % TILE_N ) == 0 ) {\ + for (int i = 0; i < 16; i++) \ + { \ + ACTIVATION_FUNCTION(_out_, _offset_ + ( 0+i) * out_pitch_y, blockC0 ##_m_ [i] + intel_sub_group_shuffle(bias[0], i)); \ + ACTIVATION_FUNCTION(_out_, _offset_ + (16+i) * out_pitch_y, blockC1 ##_m_ [i] + intel_sub_group_shuffle(bias[1], i)); \ + } \ + } \ + else if( ( OUT_DEPTH % 16 ) == 0 ) { \ + if ( ( global_x + 1 ) < get_global_size(0) ) { \ + for ( int i = 0; i < 16; i++ ) \ + { \ + ACTIVATION_FUNCTION(_out_, _offset_ + ( 0+i) * out_pitch_y, blockC0 ##_m_ [i] + intel_sub_group_shuffle(bias[0], i)); \ + ACTIVATION_FUNCTION(_out_, _offset_ + (16+i) * out_pitch_y, blockC1 ##_m_ [i] + intel_sub_group_shuffle(bias[1], i)); \ + } \ + } \ + else { \ + for (int i 
= 0; i < 16; i++) \ + { \ + ACTIVATION_FUNCTION(_out_, _offset_ + ( 0+i) * out_pitch_y, blockC0 ##_m_ [i] + intel_sub_group_shuffle(bias[0], i)); \ + } \ + } \ + } \ + else { \ + if ( ( global_x + 1 ) < get_global_size(0) ) \ + { \ + for ( int i = 0; i < 16; i++ ) \ + { \ + ACTIVATION_FUNCTION(_out_, _offset_ + ( 0+i) * out_pitch_y, blockC0 ##_m_[i] + intel_sub_group_shuffle(bias[0], i)); \ + ACTIVATION_FUNCTION(_out_, _offset_ + (16+i) * out_pitch_y, blockC1 ##_m_[i] + intel_sub_group_shuffle(bias[1], i)); \ + } \ + } \ + else { \ + if ( (OUT_DEPTH % TILE_N) > 16 ) { \ + for (int i = 0; i < 16 ; i++) \ + { \ + ACTIVATION_FUNCTION(_out_, _offset_ + ( 0+i) * out_pitch_y, blockC0 ##_m_[i] + intel_sub_group_shuffle(bias[0], i)); \ + } \ + for (int i = 0; i < OUT_DEPTH % 16 ; i++) \ + { \ + ACTIVATION_FUNCTION(_out_, _offset_ + (16+i) * out_pitch_y, blockC1 ##_m_[i] + intel_sub_group_shuffle(bias[1], i)); \ + } \ + } \ + else { \ + for (int i = 0; i < OUT_DEPTH % 16 ; i++) \ + { \ + ACTIVATION_FUNCTION(_out_, _offset_ + ( 0+i) * out_pitch_y, blockC0 ##_m_[i] + intel_sub_group_shuffle(bias[0], i)); \ + } \ + } \ + } \ + } \ + } \ + }while(0) +#endif + +#ifdef GEMM_LIKE_CONV_32_1_SIMD16 +#define TILE_M 1 +#define TILE_K KERNEL_WIDTH +#define TILE_N 32 + +#ifndef __BEIGNET__ +__attribute__((intel_reqd_sub_group_size(16))) +#endif +__kernel void Conv_Interleaved(GEMM_LIKE_KERNEL_ARGS) +{ + const int group_x = get_group_id(0); + const int group_y = get_group_id(1); + const int global_x = get_global_id(0); + const int global_y = get_global_id(1); + const int global_z = get_global_id(2); + int interleaved_y; + int kernel_y; + int kernel_idx; + + // Result ctile (*dst) is M rows x N columns + // LWG size is 1x16. Thus each thread calculates 16*M rows x N cols of ctile. + Dtype16 blockC00 = 0.f; + Dtype16 blockC10 = 0.f; + + // Src0 (patch input) is directly used as atile. + // Each work item points to the start of a different patch. + // atile is M rows x K columns. 
+ int curr_x = ( global_y % output_width ) * STRIDE_X; + int curr_y = ( global_y / output_width ) * STRIDE_Y; +#if INPUT_PAD_H != 0 || INPUT_PAD_W != 0 || DILATION_X != 1 || DILATION_Y != 1 + int saved_y = curr_y; +#endif + const __global Dtype *src0_read = src0 + + aligned_input_size * global_z // batch offset + + (curr_y - INPUT_PAD_H) * ROW_PITCH // y offset + + curr_x - INPUT_PAD_W; // x offset + const __global Dtype *src0_read_orig = src0_read; + + // Src1 (filter) is directly used as btile. + // It starts at the top of src1 and walks down. + // btile is K rows x N columns. + const __global Dtype *src1_read = src1 + ( global_x * TILE_N * 2 ); + +#define DOT_PRODUCT_16( _result, _rowA, colB ) \ + { \ + _result.s0 = mad( _rowA, sub_group_broadcast( colB, 0 ), _result.s0 ); \ + _result.s1 = mad( _rowA, sub_group_broadcast( colB, 1 ), _result.s1 ); \ + _result.s2 = mad( _rowA, sub_group_broadcast( colB, 2 ), _result.s2 ); \ + _result.s3 = mad( _rowA, sub_group_broadcast( colB, 3 ), _result.s3 ); \ + _result.s4 = mad( _rowA, sub_group_broadcast( colB, 4 ), _result.s4 ); \ + _result.s5 = mad( _rowA, sub_group_broadcast( colB, 5 ), _result.s5 ); \ + _result.s6 = mad( _rowA, sub_group_broadcast( colB, 6 ), _result.s6 ); \ + _result.s7 = mad( _rowA, sub_group_broadcast( colB, 7 ), _result.s7 ); \ + _result.s8 = mad( _rowA, sub_group_broadcast( colB, 8 ), _result.s8 ); \ + _result.s9 = mad( _rowA, sub_group_broadcast( colB, 9 ), _result.s9 ); \ + _result.sa = mad( _rowA, sub_group_broadcast( colB, 10 ), _result.sa ); \ + _result.sb = mad( _rowA, sub_group_broadcast( colB, 11 ), _result.sb ); \ + _result.sc = mad( _rowA, sub_group_broadcast( colB, 12 ), _result.sc ); \ + _result.sd = mad( _rowA, sub_group_broadcast( colB, 13 ), _result.sd ); \ + _result.se = mad( _rowA, sub_group_broadcast( colB, 14 ), _result.se ); \ + _result.sf = mad( _rowA, sub_group_broadcast( colB, 15 ), _result.sf ); \ + } + typedef CAT( Dtype, KERNEL_WIDTH ) Dtype_t; + // Walk DOWN src0 (patch 0, 
1, 2, ...) and DOWN src1. + // Inner loop loads and FMADs one row (KERNEL_WIDTH) of each input patch + // and KERNEL_WIDTH/2 rows of interleaved filter. + int patch_depth = 0; +#ifndef __BEIGNET__ + __attribute__((opencl_unroll_hint(1))) +#endif + do + { + int patch_row = 0; +#if INPUT_PAD_H != 0 || INPUT_PAD_W != 0 || DILATION_X != 1 || DILATION_Y != 1 + curr_y = saved_y; +#endif +#ifndef __BEIGNET__ + __attribute__((opencl_unroll_hint(1))) +#endif + do + { + // Load atile and btile. + // Kernel data is partially interleaved. Every 2 rows are interleaved at Dtype16 granularity. + // The exception is that if KERNEL_WIDTH is odd the last row is not interleaved. The non + // interleaved row is padded with zero to ensure same size as interleaved rows. This + // interleaving is done to ensure 0% GDR bank conflicts. For example, this is how the + // kernel data would be arranged before/after interleaving for KERNEL_WIDTH=3. + // (0, 0) (16, 0) (32, 0) (48, 0) ... (0, 0) ( 0, 1) (16, 0) ( 0, 1) (32, 0) (0, 1) (48, 0) ... + // (0, 1) (16, 1) (32, 1) (48, 1) ... => (0, 2) (16, 2) (32, 2) (48, 2) ... + // (0, 2) (16, 2) (32, 2) (48, 2) ... ... + // ... 
+ const bool kernel_width_is_odd = KERNEL_WIDTH % 2 == 1; + +#if INPUT_PAD_W == 0 && INPUT_PAD_H == 0 && DILATION_X == 1 && DILATION_Y == 1 + Dtype_t blockA00 = ( (const __global Dtype_t*)src0_read )[ 0 ]; + Dtype* pblockA00 = (Dtype*)(&blockA00); +#else + Dtype_t blockA00; + Dtype* pblockA00 = (Dtype*)(&blockA00); + int pos = 0; + LOOP(KERNEL_WIDTH, pos, + { + if (curr_y >= INPUT_PAD_H && curr_y < input_height + INPUT_PAD_H && curr_x + pos * DILATION_X >= INPUT_PAD_W && curr_x + pos * DILATION_X < input_width + INPUT_PAD_W) + pblockA00[pos] = src0_read[pos * DILATION_X]; + else + pblockA00[pos] = 0; + }) + curr_y += DILATION_Y; +#endif + src0_read += ROW_PITCH * DILATION_Y; + INT_TYPE blockB00[KERNEL_WIDTH * 2]; + INT_TYPE4* p4BlockB00 = (INT_TYPE4*)blockB00; + INT_TYPE2* p2BlockB00 = (INT_TYPE2*)blockB00; + Dtype* pBlockB00 = (Dtype*)blockB00; + interleaved_y = 0; + LOOP(KERNEL_WIDTH_DIV2, interleaved_y, + { + p4BlockB00[interleaved_y] = SUB_GROUP_BLOCK_READ4( (const __global INT_TYPE*)src1_read ); + src1_read += WIDTH1 * 2; + } ) + if ( kernel_width_is_odd ) + { + p2BlockB00[KERNEL_WIDTH - 1] = SUB_GROUP_BLOCK_READ2( (const __global INT_TYPE*)src1_read ); + src1_read += WIDTH1 * 2; + } + + // Perform MADs + kernel_idx = 0; + interleaved_y = 0; + LOOP(KERNEL_WIDTH_DIV2, interleaved_y, + { + kernel_y = interleaved_y * 2; + DOT_PRODUCT_16( blockC00, pblockA00[kernel_y ], pBlockB00[kernel_idx] ); kernel_idx++; + DOT_PRODUCT_16( blockC00, pblockA00[kernel_y + 1], pBlockB00[kernel_idx] ); kernel_idx++; + DOT_PRODUCT_16( blockC10, pblockA00[kernel_y ], pBlockB00[kernel_idx] ); kernel_idx++; + DOT_PRODUCT_16( blockC10, pblockA00[kernel_y + 1], pBlockB00[kernel_idx] ); kernel_idx++; + } ) + if ( kernel_width_is_odd ) + { + kernel_y = interleaved_y * 2; + DOT_PRODUCT_16( blockC00, pblockA00[kernel_y], pBlockB00[kernel_idx] ); kernel_idx++; + DOT_PRODUCT_16( blockC10, pblockA00[kernel_y], pBlockB00[kernel_idx] ); kernel_idx++; + } + } + + //while( ++patch_row < 1 ); 
//debug + while( ++patch_row < KERNEL_HEIGHT ); + + src0_read += slice_pitch - ( KERNEL_HEIGHT * ROW_PITCH * DILATION_Y ); // reset to start of next slice of patch + } + //while ( ++patch_depth < 1 ); //debug + while ( ++patch_depth < INPUT_DEPTH ); + + // Dst resembles a cube of width x height x (output channel * batches). Each tile writes: + // (SIMD * TILE_M) x 1 x TILE_N. Partial writes most likely generated if padding used. + int_tp out_offset = global_z * out_pitch_z // batch offset + + ( group_x * TILE_N ) * out_pitch_y // channel offset + + ( ( global_y * TILE_M ) / output_width + OUT_PADDING_HEIGHT) * OUT_PITCH_X // y offset + + ( ( global_y * TILE_M ) % output_width ) + OUT_PADDING_LEFT; // x offset + __global Dtype *out = dst + out_offset; + + Dtype bias[2]; + Dtype2 *bias_vec; + bias_vec = (Dtype2*)bias; + *bias_vec = as_Dtype2(SUB_GROUP_BLOCK_READ2((__global INT_TYPE *)biases + group_x * TILE_N)); + // Work around a potential compiler bug. + if (group_x > 0xFFFFFFFEul) { + out[0] = bias[0] + bias[1]; + } + INTERLEAVED_SIMD16_OUTPUT(dst, out_offset, 0); +} +#endif + +#ifdef GEMM_LIKE_CONV_32_2_SIMD16 + +////////////////////////////////////////////////////////////////////////////// +// Conv_Interleaved_32_2_SIMD16 +// +// Convolution: each workitem computes 1 patch x 32 filters worth of output +// data. 
+#define TILE_M 2 +#define TILE_K KERNEL_WIDTH +#define TILE_N 32 + +#ifndef __BEIGNET__ +__attribute__((intel_reqd_sub_group_size(16))) +#endif +__kernel void Conv_Interleaved(GEMM_LIKE_KERNEL_ARGS) +{ + const int group_x = get_group_id(0); + const int group_y = get_group_id(1); + const int global_x = get_global_id(0); + const int global_y = get_global_id(1); + const int global_z = get_global_id(2); + int interleaved_y; + int kernel_y; + int kernel_idx; +#define DOT_PRODUCT_16( _result, _rowA, colB ) \ + { \ + _result.s0 = mad( _rowA, sub_group_broadcast( colB, 0 ), _result.s0 ); \ + _result.s1 = mad( _rowA, sub_group_broadcast( colB, 1 ), _result.s1 ); \ + _result.s2 = mad( _rowA, sub_group_broadcast( colB, 2 ), _result.s2 ); \ + _result.s3 = mad( _rowA, sub_group_broadcast( colB, 3 ), _result.s3 ); \ + _result.s4 = mad( _rowA, sub_group_broadcast( colB, 4 ), _result.s4 ); \ + _result.s5 = mad( _rowA, sub_group_broadcast( colB, 5 ), _result.s5 ); \ + _result.s6 = mad( _rowA, sub_group_broadcast( colB, 6 ), _result.s6 ); \ + _result.s7 = mad( _rowA, sub_group_broadcast( colB, 7 ), _result.s7 ); \ + _result.s8 = mad( _rowA, sub_group_broadcast( colB, 8 ), _result.s8 ); \ + _result.s9 = mad( _rowA, sub_group_broadcast( colB, 9 ), _result.s9 ); \ + _result.sa = mad( _rowA, sub_group_broadcast( colB, 10 ), _result.sa ); \ + _result.sb = mad( _rowA, sub_group_broadcast( colB, 11 ), _result.sb ); \ + _result.sc = mad( _rowA, sub_group_broadcast( colB, 12 ), _result.sc ); \ + _result.sd = mad( _rowA, sub_group_broadcast( colB, 13 ), _result.sd ); \ + _result.se = mad( _rowA, sub_group_broadcast( colB, 14 ), _result.se ); \ + _result.sf = mad( _rowA, sub_group_broadcast( colB, 15 ), _result.sf ); \ + } + typedef CAT( Dtype, KERNEL_WIDTH ) Dtype_t; + + // True for all threads if filter_width is multiple of TILE_N + // else, true for all but right-most column of threads. + { + // Result ctile (*dst) is M rows x N columns + // LWG size is 1x8. 
Thus each thread calculates 8*M rows x N cols of ctile. + Dtype16 blockC00 = 0.f; + Dtype16 blockC10 = 0.f; + Dtype16 blockC01 = 0.f; + Dtype16 blockC11 = 0.f; + + // Src0 (patch input) is directly used as atile. + // Each work item points to the start of a different patch. + // atile is M rows x K columns. + int curr_x0 = ( ( global_y * TILE_M + 0 ) % output_width ) * STRIDE_X; + int curr_x1 = ( ( global_y * TILE_M + 1 ) % output_width ) * STRIDE_X; + int curr_y0 = ( ( global_y * TILE_M + 0 ) / output_width ) * STRIDE_Y; + int curr_y1 = ( ( global_y * TILE_M + 1 ) / output_width ) * STRIDE_Y; +#if INPUT_PAD_H != 0 || INPUT_PAD_W != 0 || DILATION_X != 1 || DILATION_Y != 1 + int saved_y0 = curr_y0; + int saved_y1 = curr_y1; +#endif + const __global Dtype *src0_read0 = src0 + + aligned_input_size * global_z // batch offset + + (curr_y0 - INPUT_PAD_H) * ROW_PITCH // y offset + + curr_x0 - INPUT_PAD_W; // x offset + const __global Dtype *src0_read1 = src0 + + aligned_input_size * global_z // batch offset + + (curr_y1 - INPUT_PAD_H) * ROW_PITCH // y offset + + curr_x1 - INPUT_PAD_W; // x offset + + // Src1 (filter) is directly used as btile. + // It starts at the top of src1 and walks down. + // btile is K rows x N columns. + const __global Dtype *src1_read = src1 + ( global_x * TILE_N * 2); + + // Walk DOWN src0 (patch 0, 1, 2, ...) and DOWN src1. + // Inner loop loads and FMADs one row (KERNEL_WIDTH) of each input patch + // and KERNEL_WIDTH/2 rows of interleaved filter. + int patch_depth = 0; + do + { + int patch_row = 0; + do + { + // Load atile and btile. + // Kernel data is partially interleaved. Every 2 rows are interleaved at Dtype8 granularity. + // The exception is that if KERNEL_WIDTH is odd the last row is not interleaved. The non + // interleaved row is padded with zero to ensure same size as interleaved rows. This + // interleaving is done to ensure 0% GDR bank conflicts. 
For example, this is how the + // kernel data would be arranged before/after interleaving for KERNEL_WIDTH=3. + // (0, 0) (8, 0) (16, 0) (24, 0) ... (0, 0) (0, 1) (8, 0) (0, 1) (16, 0) (0, 1) (24, 0) .. + // (0, 1) (8, 1) (16, 1) (24, 1) ... => (0, 2) (8, 2) (16, 2) (24, 2) ... + // (0, 2) (8, 2) (16, 2) (24, 2) ... ... + // ... + const bool kernel_width_is_odd = KERNEL_WIDTH % 2 == 1; +#if INPUT_PAD_H == 0 && INPUT_PAD_W == 0 && DILATION_X == 1 && DILATION_Y == 1 + Dtype_t blockA00 = ( (const __global Dtype_t*)src0_read0 )[ 0 ]; src0_read0 += ROW_PITCH; + Dtype_t blockA01 = ( (const __global Dtype_t*)src0_read1 )[ 0 ]; src0_read1 += ROW_PITCH; + Dtype* pblockA00 = (Dtype*)(&blockA00); + Dtype* pblockA01 = (Dtype*)(&blockA01); +#else + Dtype_t blockA00; + Dtype* pblockA00 = (Dtype*)(&blockA00); + int pos = 0; + LOOP(KERNEL_WIDTH, pos, + { + if (curr_y0 >= INPUT_PAD_H && curr_y0 < input_height + INPUT_PAD_H && curr_x0 + pos * DILATION_X >= INPUT_PAD_W && curr_x0 + pos * DILATION_X < input_width + INPUT_PAD_W) + pblockA00[pos] = src0_read0[pos * DILATION_X]; + else + pblockA00[pos] = 0; + }) + curr_y0 += DILATION_Y; + Dtype_t blockA01; + Dtype* pblockA01 = (Dtype*)(&blockA01); + pos = 0; + LOOP(KERNEL_WIDTH, pos, + { + if (curr_y1 >= INPUT_PAD_H && curr_y1 < input_height + INPUT_PAD_H && curr_x1 + pos * DILATION_X >= INPUT_PAD_W && curr_x1 + pos * DILATION_X < input_width + INPUT_PAD_W) + pblockA01[pos] = src0_read1[pos * DILATION_X]; + else + pblockA01[pos] = 0; + }) + curr_y1 += DILATION_Y; + src0_read0 += (ROW_PITCH * DILATION_Y); + src0_read1 += (ROW_PITCH * DILATION_Y); +#endif + Dtype blockB00[KERNEL_WIDTH*2]; + Dtype4* p4BlockB00 = (Dtype4*)blockB00; + Dtype2* p2BlockB00 = (Dtype2*)blockB00; + Dtype* pBlockB00 = (Dtype* )blockB00; + + interleaved_y = 0; + LOOP(KERNEL_WIDTH_DIV2, interleaved_y, + { + p4BlockB00[interleaved_y] = as_Dtype4( SUB_GROUP_BLOCK_READ4( (const __global INT_TYPE*)src1_read ) ); + src1_read += WIDTH1 * 2; + } ) + if ( kernel_width_is_odd 
) + { + p2BlockB00[KERNEL_WIDTH - 1] = as_Dtype2( SUB_GROUP_BLOCK_READ2( (const __global INT_TYPE*)src1_read ) ); + src1_read += WIDTH1 * 2; + } + // Perform MADs + kernel_idx = 0; + interleaved_y = 0; + LOOP(KERNEL_WIDTH_DIV2, interleaved_y, + { + kernel_y = interleaved_y * 2; + DOT_PRODUCT_16( blockC00, pblockA00[kernel_y ], pBlockB00[kernel_idx] ); + DOT_PRODUCT_16( blockC01, pblockA01[kernel_y ], pBlockB00[kernel_idx] ); kernel_idx++; + DOT_PRODUCT_16( blockC00, pblockA00[kernel_y + 1], pBlockB00[kernel_idx] ); + DOT_PRODUCT_16( blockC01, pblockA01[kernel_y + 1], pBlockB00[kernel_idx] ); kernel_idx++; + DOT_PRODUCT_16( blockC10, pblockA00[kernel_y ], pBlockB00[kernel_idx] ); + DOT_PRODUCT_16( blockC11, pblockA01[kernel_y ], pBlockB00[kernel_idx] ); kernel_idx++; + DOT_PRODUCT_16( blockC10, pblockA00[kernel_y + 1], pBlockB00[kernel_idx] ); + DOT_PRODUCT_16( blockC11, pblockA01[kernel_y + 1], pBlockB00[kernel_idx] ); kernel_idx++; + } ) + if ( kernel_width_is_odd ) + { + kernel_y = interleaved_y * 2; + DOT_PRODUCT_16( blockC00, pblockA00[kernel_y], pBlockB00[kernel_idx] ); + DOT_PRODUCT_16( blockC01, pblockA01[kernel_y], pBlockB00[kernel_idx] ); kernel_idx++; + DOT_PRODUCT_16( blockC10, pblockA00[kernel_y], pBlockB00[kernel_idx] ); + DOT_PRODUCT_16( blockC11, pblockA01[kernel_y], pBlockB00[kernel_idx] ); kernel_idx++; + } + } + + //while( ++patch_row < 1 ); //debug + while( ++patch_row < KERNEL_HEIGHT ); +#if INPUT_PAD_W != 0 || INPUT_PAD_H != 0 || DILATION_X != 1 || DILATION_Y != 1 + curr_y0 = saved_y0; + curr_y1 = saved_y1; +#endif + src0_read0 += slice_pitch - ( KERNEL_HEIGHT * ROW_PITCH * DILATION_Y); // reset to start of next slice of patch + src0_read1 += slice_pitch - ( KERNEL_HEIGHT * ROW_PITCH * DILATION_Y); + } + //while ( ++patch_depth < 1 ); //debug + while ( ++patch_depth < INPUT_DEPTH ); + + // Dst resembles a cube of width x height x (output channel * batches). Each tile writes: + // (SIMD * TILE_M) x 1 x TILE_N. 
Partial writes most likely generated if padding used. + int_tp out0_offset = global_z * out_pitch_z // batch offset + + ( group_x * TILE_N ) * out_pitch_y // channel offset + + ( ( global_y * TILE_M + 0 ) / output_width + OUT_PADDING_HEIGHT ) * OUT_PITCH_X // y offset + + ( ( global_y * TILE_M + 0 ) % output_width ) + OUT_PADDING_LEFT; // x offset + int_tp out1_offset = global_z * out_pitch_z // batch offset + + ( group_x * TILE_N ) * out_pitch_y // channel offset + + ( ( global_y * TILE_M + 1 ) / output_width + OUT_PADDING_HEIGHT ) * OUT_PITCH_X // y offset + + ( ( global_y * TILE_M + 1 ) % output_width ) + OUT_PADDING_LEFT; // x offset + + Dtype bias[2]; + Dtype2 *bias_vec; + bias_vec = (Dtype2*)bias; + *bias_vec = as_Dtype2(SUB_GROUP_BLOCK_READ2((__global INT_TYPE *)biases + group_x * TILE_N)); + + INTERLEAVED_SIMD16_OUTPUT(dst, out0_offset, 0); + INTERLEAVED_SIMD16_OUTPUT(dst, out1_offset, 1); + } +} +#endif diff --git a/src/caffe/greentea/cl_kernels/dropout.cl b/src/caffe/greentea/cl_kernels/dropout.cl index 103ab889c56..67540ce5eb1 100644 --- a/src/caffe/greentea/cl_kernels/dropout.cl +++ b/src/caffe/greentea/cl_kernels/dropout.cl @@ -6,7 +6,7 @@ __kernel void TEMPLATE(dropout_forward,Dtype)(const int_tp n, __global const Dtype* in, __global const uint_tp* mask, const uint_tp threshold, - const Dtype scale, + const KERNEL_ARG_DTYPE scale, __global Dtype* out) { for (int_tp index = get_global_id(0); index < n; index += get_global_size(0)) { out[index] = in[index] * (Dtype)((mask[index] > threshold)?1.0:0.0) * scale; @@ -16,7 +16,7 @@ __kernel void TEMPLATE(dropout_forward,Dtype)(const int_tp n, __kernel void TEMPLATE(dropout_backward,Dtype)( const int_tp n, __global const Dtype* in_diff, __global const uint_tp* mask, const uint_tp threshold, - const Dtype scale, + const KERNEL_ARG_DTYPE scale, __global Dtype* out_diff) { for (int_tp index = get_global_id(0); index < n; index += get_global_size(0)) { out_diff[index] = in_diff[index] * (Dtype)((mask[index] > 
threshold)?1.0:0.0) * scale; diff --git a/src/caffe/greentea/cl_kernels/eltwise.cl b/src/caffe/greentea/cl_kernels/eltwise.cl index 7a075cb5e75..0328c1dcec4 100644 --- a/src/caffe/greentea/cl_kernels/eltwise.cl +++ b/src/caffe/greentea/cl_kernels/eltwise.cl @@ -9,7 +9,7 @@ __kernel void TEMPLATE(eltwise_max_forward,Dtype)( __global int_tp* mask) { for (int_tp index = get_global_id(0); index < nthreads; index += get_global_size(0)) { - Dtype maxval = -FLT_MAX; + Dtype maxval = -DTYPE_MAX; int_tp maxidx = -1; if (bottom_data_a[index] > bottom_data_b[index]) { // only update for very first bottom_data blob (blob_idx == 0) diff --git a/src/caffe/greentea/cl_kernels/elu.cl b/src/caffe/greentea/cl_kernels/elu.cl index 0e3ef6f0d8c..4c5e357ceb0 100644 --- a/src/caffe/greentea/cl_kernels/elu.cl +++ b/src/caffe/greentea/cl_kernels/elu.cl @@ -4,7 +4,7 @@ __kernel void TEMPLATE(elu_forward,Dtype)(const int n, __global const Dtype* in, __global Dtype* out, - Dtype alpha) { + KERNEL_ARG_DTYPE alpha) { for (int_tp index = get_global_id(0); index < n; index += get_global_size(0)) { out[index] = in[index] > 0 ? in[index] : alpha * (exp(in[index]) - (Dtype)1.0); } @@ -14,7 +14,7 @@ __kernel void TEMPLATE(elu_backward,Dtype)(const int n, __global const Dtype* in __global const Dtype* out_data, __global const Dtype* in_data, __global Dtype* out_diff, - Dtype alpha) { + KERNEL_ARG_DTYPE alpha) { for (int_tp index = get_global_id(0); index < n; index += get_global_size(0)) { out_diff[index] = in_data[index] > 0 ? 
diff --git a/src/caffe/greentea/cl_kernels/embed.cl b/src/caffe/greentea/cl_kernels/embed.cl index 6face2dbf33..01c12cb422e 100644 --- a/src/caffe/greentea/cl_kernels/embed.cl +++ b/src/caffe/greentea/cl_kernels/embed.cl @@ -19,6 +19,45 @@ __kernel void TEMPLATE(embed_forward,Dtype)(const int_tp nthreads, } // atomic_add from: http://suhorukov.blogspot.com/2011/12/opencl-11-atomic-operations-on-floating.html + +// atomic_add from: http://suhorukov.blogspot.com/2011/12/opencl-11-atomic-operations-on-floating.html +#if (TYPE == TYPE_HALF) + +// FIXME, has bug which may hang GPU. +inline void TEMPLATE(atomic_add,Dtype)(volatile __global Dtype *source, const Dtype operand) { + union { + uint_tp intVal; + Dtype floatVal[2]; + } newVal; + union { + uint_tp intVal; + Dtype floatVal[2]; + } prevVal; + do { + // FIXME, need to consider buffer overflow. + prevVal.floatVal[0] = *source; + prevVal.floatVal[1] = *(source+1); + newVal.floatVal[0] = prevVal.floatVal[0] + operand; + newVal.floatVal[1] = prevVal.floatVal[1]; + } while (atomic_cmpxchg((volatile __global unsigned int *)source, prevVal.intVal, newVal.intVal) != prevVal.intVal); +} + +__kernel void TEMPLATE(embed_backward,Dtype)(const int_tp nthreads, __global const Dtype* bottom_data, + __global const Dtype* top_diff, const int_tp M, const int_tp N, const int_tp K, + __global Dtype* weight_diff) { + for (int_tp top_index = get_global_id(0); top_index < nthreads; + top_index += get_global_size(0)) { + const int_tp n = top_index / N; + const int_tp d = top_index % N; + const int_tp index = (int_tp)(bottom_data[n]); + const int_tp weight_index = index * N + d; + + TEMPLATE(atomic_add,Dtype)((weight_diff + weight_index), *(top_diff + top_index)); + } +} +#endif + + #if (TYPE == TYPE_FLOAT) #ifdef ATOMICS_32_AVAILABLE inline void TEMPLATE(atomic_add,Dtype)(volatile __global Dtype *source, const Dtype operand) { diff --git a/src/caffe/greentea/cl_kernels/fft.cl b/src/caffe/greentea/cl_kernels/fft.cl index
589a5607fbf..85e7df625c7 100644 --- a/src/caffe/greentea/cl_kernels/fft.cl +++ b/src/caffe/greentea/cl_kernels/fft.cl @@ -2,7 +2,7 @@ #include "header.cl" #endif -__kernel void TEMPLATE(fft_phony,Dtype)(Dtype arg) { +__kernel void TEMPLATE(fft_phony,Dtype)(KERNEL_ARG_DTYPE arg) { Dtype out = arg; } @@ -813,8 +813,8 @@ __kernel void TEMPLATE(batchedCdotc,Dtype)(__global Dtype2* dst, cdotc4.xz += mad( s1.xz, s2.xz, s1.yw * s2.yw); cdotc4.yw += mad(-s1.xz, s2.yw, s1.yw * s2.xz); } - cdotc.x += dot(cdotc4.xz, (float2)(1)); - cdotc.y += dot(cdotc4.yw, (float2)(1)); + cdotc.x += dot(cdotc4.xz, (Dtype2)(1)); + cdotc.y += dot(cdotc4.yw, (Dtype2)(1)); if (r == 1) { const __global Dtype* src1_ptr2 = (const __global Dtype*)(((const __global Dtype4*)(src1_ptr)) + n); diff --git a/src/caffe/greentea/cl_kernels/fillbuffer.cl b/src/caffe/greentea/cl_kernels/fillbuffer.cl index 52d55a04a1a..f1847817022 100644 --- a/src/caffe/greentea/cl_kernels/fillbuffer.cl +++ b/src/caffe/greentea/cl_kernels/fillbuffer.cl @@ -9,7 +9,7 @@ __kernel void TEMPLATE(fillbuffer,Dtype)(const int_tp n, const char alpha, __glo } } -__kernel void TEMPLATE(fill,Dtype)(const int_tp n, const Dtype alpha, __global Dtype* x, +__kernel void TEMPLATE(fill,Dtype)(const int_tp n, const KERNEL_ARG_DTYPE alpha, __global Dtype* x, const int_tp offx) { for (int_tp index = get_global_id(0); index < n; index += get_global_size(0)) { x[index + offx] = alpha; diff --git a/src/caffe/greentea/cl_kernels/gemm.cl b/src/caffe/greentea/cl_kernels/gemm.cl index 3767107761a..f3577bd9522 100644 --- a/src/caffe/greentea/cl_kernels/gemm.cl +++ b/src/caffe/greentea/cl_kernels/gemm.cl @@ -30,17 +30,17 @@ //#define USE_IMAGE_C #ifdef USE_IMAGE_C #if TYPE == TYPE_HALF -#define BLOCKC_READ8( _C, _coordC ) as_float8( intel_sub_group_block_read_us8( _C, _coordC ) ) +#define BLOCKC_READ8( _C, _coordC ) as_Dtype8( intel_sub_group_block_read_us8( _C, _coordC ) ) #define BLOCKC_WRITE8( _C, _coordC, _val ) intel_sub_group_block_write_us8( _C, 
_coordC, as_ushort8( _val ) ) #else -#define BLOCKC_READ8( _C, _coordC ) as_float8( intel_sub_group_block_read8( _C, _coordC ) ) +#define BLOCKC_READ8( _C, _coordC ) as_Dtype8( intel_sub_group_block_read8( _C, _coordC ) ) #define BLOCKC_WRITE8( _C, _coordC, _val ) intel_sub_group_block_write8( _C, _coordC, as_uint8( _val ) ) #endif #define MATC_PARAMETER __read_only image2d_t C, __write_only image2d_t dst #define GEMM_OUTPUT(ALPHA1, BETA_NOT0) GEMM_OUTPUT_EXT(ALPHA1, BETA_NOT0, C, dst, sizeof(uint)) #else #define BLOCKC_READ8( _C, _coordC ) \ - (float8) ( (_coordC.x + get_local_id(0) < N && _coordC.y < M) ? _C[ _coordC.y * ldc + _coordC.x + get_local_id(0) ] : 0, \ + (Dtype8) ( (_coordC.x + get_local_id(0) < N && _coordC.y < M) ? _C[ _coordC.y * ldc + _coordC.x + get_local_id(0) ] : 0, \ (_coordC.x + get_local_id(0) < N && _coordC.y + 1 < M) ? _C[ ( _coordC.y + 1 ) * ldc + _coordC.x + get_local_id(0) ] : 0, \ (_coordC.x + get_local_id(0) < N && _coordC.y + 2 < M) ? _C[ ( _coordC.y + 2 ) * ldc + _coordC.x + get_local_id(0) ] : 0, \ (_coordC.x + get_local_id(0) < N && _coordC.y + 3 < M) ? _C[ ( _coordC.y + 3 ) * ldc + _coordC.x + get_local_id(0) ] : 0, \ @@ -68,27 +68,27 @@ if (_coordC.y + 7 < M) \ _C[ ( _coordC.y + 7 )* ldc + _coordC.x + get_local_id(0) ] = _val.s7; \ }} while(0) -#define MATC_PARAMETER __global float * C, const int offC, const int M, const int N, const int ldc +#define MATC_PARAMETER __global Dtype * C, const int offC, const int M, const int N, const int ldc #define GEMM_OUTPUT(ALPHA1, BETA_NOT0) GEMM_OUTPUT_EXT(ALPHA1, BETA_NOT0, (C + offC), (C + offC), 1) #endif #define GEMM_OUTPUT_EXT(ALPHA1, BETA_NOT0, _C, _dst, _C_step) \ int2 coordDst = (int2)( ( group_x * TILE_N ) * _C_step, ( group_y * TILE_M ) ); \ int2 coordC = coordDst; \ - float8 blockC00; \ - float8 blockC01; \ - float8 blockC02; \ - float8 blockC03; \ + Dtype8 blockC00; \ + Dtype8 blockC01; \ + Dtype8 blockC02; \ + Dtype8 blockC03; \ if (BETA_NOT0) { \ blockC00 = isFirstColBlock ? 
BLOCKC_READ8( _C, coordC ) * beta : BLOCKC_READ8( _C, coordC ); coordC.y += 8; \ blockC01 = isFirstColBlock ? BLOCKC_READ8( _C, coordC ) * beta : BLOCKC_READ8( _C, coordC ); coordC.y += 8; \ blockC02 = isFirstColBlock ? BLOCKC_READ8( _C, coordC ) * beta : BLOCKC_READ8( _C, coordC ); coordC.y += 8; \ blockC03 = isFirstColBlock ? BLOCKC_READ8( _C, coordC ) * beta : BLOCKC_READ8( _C, coordC ); \ if (!ALPHA1) { \ - blockC00 = mad(blockAxB00, (float8)alpha, blockC00); \ - blockC01 = mad(blockAxB01, (float8)alpha, blockC01); \ - blockC02 = mad(blockAxB02, (float8)alpha, blockC02); \ - blockC03 = mad(blockAxB03, (float8)alpha, blockC03); \ + blockC00 = mad(blockAxB00, (Dtype8)alpha, blockC00); \ + blockC01 = mad(blockAxB01, (Dtype8)alpha, blockC01); \ + blockC02 = mad(blockAxB02, (Dtype8)alpha, blockC02); \ + blockC03 = mad(blockAxB03, (Dtype8)alpha, blockC03); \ } else { \ blockC00 += blockAxB00; \ blockC01 += blockAxB01; \ @@ -101,10 +101,10 @@ blockC02 = isFirstColBlock ? BLOCKC_READ8( _C, coordC ) * beta : BLOCKC_READ8( _C, coordC ); coordC.y += 8; \ blockC03 = isFirstColBlock ? 
BLOCKC_READ8( _C, coordC ) * beta : BLOCKC_READ8( _C, coordC ); \ if (!ALPHA1) { \ - blockC00 = mad(blockAxB00, (float8)alpha, blockC00); \ - blockC01 = mad(blockAxB01, (float8)alpha, blockC01); \ - blockC02 = mad(blockAxB02, (float8)alpha, blockC02); \ - blockC03 = mad(blockAxB03, (float8)alpha, blockC03); \ + blockC00 = mad(blockAxB00, (Dtype8)alpha, blockC00); \ + blockC01 = mad(blockAxB01, (Dtype8)alpha, blockC01); \ + blockC02 = mad(blockAxB02, (Dtype8)alpha, blockC02); \ + blockC03 = mad(blockAxB03, (Dtype8)alpha, blockC03); \ } else { \ blockC00 += blockAxB00; \ blockC01 += blockAxB01; \ @@ -119,7 +119,7 @@ // Get the specified column of the block of the block #define TRANSPOSE_BLOCK_8( _block, _col ) \ - (float8)( intel_sub_group_shuffle( _block.s0, _col ), \ + (Dtype8)( intel_sub_group_shuffle( _block.s0, _col ), \ intel_sub_group_shuffle( _block.s1, _col ), \ intel_sub_group_shuffle( _block.s2, _col ), \ intel_sub_group_shuffle( _block.s3, _col ), \ @@ -132,58 +132,58 @@ #if TYPE == TYPE_HALF #define MULTIPLY_BLOCKS_8x8( _result, _blockA, _blockB00, _blockB01 ) \ { \ - const float8 acol0 = TRANSPOSE_BLOCK_8( _blockA, 0 ); \ - const float8 acol1 = TRANSPOSE_BLOCK_8( _blockA, 1 ); \ - const float8 acol2 = TRANSPOSE_BLOCK_8( _blockA, 2 ); \ - const float8 acol3 = TRANSPOSE_BLOCK_8( _blockA, 3 ); \ - const float8 acol4 = TRANSPOSE_BLOCK_8( _blockA, 4 ); \ - const float8 acol5 = TRANSPOSE_BLOCK_8( _blockA, 5 ); \ - const float8 acol6 = TRANSPOSE_BLOCK_8( _blockA, 6 ); \ - const float8 acol7 = TRANSPOSE_BLOCK_8( _blockA, 7 ); \ - const float8 acol8 = TRANSPOSE_BLOCK_8( _blockA, 8 ); \ - const float8 acol9 = TRANSPOSE_BLOCK_8( _blockA, 9 ); \ - const float8 acola = TRANSPOSE_BLOCK_8( _blockA, 10 ); \ - const float8 acolb = TRANSPOSE_BLOCK_8( _blockA, 11 ); \ - const float8 acolc = TRANSPOSE_BLOCK_8( _blockA, 12 ); \ - const float8 acold = TRANSPOSE_BLOCK_8( _blockA, 13 ); \ - const float8 acole = TRANSPOSE_BLOCK_8( _blockA, 14 ); \ - const float8 acolf = 
TRANSPOSE_BLOCK_8( _blockA, 15 ); \ - _result = mad( (float8)(_blockB00.s0), acol0, _result ); \ - _result = mad( (float8)(_blockB00.s1), acol1, _result ); \ - _result = mad( (float8)(_blockB00.s2), acol2, _result ); \ - _result = mad( (float8)(_blockB00.s3), acol3, _result ); \ - _result = mad( (float8)(_blockB00.s4), acol4, _result ); \ - _result = mad( (float8)(_blockB00.s5), acol5, _result ); \ - _result = mad( (float8)(_blockB00.s6), acol6, _result ); \ - _result = mad( (float8)(_blockB00.s7), acol7, _result ); \ - _result = mad( (float8)(_blockB01.s0), acol8, _result ); \ - _result = mad( (float8)(_blockB01.s1), acol9, _result ); \ - _result = mad( (float8)(_blockB01.s2), acola, _result ); \ - _result = mad( (float8)(_blockB01.s3), acolb, _result ); \ - _result = mad( (float8)(_blockB01.s4), acolc, _result ); \ - _result = mad( (float8)(_blockB01.s5), acold, _result ); \ - _result = mad( (float8)(_blockB01.s6), acole, _result ); \ - _result = mad( (float8)(_blockB01.s7), acolf, _result ); \ + const Dtype8 acol0 = TRANSPOSE_BLOCK_8( _blockA, 0 ); \ + const Dtype8 acol1 = TRANSPOSE_BLOCK_8( _blockA, 1 ); \ + const Dtype8 acol2 = TRANSPOSE_BLOCK_8( _blockA, 2 ); \ + const Dtype8 acol3 = TRANSPOSE_BLOCK_8( _blockA, 3 ); \ + const Dtype8 acol4 = TRANSPOSE_BLOCK_8( _blockA, 4 ); \ + const Dtype8 acol5 = TRANSPOSE_BLOCK_8( _blockA, 5 ); \ + const Dtype8 acol6 = TRANSPOSE_BLOCK_8( _blockA, 6 ); \ + const Dtype8 acol7 = TRANSPOSE_BLOCK_8( _blockA, 7 ); \ + const Dtype8 acol8 = TRANSPOSE_BLOCK_8( _blockA, 8 ); \ + const Dtype8 acol9 = TRANSPOSE_BLOCK_8( _blockA, 9 ); \ + const Dtype8 acola = TRANSPOSE_BLOCK_8( _blockA, 10 ); \ + const Dtype8 acolb = TRANSPOSE_BLOCK_8( _blockA, 11 ); \ + const Dtype8 acolc = TRANSPOSE_BLOCK_8( _blockA, 12 ); \ + const Dtype8 acold = TRANSPOSE_BLOCK_8( _blockA, 13 ); \ + const Dtype8 acole = TRANSPOSE_BLOCK_8( _blockA, 14 ); \ + const Dtype8 acolf = TRANSPOSE_BLOCK_8( _blockA, 15 ); \ + _result = mad( (Dtype8)(_blockB00.s0), acol0, 
_result ); \ + _result = mad( (Dtype8)(_blockB00.s1), acol1, _result ); \ + _result = mad( (Dtype8)(_blockB00.s2), acol2, _result ); \ + _result = mad( (Dtype8)(_blockB00.s3), acol3, _result ); \ + _result = mad( (Dtype8)(_blockB00.s4), acol4, _result ); \ + _result = mad( (Dtype8)(_blockB00.s5), acol5, _result ); \ + _result = mad( (Dtype8)(_blockB00.s6), acol6, _result ); \ + _result = mad( (Dtype8)(_blockB00.s7), acol7, _result ); \ + _result = mad( (Dtype8)(_blockB01.s0), acol8, _result ); \ + _result = mad( (Dtype8)(_blockB01.s1), acol9, _result ); \ + _result = mad( (Dtype8)(_blockB01.s2), acola, _result ); \ + _result = mad( (Dtype8)(_blockB01.s3), acolb, _result ); \ + _result = mad( (Dtype8)(_blockB01.s4), acolc, _result ); \ + _result = mad( (Dtype8)(_blockB01.s5), acold, _result ); \ + _result = mad( (Dtype8)(_blockB01.s6), acole, _result ); \ + _result = mad( (Dtype8)(_blockB01.s7), acolf, _result ); \ } #else #define MULTIPLY_BLOCKS_8x8( _result, _blockA, _blockB ) \ { \ - const float8 acol0 = TRANSPOSE_BLOCK_8( _blockA, 0 ); \ - const float8 acol1 = TRANSPOSE_BLOCK_8( _blockA, 1 ); \ - const float8 acol2 = TRANSPOSE_BLOCK_8( _blockA, 2 ); \ - const float8 acol3 = TRANSPOSE_BLOCK_8( _blockA, 3 ); \ - const float8 acol4 = TRANSPOSE_BLOCK_8( _blockA, 4 ); \ - const float8 acol5 = TRANSPOSE_BLOCK_8( _blockA, 5 ); \ - const float8 acol6 = TRANSPOSE_BLOCK_8( _blockA, 6 ); \ - const float8 acol7 = TRANSPOSE_BLOCK_8( _blockA, 7 ); \ - _result = mad( (float8)(_blockB.s0), acol0, _result ); \ - _result = mad( (float8)(_blockB.s1), acol1, _result ); \ - _result = mad( (float8)(_blockB.s2), acol2, _result ); \ - _result = mad( (float8)(_blockB.s3), acol3, _result ); \ - _result = mad( (float8)(_blockB.s4), acol4, _result ); \ - _result = mad( (float8)(_blockB.s5), acol5, _result ); \ - _result = mad( (float8)(_blockB.s6), acol6, _result ); \ - _result = mad( (float8)(_blockB.s7), acol7, _result ); \ + const Dtype8 acol0 = TRANSPOSE_BLOCK_8( _blockA, 0 ); \ + 
const Dtype8 acol1 = TRANSPOSE_BLOCK_8( _blockA, 1 ); \ + const Dtype8 acol2 = TRANSPOSE_BLOCK_8( _blockA, 2 ); \ + const Dtype8 acol3 = TRANSPOSE_BLOCK_8( _blockA, 3 ); \ + const Dtype8 acol4 = TRANSPOSE_BLOCK_8( _blockA, 4 ); \ + const Dtype8 acol5 = TRANSPOSE_BLOCK_8( _blockA, 5 ); \ + const Dtype8 acol6 = TRANSPOSE_BLOCK_8( _blockA, 6 ); \ + const Dtype8 acol7 = TRANSPOSE_BLOCK_8( _blockA, 7 ); \ + _result = mad( (Dtype8)(_blockB.s0), acol0, _result ); \ + _result = mad( (Dtype8)(_blockB.s1), acol1, _result ); \ + _result = mad( (Dtype8)(_blockB.s2), acol2, _result ); \ + _result = mad( (Dtype8)(_blockB.s3), acol3, _result ); \ + _result = mad( (Dtype8)(_blockB.s4), acol4, _result ); \ + _result = mad( (Dtype8)(_blockB.s5), acol5, _result ); \ + _result = mad( (Dtype8)(_blockB.s6), acol6, _result ); \ + _result = mad( (Dtype8)(_blockB.s7), acol7, _result ); \ } #endif @@ -195,31 +195,31 @@ __kernel void TEMPLATE(gemm_32_1_NN_ ##ALPHA1 ##_ ##BETA_NOT0, Dtype)( \ __read_only image2d_t A, \ __read_only image2d_t B, \ MATC_PARAMETER, \ - float alpha_in, \ - float beta_in, \ + KERNEL_ARG_DTYPE alpha_in, \ + KERNEL_ARG_DTYPE beta_in, \ int width0, \ int isFirstColBlock) \ { \ - const float alpha = (float)alpha_in; \ - const float beta = (float)beta_in; \ + const Dtype alpha = (Dtype)alpha_in; \ + const Dtype beta = (Dtype)beta_in; \ const int group_x = get_group_id(0); \ const int group_y = get_group_id(1); \ - float8 blockAxB00 = 0; \ - float8 blockAxB01 = 0; \ - float8 blockAxB02 = 0; \ - float8 blockAxB03 = 0; \ + Dtype8 blockAxB00 = 0; \ + Dtype8 blockAxB01 = 0; \ + Dtype8 blockAxB02 = 0; \ + Dtype8 blockAxB03 = 0; \ int2 coordA = (int2)( 0, group_y * TILE_M ); \ int2 coordB = (int2)( ( group_x * TILE_N ) * SIZE_OF_ELEMENT, 0 ); \ do \ { \ int2 coordBTemp = coordB; \ - float8 blockB00 = as_float8( SUBGROUP_BLOCK_READ8( B, coordBTemp ) ); coordB.y += TILE_K; \ - float8 blockB01 = as_float8( SUBGROUP_BLOCK_READ8( B, coordBTemp ) ); coordB.y += TILE_K; \ + Dtype8 
blockB00 = as_Dtype8( SUBGROUP_BLOCK_READ8( B, coordBTemp ) ); coordB.y += TILE_K; \ + Dtype8 blockB01 = as_Dtype8( SUBGROUP_BLOCK_READ8( B, coordBTemp ) ); coordB.y += TILE_K; \ int2 coordATemp = coordA; \ - float8 blockA00 = as_float8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.y += 8; \ - float8 blockA01 = as_float8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.y += 8; \ - float8 blockA02 = as_float8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.y += 8; \ - float8 blockA03 = as_float8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordA.x += TILE_K * SIZE_OF_ELEMENT * 2; \ + Dtype8 blockA00 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.y += 8; \ + Dtype8 blockA01 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.y += 8; \ + Dtype8 blockA02 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.y += 8; \ + Dtype8 blockA03 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordA.x += TILE_K * SIZE_OF_ELEMENT * 2; \ MULTIPLY_BLOCKS_8x8( blockAxB00, blockA00, blockB00, blockB01 ); \ MULTIPLY_BLOCKS_8x8( blockAxB01, blockA01, blockB00, blockB01 ); \ MULTIPLY_BLOCKS_8x8( blockAxB02, blockA02, blockB00, blockB01 ); \ @@ -236,30 +236,30 @@ __kernel void TEMPLATE(gemm_32_1_NN_ ##ALPHA1 ##_ ##BETA_NOT0, Dtype)( \ __read_only image2d_t A, \ __read_only image2d_t B, \ MATC_PARAMETER, \ - float alpha_in, \ - float beta_in, \ + KERNEL_ARG_DTYPE alpha_in, \ + KERNEL_ARG_DTYPE beta_in, \ int width0, \ int isFirstColBlock) \ { \ - const float alpha = (float)alpha_in; \ - const float beta = (float)beta_in; \ + const Dtype alpha = (Dtype)alpha_in; \ + const Dtype beta = (Dtype)beta_in; \ const int group_x = get_group_id(0); \ const int group_y = get_group_id(1); \ - float8 blockAxB00 = 0.0f; \ - float8 blockAxB01 = 0.0f; \ - float8 blockAxB02 = 0.0f; \ - float8 blockAxB03 = 0.0f; \ + Dtype8 blockAxB00 = 0.0f; \ + Dtype8 blockAxB01 = 0.0f; \ + Dtype8 blockAxB02 = 0.0f; \ + Dtype8 blockAxB03 = 0.0f; \ int2 coordA = 
(int2)( 0, group_y * TILE_M ); \ int2 coordB = (int2)( ( group_x * TILE_N ) * SIZE_OF_ELEMENT, 0 ); \ do \ { \ int2 coordBTemp = coordB; \ - float8 blockB00 = as_float8( SUBGROUP_BLOCK_READ8( B, coordBTemp ) ); coordB.y += TILE_K; \ + Dtype8 blockB00 = as_Dtype8( SUBGROUP_BLOCK_READ8( B, coordBTemp ) ); coordB.y += TILE_K; \ int2 coordATemp = coordA; \ - float8 blockA00 = as_float8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.y += 8; \ - float8 blockA01 = as_float8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.y += 8; \ - float8 blockA02 = as_float8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.y += 8; \ - float8 blockA03 = as_float8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordA.x += TILE_K * SIZE_OF_ELEMENT; \ + Dtype8 blockA00 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.y += 8; \ + Dtype8 blockA01 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.y += 8; \ + Dtype8 blockA02 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.y += 8; \ + Dtype8 blockA03 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordA.x += TILE_K * SIZE_OF_ELEMENT; \ MULTIPLY_BLOCKS_8x8( blockAxB00, blockA00, blockB00 ); \ MULTIPLY_BLOCKS_8x8( blockAxB01, blockA01, blockB00 ); \ MULTIPLY_BLOCKS_8x8( blockAxB02, blockA02, blockB00 ); \ @@ -277,10 +277,11 @@ GEMM_NN(0, 1) // ALPHA != 1, BETA != 0 #undef TRANSPOSE_BLOCK_8 #undef MULTIPLY_BLOCKS_8x8 +#undef GEMM_NN // replicate the first row to column block. 
#define TRANSPOSE_BLOCK_8(_vec, _col) \ - (float8)( intel_sub_group_shuffle(_vec, _col + 0), \ + (Dtype8)( intel_sub_group_shuffle(_vec, _col + 0), \ intel_sub_group_shuffle(_vec, _col + 1), \ intel_sub_group_shuffle(_vec, _col + 2), \ intel_sub_group_shuffle(_vec, _col + 3), \ @@ -291,14 +292,14 @@ GEMM_NN(0, 1) // ALPHA != 1, BETA != 0 #define MULTIPLY_BLOCKS_8x8( _result, _blockA, _blockB, _col ) \ { \ - _result = mad( (float8)(_blockB.s0), TRANSPOSE_BLOCK_8(_blockA.s0, _col), _result ); \ - _result = mad( (float8)(_blockB.s1), TRANSPOSE_BLOCK_8(_blockA.s1, _col), _result ); \ - _result = mad( (float8)(_blockB.s2), TRANSPOSE_BLOCK_8(_blockA.s2, _col), _result ); \ - _result = mad( (float8)(_blockB.s3), TRANSPOSE_BLOCK_8(_blockA.s3, _col), _result ); \ - _result = mad( (float8)(_blockB.s4), TRANSPOSE_BLOCK_8(_blockA.s4, _col), _result ); \ - _result = mad( (float8)(_blockB.s5), TRANSPOSE_BLOCK_8(_blockA.s5, _col), _result ); \ - _result = mad( (float8)(_blockB.s6), TRANSPOSE_BLOCK_8(_blockA.s6, _col), _result ); \ - _result = mad( (float8)(_blockB.s7), TRANSPOSE_BLOCK_8(_blockA.s7, _col), _result ); \ + _result = mad( (Dtype8)(_blockB.s0), TRANSPOSE_BLOCK_8(_blockA.s0, _col), _result ); \ + _result = mad( (Dtype8)(_blockB.s1), TRANSPOSE_BLOCK_8(_blockA.s1, _col), _result ); \ + _result = mad( (Dtype8)(_blockB.s2), TRANSPOSE_BLOCK_8(_blockA.s2, _col), _result ); \ + _result = mad( (Dtype8)(_blockB.s3), TRANSPOSE_BLOCK_8(_blockA.s3, _col), _result ); \ + _result = mad( (Dtype8)(_blockB.s4), TRANSPOSE_BLOCK_8(_blockA.s4, _col), _result ); \ + _result = mad( (Dtype8)(_blockB.s5), TRANSPOSE_BLOCK_8(_blockA.s5, _col), _result ); \ + _result = mad( (Dtype8)(_blockB.s6), TRANSPOSE_BLOCK_8(_blockA.s6, _col), _result ); \ + _result = mad( (Dtype8)(_blockB.s7), TRANSPOSE_BLOCK_8(_blockA.s7, _col), _result ); \ } #if TYPE == TYPE_HALF @@ -309,28 +310,28 @@ __kernel void TEMPLATE(gemm_32_1_TN_ ##ALPHA1 ##_ ##BETA_NOT0,Dtype)( \ __read_only image2d_t A, \ __read_only image2d_t 
B, \ MATC_PARAMETER, \ - float alpha_in, \ - float beta_in, \ + KERNEL_ARG_DTYPE alpha_in, \ + KERNEL_ARG_DTYPE beta_in, \ int width0, \ int isFirstColBlock) \ { \ - const float alpha = (float)alpha_in; \ - const float beta = (float)beta_in; \ + const Dtype alpha = (Dtype)alpha_in; \ + const Dtype beta = (Dtype)beta_in; \ const int group_x = get_group_id(0);\ const int group_y = get_group_id(1);\ - float8 blockAxB00 = 0;\ - float8 blockAxB01 = 0;\ - float8 blockAxB02 = 0;\ - float8 blockAxB03 = 0;\ + Dtype8 blockAxB00 = 0;\ + Dtype8 blockAxB01 = 0;\ + Dtype8 blockAxB02 = 0;\ + Dtype8 blockAxB03 = 0;\ int2 coordA = (int2)( group_y * TILE_M * SIZE_OF_ELEMENT, 0 );\ int2 coordB = (int2)( ( group_x * TILE_N ) * SIZE_OF_ELEMENT, 0 );\ do\ {\ int2 coordBTemp = coordB;\ - float8 blockB00 = as_float8( SUBGROUP_BLOCK_READ8( B, coordBTemp ) ); coordB.y += TILE_K;\ + Dtype8 blockB00 = as_Dtype8( SUBGROUP_BLOCK_READ8( B, coordBTemp ) ); coordB.y += TILE_K;\ int2 coordATemp = coordA;\ - float8 blockA00 = as_half8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.x += 16 * SIZE_OF_ELEMENT;\ - float8 blockA01 = as_half8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordA.y += TILE_K;\ + Dtype8 blockA00 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.x += 16 * SIZE_OF_ELEMENT;\ + Dtype8 blockA01 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordA.y += TILE_K;\ MULTIPLY_BLOCKS_8x8( blockAxB00, blockA00, blockB00, 0); \ MULTIPLY_BLOCKS_8x8( blockAxB01, blockA00, blockB00, 8); \ MULTIPLY_BLOCKS_8x8( blockAxB02, blockA01, blockB00, 0); \ @@ -347,30 +348,30 @@ __kernel void TEMPLATE(gemm_32_1_TN_ ##ALPHA1 ##_ ##BETA_NOT0,Dtype)( \ __read_only image2d_t A, \ __read_only image2d_t B, \ MATC_PARAMETER, \ - float alpha_in, \ - float beta_in, \ + KERNEL_ARG_DTYPE alpha_in, \ + KERNEL_ARG_DTYPE beta_in, \ int width0, \ int isFirstColBlock) \ { \ - const float alpha = (float)alpha_in; \ - const float beta = (float)beta_in; \ + const Dtype alpha = (Dtype)alpha_in; \ + 
const Dtype beta = (Dtype)beta_in; \ const int group_x = get_group_id(0);\ const int group_y = get_group_id(1);\ - float8 blockAxB00 = 0.0f;\ - float8 blockAxB01 = 0.0f;\ - float8 blockAxB02 = 0.0f;\ - float8 blockAxB03 = 0.0f;\ + Dtype8 blockAxB00 = 0.0f;\ + Dtype8 blockAxB01 = 0.0f;\ + Dtype8 blockAxB02 = 0.0f;\ + Dtype8 blockAxB03 = 0.0f;\ int2 coordA = (int2)( group_y * TILE_M * SIZE_OF_ELEMENT, 0 );\ int2 coordB = (int2)( ( group_x * TILE_N ) * SIZE_OF_ELEMENT, 0 );\ do\ {\ int2 coordBTemp = coordB;\ - float8 blockB00 = as_float8( SUBGROUP_BLOCK_READ8( B, coordBTemp ) ); coordB.y += TILE_K;\ + Dtype8 blockB00 = as_Dtype8( SUBGROUP_BLOCK_READ8( B, coordBTemp ) ); coordB.y += TILE_K;\ int2 coordATemp = coordA;\ - float8 blockA00 = as_float8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.x += 8 * SIZE_OF_ELEMENT;\ - float8 blockA01 = as_float8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.x += 8 * SIZE_OF_ELEMENT;\ - float8 blockA02 = as_float8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.x += 8 * SIZE_OF_ELEMENT;\ - float8 blockA03 = as_float8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordA.y += TILE_K;\ + Dtype8 blockA00 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.x += 8 * SIZE_OF_ELEMENT;\ + Dtype8 blockA01 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.x += 8 * SIZE_OF_ELEMENT;\ + Dtype8 blockA02 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.x += 8 * SIZE_OF_ELEMENT;\ + Dtype8 blockA03 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordA.y += TILE_K;\ MULTIPLY_BLOCKS_8x8( blockAxB00, blockA00, blockB00, 0 ); \ MULTIPLY_BLOCKS_8x8( blockAxB01, blockA01, blockB00, 0 ); \ MULTIPLY_BLOCKS_8x8( blockAxB02, blockA02, blockB00, 0 ); \ @@ -388,10 +389,11 @@ GEMM_TN(0, 1) // ALPHA != 1, BETA != 0 #undef MULTIPLY_BLOCKS_8x8 #undef TRANSPOSE_BLOCK_8 +#undef GEMM_TN // The same as GEMM_NN #define TRANSPOSE_BLOCK_8( _block, _col ) \ - (float8)( intel_sub_group_shuffle( _block.s0, 
_col), \ + (Dtype8)( intel_sub_group_shuffle( _block.s0, _col), \ intel_sub_group_shuffle( _block.s1, _col), \ intel_sub_group_shuffle( _block.s2, _col), \ intel_sub_group_shuffle( _block.s3, _col), \ @@ -403,58 +405,58 @@ GEMM_TN(0, 1) // ALPHA != 1, BETA != 0 #if TYPE == TYPE_HALF #define MULTIPLY_BLOCKS_8x8( _result, _blockA, _blockB ) \ { \ - const float8 acol0 = TRANSPOSE_BLOCK_8( _blockA, 0 ); \ - const float8 acol1 = TRANSPOSE_BLOCK_8( _blockA, 1 ); \ - const float8 acol2 = TRANSPOSE_BLOCK_8( _blockA, 2 ); \ - const float8 acol3 = TRANSPOSE_BLOCK_8( _blockA, 3 ); \ - const float8 acol4 = TRANSPOSE_BLOCK_8( _blockA, 4 ); \ - const float8 acol5 = TRANSPOSE_BLOCK_8( _blockA, 5 ); \ - const float8 acol6 = TRANSPOSE_BLOCK_8( _blockA, 6 ); \ - const float8 acol7 = TRANSPOSE_BLOCK_8( _blockA, 7 ); \ - const float8 acol8 = TRANSPOSE_BLOCK_8( _blockA, 8 ); \ - const float8 acol9 = TRANSPOSE_BLOCK_8( _blockA, 9 ); \ - const float8 acola = TRANSPOSE_BLOCK_8( _blockA, 10 ); \ - const float8 acolb = TRANSPOSE_BLOCK_8( _blockA, 11 ); \ - const float8 acolc = TRANSPOSE_BLOCK_8( _blockA, 12 ); \ - const float8 acold = TRANSPOSE_BLOCK_8( _blockA, 13 ); \ - const float8 acole = TRANSPOSE_BLOCK_8( _blockA, 14 ); \ - const float8 acolf = TRANSPOSE_BLOCK_8( _blockA, 15 ); \ - _result = mad( (float8)_blockB.s0, acol0, _result ); \ - _result = mad( (float8)_blockB.s1, acol1, _result ); \ - _result = mad( (float8)_blockB.s2, acol2, _result ); \ - _result = mad( (float8)_blockB.s3, acol3, _result ); \ - _result = mad( (float8)_blockB.s4, acol4, _result ); \ - _result = mad( (float8)_blockB.s5, acol5, _result ); \ - _result = mad( (float8)_blockB.s6, acol6, _result ); \ - _result = mad( (float8)_blockB.s7, acol7, _result ); \ - _result = mad( (float8)_blockB.s8, acol8, _result ); \ - _result = mad( (float8)_blockB.s9, acol9, _result ); \ - _result = mad( (float8)_blockB.sa, acola, _result ); \ - _result = mad( (float8)_blockB.sb, acolb, _result ); \ - _result = mad( 
(float8)_blockB.sc, acolc, _result ); \ - _result = mad( (float8)_blockB.sd, acold, _result ); \ - _result = mad( (float8)_blockB.se, acole, _result ); \ - _result = mad( (float8)_blockB.sf, acolf, _result ); \ + const Dtype8 acol0 = TRANSPOSE_BLOCK_8( _blockA, 0 ); \ + const Dtype8 acol1 = TRANSPOSE_BLOCK_8( _blockA, 1 ); \ + const Dtype8 acol2 = TRANSPOSE_BLOCK_8( _blockA, 2 ); \ + const Dtype8 acol3 = TRANSPOSE_BLOCK_8( _blockA, 3 ); \ + const Dtype8 acol4 = TRANSPOSE_BLOCK_8( _blockA, 4 ); \ + const Dtype8 acol5 = TRANSPOSE_BLOCK_8( _blockA, 5 ); \ + const Dtype8 acol6 = TRANSPOSE_BLOCK_8( _blockA, 6 ); \ + const Dtype8 acol7 = TRANSPOSE_BLOCK_8( _blockA, 7 ); \ + const Dtype8 acol8 = TRANSPOSE_BLOCK_8( _blockA, 8 ); \ + const Dtype8 acol9 = TRANSPOSE_BLOCK_8( _blockA, 9 ); \ + const Dtype8 acola = TRANSPOSE_BLOCK_8( _blockA, 10 ); \ + const Dtype8 acolb = TRANSPOSE_BLOCK_8( _blockA, 11 ); \ + const Dtype8 acolc = TRANSPOSE_BLOCK_8( _blockA, 12 ); \ + const Dtype8 acold = TRANSPOSE_BLOCK_8( _blockA, 13 ); \ + const Dtype8 acole = TRANSPOSE_BLOCK_8( _blockA, 14 ); \ + const Dtype8 acolf = TRANSPOSE_BLOCK_8( _blockA, 15 ); \ + _result = mad( (Dtype8)_blockB.s0, acol0, _result ); \ + _result = mad( (Dtype8)_blockB.s1, acol1, _result ); \ + _result = mad( (Dtype8)_blockB.s2, acol2, _result ); \ + _result = mad( (Dtype8)_blockB.s3, acol3, _result ); \ + _result = mad( (Dtype8)_blockB.s4, acol4, _result ); \ + _result = mad( (Dtype8)_blockB.s5, acol5, _result ); \ + _result = mad( (Dtype8)_blockB.s6, acol6, _result ); \ + _result = mad( (Dtype8)_blockB.s7, acol7, _result ); \ + _result = mad( (Dtype8)_blockB.s8, acol8, _result ); \ + _result = mad( (Dtype8)_blockB.s9, acol9, _result ); \ + _result = mad( (Dtype8)_blockB.sa, acola, _result ); \ + _result = mad( (Dtype8)_blockB.sb, acolb, _result ); \ + _result = mad( (Dtype8)_blockB.sc, acolc, _result ); \ + _result = mad( (Dtype8)_blockB.sd, acold, _result ); \ + _result = mad( (Dtype8)_blockB.se, acole, _result ); \ 
+ _result = mad( (Dtype8)_blockB.sf, acolf, _result ); \ } #else #define MULTIPLY_BLOCKS_8x8( _result, _blockA, _blockB ) \ { \ - const float8 acol0 = TRANSPOSE_BLOCK_8( _blockA, 0 ); \ - const float8 acol1 = TRANSPOSE_BLOCK_8( _blockA, 1 ); \ - const float8 acol2 = TRANSPOSE_BLOCK_8( _blockA, 2 ); \ - const float8 acol3 = TRANSPOSE_BLOCK_8( _blockA, 3 ); \ - const float8 acol4 = TRANSPOSE_BLOCK_8( _blockA, 4 ); \ - const float8 acol5 = TRANSPOSE_BLOCK_8( _blockA, 5 ); \ - const float8 acol6 = TRANSPOSE_BLOCK_8( _blockA, 6 ); \ - const float8 acol7 = TRANSPOSE_BLOCK_8( _blockA, 7 ); \ - _result = mad( (float8)_blockB.s0, acol0, _result ); \ - _result = mad( (float8)_blockB.s1, acol1, _result ); \ - _result = mad( (float8)_blockB.s2, acol2, _result ); \ - _result = mad( (float8)_blockB.s3, acol3, _result ); \ - _result = mad( (float8)_blockB.s4, acol4, _result ); \ - _result = mad( (float8)_blockB.s5, acol5, _result ); \ - _result = mad( (float8)_blockB.s6, acol6, _result ); \ - _result = mad( (float8)_blockB.s7, acol7, _result ); \ + const Dtype8 acol0 = TRANSPOSE_BLOCK_8( _blockA, 0 ); \ + const Dtype8 acol1 = TRANSPOSE_BLOCK_8( _blockA, 1 ); \ + const Dtype8 acol2 = TRANSPOSE_BLOCK_8( _blockA, 2 ); \ + const Dtype8 acol3 = TRANSPOSE_BLOCK_8( _blockA, 3 ); \ + const Dtype8 acol4 = TRANSPOSE_BLOCK_8( _blockA, 4 ); \ + const Dtype8 acol5 = TRANSPOSE_BLOCK_8( _blockA, 5 ); \ + const Dtype8 acol6 = TRANSPOSE_BLOCK_8( _blockA, 6 ); \ + const Dtype8 acol7 = TRANSPOSE_BLOCK_8( _blockA, 7 ); \ + _result = mad( (Dtype8)_blockB.s0, acol0, _result ); \ + _result = mad( (Dtype8)_blockB.s1, acol1, _result ); \ + _result = mad( (Dtype8)_blockB.s2, acol2, _result ); \ + _result = mad( (Dtype8)_blockB.s3, acol3, _result ); \ + _result = mad( (Dtype8)_blockB.s4, acol4, _result ); \ + _result = mad( (Dtype8)_blockB.s5, acol5, _result ); \ + _result = mad( (Dtype8)_blockB.s6, acol6, _result ); \ + _result = mad( (Dtype8)_blockB.s7, acol7, _result ); \ } #endif @@ -466,32 +468,32 @@ 
__kernel void TEMPLATE(gemm_32_1_NT_ ##VECSCALAR ##_ ##ALPHA1 ##_ ##BETA_NOT0,Dt __read_only image2d_t A, \ MATB_PARAMETER, \ MATC_PARAMETER, \ - float alpha_in, \ - float beta_in, \ + KERNEL_ARG_DTYPE alpha_in, \ + KERNEL_ARG_DTYPE beta_in, \ int padded_k, \ int k, \ int isFirstColBlock) \ { \ - const float alpha = (float)alpha_in; \ - const float beta = (float)beta_in; \ + const Dtype alpha = (Dtype)alpha_in; \ + const Dtype beta = (Dtype)beta_in; \ const int group_x = get_group_id(0); \ const int group_y = get_group_id(1); \ - float8 blockAxB00 = 0; \ - float8 blockAxB01 = 0; \ - float8 blockAxB02 = 0; \ - float8 blockAxB03 = 0; \ + Dtype8 blockAxB00 = 0; \ + Dtype8 blockAxB01 = 0; \ + Dtype8 blockAxB02 = 0; \ + Dtype8 blockAxB03 = 0; \ int2 coordA = (int2)( 0, group_y * TILE_M ); \ int2 coordB = (int2)( 0, ( group_x * TILE_N )); \ const sampler_t sampler = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; \ do \ { \ - float16 blockB00; \ + Dtype16 blockB00; \ BLOCKB_READ8(blockB00, B, coordB); \ int2 coordATemp = coordA; \ - float8 blockA00 = as_float8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.y += 8; \ - float8 blockA01 = as_float8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.y += 8; \ - float8 blockA02 = as_float8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.y += 8; \ - float8 blockA03 = as_float8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordA.x += TILE_K * SIZE_OF_ELEMENT * 2; \ + Dtype8 blockA00 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.y += 8; \ + Dtype8 blockA01 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.y += 8; \ + Dtype8 blockA02 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.y += 8; \ + Dtype8 blockA03 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordA.x += TILE_K * SIZE_OF_ELEMENT * 2; \ MULTIPLY_BLOCKS_8x8( blockAxB00, blockA00, blockB00 ); \ MULTIPLY_BLOCKS_8x8( blockAxB01, blockA01, blockB00 ); \ MULTIPLY_BLOCKS_8x8( blockAxB02, 
blockA02, blockB00 ); \ @@ -508,32 +510,32 @@ __kernel void TEMPLATE(gemm_32_1_NT_ ##VECSCALAR ##_ ##ALPHA1 ##_ ##BETA_NOT0,Dt __read_only image2d_t A, \ MATB_PARAMETER, \ MATC_PARAMETER, \ - float alpha_in, \ - float beta_in, \ + KERNEL_ARG_DTYPE alpha_in, \ + KERNEL_ARG_DTYPE beta_in, \ int padded_k, \ int k, \ int isFirstColBlock) \ { \ - const float alpha = (float)alpha_in; \ - const float beta = (float)beta_in; \ + const Dtype alpha = (Dtype)alpha_in; \ + const Dtype beta = (Dtype)beta_in; \ const int group_x = get_group_id(0); \ const int group_y = get_group_id(1); \ - float8 blockAxB00 = 0.0f; \ - float8 blockAxB01 = 0.0f; \ - float8 blockAxB02 = 0.0f; \ - float8 blockAxB03 = 0.0f; \ + Dtype8 blockAxB00 = 0.0f; \ + Dtype8 blockAxB01 = 0.0f; \ + Dtype8 blockAxB02 = 0.0f; \ + Dtype8 blockAxB03 = 0.0f; \ int2 coordA = (int2)( 0, group_y * TILE_M ); \ int2 coordB = (int2)( 0, ( group_x * TILE_N )); \ const sampler_t sampler = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; \ do \ { \ - float8 blockB00; \ + Dtype8 blockB00; \ BLOCKB_READ8(blockB00, B, coordB); \ int2 coordATemp = coordA; \ - float8 blockA00 = as_float8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.y += 8; \ - float8 blockA01 = as_float8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.y += 8; \ - float8 blockA02 = as_float8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.y += 8; \ - float8 blockA03 = as_float8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordA.x += TILE_K * SIZE_OF_ELEMENT; \ + Dtype8 blockA00 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.y += 8; \ + Dtype8 blockA01 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.y += 8; \ + Dtype8 blockA02 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.y += 8; \ + Dtype8 blockA03 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordA.x += TILE_K * SIZE_OF_ELEMENT; \ MULTIPLY_BLOCKS_8x8( blockAxB00, blockA00, blockB00 ); \ MULTIPLY_BLOCKS_8x8( 
blockAxB01, blockA01, blockB00 ); \ MULTIPLY_BLOCKS_8x8( blockAxB02, blockA02, blockB00 ); \ @@ -574,18 +576,18 @@ GEMM_NT(0, 1, VEC4, 4) // ALPHA != 1, BETA != 0 int2 _coordBTemp = _coordB; \ _coordBTemp.y += get_local_id(0); \ const __global float *B_read = (__global float *)(_B + (_coordBTemp.y * ldb) + _coordBTemp.x + offB); \ - _blockb = as_half16(as_ushort16(vload8(0, B_read))); \ + _blockb = as_Dtype16(as_ushort16(vload8(0, B_read))); \ _coordB.x += TILE_K * 2; #else #define BLOCKB_READ8(_blockb, _B, _coordB) \ int2 _coordBTemp = _coordB; \ _coordBTemp.y += get_local_id(0); \ - const __global float *B_read = (__global float *)(_B + (_coordBTemp.y * ldb) + _coordBTemp.x + offB); \ + const __global Dtype *B_read = (__global Dtype *)(_B + (_coordBTemp.y * ldb) + _coordBTemp.x + offB); \ _blockb = vload8(0, B_read); \ _coordB.x += TILE_K; #endif -#define MATB_PARAMETER __global float *B, int offB, int ldb +#define MATB_PARAMETER __global Dtype *B, int offB, int ldb GEMM_NT(1, 0, BUFFER, 1) // ALPHA == 1, BETA == 0 GEMM_NT(1, 1, BUFFER, 1) // ALPHA == 1, BETA != 0 @@ -598,7 +600,7 @@ GEMM_NT(0, 1, BUFFER, 1) // ALPHA != 1, BETA != 0 #define BLOCKB_READ8(_blockb, _B, _coordB) \ int2 _coordBTemp = _coordB; \ _coordBTemp.y += get_local_id(0); \ - float4 temp; \ + Dtype4 temp; \ temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \ _blockb.s0 = temp.s0; \ temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \ @@ -636,7 +638,7 @@ GEMM_NT(0, 1, BUFFER, 1) // ALPHA != 1, BETA != 0 #define BLOCKB_READ8(_blockb, _B, _coordB) \ int2 _coordBTemp = _coordB; \ _coordBTemp.y += get_local_id(0); \ - float4 temp; \ + Dtype4 temp; \ temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \ _blockb.s0 = temp.s0; \ temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \ @@ -667,10 +669,11 @@ GEMM_NT(0, 1, SCALAR, 1) // ALPHA != 1, BETA != 0 #undef MULTIPLY_BLOCKS_8x8 #undef TRANSPOSE_BLOCK_8 +#undef GEMM_NT //The same as GEMM_TN. 
#define TRANSPOSE_BLOCK_8(_vec, _col) \ - (float8)( intel_sub_group_shuffle(_vec, _col + 0), \ + (Dtype8)( intel_sub_group_shuffle(_vec, _col + 0), \ intel_sub_group_shuffle(_vec, _col + 1), \ intel_sub_group_shuffle(_vec, _col + 2), \ intel_sub_group_shuffle(_vec, _col + 3), \ @@ -681,22 +684,22 @@ GEMM_NT(0, 1, SCALAR, 1) // ALPHA != 1, BETA != 0 #define MULTIPLY_BLOCKS_8x8( _result, _blockA, _blockB, _col ) \ { \ - const float8 acol0 = TRANSPOSE_BLOCK_8( _blockA.s0, _col ); \ - const float8 acol1 = TRANSPOSE_BLOCK_8( _blockA.s1, _col ); \ - const float8 acol2 = TRANSPOSE_BLOCK_8( _blockA.s2, _col ); \ - const float8 acol3 = TRANSPOSE_BLOCK_8( _blockA.s3, _col ); \ - const float8 acol4 = TRANSPOSE_BLOCK_8( _blockA.s4, _col ); \ - const float8 acol5 = TRANSPOSE_BLOCK_8( _blockA.s5, _col ); \ - const float8 acol6 = TRANSPOSE_BLOCK_8( _blockA.s6, _col ); \ - const float8 acol7 = TRANSPOSE_BLOCK_8( _blockA.s7, _col ); \ - _result = mad( (float8)_blockB.s0, acol0, _result ); \ - _result = mad( (float8)_blockB.s1, acol1, _result ); \ - _result = mad( (float8)_blockB.s2, acol2, _result ); \ - _result = mad( (float8)_blockB.s3, acol3, _result ); \ - _result = mad( (float8)_blockB.s4, acol4, _result ); \ - _result = mad( (float8)_blockB.s5, acol5, _result ); \ - _result = mad( (float8)_blockB.s6, acol6, _result ); \ - _result = mad( (float8)_blockB.s7, acol7, _result ); \ + const Dtype8 acol0 = TRANSPOSE_BLOCK_8( _blockA.s0, _col ); \ + const Dtype8 acol1 = TRANSPOSE_BLOCK_8( _blockA.s1, _col ); \ + const Dtype8 acol2 = TRANSPOSE_BLOCK_8( _blockA.s2, _col ); \ + const Dtype8 acol3 = TRANSPOSE_BLOCK_8( _blockA.s3, _col ); \ + const Dtype8 acol4 = TRANSPOSE_BLOCK_8( _blockA.s4, _col ); \ + const Dtype8 acol5 = TRANSPOSE_BLOCK_8( _blockA.s5, _col ); \ + const Dtype8 acol6 = TRANSPOSE_BLOCK_8( _blockA.s6, _col ); \ + const Dtype8 acol7 = TRANSPOSE_BLOCK_8( _blockA.s7, _col ); \ + _result = mad( (Dtype8)_blockB.s0, acol0, _result ); \ + _result = mad( (Dtype8)_blockB.s1, 
acol1, _result ); \ + _result = mad( (Dtype8)_blockB.s2, acol2, _result ); \ + _result = mad( (Dtype8)_blockB.s3, acol3, _result ); \ + _result = mad( (Dtype8)_blockB.s4, acol4, _result ); \ + _result = mad( (Dtype8)_blockB.s5, acol5, _result ); \ + _result = mad( (Dtype8)_blockB.s6, acol6, _result ); \ + _result = mad( (Dtype8)_blockB.s7, acol7, _result ); \ } #if TYPE == TYPE_HALF @@ -707,30 +710,30 @@ __kernel void TEMPLATE(gemm_32_1_TT_ ##VECSCALAR ##_ ##ALPHA1 ##_ ##BETA_NOT0, D __read_only image2d_t A, \ MATB_PARAMETER, \ MATC_PARAMETER, \ - float alpha_in, \ - float beta_in, \ + KERNEL_ARG_DTYPE alpha_in, \ + KERNEL_ARG_DTYPE beta_in, \ int padded_k, \ int k, \ int isFirstColBlock) \ { \ - const float alpha = (float)alpha_in; \ - const float beta = (float)beta_in; \ + const Dtype alpha = (Dtype)alpha_in; \ + const Dtype beta = (Dtype)beta_in; \ const int group_x = get_group_id(0); \ const int group_y = get_group_id(1); \ - float8 blockAxB00 = 0; \ - float8 blockAxB01 = 0; \ - float8 blockAxB02 = 0; \ - float8 blockAxB03 = 0; \ + Dtype8 blockAxB00 = 0; \ + Dtype8 blockAxB01 = 0; \ + Dtype8 blockAxB02 = 0; \ + Dtype8 blockAxB03 = 0; \ int2 coordA = (int2)( group_y * TILE_M * SIZE_OF_ELEMENT, 0 ); \ int2 coordB = (int2)( 0, ( group_x * TILE_N )); \ const sampler_t sampler = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; \ do \ { \ - float8 blockB00; \ + Dtype8 blockB00; \ BLOCKB_READ8(blockB00, B, coordB); \ int2 coordATemp = coordA; \ - float8 blockA00 = as_float8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.x += 16 * SIZE_OF_ELEMENT;\ - float8 blockA01 = as_float8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordA.y += TILE_K;\ + Dtype8 blockA00 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.x += 16 * SIZE_OF_ELEMENT;\ + Dtype8 blockA01 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordA.y += TILE_K;\ MULTIPLY_BLOCKS_8x8( blockAxB00, blockA00, blockB00, 0); \ MULTIPLY_BLOCKS_8x8( blockAxB01, blockA00, 
blockB00, 8); \ MULTIPLY_BLOCKS_8x8( blockAxB02, blockA01, blockB00, 0); \ @@ -747,32 +750,32 @@ __kernel void TEMPLATE(gemm_32_1_TT_ ##VECSCALAR ##_ ##ALPHA1 ##_ ##BETA_NOT0, D __read_only image2d_t A, \ MATB_PARAMETER, \ MATC_PARAMETER, \ - float alpha_in, \ - float beta_in, \ + KERNEL_ARG_DTYPE alpha_in, \ + KERNEL_ARG_DTYPE beta_in, \ int padded_k, \ int k, \ int isFirstColBlock) \ { \ - const float alpha = (float)alpha_in; \ - const float beta = (float)beta_in; \ + const Dtype alpha = (Dtype)alpha_in; \ + const Dtype beta = (Dtype)beta_in; \ const int group_x = get_group_id(0); \ const int group_y = get_group_id(1); \ - float8 blockAxB00 = 0.0f; \ - float8 blockAxB01 = 0.0f; \ - float8 blockAxB02 = 0.0f; \ - float8 blockAxB03 = 0.0f; \ + Dtype8 blockAxB00 = 0.0f; \ + Dtype8 blockAxB01 = 0.0f; \ + Dtype8 blockAxB02 = 0.0f; \ + Dtype8 blockAxB03 = 0.0f; \ int2 coordA = (int2)( group_y * TILE_M * SIZE_OF_ELEMENT, 0 ); \ int2 coordB = (int2)( 0, ( group_x * TILE_N )); \ const sampler_t sampler = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; \ do \ { \ - float8 blockB00; \ + Dtype8 blockB00; \ BLOCKB_READ8(blockB00, B, coordB); \ int2 coordATemp = coordA; \ - float8 blockA00 = as_float8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.x += 8 * SIZE_OF_ELEMENT; \ - float8 blockA01 = as_float8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.x += 8 * SIZE_OF_ELEMENT; \ - float8 blockA02 = as_float8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.x += 8 * SIZE_OF_ELEMENT; \ - float8 blockA03 = as_float8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordA.y += TILE_K; \ + Dtype8 blockA00 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.x += 8 * SIZE_OF_ELEMENT; \ + Dtype8 blockA01 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.x += 8 * SIZE_OF_ELEMENT; \ + Dtype8 blockA02 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.x += 8 * SIZE_OF_ELEMENT; \ + Dtype8 blockA03 = as_Dtype8( 
SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordA.y += TILE_K; \ MULTIPLY_BLOCKS_8x8( blockAxB00, blockA00 , blockB00, 0 ); \ MULTIPLY_BLOCKS_8x8( blockAxB01, blockA01 , blockB00, 0 ); \ MULTIPLY_BLOCKS_8x8( blockAxB02, blockA02 , blockB00, 0 ); \ @@ -803,18 +806,18 @@ GEMM_TT(0, 1, VEC4, 4) // ALPHA != 1, BETA != 0 int2 _coordBTemp = _coordB; \ _coordBTemp.y += get_local_id(0); \ const __global float *B_read = (__global float *)(_B + (_coordBTemp.y * k) + _coordBTemp.x + offB); \ - _blockb = as_half8(as_ushort8(vload4(0, B_read))); \ + _blockb = as_Dtype8(as_ushort8(vload4(0, B_read))); \ _coordB.x += TILE_K; #else #define BLOCKB_READ8(_blockb, _B, _coordB) \ int2 _coordBTemp = _coordB; \ _coordBTemp.y += get_local_id(0); \ - const __global float *B_read = (__global float *)(_B + (_coordBTemp.y * k) + _coordBTemp.x + offB); \ + const __global Dtype *B_read = (__global Dtype *)(_B + (_coordBTemp.y * k) + _coordBTemp.x + offB); \ _blockb = vload8(0, B_read); \ _coordB.x += TILE_K; #endif -#define MATB_PARAMETER __global float *B, int offB, int ldb +#define MATB_PARAMETER __global Dtype *B, int offB, int ldb GEMM_TT(1, 0, BUFFER, 1) // ALPHA == 1, BETA == 0 GEMM_TT(1, 1, BUFFER, 1) // ALPHA == 1, BETA != 0 @@ -826,7 +829,7 @@ GEMM_TT(0, 1, BUFFER, 1) // ALPHA != 1, BETA != 0 #define BLOCKB_READ8(_blockb, _B, _coordB) \ int2 _coordBTemp = _coordB; \ _coordBTemp.y += get_local_id(0); \ - float4 temp; \ + Dtype4 temp; \ temp = READ_IMAGE(B, _coordBTemp); _coordBTemp.x += 1; \ _blockb.s0 = temp.s0; \ temp = READ_IMAGE(B, _coordBTemp); _coordBTemp.x += 1; \ @@ -856,13 +859,17 @@ GEMM_TT(0, 1, SCALAR, 1) // ALPHA != 1, BETA != 0 #undef MULTIPLY_BLOCKS_8x8 #undef TRANSPOSE_BLOCK_8 +#undef GEMM_TT #undef TILE_M #undef TILE_K #undef TILE_N +#undef SUBGROUP_BLOCK_READ8 +#undef READ_IMAGE +#undef SIZE_OF_ELEMENT __kernel void TEMPLATE(gemm_buffer_copy_image_transpose, Dtype)( - __global float* A, + __global Dtype* A, __write_only image2d_t ImA, int offA, int width, @@ -872,17 
+879,17 @@ __kernel void TEMPLATE(gemm_buffer_copy_image_transpose, Dtype)( const int gidx = get_global_id(0); const int gidy = get_global_id(1); int2 coord_dst = (int2)(gidx, gidy); - __global float* A_off = A + offA; - float srcA = A_off[gidy * ldA + gidx]; + __global Dtype* A_off = A + offA; + Dtype srcA = A_off[gidy * ldA + gidx]; #if TYPE == TYPE_HALF - write_imageh(ImA, coord_dst, (float4)srcA); + write_imageh(ImA, coord_dst, (Dtype4)srcA); #else - write_imagef(ImA, coord_dst, (float4)srcA); + write_imagef(ImA, coord_dst, (Dtype4)srcA); #endif } __kernel void TEMPLATE(gemm_buffer_copy_image_no_transpose, Dtype)( - __global float* A, + __global Dtype* A, __write_only image2d_t ImA, int offA, int width, @@ -897,14 +904,14 @@ __kernel void TEMPLATE(gemm_buffer_copy_image_no_transpose, Dtype)( write_imageh(ImA, coord_dst, 0); return; } - __global float* A_off = A + offA; + __global Dtype* A_off = A + offA; write_imageh(ImA, coord_dst, A_off[gidy * ldA + gidx]); #else if (gidx >= width || gidy >= height) { write_imageui(ImA, coord_dst, (uint4)0); return; } - __global float* A_off = A + offA; + __global Dtype* A_off = A + offA; uint4 srcA = convert_uint4(as_uchar4(A_off[gidy * ldA + gidx])); write_imageui(ImA, coord_dst, srcA); #endif @@ -925,18 +932,18 @@ __kernel void TEMPLATE(gemm_buffer_copy_image_no_transpose, Dtype)( __attribute__((reqd_work_group_size(SIMD_SIZE_GEMM, LWG_HEIGHT, 1))) __attribute__((intel_reqd_sub_group_size(SIMD_SIZE_GEMM))) __kernel void TEMPLATE(gemm_buffer_NN, Dtype)( - const __global float *src0, int off0, - const __global float *src1, int off1, - __global float *dst, int offd, + const __global Dtype *src0, int off0, + const __global Dtype *src1, int off1, + __global Dtype *dst, int offd, int M, int N, int K, - float alpha_in, - float beta_in, + KERNEL_ARG_DTYPE alpha_in, + KERNEL_ARG_DTYPE beta_in, int start_index) { - const float alpha = (float)alpha_in; - const float beta = (float)beta_in; + const Dtype alpha = (Dtype)alpha_in; + 
const Dtype beta = (Dtype)beta_in; const int group_x = get_group_id(0); const int group_y = get_group_id(1); const int local_x = get_local_id(0); @@ -944,14 +951,14 @@ __kernel void TEMPLATE(gemm_buffer_NN, Dtype)( const int global_x = get_global_id(0); const int global_y = get_global_id(1); - float4 brow; - float2 arow0, arow1, arow2, arow3, arow4, arow5, arow6, arow7; + Dtype4 brow; + Dtype2 arow0, arow1, arow2, arow3, arow4, arow5, arow6, arow7; - __global float *dst_write0 = dst + local_x * VEC_SIZE + (group_x * TILE_N) + (group_y * LWG_HEIGHT * TILE_M + local_y * TILE_M) * N + offd; + __global Dtype *dst_write0 = dst + local_x * VEC_SIZE + (group_x * TILE_N) + (group_y * LWG_HEIGHT * TILE_M + local_y * TILE_M) * N + offd; - const __global float *src0_read = src0 + local_x * (TILE_K / SIMD_SIZE_GEMM) + (group_y * LWG_HEIGHT * TILE_M + local_y * TILE_M) * K + start_index + off0; + const __global Dtype *src0_read = src0 + local_x * (TILE_K / SIMD_SIZE_GEMM) + (group_y * LWG_HEIGHT * TILE_M + local_y * TILE_M) * K + start_index + off0; - const __global float *src1_read0 = src1 + local_x * VEC_SIZE + (group_x * TILE_N) + start_index * N + off1; + const __global Dtype *src1_read0 = src1 + local_x * VEC_SIZE + (group_x * TILE_N) + start_index * N + off1; int border = -(group_y * LWG_HEIGHT * TILE_M + local_y * TILE_M); @@ -964,14 +971,14 @@ __kernel void TEMPLATE(gemm_buffer_NN, Dtype)( int row6 = mad24(global_y, TILE_M, 6) < M ? 6 : border; int row7 = mad24(global_y, TILE_M, 7) < M ? 7 : border; - float4 dot00 = (start_index != 0) ? vload4(0, dst_write0) : beta * vload4(0, dst_write0); - float4 dot01 = (start_index != 0) ? vload4(0, dst_write0 + 1 * N) : beta * vload4(0, dst_write0 + 1 * N); - float4 dot02 = (start_index != 0) ? vload4(0, dst_write0 + 2 * N) : beta * vload4(0, dst_write0 + 2 * N); - float4 dot03 = (start_index != 0) ? vload4(0, dst_write0 + 3 * N) : beta * vload4(0, dst_write0 + 3 * N); - float4 dot04 = (start_index != 0) ? 
vload4(0, dst_write0 + 4 * N) : beta * vload4(0, dst_write0 + 4 * N); - float4 dot05 = (start_index != 0) ? vload4(0, dst_write0 + 5 * N) : beta * vload4(0, dst_write0 + 5 * N); - float4 dot06 = (start_index != 0) ? vload4(0, dst_write0 + 6 * N) : beta * vload4(0, dst_write0 + 6 * N); - float4 dot07 = (start_index != 0) ? vload4(0, dst_write0 + 7 * N) : beta * vload4(0, dst_write0 + 7 * N); + Dtype4 dot00 = (start_index != 0) ? vload4(0, dst_write0) : beta * vload4(0, dst_write0); + Dtype4 dot01 = (start_index != 0) ? vload4(0, dst_write0 + 1 * N) : beta * vload4(0, dst_write0 + 1 * N); + Dtype4 dot02 = (start_index != 0) ? vload4(0, dst_write0 + 2 * N) : beta * vload4(0, dst_write0 + 2 * N); + Dtype4 dot03 = (start_index != 0) ? vload4(0, dst_write0 + 3 * N) : beta * vload4(0, dst_write0 + 3 * N); + Dtype4 dot04 = (start_index != 0) ? vload4(0, dst_write0 + 4 * N) : beta * vload4(0, dst_write0 + 4 * N); + Dtype4 dot05 = (start_index != 0) ? vload4(0, dst_write0 + 5 * N) : beta * vload4(0, dst_write0 + 5 * N); + Dtype4 dot06 = (start_index != 0) ? vload4(0, dst_write0 + 6 * N) : beta * vload4(0, dst_write0 + 6 * N); + Dtype4 dot07 = (start_index != 0) ? 
vload4(0, dst_write0 + 7 * N) : beta * vload4(0, dst_write0 + 7 * N); int end_index = min(start_index + 256, K); int w = start_index; @@ -987,14 +994,14 @@ __kernel void TEMPLATE(gemm_buffer_NN, Dtype)( #define MM_DOT_PRODUCT( index, suffix ) \ brow = vload4(0, src1_read0); src1_read0 += N; \ - dot00 = mad( (float4)(as_float2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow0), index )).s##suffix), brow, dot00 ); \ - dot01 = mad( (float4)(as_float2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow1), index )).s##suffix), brow, dot01 ); \ - dot02 = mad( (float4)(as_float2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow2), index )).s##suffix), brow, dot02 ); \ - dot03 = mad( (float4)(as_float2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow3), index )).s##suffix), brow, dot03 ); \ - dot04 = mad( (float4)(as_float2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow4), index )).s##suffix), brow, dot04 ); \ - dot05 = mad( (float4)(as_float2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow5), index )).s##suffix), brow, dot05 ); \ - dot06 = mad( (float4)(as_float2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow6), index )).s##suffix), brow, dot06 ); \ - dot07 = mad( (float4)(as_float2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow7), index )).s##suffix), brow, dot07 ); \ + dot00 = mad( (Dtype4)(as_Dtype2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow0), index )).s##suffix), brow, dot00 ); \ + dot01 = mad( (Dtype4)(as_Dtype2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow1), index )).s##suffix), brow, dot01 ); \ + dot02 = mad( (Dtype4)(as_Dtype2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow2), index )).s##suffix), brow, dot02 ); \ + dot03 = mad( (Dtype4)(as_Dtype2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow3), index )).s##suffix), brow, dot03 ); \ + dot04 = mad( (Dtype4)(as_Dtype2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow4), index )).s##suffix), brow, dot04 ); \ + dot05 = mad( (Dtype4)(as_Dtype2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow5), index )).s##suffix), brow, dot05 ); \ + dot06 = mad( 
(Dtype4)(as_Dtype2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow6), index )).s##suffix), brow, dot06 ); \ + dot07 = mad( (Dtype4)(as_Dtype2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow7), index )).s##suffix), brow, dot07 ); \ MM_DOT_PRODUCT(0, 0); MM_DOT_PRODUCT(0, 1); @@ -1055,15 +1062,15 @@ __kernel void TEMPLATE(gemm_buffer_NN, Dtype)( arow7.y = ((w + local_x * 2 + 1) < K) ? alpha * (src0_read + row7 * K)[1] : 0.0f; #define MM_DOT_PRODUCT( index, suffix ) \ - brow = (w < K) ? vload4(0, src1_read0) : (float4)0.0f; src1_read0 += N; w++; \ - dot00 = mad( (float4)(as_float2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow0), index )).s##suffix), brow, dot00 ); \ - dot01 = mad( (float4)(as_float2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow1), index )).s##suffix), brow, dot01 ); \ - dot02 = mad( (float4)(as_float2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow2), index )).s##suffix), brow, dot02 ); \ - dot03 = mad( (float4)(as_float2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow3), index )).s##suffix), brow, dot03 ); \ - dot04 = mad( (float4)(as_float2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow4), index )).s##suffix), brow, dot04 ); \ - dot05 = mad( (float4)(as_float2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow5), index )).s##suffix), brow, dot05 ); \ - dot06 = mad( (float4)(as_float2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow6), index )).s##suffix), brow, dot06 ); \ - dot07 = mad( (float4)(as_float2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow7), index )).s##suffix), brow, dot07 ); \ + brow = (w < K) ? 
vload4(0, src1_read0) : (Dtype4)0.0f; src1_read0 += N; w++; \ + dot00 = mad( (Dtype4)(as_Dtype2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow0), index )).s##suffix), brow, dot00 ); \ + dot01 = mad( (Dtype4)(as_Dtype2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow1), index )).s##suffix), brow, dot01 ); \ + dot02 = mad( (Dtype4)(as_Dtype2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow2), index )).s##suffix), brow, dot02 ); \ + dot03 = mad( (Dtype4)(as_Dtype2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow3), index )).s##suffix), brow, dot03 ); \ + dot04 = mad( (Dtype4)(as_Dtype2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow4), index )).s##suffix), brow, dot04 ); \ + dot05 = mad( (Dtype4)(as_Dtype2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow5), index )).s##suffix), brow, dot05 ); \ + dot06 = mad( (Dtype4)(as_Dtype2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow6), index )).s##suffix), brow, dot06 ); \ + dot07 = mad( (Dtype4)(as_Dtype2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow7), index )).s##suffix), brow, dot07 ); \ MM_DOT_PRODUCT(0, 0); MM_DOT_PRODUCT(0, 1); @@ -1213,17 +1220,17 @@ __kernel void TEMPLATE(gemm_buffer_NN, Dtype)( __attribute__((reqd_work_group_size(8, LWG_HEIGHT, 1))) __attribute__((intel_reqd_sub_group_size(8))) __kernel void TEMPLATE(gemm_buffer_NT, Dtype)( - const __global float *src0, int off0, - const __global float *src1, int off1, - __global float *dst, int offd, + const __global Dtype *src0, int off0, + const __global Dtype *src1, int off1, + __global Dtype *dst, int offd, int M, int N, int K, - float alpha_in, - float beta_in) + KERNEL_ARG_DTYPE alpha_in, + KERNEL_ARG_DTYPE beta_in) { - const float alpha = (float)alpha_in; - const float beta = (float)beta_in; + const Dtype alpha = (Dtype)alpha_in; + const Dtype beta = (Dtype)beta_in; const int group_x = get_group_id(0); const int group_y = get_group_id(1); const int local_x = get_local_id(0); @@ -1231,32 +1238,32 @@ __kernel void TEMPLATE(gemm_buffer_NT, Dtype)( const int global_x = get_global_id(0); const 
int global_y = get_global_id(1); - float8 dot00 = 0.f; - float8 dot01 = 0.f; - float8 dot02 = 0.f; - float8 dot03 = 0.f; - float8 dot04 = 0.f; - float8 dot05 = 0.f; - float8 dot06 = 0.f; - float8 dot07 = 0.f; + Dtype8 dot00 = 0.f; + Dtype8 dot01 = 0.f; + Dtype8 dot02 = 0.f; + Dtype8 dot03 = 0.f; + Dtype8 dot04 = 0.f; + Dtype8 dot05 = 0.f; + Dtype8 dot06 = 0.f; + Dtype8 dot07 = 0.f; - float4 brow0; - float4 brow1; - float4 brow2; - float4 brow3; - float4 brow4; - float4 brow5; - float4 brow6; - float4 brow7; + Dtype4 brow0; + Dtype4 brow1; + Dtype4 brow2; + Dtype4 brow3; + Dtype4 brow4; + Dtype4 brow5; + Dtype4 brow6; + Dtype4 brow7; - __global float *dst_write0 = dst + local_x * VEC_SIZE + (group_x * TILE_N) + (group_y * LWG_HEIGHT * TILE_M + local_y * TILE_M) * N + offd; + __global Dtype *dst_write0 = dst + local_x * VEC_SIZE + (group_x * TILE_N) + (group_y * LWG_HEIGHT * TILE_M + local_y * TILE_M) * N + offd; - const __global float *src0_read = src0 + local_x * (TILE_K / 8) + (group_y * LWG_HEIGHT * TILE_M + local_y * TILE_M) * K + off0; + const __global Dtype *src0_read = src0 + local_x * (TILE_K / 8) + (group_y * LWG_HEIGHT * TILE_M + local_y * TILE_M) * K + off0; - const __global float *src1_read0 = src1 + (group_x * TILE_N) * K + off1; + const __global Dtype *src1_read0 = src1 + (group_x * TILE_N) * K + off1; - __local float slm_brow[8 * SLM_BLOCK]; - __local float* slm_brow0; + __local Dtype slm_brow[8 * SLM_BLOCK]; + __local Dtype* slm_brow0; int local_index = mad24(local_y, 8, local_x) * 4; int w; @@ -1276,7 +1283,7 @@ __kernel void TEMPLATE(gemm_buffer_NT, Dtype)( w = b_tile; int end_w = min(b_tile + SLM_BLOCK, K); while( w + TILE_K <= end_w ) { - float4 arow; + Dtype4 arow; brow0 = vload4(0, slm_brow0 + 0 * SLM_BLOCK); brow1 = vload4(0, slm_brow0 + 1 * SLM_BLOCK); @@ -1289,10 +1296,10 @@ __kernel void TEMPLATE(gemm_buffer_NT, Dtype)( #define MM_DOT_PRODUCT( _row, _dot ) \ arow = vload4(0, src0_read + _row * K); \ - _dot = mad( (float8)(arow.x), 
(float8)(brow0.x, brow1.x, brow2.x, brow3.x, brow4.x, brow5.x, brow6.x, brow7.x), _dot ); \ - _dot = mad( (float8)(arow.y), (float8)(brow0.y, brow1.y, brow2.y, brow3.y, brow4.y, brow5.y, brow6.y, brow7.y), _dot ); \ - _dot = mad( (float8)(arow.z), (float8)(brow0.z, brow1.z, brow2.z, brow3.z, brow4.z, brow5.z, brow6.z, brow7.z), _dot ); \ - _dot = mad( (float8)(arow.w), (float8)(brow0.w, brow1.w, brow2.w, brow3.w, brow4.w, brow5.w, brow6.w, brow7.w), _dot ); \ + _dot = mad( (Dtype8)(arow.x), (Dtype8)(brow0.x, brow1.x, brow2.x, brow3.x, brow4.x, brow5.x, brow6.x, brow7.x), _dot ); \ + _dot = mad( (Dtype8)(arow.y), (Dtype8)(brow0.y, brow1.y, brow2.y, brow3.y, brow4.y, brow5.y, brow6.y, brow7.y), _dot ); \ + _dot = mad( (Dtype8)(arow.z), (Dtype8)(brow0.z, brow1.z, brow2.z, brow3.z, brow4.z, brow5.z, brow6.z, brow7.z), _dot ); \ + _dot = mad( (Dtype8)(arow.w), (Dtype8)(brow0.w, brow1.w, brow2.w, brow3.w, brow4.w, brow5.w, brow6.w, brow7.w), _dot ); \ MM_DOT_PRODUCT( 0, dot00 ); MM_DOT_PRODUCT( 1, dot01 ); @@ -1312,7 +1319,7 @@ __kernel void TEMPLATE(gemm_buffer_NT, Dtype)( } if(w < K) { - float4 arow; + Dtype4 arow; #define READ_BROW(_brow, _row) \ _brow = vload4(0, slm_brow0 + _row * SLM_BLOCK); \ @@ -1336,10 +1343,10 @@ __kernel void TEMPLATE(gemm_buffer_NT, Dtype)( arow.y = (mad24(local_x, 4, w + 1) < K) ? arow.y : 0.0f; \ arow.z = (mad24(local_x, 4, w + 2) < K) ? arow.z : 0.0f; \ arow.w = (mad24(local_x, 4, w + 3) < K) ? 
arow.w : 0.0f; \ - _dot = mad( (float8)(arow.x), (float8)(brow0.x, brow1.x, brow2.x, brow3.x, brow4.x, brow5.x, brow6.x, brow7.x), _dot ); \ - _dot = mad( (float8)(arow.y), (float8)(brow0.y, brow1.y, brow2.y, brow3.y, brow4.y, brow5.y, brow6.y, brow7.y), _dot ); \ - _dot = mad( (float8)(arow.z), (float8)(brow0.z, brow1.z, brow2.z, brow3.z, brow4.z, brow5.z, brow6.z, brow7.z), _dot ); \ - _dot = mad( (float8)(arow.w), (float8)(brow0.w, brow1.w, brow2.w, brow3.w, brow4.w, brow5.w, brow6.w, brow7.w), _dot ); \ + _dot = mad( (Dtype8)(arow.x), (Dtype8)(brow0.x, brow1.x, brow2.x, brow3.x, brow4.x, brow5.x, brow6.x, brow7.x), _dot ); \ + _dot = mad( (Dtype8)(arow.y), (Dtype8)(brow0.y, brow1.y, brow2.y, brow3.y, brow4.y, brow5.y, brow6.y, brow7.y), _dot ); \ + _dot = mad( (Dtype8)(arow.z), (Dtype8)(brow0.z, brow1.z, brow2.z, brow3.z, brow4.z, brow5.z, brow6.z, brow7.z), _dot ); \ + _dot = mad( (Dtype8)(arow.w), (Dtype8)(brow0.w, brow1.w, brow2.w, brow3.w, brow4.w, brow5.w, brow6.w, brow7.w), _dot ); \ MM_DOT_PRODUCT( 0, dot00 ); MM_DOT_PRODUCT( 1, dot01 ); @@ -1353,8 +1360,8 @@ __kernel void TEMPLATE(gemm_buffer_NT, Dtype)( } #define REDUCE(_dot) \ - _dot = as_float8(intel_sub_group_shuffle(SHUFFLE_TYPE8(_dot), 0)) + as_float8(intel_sub_group_shuffle(SHUFFLE_TYPE8(_dot), 1)) + as_float8(intel_sub_group_shuffle(SHUFFLE_TYPE8(_dot), 2)) + as_float8(intel_sub_group_shuffle(SHUFFLE_TYPE8(_dot), 3)) + \ - as_float8(intel_sub_group_shuffle(SHUFFLE_TYPE8(_dot), 4)) + as_float8(intel_sub_group_shuffle(SHUFFLE_TYPE8(_dot), 5)) + as_float8(intel_sub_group_shuffle(SHUFFLE_TYPE8(_dot), 6)) + as_float8(intel_sub_group_shuffle(SHUFFLE_TYPE8(_dot), 7)); \ + _dot = as_Dtype8(intel_sub_group_shuffle(SHUFFLE_TYPE8(_dot), 0)) + as_Dtype8(intel_sub_group_shuffle(SHUFFLE_TYPE8(_dot), 1)) + as_Dtype8(intel_sub_group_shuffle(SHUFFLE_TYPE8(_dot), 2)) + as_Dtype8(intel_sub_group_shuffle(SHUFFLE_TYPE8(_dot), 3)) + \ + as_Dtype8(intel_sub_group_shuffle(SHUFFLE_TYPE8(_dot), 4)) + 
as_Dtype8(intel_sub_group_shuffle(SHUFFLE_TYPE8(_dot), 5)) + as_Dtype8(intel_sub_group_shuffle(SHUFFLE_TYPE8(_dot), 6)) + as_Dtype8(intel_sub_group_shuffle(SHUFFLE_TYPE8(_dot), 7)); \ REDUCE(dot00); REDUCE(dot01); @@ -1366,7 +1373,7 @@ __kernel void TEMPLATE(gemm_buffer_NT, Dtype)( REDUCE(dot07); #undef REDUCE - float output = 0.0f; + Dtype output = 0.0f; #define OUTPUT( _dot) \ output = (local_x == 0) ? _dot.s0 : output; \ output = (local_x == 1) ? _dot.s1 : output; \ @@ -1401,32 +1408,32 @@ __kernel void TEMPLATE(gemm_buffer_NT, Dtype)( #define SLM_SIZE 64 void TEMPLATE(gemm_buffer_NT_M_2_edgerows,Dtype)( - const __global float* srca_read0, - const __global float* srca_read1, - const __global float* srcb_read, - __local float4* work0, - __local float4* work1, + const __global Dtype* srca_read0, + const __global Dtype* srca_read1, + const __global Dtype* srcb_read, + __local Dtype4* work0, + __local Dtype4* work1, int N, int K, int x_gid, int lid, - float alpha, - float beta, - __global float* dstc0, - __global float* dstc1) + Dtype alpha, + Dtype beta, + __global Dtype* dstc0, + __global Dtype* dstc1) { - __local float* work_each0 = (__local float*)work0; - __local float* work_each1 = (__local float*)work1; + __local Dtype* work_each0 = (__local Dtype*)work0; + __local Dtype* work_each1 = (__local Dtype*)work1; int rows = N - x_gid * 4; - float4 dot0[3] = {(float4)(0.), (float4)(0.), (float4)(0.)}; - float4 dot1[3] = {(float4)(0.), (float4)(0.), (float4)(0.)}; + Dtype4 dot0[3] = {(Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.)}; + Dtype4 dot1[3] = {(Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.)}; int i = lid; while( i < K / 4) { - const float4 b0 = {srca_read0[i*4], srca_read0[(i*4+1)], srca_read0[(i*4+2)], srca_read0[(i*4+3)]}; - const float4 b1 = {srca_read1[i*4], srca_read1[(i*4+1)], srca_read1[(i*4+2)], srca_read1[(i*4+3)]}; + const Dtype4 b0 = {srca_read0[i*4], srca_read0[(i*4+1)], srca_read0[(i*4+2)], srca_read0[(i*4+3)]}; + const Dtype4 b1 = {srca_read1[i*4], 
srca_read1[(i*4+1)], srca_read1[(i*4+2)], srca_read1[(i*4+3)]}; #pragma unroll for(int j = 0; j < rows; ++j) { dot0[j] += b0 * vload4(i, srcb_read + j * K); @@ -1445,13 +1452,13 @@ void TEMPLATE(gemm_buffer_NT_M_2_edgerows,Dtype)( short tail_items = K % 4; if(tail_items != 0) { - const __global float *srcb_tail = srcb_read + i * 4; - const __global float *srca_tail0 = srca_read0 + i * 4; - const __global float *srca_tail1 = srca_read1 + i * 4; + const __global Dtype *srcb_tail = srcb_read + i * 4; + const __global Dtype *srca_tail0 = srca_read0 + i * 4; + const __global Dtype *srca_tail1 = srca_read1 + i * 4; #pragma unroll for(short i = 0; i < tail_items; ++i) { - const float at0 = srca_tail0[i]; - const float at1 = srca_tail1[i]; + const Dtype at0 = srca_tail0[i]; + const Dtype at1 = srca_tail1[i]; #pragma unroll for(int j = 0; j < rows; ++j) { work_each0[lid * 4 + j] += at0 * srcb_tail[i + j * K]; @@ -1479,49 +1486,49 @@ void TEMPLATE(gemm_buffer_NT_M_2_edgerows,Dtype)( } __kernel void TEMPLATE(gemm_buffer_NT_M_2,Dtype)( - __global const float * A, + __global const Dtype * A, int offA, - __global const float * B, + __global const Dtype * B, int offB, - __global float * C, + __global Dtype * C, int offC, int M, int N, int K, - float alpha_f, - float beta_f) + KERNEL_ARG_DTYPE alpha_f, + KERNEL_ARG_DTYPE beta_f) { - float alpha = (float)alpha_f; - float beta = (float)beta_f; + Dtype alpha = (Dtype)alpha_f; + Dtype beta = (Dtype)beta_f; int x_gid = get_group_id(0); int lid = get_local_id(0); - const __global float *srca_read0 = A + offA; - const __global float *srca_read1 = srca_read0 + K; + const __global Dtype *srca_read0 = A + offA; + const __global Dtype *srca_read1 = srca_read0 + K; - const __global float *srcb_read = B + x_gid * 4 * K + offB; + const __global Dtype *srcb_read = B + x_gid * 4 * K + offB; - __global float4 *dstc0 = (__global float4*)(C + offC); - __global float4 *dstc1 = (__global float4*)((__global float*)(dstc0) + N); + __global Dtype4 *dstc0 
= (__global Dtype4*)(C + offC); + __global Dtype4 *dstc1 = (__global Dtype4*)((__global Dtype*)(dstc0) + N); - __local float4 work0[SLM_SIZE]; - __local float4 work1[SLM_SIZE]; - __local float* work_each0 = (__local float*)work0; - __local float* work_each1 = (__local float*)work1; + __local Dtype4 work0[SLM_SIZE]; + __local Dtype4 work1[SLM_SIZE]; + __local Dtype* work_each0 = (__local Dtype*)work0; + __local Dtype* work_each1 = (__local Dtype*)work1; if(x_gid == N / 4) { TEMPLATE(gemm_buffer_NT_M_2_edgerows,Dtype) \ - (srca_read0, srca_read1, srcb_read, work0, work1, N, K, x_gid, lid, alpha, beta, (__global float*)dstc0, (__global float*)dstc1); + (srca_read0, srca_read1, srcb_read, work0, work1, N, K, x_gid, lid, alpha, beta, (__global Dtype*)dstc0, (__global Dtype*)dstc1); } else { - float4 dot0[4] = {(float4)(0.), (float4)(0.), (float4)(0.), (float4)(0.)}; - float4 dot1[4] = {(float4)(0.), (float4)(0.), (float4)(0.), (float4)(0.)}; + Dtype4 dot0[4] = {(Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.)}; + Dtype4 dot1[4] = {(Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.)}; int i = lid; while( i < K / 4) { - const float4 b0 = vload4(i, srca_read0); - const float4 b1 = vload4(i, srca_read1); + const Dtype4 b0 = vload4(i, srca_read0); + const Dtype4 b1 = vload4(i, srca_read1); #pragma unroll for(int j = 0; j < 4; ++j) { - float4 a = vload4(i, srcb_read + j * K); + Dtype4 a = vload4(i, srcb_read + j * K); dot0[j] += b0 * a; dot1[j] += b1 * a; } @@ -1537,14 +1544,14 @@ __kernel void TEMPLATE(gemm_buffer_NT_M_2,Dtype)( if(i == K / 4) { short tail_items = K % 4; if(tail_items != 0) { - const __global float *srcb_tail = srcb_read + i * 4; + const __global Dtype *srcb_tail = srcb_read + i * 4; - const __global float *srca_tail0 = srca_read0 + i * 4; - const __global float *srca_tail1 = srca_read1 + i * 4; + const __global Dtype *srca_tail0 = srca_read0 + i * 4; + const __global Dtype *srca_tail1 = srca_read1 + i * 4; #pragma unroll for(short i = 0; i < 
tail_items; ++i) { - const float at0 = srca_tail0[i]; - const float at1 = srca_tail1[i]; + const Dtype at0 = srca_tail0[i]; + const Dtype at1 = srca_tail1[i]; #pragma unroll for(int j = 0; j < 4; ++j) { work_each0[lid * 4 + j] += at0 * srcb_tail[i + j * K]; @@ -1572,44 +1579,44 @@ __kernel void TEMPLATE(gemm_buffer_NT_M_2,Dtype)( #define SLM_SIZE 32 void TEMPLATE(gemm_buffer_NT_M_4_edgerows,Dtype)( - const __global float* srca_read0, - const __global float* srca_read1, - const __global float* srca_read2, - const __global float* srca_read3, - const __global float* srcb_read, - __local float4* work0, - __local float4* work1, - __local float4* work2, - __local float4* work3, + const __global Dtype* srca_read0, + const __global Dtype* srca_read1, + const __global Dtype* srca_read2, + const __global Dtype* srca_read3, + const __global Dtype* srcb_read, + __local Dtype4* work0, + __local Dtype4* work1, + __local Dtype4* work2, + __local Dtype4* work3, int N, int K, int x_gid, int lid, - float alpha, - float beta, - __global float* dstc0, - __global float* dstc1, - __global float* dstc2, - __global float* dstc3) + Dtype alpha, + Dtype beta, + __global Dtype* dstc0, + __global Dtype* dstc1, + __global Dtype* dstc2, + __global Dtype* dstc3) { - __local float* work_each0 = (__local float*)(work0 + lid); - __local float* work_each1 = (__local float*)(work1 + lid); - __local float* work_each2 = (__local float*)(work2 + lid); - __local float* work_each3 = (__local float*)(work3 + lid); + __local Dtype* work_each0 = (__local Dtype*)(work0 + lid); + __local Dtype* work_each1 = (__local Dtype*)(work1 + lid); + __local Dtype* work_each2 = (__local Dtype*)(work2 + lid); + __local Dtype* work_each3 = (__local Dtype*)(work3 + lid); int rows = N - x_gid * 4; - float4 dot0[3] = {(float4)(0.), (float4)(0.), (float4)(0.)}; - float4 dot1[3] = {(float4)(0.), (float4)(0.), (float4)(0.)}; - float4 dot2[3] = {(float4)(0.), (float4)(0.), (float4)(0.)}; - float4 dot3[3] = {(float4)(0.), 
(float4)(0.), (float4)(0.)}; + Dtype4 dot0[3] = {(Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.)}; + Dtype4 dot1[3] = {(Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.)}; + Dtype4 dot2[3] = {(Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.)}; + Dtype4 dot3[3] = {(Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.)}; int i = lid; while( i < K / 4) { - const float4 a0 = {srca_read0[i*4], srca_read0[(i*4+1)], srca_read0[(i*4+2)], srca_read0[(i*4+3)]}; - const float4 a1 = {srca_read1[i*4], srca_read1[(i*4+1)], srca_read1[(i*4+2)], srca_read1[(i*4+3)]}; - const float4 a2 = {srca_read2[i*4], srca_read2[(i*4+1)], srca_read2[(i*4+2)], srca_read2[(i*4+3)]}; - const float4 a3 = {srca_read3[i*4], srca_read3[(i*4+1)], srca_read3[(i*4+2)], srca_read3[(i*4+3)]}; + const Dtype4 a0 = {srca_read0[i*4], srca_read0[(i*4+1)], srca_read0[(i*4+2)], srca_read0[(i*4+3)]}; + const Dtype4 a1 = {srca_read1[i*4], srca_read1[(i*4+1)], srca_read1[(i*4+2)], srca_read1[(i*4+3)]}; + const Dtype4 a2 = {srca_read2[i*4], srca_read2[(i*4+1)], srca_read2[(i*4+2)], srca_read2[(i*4+3)]}; + const Dtype4 a3 = {srca_read3[i*4], srca_read3[(i*4+1)], srca_read3[(i*4+2)], srca_read3[(i*4+3)]}; #pragma unrol for(int j = 0; j < rows; ++j) { dot0[j] += a0 * vload4(i, srcb_read + j * K); @@ -1632,18 +1639,18 @@ void TEMPLATE(gemm_buffer_NT_M_4_edgerows,Dtype)( short tail_items = K % 4; if(tail_items != 0) { - const __global float *srcb_tail = srcb_read + i * 4; + const __global Dtype *srcb_tail = srcb_read + i * 4; - const __global float *srca_tail0 = srca_read0 + i * 4; - const __global float *srca_tail1 = srca_read1 + i * 4; - const __global float *srca_tail2 = srca_read2 + i * 4; - const __global float *srca_tail3 = srca_read3 + i * 4; + const __global Dtype *srca_tail0 = srca_read0 + i * 4; + const __global Dtype *srca_tail1 = srca_read1 + i * 4; + const __global Dtype *srca_tail2 = srca_read2 + i * 4; + const __global Dtype *srca_tail3 = srca_read3 + i * 4; #pragma unroll for(short i = 0; i < tail_items; ++i) { - const float at0 = srca_tail0[i]; 
- const float at1 = srca_tail1[i]; - const float at2 = srca_tail2[i]; - const float at3 = srca_tail3[i]; + const Dtype at0 = srca_tail0[i]; + const Dtype at1 = srca_tail1[i]; + const Dtype at2 = srca_tail2[i]; + const Dtype at3 = srca_tail3[i]; #pragma unroll for(int j = 0; j < rows; ++j) { work_each0[j] += at0 * srcb_tail[i + j * K]; @@ -1677,65 +1684,65 @@ void TEMPLATE(gemm_buffer_NT_M_4_edgerows,Dtype)( } __kernel void TEMPLATE(gemm_buffer_NT_M_4,Dtype)( - __global const float * A, + __global const Dtype * A, int offA, - __global const float * B, + __global const Dtype * B, int offB, - __global float * C, + __global Dtype * C, int offC, int M, int N, int K, - float alpha_f, - float beta_f) + KERNEL_ARG_DTYPE alpha_f, + KERNEL_ARG_DTYPE beta_f) { - float alpha = (float)alpha_f; - float beta = (float)beta_f; + Dtype alpha = (Dtype)alpha_f; + Dtype beta = (Dtype)beta_f; int x_gid = get_group_id(0); int lid = get_local_id(0); int lsize = get_local_size(0); - const __global float *srca_read0 = A + offA; - const __global float *srca_read1 = srca_read0 + K; - const __global float *srca_read2 = srca_read1 + K; - const __global float *srca_read3 = srca_read2 + K; + const __global Dtype *srca_read0 = A + offA; + const __global Dtype *srca_read1 = srca_read0 + K; + const __global Dtype *srca_read2 = srca_read1 + K; + const __global Dtype *srca_read3 = srca_read2 + K; - const __global float *srcb_read = B + x_gid * 4 * K + offB; + const __global Dtype *srcb_read = B + x_gid * 4 * K + offB; - __global float4 *dstc0 = (__global float4*)(C + offC); - __global float4 *dstc1 = (__global float4*)((__global float*)(dstc0) + N); - __global float4 *dstc2 = (__global float4*)((__global float*)(dstc1) + N); - __global float4 *dstc3 = (__global float4*)((__global float*)(dstc2) + N); + __global Dtype4 *dstc0 = (__global Dtype4*)(C + offC); + __global Dtype4 *dstc1 = (__global Dtype4*)((__global Dtype*)(dstc0) + N); + __global Dtype4 *dstc2 = (__global Dtype4*)((__global Dtype*)(dstc1) 
+ N); + __global Dtype4 *dstc3 = (__global Dtype4*)((__global Dtype*)(dstc2) + N); - __local float4 work0[SLM_SIZE]; - __local float4 work1[SLM_SIZE]; - __local float4 work2[SLM_SIZE]; - __local float4 work3[SLM_SIZE]; - __local float* work_each0 = (__local float*)(work0 + lid); - __local float* work_each1 = (__local float*)(work1 + lid); - __local float* work_each2 = (__local float*)(work2 + lid); - __local float* work_each3 = (__local float*)(work3 + lid); + __local Dtype4 work0[SLM_SIZE]; + __local Dtype4 work1[SLM_SIZE]; + __local Dtype4 work2[SLM_SIZE]; + __local Dtype4 work3[SLM_SIZE]; + __local Dtype* work_each0 = (__local Dtype*)(work0 + lid); + __local Dtype* work_each1 = (__local Dtype*)(work1 + lid); + __local Dtype* work_each2 = (__local Dtype*)(work2 + lid); + __local Dtype* work_each3 = (__local Dtype*)(work3 + lid); if(x_gid == N / 4) { TEMPLATE(gemm_buffer_NT_M_4_edgerows,Dtype) \ (srca_read0, srca_read1, srca_read2, srca_read3, srcb_read, \ work0, work1, work2, work3, N, K, x_gid, lid, alpha, beta, \ - (__global float*)dstc0, (__global float*)dstc1, (__global float*)dstc2, (__global float*)dstc3); + (__global Dtype*)dstc0, (__global Dtype*)dstc1, (__global Dtype*)dstc2, (__global Dtype*)dstc3); } else { - float4 dot0[4] = {(float4)(0.), (float4)(0.), (float4)(0.), (float4)(0.)}; - float4 dot1[4] = {(float4)(0.), (float4)(0.), (float4)(0.), (float4)(0.)}; - float4 dot2[4] = {(float4)(0.), (float4)(0.), (float4)(0.), (float4)(0.)}; - float4 dot3[4] = {(float4)(0.), (float4)(0.), (float4)(0.), (float4)(0.)}; + Dtype4 dot0[4] = {(Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.)}; + Dtype4 dot1[4] = {(Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.)}; + Dtype4 dot2[4] = {(Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.)}; + Dtype4 dot3[4] = {(Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.)}; int kid = lid; while( kid < K / 4) { - const float4 b0 = vload4(kid, srca_read0); - const float4 b1 = vload4(kid, srca_read1); - const float4 
b2 = vload4(kid, srca_read2); - const float4 b3 = vload4(kid, srca_read3); + const Dtype4 b0 = vload4(kid, srca_read0); + const Dtype4 b1 = vload4(kid, srca_read1); + const Dtype4 b2 = vload4(kid, srca_read2); + const Dtype4 b3 = vload4(kid, srca_read3); #pragma unroll for(int j = 0; j < 4; ++j) { - float4 a = vload4(kid, srcb_read + j * K); + Dtype4 a = vload4(kid, srcb_read + j * K); dot0[j] += b0 * a; dot1[j] += b1 * a; dot2[j] += b2 * a; @@ -1755,18 +1762,18 @@ __kernel void TEMPLATE(gemm_buffer_NT_M_4,Dtype)( short tail_items = K % 4; if(tail_items != 0) { int offset = kid << 2; - const __global float *srcb_tail = srcb_read + offset; + const __global Dtype *srcb_tail = srcb_read + offset; - const __global float *srca_tail0 = srca_read0 + offset; - const __global float *srca_tail1 = srca_read1 + offset; - const __global float *srca_tail2 = srca_read2 + offset; - const __global float *srca_tail3 = srca_read3 + offset; + const __global Dtype *srca_tail0 = srca_read0 + offset; + const __global Dtype *srca_tail1 = srca_read1 + offset; + const __global Dtype *srca_tail2 = srca_read2 + offset; + const __global Dtype *srca_tail3 = srca_read3 + offset; #pragma unroll for(short i = 0; i < tail_items; ++i) { - const float at0 = srca_tail0[i]; - const float at1 = srca_tail1[i]; - const float at2 = srca_tail2[i]; - const float at3 = srca_tail3[i]; + const Dtype at0 = srca_tail0[i]; + const Dtype at1 = srca_tail1[i]; + const Dtype at2 = srca_tail2[i]; + const Dtype at3 = srca_tail3[i]; #pragma unroll for(int j = 0; j < 4; ++j) { work_each0[j] += at0 * srcb_tail[i + j * K]; @@ -1800,73 +1807,73 @@ __kernel void TEMPLATE(gemm_buffer_NT_M_4,Dtype)( #define SLM_SIZE 16 __kernel void TEMPLATE(gemm_buffer_NT_M_8,Dtype)( - __global const float * A, + __global const Dtype * A, int offA, - __global const float * B, + __global const Dtype * B, int offB, - __global float * C, + __global Dtype * C, int offC, int M, int N, int K, - float alpha_f, - float beta_f) + KERNEL_ARG_DTYPE 
alpha_f, + KERNEL_ARG_DTYPE beta_f) { - float alpha = (float)alpha_f; - float beta = (float)beta_f; + Dtype alpha = (Dtype)alpha_f; + Dtype beta = (Dtype)beta_f; int x_gid = get_group_id(0); int lid = get_local_id(0); int lsize = get_local_size(0); - const __global float *srca_read0 = A + offA; - const __global float *srca_read1 = srca_read0 + K; - const __global float *srca_read2 = srca_read1 + K; - const __global float *srca_read3 = srca_read2 + K; - const __global float *srca_read4 = srca_read3 + K; - const __global float *srca_read5 = srca_read4 + K; - const __global float *srca_read6 = srca_read5 + K; - const __global float *srca_read7 = srca_read6 + K; - - const __global float *srcb_read = B + x_gid * K + offB; - - __global float *dstc0 = C + offC; - __global float *dstc1 = dstc0 + N; - __global float *dstc2 = dstc1 + N; - __global float *dstc3 = dstc2 + N; - __global float *dstc4 = dstc3 + N; - __global float *dstc5 = dstc4 + N; - __global float *dstc6 = dstc5 + N; - __global float *dstc7 = dstc6 + N; - - __local float work0[SLM_SIZE]; - __local float work1[SLM_SIZE]; - __local float work2[SLM_SIZE]; - __local float work3[SLM_SIZE]; - __local float work4[SLM_SIZE]; - __local float work5[SLM_SIZE]; - __local float work6[SLM_SIZE]; - __local float work7[SLM_SIZE]; - - float4 dot0 = (float4)(0.); - float4 dot1 = (float4)(0.); - float4 dot2 = (float4)(0.); - float4 dot3 = (float4)(0.); - float4 dot4 = (float4)(0.); - float4 dot5 = (float4)(0.); - float4 dot6 = (float4)(0.); - float4 dot7 = (float4)(0.); + const __global Dtype *srca_read0 = A + offA; + const __global Dtype *srca_read1 = srca_read0 + K; + const __global Dtype *srca_read2 = srca_read1 + K; + const __global Dtype *srca_read3 = srca_read2 + K; + const __global Dtype *srca_read4 = srca_read3 + K; + const __global Dtype *srca_read5 = srca_read4 + K; + const __global Dtype *srca_read6 = srca_read5 + K; + const __global Dtype *srca_read7 = srca_read6 + K; + + const __global Dtype *srcb_read = B + x_gid * 
K + offB; + + __global Dtype *dstc0 = C + offC; + __global Dtype *dstc1 = dstc0 + N; + __global Dtype *dstc2 = dstc1 + N; + __global Dtype *dstc3 = dstc2 + N; + __global Dtype *dstc4 = dstc3 + N; + __global Dtype *dstc5 = dstc4 + N; + __global Dtype *dstc6 = dstc5 + N; + __global Dtype *dstc7 = dstc6 + N; + + __local Dtype work0[SLM_SIZE]; + __local Dtype work1[SLM_SIZE]; + __local Dtype work2[SLM_SIZE]; + __local Dtype work3[SLM_SIZE]; + __local Dtype work4[SLM_SIZE]; + __local Dtype work5[SLM_SIZE]; + __local Dtype work6[SLM_SIZE]; + __local Dtype work7[SLM_SIZE]; + + Dtype4 dot0 = (Dtype4)(0.); + Dtype4 dot1 = (Dtype4)(0.); + Dtype4 dot2 = (Dtype4)(0.); + Dtype4 dot3 = (Dtype4)(0.); + Dtype4 dot4 = (Dtype4)(0.); + Dtype4 dot5 = (Dtype4)(0.); + Dtype4 dot6 = (Dtype4)(0.); + Dtype4 dot7 = (Dtype4)(0.); int kid = lid; while( kid < K / 4) { - const float4 a0 = vload4(kid, srca_read0); - const float4 a1 = vload4(kid, srca_read1); - const float4 a2 = vload4(kid, srca_read2); - const float4 a3 = vload4(kid, srca_read3); - const float4 a4 = vload4(kid, srca_read4); - const float4 a5 = vload4(kid, srca_read5); - const float4 a6 = vload4(kid, srca_read6); - const float4 a7 = vload4(kid, srca_read7); - float4 b = vload4(kid, srcb_read); + const Dtype4 a0 = vload4(kid, srca_read0); + const Dtype4 a1 = vload4(kid, srca_read1); + const Dtype4 a2 = vload4(kid, srca_read2); + const Dtype4 a3 = vload4(kid, srca_read3); + const Dtype4 a4 = vload4(kid, srca_read4); + const Dtype4 a5 = vload4(kid, srca_read5); + const Dtype4 a6 = vload4(kid, srca_read6); + const Dtype4 a7 = vload4(kid, srca_read7); + Dtype4 b = vload4(kid, srcb_read); dot0 += a0 * b; dot1 += a1 * b; dot2 += a2 * b; @@ -1891,16 +1898,16 @@ __kernel void TEMPLATE(gemm_buffer_NT_M_8,Dtype)( short tail_items = K % 4; if(tail_items != 0) { int offset = kid << 2; - const __global float *srcb_tail = srcb_read + offset; - - const __global float *srca_tail0 = srca_read0 + offset; - const __global float *srca_tail1 = 
srca_read1 + offset; - const __global float *srca_tail2 = srca_read2 + offset; - const __global float *srca_tail3 = srca_read3 + offset; - const __global float *srca_tail4 = srca_read4 + offset; - const __global float *srca_tail5 = srca_read5 + offset; - const __global float *srca_tail6 = srca_read6 + offset; - const __global float *srca_tail7 = srca_read7 + offset; + const __global Dtype *srcb_tail = srcb_read + offset; + + const __global Dtype *srca_tail0 = srca_read0 + offset; + const __global Dtype *srca_tail1 = srca_read1 + offset; + const __global Dtype *srca_tail2 = srca_read2 + offset; + const __global Dtype *srca_tail3 = srca_read3 + offset; + const __global Dtype *srca_tail4 = srca_read4 + offset; + const __global Dtype *srca_tail5 = srca_read5 + offset; + const __global Dtype *srca_tail6 = srca_read6 + offset; + const __global Dtype *srca_tail7 = srca_read7 + offset; #pragma unroll for(short item = 0; item < tail_items; ++item) { work0[lid] += srca_tail0[item] * srcb_tail[item]; @@ -1956,19 +1963,19 @@ __kernel void TEMPLATE(gemm_buffer_NT_M_8,Dtype)( __attribute__((reqd_work_group_size(SIMD_SIZE_GEMM, LWG_HEIGHT, 1))) __attribute__((intel_reqd_sub_group_size(SIMD_SIZE_GEMM))) __kernel void TEMPLATE(gemm_buffer_TN, Dtype)( - const __global float *src0, int off0, - const __global float *src1, int off1, - __global float *dst, int offd, + const __global Dtype *src0, int off0, + const __global Dtype *src1, int off1, + __global Dtype *dst, int offd, int M, int N, int K, - float alpha_in, - float beta_in, + KERNEL_ARG_DTYPE alpha_in, + KERNEL_ARG_DTYPE beta_in, int start_index) { - const float alpha = (float)alpha_in; - const float beta = (float)beta_in; + const Dtype alpha = (Dtype)alpha_in; + const Dtype beta = (Dtype)beta_in; const int group_x = get_group_id(0); const int group_y = get_group_id(1); const int local_x = get_local_id(0); @@ -1976,72 +1983,72 @@ __kernel void TEMPLATE(gemm_buffer_TN, Dtype)( const int global_x = get_global_id(0); const int 
global_y = get_global_id(1); - float4 brow; + Dtype4 brow; - __global float *dst_write0 = dst + local_x * VEC_SIZE + (group_x * TILE_N) + (group_y * LWG_HEIGHT * TILE_M + local_y * TILE_M) * N + offd; + __global Dtype *dst_write0 = dst + local_x * VEC_SIZE + (group_x * TILE_N) + (group_y * LWG_HEIGHT * TILE_M + local_y * TILE_M) * N + offd; - const __global float *src0_read = src0 + (local_x * (TILE_K / SIMD_SIZE_GEMM) + start_index) * M + group_y * LWG_HEIGHT * TILE_M + local_y * TILE_M + off0; + const __global Dtype *src0_read = src0 + (local_x * (TILE_K / SIMD_SIZE_GEMM) + start_index) * M + group_y * LWG_HEIGHT * TILE_M + local_y * TILE_M + off0; - const __global float *src1_read0 = src1 + local_x * VEC_SIZE + (group_x * TILE_N) + start_index * N + off1; + const __global Dtype *src1_read0 = src1 + local_x * VEC_SIZE + (group_x * TILE_N) + start_index * N + off1; - float4 dot00 = (start_index != 0) ? vload4(0, dst_write0) : beta * vload4(0, dst_write0); - float4 dot01 = (start_index != 0) ? vload4(0, dst_write0 + N) : beta * vload4(0, dst_write0 + N); - float4 dot02 = (start_index != 0) ? vload4(0, dst_write0 + 2 * N) : beta * vload4(0, dst_write0 + 2 * N); - float4 dot03 = (start_index != 0) ? vload4(0, dst_write0 + 3 * N) : beta * vload4(0, dst_write0 + 3 * N); - float4 dot04 = (start_index != 0) ? vload4(0, dst_write0 + 4 * N) : beta * vload4(0, dst_write0 + 4 * N); - float4 dot05 = (start_index != 0) ? vload4(0, dst_write0 + 5 * N) : beta * vload4(0, dst_write0 + 5 * N); - float4 dot06 = (start_index != 0) ? vload4(0, dst_write0 + 6 * N) : beta * vload4(0, dst_write0 + 6 * N); - float4 dot07 = (start_index != 0) ? vload4(0, dst_write0 + 7 * N) : beta * vload4(0, dst_write0 + 7 * N); + Dtype4 dot00 = (start_index != 0) ? vload4(0, dst_write0) : beta * vload4(0, dst_write0); + Dtype4 dot01 = (start_index != 0) ? vload4(0, dst_write0 + N) : beta * vload4(0, dst_write0 + N); + Dtype4 dot02 = (start_index != 0) ? 
vload4(0, dst_write0 + 2 * N) : beta * vload4(0, dst_write0 + 2 * N); + Dtype4 dot03 = (start_index != 0) ? vload4(0, dst_write0 + 3 * N) : beta * vload4(0, dst_write0 + 3 * N); + Dtype4 dot04 = (start_index != 0) ? vload4(0, dst_write0 + 4 * N) : beta * vload4(0, dst_write0 + 4 * N); + Dtype4 dot05 = (start_index != 0) ? vload4(0, dst_write0 + 5 * N) : beta * vload4(0, dst_write0 + 5 * N); + Dtype4 dot06 = (start_index != 0) ? vload4(0, dst_write0 + 6 * N) : beta * vload4(0, dst_write0 + 6 * N); + Dtype4 dot07 = (start_index != 0) ? vload4(0, dst_write0 + 7 * N) : beta * vload4(0, dst_write0 + 7 * N); int end_index = min(start_index + 256, K); while( start_index + TILE_K <= end_index ) { - float8 arow0 = alpha * vload8(0, src0_read); - float8 arow1 = alpha * vload8(0, src0_read + M); + Dtype8 arow0 = alpha * vload8(0, src0_read); + Dtype8 arow1 = alpha * vload8(0, src0_read + M); #define MM_DOT_PRODUCT( _arow ) \ brow = vload4(0, src1_read0); src1_read0 += N; \ - dot00 = mad( (float4)(_arow.s0), brow, dot00 ); \ - dot01 = mad( (float4)(_arow.s1), brow, dot01 ); \ - dot02 = mad( (float4)(_arow.s2), brow, dot02 ); \ - dot03 = mad( (float4)(_arow.s3), brow, dot03 ); \ - dot04 = mad( (float4)(_arow.s4), brow, dot04 ); \ - dot05 = mad( (float4)(_arow.s5), brow, dot05 ); \ - dot06 = mad( (float4)(_arow.s6), brow, dot06 ); \ - dot07 = mad( (float4)(_arow.s7), brow, dot07 ); \ - - MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 0 )) ); - MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 0 )) ); - MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 1 )) ); - MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 1 )) ); - MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 2 )) ); - MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 2 )) ); - MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 3 )) ); - MM_DOT_PRODUCT( 
as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 3 )) ); - MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 4 )) ); - MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 4 )) ); - MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 5 )) ); - MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 5 )) ); - MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 6 )) ); - MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 6 )) ); - MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 7 )) ); - MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 7 )) ); + dot00 = mad( (Dtype4)(_arow.s0), brow, dot00 ); \ + dot01 = mad( (Dtype4)(_arow.s1), brow, dot01 ); \ + dot02 = mad( (Dtype4)(_arow.s2), brow, dot02 ); \ + dot03 = mad( (Dtype4)(_arow.s3), brow, dot03 ); \ + dot04 = mad( (Dtype4)(_arow.s4), brow, dot04 ); \ + dot05 = mad( (Dtype4)(_arow.s5), brow, dot05 ); \ + dot06 = mad( (Dtype4)(_arow.s6), brow, dot06 ); \ + dot07 = mad( (Dtype4)(_arow.s7), brow, dot07 ); \ + + MM_DOT_PRODUCT( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 0 )) ); + MM_DOT_PRODUCT( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 0 )) ); + MM_DOT_PRODUCT( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 1 )) ); + MM_DOT_PRODUCT( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 1 )) ); + MM_DOT_PRODUCT( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 2 )) ); + MM_DOT_PRODUCT( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 2 )) ); + MM_DOT_PRODUCT( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 3 )) ); + MM_DOT_PRODUCT( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 3 )) ); + MM_DOT_PRODUCT( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 4 )) ); + MM_DOT_PRODUCT( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 4 )) ); + MM_DOT_PRODUCT( 
as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 5 )) ); + MM_DOT_PRODUCT( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 5 )) ); + MM_DOT_PRODUCT( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 6 )) ); + MM_DOT_PRODUCT( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 6 )) ); + MM_DOT_PRODUCT( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 7 )) ); + MM_DOT_PRODUCT( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 7 )) ); #if TYPE == TYPE_HALF - MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 8 )) ); - MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 8 )) ); - MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 9 )) ); - MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 9 )) ); - MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 10 )) ); - MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 10 )) ); - MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 11 )) ); - MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 11 )) ); - MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 12 )) ); - MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 12 )) ); - MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 13 )) ); - MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 13 )) ); - MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 14 )) ); - MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 14 )) ); - MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 15 )) ); - MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 15 )) ); + MM_DOT_PRODUCT( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 8 )) ); + MM_DOT_PRODUCT( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 
8 )) ); + MM_DOT_PRODUCT( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 9 )) ); + MM_DOT_PRODUCT( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 9 )) ); + MM_DOT_PRODUCT( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 10 )) ); + MM_DOT_PRODUCT( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 10 )) ); + MM_DOT_PRODUCT( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 11 )) ); + MM_DOT_PRODUCT( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 11 )) ); + MM_DOT_PRODUCT( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 12 )) ); + MM_DOT_PRODUCT( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 12 )) ); + MM_DOT_PRODUCT( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 13 )) ); + MM_DOT_PRODUCT( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 13 )) ); + MM_DOT_PRODUCT( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 14 )) ); + MM_DOT_PRODUCT( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 14 )) ); + MM_DOT_PRODUCT( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 15 )) ); + MM_DOT_PRODUCT( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 15 )) ); #endif #undef MM_DOT_PRODUCT @@ -2050,53 +2057,53 @@ __kernel void TEMPLATE(gemm_buffer_TN, Dtype)( } if(start_index < end_index) { - float8 arow0 = ((start_index + local_x * 2) < K) ? alpha * vload8(0, src0_read) : (float8)0.0f; - float8 arow1 = ((start_index + local_x * 2 + 1) < K) ? alpha * vload8(0, src0_read + M) : (float8)0.0f; + Dtype8 arow0 = ((start_index + local_x * 2) < K) ? alpha * vload8(0, src0_read) : (Dtype8)0.0f; + Dtype8 arow1 = ((start_index + local_x * 2 + 1) < K) ? alpha * vload8(0, src0_read + M) : (Dtype8)0.0f; #define MM_DOT_PRODUCT( _arow ) \ - brow = (start_index < K) ? 
vload4(0, src1_read0) : (float4)0.0f; src1_read0 += N; start_index++; \ - dot00 = mad( (float4)(_arow.s0), brow, dot00 ); \ - dot01 = mad( (float4)(_arow.s1), brow, dot01 ); \ - dot02 = mad( (float4)(_arow.s2), brow, dot02 ); \ - dot03 = mad( (float4)(_arow.s3), brow, dot03 ); \ - dot04 = mad( (float4)(_arow.s4), brow, dot04 ); \ - dot05 = mad( (float4)(_arow.s5), brow, dot05 ); \ - dot06 = mad( (float4)(_arow.s6), brow, dot06 ); \ - dot07 = mad( (float4)(_arow.s7), brow, dot07 ); \ - - MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 0 )) ); - MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 0 )) ); - MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 1 )) ); - MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 1 )) ); - MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 2 )) ); - MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 2 )) ); - MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 3 )) ); - MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 3 )) ); - MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 4 )) ); - MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 4 )) ); - MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 5 )) ); - MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 5 )) ); - MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 6 )) ); - MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 6 )) ); - MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 7 )) ); - MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 7 )) ); + brow = (start_index < K) ? 
vload4(0, src1_read0) : (Dtype4)0.0f; src1_read0 += N; start_index++; \ + dot00 = mad( (Dtype4)(_arow.s0), brow, dot00 ); \ + dot01 = mad( (Dtype4)(_arow.s1), brow, dot01 ); \ + dot02 = mad( (Dtype4)(_arow.s2), brow, dot02 ); \ + dot03 = mad( (Dtype4)(_arow.s3), brow, dot03 ); \ + dot04 = mad( (Dtype4)(_arow.s4), brow, dot04 ); \ + dot05 = mad( (Dtype4)(_arow.s5), brow, dot05 ); \ + dot06 = mad( (Dtype4)(_arow.s6), brow, dot06 ); \ + dot07 = mad( (Dtype4)(_arow.s7), brow, dot07 ); \ + + MM_DOT_PRODUCT( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 0 )) ); + MM_DOT_PRODUCT( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 0 )) ); + MM_DOT_PRODUCT( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 1 )) ); + MM_DOT_PRODUCT( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 1 )) ); + MM_DOT_PRODUCT( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 2 )) ); + MM_DOT_PRODUCT( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 2 )) ); + MM_DOT_PRODUCT( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 3 )) ); + MM_DOT_PRODUCT( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 3 )) ); + MM_DOT_PRODUCT( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 4 )) ); + MM_DOT_PRODUCT( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 4 )) ); + MM_DOT_PRODUCT( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 5 )) ); + MM_DOT_PRODUCT( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 5 )) ); + MM_DOT_PRODUCT( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 6 )) ); + MM_DOT_PRODUCT( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 6 )) ); + MM_DOT_PRODUCT( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 7 )) ); + MM_DOT_PRODUCT( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 7 )) ); #if TYPE == TYPE_HALF - MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 8 )) ); - MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 8 
)) ); - MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 9 )) ); - MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 9 )) ); - MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 10 )) ); - MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 10 )) ); - MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 11 )) ); - MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 11 )) ); - MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 12 )) ); - MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 12 )) ); - MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 13 )) ); - MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 13 )) ); - MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 14 )) ); - MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 14 )) ); - MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 15 )) ); - MM_DOT_PRODUCT( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 15 )) ); + MM_DOT_PRODUCT( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 8 )) ); + MM_DOT_PRODUCT( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 8 )) ); + MM_DOT_PRODUCT( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 9 )) ); + MM_DOT_PRODUCT( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 9 )) ); + MM_DOT_PRODUCT( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 10 )) ); + MM_DOT_PRODUCT( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 10 )) ); + MM_DOT_PRODUCT( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 11 )) ); + MM_DOT_PRODUCT( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 11 )) ); + MM_DOT_PRODUCT( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 12 )) ); + MM_DOT_PRODUCT( as_Dtype8(intel_sub_group_shuffle( 
SHUFFLE_TYPE8(arow1), 12 )) ); + MM_DOT_PRODUCT( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 13 )) ); + MM_DOT_PRODUCT( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 13 )) ); + MM_DOT_PRODUCT( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 14 )) ); + MM_DOT_PRODUCT( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 14 )) ); + MM_DOT_PRODUCT( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 15 )) ); + MM_DOT_PRODUCT( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 15 )) ); #endif #undef MM_DOT_PRODUCT } @@ -2202,19 +2209,19 @@ __kernel void TEMPLATE(gemm_buffer_TN, Dtype)( __attribute__((reqd_work_group_size(8, LWG_HEIGHT, 1))) __attribute__((intel_reqd_sub_group_size(8))) __kernel void TEMPLATE(gemm_buffer_TT, Dtype)( - const __global float *src0, int off0, - const __global float *src1, int off1, - __global float *dst, int offd, + const __global Dtype *src0, int off0, + const __global Dtype *src1, int off1, + __global Dtype *dst, int offd, int M, int N, int K, - float alpha_in, - float beta_in, + KERNEL_ARG_DTYPE alpha_in, + KERNEL_ARG_DTYPE beta_in, int start_index) { - const float alpha = (float)alpha_in; - const float beta = (float)beta_in; + const Dtype alpha = (Dtype)alpha_in; + const Dtype beta = (Dtype)beta_in; const int group_x = get_group_id(0); const int group_y = get_group_id(1); const int local_x = get_local_id(0); @@ -2222,30 +2229,30 @@ __kernel void TEMPLATE(gemm_buffer_TT, Dtype)( const int global_x = get_global_id(0); const int global_y = get_global_id(1); - float8 dot0 = 0.f; - float8 dot1 = 0.f; - float8 dot2 = 0.f; - float8 dot3 = 0.f; + Dtype8 dot0 = 0.f; + Dtype8 dot1 = 0.f; + Dtype8 dot2 = 0.f; + Dtype8 dot3 = 0.f; - float16 brow0; - float16 brow1; - float16 brow2; - float16 brow3; + Dtype16 brow0; + Dtype16 brow1; + Dtype16 brow2; + Dtype16 brow3; - __global float *dst_write0 = dst + local_x * VEC_SIZE + (group_x * TILE_N) + (group_y * LWG_HEIGHT * TILE_M + local_y * TILE_M) * N + 
offd; + __global Dtype *dst_write0 = dst + local_x * VEC_SIZE + (group_x * TILE_N) + (group_y * LWG_HEIGHT * TILE_M + local_y * TILE_M) * N + offd; - const __global float *src0_read = src0 + (local_x * (TILE_K / 8) + start_index) * M + group_y * LWG_HEIGHT * TILE_M + local_y * TILE_M + off0; + const __global Dtype *src0_read = src0 + (local_x * (TILE_K / 8) + start_index) * M + group_y * LWG_HEIGHT * TILE_M + local_y * TILE_M + off0; - const __global float *src1_read0 = src1 + (local_x * VEC_SIZE + (group_x * TILE_N)) * K + start_index + off1; + const __global Dtype *src1_read0 = src1 + (local_x * VEC_SIZE + (group_x * TILE_N)) * K + start_index + off1; - float4 dot00 = (start_index != 0) ? vload4(0, dst_write0) : beta * vload4(0, dst_write0); - float4 dot01 = (start_index != 0) ? vload4(0, dst_write0 + N) : beta * vload4(0, dst_write0 + N); - float4 dot02 = (start_index != 0) ? vload4(0, dst_write0 + 2 * N) : beta * vload4(0, dst_write0 + 2 * N); - float4 dot03 = (start_index != 0) ? vload4(0, dst_write0 + 3 * N) : beta * vload4(0, dst_write0 + 3 * N); - float4 dot04 = (start_index != 0) ? vload4(0, dst_write0 + 4 * N) : beta * vload4(0, dst_write0 + 4 * N); - float4 dot05 = (start_index != 0) ? vload4(0, dst_write0 + 5 * N) : beta * vload4(0, dst_write0 + 5 * N); - float4 dot06 = (start_index != 0) ? vload4(0, dst_write0 + 6 * N) : beta * vload4(0, dst_write0 + 6 * N); - float4 dot07 = (start_index != 0) ? vload4(0, dst_write0 + 7 * N) : beta * vload4(0, dst_write0 + 7 * N); + Dtype4 dot00 = (start_index != 0) ? vload4(0, dst_write0) : beta * vload4(0, dst_write0); + Dtype4 dot01 = (start_index != 0) ? vload4(0, dst_write0 + N) : beta * vload4(0, dst_write0 + N); + Dtype4 dot02 = (start_index != 0) ? vload4(0, dst_write0 + 2 * N) : beta * vload4(0, dst_write0 + 2 * N); + Dtype4 dot03 = (start_index != 0) ? vload4(0, dst_write0 + 3 * N) : beta * vload4(0, dst_write0 + 3 * N); + Dtype4 dot04 = (start_index != 0) ? 
vload4(0, dst_write0 + 4 * N) : beta * vload4(0, dst_write0 + 4 * N); + Dtype4 dot05 = (start_index != 0) ? vload4(0, dst_write0 + 5 * N) : beta * vload4(0, dst_write0 + 5 * N); + Dtype4 dot06 = (start_index != 0) ? vload4(0, dst_write0 + 6 * N) : beta * vload4(0, dst_write0 + 6 * N); + Dtype4 dot07 = (start_index != 0) ? vload4(0, dst_write0 + 7 * N) : beta * vload4(0, dst_write0 + 7 * N); int end_index = min(start_index + 256, K); while( start_index + TILE_K <= end_index ) { @@ -2254,26 +2261,26 @@ __kernel void TEMPLATE(gemm_buffer_TT, Dtype)( brow2 = vload16(0, src1_read0 + 2 * K); brow3 = vload16(0, src1_read0 + 3 * K); - float8 arow0 = alpha * vload8(0, src0_read); - float8 arow1 = alpha * vload8(0, src0_read + M); + Dtype8 arow0 = alpha * vload8(0, src0_read); + Dtype8 arow1 = alpha * vload8(0, src0_read + M); #define MM_DOT_PRODUCT( _brow, _dot) \ - _dot = mad( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 0 )), (float8)_brow.s0, _dot ); \ - _dot = mad( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 0 )), (float8)_brow.s1, _dot ); \ - _dot = mad( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 1 )), (float8)_brow.s2, _dot ); \ - _dot = mad( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 1 )), (float8)_brow.s3, _dot ); \ - _dot = mad( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 2 )), (float8)_brow.s4, _dot ); \ - _dot = mad( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 2 )), (float8)_brow.s5, _dot ); \ - _dot = mad( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 3 )), (float8)_brow.s6, _dot ); \ - _dot = mad( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 3 )), (float8)_brow.s7, _dot ); \ - _dot = mad( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 4 )), (float8)_brow.s8, _dot ); \ - _dot = mad( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 4 )), (float8)_brow.s9, _dot ); \ - _dot = mad( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 5 )), 
(float8)_brow.sa, _dot ); \ - _dot = mad( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 5 )), (float8)_brow.sb, _dot ); \ - _dot = mad( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 6 )), (float8)_brow.sc, _dot ); \ - _dot = mad( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 6 )), (float8)_brow.sd, _dot ); \ - _dot = mad( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 7 )), (float8)_brow.se, _dot ); \ - _dot = mad( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 7 )), (float8)_brow.sf, _dot ); \ + _dot = mad( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 0 )), (Dtype8)_brow.s0, _dot ); \ + _dot = mad( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 0 )), (Dtype8)_brow.s1, _dot ); \ + _dot = mad( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 1 )), (Dtype8)_brow.s2, _dot ); \ + _dot = mad( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 1 )), (Dtype8)_brow.s3, _dot ); \ + _dot = mad( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 2 )), (Dtype8)_brow.s4, _dot ); \ + _dot = mad( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 2 )), (Dtype8)_brow.s5, _dot ); \ + _dot = mad( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 3 )), (Dtype8)_brow.s6, _dot ); \ + _dot = mad( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 3 )), (Dtype8)_brow.s7, _dot ); \ + _dot = mad( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 4 )), (Dtype8)_brow.s8, _dot ); \ + _dot = mad( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 4 )), (Dtype8)_brow.s9, _dot ); \ + _dot = mad( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 5 )), (Dtype8)_brow.sa, _dot ); \ + _dot = mad( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 5 )), (Dtype8)_brow.sb, _dot ); \ + _dot = mad( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 6 )), (Dtype8)_brow.sc, _dot ); \ + _dot = mad( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 6 )), 
(Dtype8)_brow.sd, _dot ); \ + _dot = mad( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 7 )), (Dtype8)_brow.se, _dot ); \ + _dot = mad( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 7 )), (Dtype8)_brow.sf, _dot ); \ MM_DOT_PRODUCT( brow0, dot0 ); MM_DOT_PRODUCT( brow1, dot1 ); @@ -2292,26 +2299,26 @@ __kernel void TEMPLATE(gemm_buffer_TT, Dtype)( brow2 = vload16(0, src1_read0); src1_read0 += K; brow3 = vload16(0, src1_read0); - float8 arow0 = alpha * vload8(0, src0_read); - float8 arow1 = alpha * vload8(0, src0_read + M); + Dtype8 arow0 = alpha * vload8(0, src0_read); + Dtype8 arow1 = alpha * vload8(0, src0_read + M); #define MM_DOT_PRODUCT( _brow, _dot) \ - _dot = (w++ < K) ? mad( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 0 )), (float8)_brow.s0, _dot ) : _dot; \ - _dot = (w++ < K) ? mad( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 0 )), (float8)_brow.s1, _dot ) : _dot; \ - _dot = (w++ < K) ? mad( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 1 )), (float8)_brow.s2, _dot ) : _dot; \ - _dot = (w++ < K) ? mad( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 1 )), (float8)_brow.s3, _dot ) : _dot; \ - _dot = (w++ < K) ? mad( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 2 )), (float8)_brow.s4, _dot ) : _dot; \ - _dot = (w++ < K) ? mad( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 2 )), (float8)_brow.s5, _dot ) : _dot; \ - _dot = (w++ < K) ? mad( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 3 )), (float8)_brow.s6, _dot ) : _dot; \ - _dot = (w++ < K) ? mad( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 3 )), (float8)_brow.s7, _dot ) : _dot; \ - _dot = (w++ < K) ? mad( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 4 )), (float8)_brow.s8, _dot ) : _dot; \ - _dot = (w++ < K) ? mad( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 4 )), (float8)_brow.s9, _dot ) : _dot; \ - _dot = (w++ < K) ? 
mad( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 5 )), (float8)_brow.sa, _dot ) : _dot; \ - _dot = (w++ < K) ? mad( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 5 )), (float8)_brow.sb, _dot ) : _dot; \ - _dot = (w++ < K) ? mad( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 6 )), (float8)_brow.sc, _dot ) : _dot; \ - _dot = (w++ < K) ? mad( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 6 )), (float8)_brow.sd, _dot ) : _dot; \ - _dot = (w++ < K) ? mad( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 7 )), (float8)_brow.se, _dot ) : _dot; \ - _dot = (w++ < K) ? mad( as_float8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 7 )), (float8)_brow.sf, _dot ) : _dot; \ + _dot = (w++ < K) ? mad( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 0 )), (Dtype8)_brow.s0, _dot ) : _dot; \ + _dot = (w++ < K) ? mad( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 0 )), (Dtype8)_brow.s1, _dot ) : _dot; \ + _dot = (w++ < K) ? mad( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 1 )), (Dtype8)_brow.s2, _dot ) : _dot; \ + _dot = (w++ < K) ? mad( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 1 )), (Dtype8)_brow.s3, _dot ) : _dot; \ + _dot = (w++ < K) ? mad( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 2 )), (Dtype8)_brow.s4, _dot ) : _dot; \ + _dot = (w++ < K) ? mad( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 2 )), (Dtype8)_brow.s5, _dot ) : _dot; \ + _dot = (w++ < K) ? mad( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 3 )), (Dtype8)_brow.s6, _dot ) : _dot; \ + _dot = (w++ < K) ? mad( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 3 )), (Dtype8)_brow.s7, _dot ) : _dot; \ + _dot = (w++ < K) ? mad( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 4 )), (Dtype8)_brow.s8, _dot ) : _dot; \ + _dot = (w++ < K) ? mad( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 4 )), (Dtype8)_brow.s9, _dot ) : _dot; \ + _dot = (w++ < K) ? 
mad( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 5 )), (Dtype8)_brow.sa, _dot ) : _dot; \ + _dot = (w++ < K) ? mad( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 5 )), (Dtype8)_brow.sb, _dot ) : _dot; \ + _dot = (w++ < K) ? mad( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 6 )), (Dtype8)_brow.sc, _dot ) : _dot; \ + _dot = (w++ < K) ? mad( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 6 )), (Dtype8)_brow.sd, _dot ) : _dot; \ + _dot = (w++ < K) ? mad( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow0), 7 )), (Dtype8)_brow.se, _dot ) : _dot; \ + _dot = (w++ < K) ? mad( as_Dtype8(intel_sub_group_shuffle( SHUFFLE_TYPE8(arow1), 7 )), (Dtype8)_brow.sf, _dot ) : _dot; \ int w = start_index; MM_DOT_PRODUCT( brow0, dot0 ); @@ -2324,14 +2331,14 @@ __kernel void TEMPLATE(gemm_buffer_TT, Dtype)( #undef MM_DOT_PRODUCT } - dot00 += (float4)(dot0.s0, dot1.s0, dot2.s0, dot3.s0); - dot01 += (float4)(dot0.s1, dot1.s1, dot2.s1, dot3.s1); - dot02 += (float4)(dot0.s2, dot1.s2, dot2.s2, dot3.s2); - dot03 += (float4)(dot0.s3, dot1.s3, dot2.s3, dot3.s3); - dot04 += (float4)(dot0.s4, dot1.s4, dot2.s4, dot3.s4); - dot05 += (float4)(dot0.s5, dot1.s5, dot2.s5, dot3.s5); - dot06 += (float4)(dot0.s6, dot1.s6, dot2.s6, dot3.s6); - dot07 += (float4)(dot0.s7, dot1.s7, dot2.s7, dot3.s7); + dot00 += (Dtype4)(dot0.s0, dot1.s0, dot2.s0, dot3.s0); + dot01 += (Dtype4)(dot0.s1, dot1.s1, dot2.s1, dot3.s1); + dot02 += (Dtype4)(dot0.s2, dot1.s2, dot2.s2, dot3.s2); + dot03 += (Dtype4)(dot0.s3, dot1.s3, dot2.s3, dot3.s3); + dot04 += (Dtype4)(dot0.s4, dot1.s4, dot2.s4, dot3.s4); + dot05 += (Dtype4)(dot0.s5, dot1.s5, dot2.s5, dot3.s5); + dot06 += (Dtype4)(dot0.s6, dot1.s6, dot2.s6, dot3.s6); + dot07 += (Dtype4)(dot0.s7, dot1.s7, dot2.s7, dot3.s7); if(global_x * 4 < N && global_y * 8 < M) { if(mad24(global_x, 4, 3) < N) { @@ -2424,5 +2431,8 @@ __kernel void TEMPLATE(gemm_buffer_TT, Dtype)( #undef TILE_M #undef TILE_K #undef TILE_N +#undef SIMD_SIZE_GEMM +#undef 
SHUFFLE_TYPE2 +#undef SHUFFLE_TYPE8 #endif diff --git a/src/caffe/greentea/cl_kernels/lrn.cl b/src/caffe/greentea/cl_kernels/lrn.cl index f4f38fcb5f2..93cd1c25e90 100644 --- a/src/caffe/greentea/cl_kernels/lrn.cl +++ b/src/caffe/greentea/cl_kernels/lrn.cl @@ -5,18 +5,18 @@ __kernel void TEMPLATE(lrn_compute_output,Dtype)(const int_tp nthreads, __global const Dtype* in, __global const Dtype* scale, - const Dtype negative_beta, + const KERNEL_ARG_DTYPE negative_beta, __global Dtype* out) { for (int_tp index = get_global_id(0); index < nthreads; index += get_global_size(0)) { - out[index] = in[index] * pow(scale[index], negative_beta); + out[index] = in[index] * pow(scale[index], (Dtype)negative_beta); } } __kernel void TEMPLATE(lrn_fill_scale,Dtype)(const int_tp nthreads, __global const Dtype* in, const int_tp num, const int_tp channels, const int_tp height, const int_tp width, const int_tp size, - const Dtype alpha_over_size, const Dtype k, + const KERNEL_ARG_DTYPE alpha_over_size, const KERNEL_ARG_DTYPE k, __global Dtype* const scale) { for (int_tp index = get_global_id(0); index < nthreads; index += get_global_size(0)) { @@ -67,8 +67,8 @@ __kernel void TEMPLATE(lrn_compute_diff,Dtype)(const int_tp nthreads, __global const Dtype* top_diff, const int_tp num, const int_tp channels, const int_tp height, const int_tp width, const int_tp size, - const Dtype negative_beta, - const Dtype cache_ratio, + const KERNEL_ARG_DTYPE negative_beta, + const KERNEL_ARG_DTYPE cache_ratio, __global Dtype* bottom_diff) { for (int_tp index = get_global_id(0); index < nthreads; index += get_global_size(0)) { @@ -102,7 +102,7 @@ __kernel void TEMPLATE(lrn_compute_diff,Dtype)(const int_tp nthreads, * top_off[(head - size) * step] / scale_off[(head - size) * step]; } bottom_diff_off[(head - post_pad) * step] = top_diff_off[(head - post_pad) - * step] * pow(scale_off[(head - post_pad) * step], negative_beta) + * step] * pow(scale_off[(head - post_pad) * step], (Dtype)negative_beta) - 
cache_ratio * bottom_off[(head - post_pad) * step] * accum_ratio; ++head; } @@ -113,7 +113,7 @@ __kernel void TEMPLATE(lrn_compute_diff,Dtype)(const int_tp nthreads, * top_off[(head - size) * step] / scale_off[(head - size) * step]; } bottom_diff_off[(head - post_pad) * step] = top_diff_off[(head - post_pad) - * step] * pow(scale_off[(head - post_pad) * step], negative_beta) + * step] * pow(scale_off[(head - post_pad) * step], (Dtype)negative_beta) - cache_ratio * bottom_off[(head - post_pad) * step] * accum_ratio; ++head; } @@ -136,9 +136,9 @@ __kernel void TEMPLATE(lrn_fuse_pool_max,Dtype)( const int_tp height, const int_tp width, const int_tp tiled_height, int_tp tiled_width, const int_tp size, - const Dtype alpha_over_size, const Dtype k, + const KERNEL_ARG_DTYPE alpha_over_size, const KERNEL_ARG_DTYPE k, __global Dtype* const out, - const Dtype negative_beta, + const KERNEL_ARG_DTYPE negative_beta, const int_tp pool_h, const int_tp pool_w, const int_tp pool_stride_h, int_tp pool_stride_w, const int_tp pooled_height, const int_tp pooled_width, const int_tp tile_pooled_block_h, const int_tp tile_pooled_block_w) { @@ -168,7 +168,7 @@ __kernel void TEMPLATE(lrn_fuse_pool_max,Dtype)( while ( head < channels + post_pad ) { int ph = 0; int cur_out_h = 0; - Dtype output_val = -FLT_MAX; + Dtype output_val = -DTYPE_MAX; // fill the scale at [n, :, h, w] // accumulate values for( int lrn_out_h = 0; lrn_out_h < TILE_H && (lrn_out_h + h) < height; lrn_out_h++) { @@ -184,11 +184,11 @@ __kernel void TEMPLATE(lrn_fuse_pool_max,Dtype)( // compute output. 
if (head >= post_pad) { scale_val = k + prev_val * alpha_over_size; - Dtype tmp = -FLT_MAX; + Dtype tmp = -DTYPE_MAX; //if (w + get_local_id(1) < width) - tmp = in_off[(head - post_pad) * step + width * lrn_out_h] * native_powr(scale_val, negative_beta); + tmp = in_off[(head - post_pad) * step + width * lrn_out_h] * native_powr(scale_val, (Dtype)negative_beta); - Dtype h_max_val = -FLT_MAX; + Dtype h_max_val = -DTYPE_MAX; int index = (get_local_id(1) * pool_stride_w) % SIMD_WIDTH; for(int i = 0; i < pool_w; i++) { Dtype val = intel_sub_group_shuffle(tmp, index); @@ -233,9 +233,9 @@ __kernel void TEMPLATE(lrn_fuse_pool_max,Dtype)( __kernel void TEMPLATE(lrn_full_no_scale,Dtype)(const int_tp nthreads, __global const Dtype* in, const int_tp num, const int_tp channels, const int_tp height, const int_tp width, const int_tp size, - const Dtype alpha_over_size, const Dtype k, + const KERNEL_ARG_DTYPE alpha_over_size, const KERNEL_ARG_DTYPE k, __global Dtype* const out, - const Dtype negative_beta) { + const KERNEL_ARG_DTYPE negative_beta) { for (int_tp index = get_global_id(0); index < nthreads; index += get_global_size(0)) { // find out the local offset @@ -265,7 +265,7 @@ __kernel void TEMPLATE(lrn_full_no_scale,Dtype)(const int_tp nthreads, __global * in_off[(head - size) * step]; } scale_val = k + accum_scale * alpha_over_size; - out_off[(head - post_pad) * step] = in_off[(head - post_pad) * step] * (Dtype)native_powr((float)scale_val, (float)negative_beta); + out_off[(head - post_pad) * step] = in_off[(head - post_pad) * step] * (Dtype)native_powr((Dtype)scale_val, (Dtype)negative_beta); ++head; } // subtract only @@ -275,7 +275,7 @@ __kernel void TEMPLATE(lrn_full_no_scale,Dtype)(const int_tp nthreads, __global * in_off[(head - size) * step]; } scale_val = k + accum_scale * alpha_over_size; - out_off[(head - post_pad) * step] = in_off[(head - post_pad) * step] * (Dtype)native_powr((float)scale_val, (float)negative_beta); + out_off[(head - post_pad) * step] = 
in_off[(head - post_pad) * step] * (Dtype)native_powr((Dtype)scale_val, (Dtype)negative_beta); ++head; } } @@ -284,10 +284,10 @@ __kernel void TEMPLATE(lrn_full_no_scale,Dtype)(const int_tp nthreads, __global __kernel void TEMPLATE(lrn_full,Dtype)(const int_tp nthreads, __global const Dtype* in, const int_tp num, const int_tp channels, const int_tp height, const int_tp width, const int_tp size, - const Dtype alpha_over_size, const Dtype k, + const KERNEL_ARG_DTYPE alpha_over_size, const KERNEL_ARG_DTYPE k, __global Dtype* const scale, __global Dtype* const out, - const Dtype negative_beta) { + const KERNEL_ARG_DTYPE negative_beta) { for (int_tp index = get_global_id(0); index < nthreads; index += get_global_size(0)) { // find out the local offset @@ -319,7 +319,7 @@ __kernel void TEMPLATE(lrn_full,Dtype)(const int_tp nthreads, __global const Dty } scale_val = k + accum_scale * alpha_over_size; scale_off[(head - post_pad) * step] = scale_val; - out_off[(head - post_pad) * step] = in_off[(head - post_pad) * step] * (Dtype)native_powr((float)scale_val, (float)negative_beta); + out_off[(head - post_pad) * step] = in_off[(head - post_pad) * step] * (Dtype)native_powr((Dtype)scale_val, (Dtype)negative_beta); ++head; } // subtract only @@ -330,7 +330,7 @@ __kernel void TEMPLATE(lrn_full,Dtype)(const int_tp nthreads, __global const Dty } scale_val = k + accum_scale * alpha_over_size; scale_off[(head - post_pad) * step] = scale_val; - out_off[(head - post_pad) * step] = in_off[(head - post_pad) * step] * (Dtype)native_powr((float)scale_val, (float)negative_beta); + out_off[(head - post_pad) * step] = in_off[(head - post_pad) * step] * (Dtype)native_powr((Dtype)scale_val, (Dtype)negative_beta); ++head; } } diff --git a/src/caffe/greentea/cl_kernels/math.cl b/src/caffe/greentea/cl_kernels/math.cl index fadaa289a8d..8e3ff794c8e 100644 --- a/src/caffe/greentea/cl_kernels/math.cl +++ b/src/caffe/greentea/cl_kernels/math.cl @@ -22,7 +22,7 @@ __kernel void 
TEMPLATE(div,Dtype)(const int_tp n, __global const Dtype* a, } } -__kernel void TEMPLATE(add_scalar,Dtype)(const int_tp N, const Dtype alpha, +__kernel void TEMPLATE(add_scalar,Dtype)(const int_tp N, const KERNEL_ARG_DTYPE alpha, __global Dtype* Y, const int_tp offY) { for (int_tp index = get_global_id(0); index < N; index += get_global_size(0)) { @@ -81,7 +81,7 @@ __kernel void TEMPLATE(sqrt,Dtype)(const int_tp n, __global const Dtype* a, } __kernel void TEMPLATE(powx,Dtype)(const int_tp n, __global const Dtype* a, - const int_tp offa, Dtype alpha, + const int_tp offa, KERNEL_ARG_DTYPE alpha, __global Dtype* y, const int_tp offy) { for (int_tp index = get_global_id(0); index < n; index += get_global_size(0)) { diff --git a/src/caffe/greentea/cl_kernels/matvec_mul.cl b/src/caffe/greentea/cl_kernels/matvec_mul.cl index dee7779ce9c..97c11c46a08 100644 --- a/src/caffe/greentea/cl_kernels/matvec_mul.cl +++ b/src/caffe/greentea/cl_kernels/matvec_mul.cl @@ -1,152 +1,177 @@ -#ifndef __OPENCL_VERSION__ -#include "header.cl" -#endif - -__kernel void TEMPLATE(matvec_mul4,Dtype)( - __global const float * A, - int offA, - unsigned int A_col_size, - unsigned int trail_item, - __global const float * v, - int offv, - float alpha, - float beta, - __global float4 * result, - int offr, - __local float4 * work) +void TEMPLATE(matvec_mul_trail_rows,Dtype)(unsigned int M, + unsigned int N, + int row_gid, + int lid, + const __global Dtype* src0_read, + int lda, + const __global Dtype* src1_read, + int incv, + __local Dtype4* work, + Dtype alpha, + Dtype beta, + __global Dtype* result, + int incr) { - unsigned int row_gid = get_group_id(0); - unsigned int lid = get_local_id(0); - const __global float *src0_read = A + row_gid * 4 * A_col_size + offA; - const __global float *src1_read = v + offv; - result = (__global float4*)((__global float*)result + offr); - float4 dot0 = (float4)(0.f); - float4 dot1 = (float4)(0.f); - float4 dot2 = (float4)(0.f); - float4 dot3 = (float4)(0.f); - - 
unsigned int i = lid; - while( i < A_col_size / 4) { - const float4 a0 = vload4(i, src0_read); - const float4 a1 = vload4(i, src0_read + A_col_size); - const float4 a2 = vload4(i, src0_read + 2 * A_col_size); - const float4 a3 = vload4(i, src0_read + 3 * A_col_size); - - const float4 b0 = vload4(i, src1_read); - - dot0 += a0 * b0; - dot1 += a1 * b0; - dot2 += a2 * b0; - dot3 += a3 * b0; + __local Dtype* work_each = (__local Dtype*)work; + + int rows = M - row_gid * 4; + + Dtype4 dot[3] = {(Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.)}; + + int i = lid; + while( i < N / 4) { + const Dtype4 b0 = {src1_read[i*4*incv], src1_read[(i*4+1)*incv], src1_read[(i*4+2)*incv], src1_read[(i*4+3)*incv]}; +#pragma unroll + for(int j = 0; j < rows; ++j) { + dot[j] += b0 * vload4(i, src0_read + j * lda); + } i += get_local_size(0); } +#pragma unroll + for(int j = 0; j < rows; ++j) { + work_each[lid * 4 + j] = dot[j].x + dot[j].y + dot[j].z + dot[j].w; + } - work[lid].s0 = dot0.x + dot0.y + dot0.z + dot0.w; - work[lid].s1 = dot1.x + dot1.y + dot1.z + dot1.w; - work[lid].s2 = dot2.x + dot2.y + dot2.z + dot2.w; - work[lid].s3 = dot3.x + dot3.y + dot3.z + dot3.w; - - if(i == A_col_size / 4) - { - if(trail_item != 0) - { - const __global float *src0_trail = src0_read + i * 4; - const __global float *src1_trail = src1_read + i * 4; - for(unsigned int i = 0; i < trail_item; ++i) { - const float at0 = src0_trail[i]; - const float at1 = src0_trail[i + A_col_size]; - const float at2 = src0_trail[i + 2 * A_col_size]; - const float at3 = src0_trail[i + 3 * A_col_size]; - - const float bt = src1_trail[i]; - - work[lid].s0 += at0 * bt; - work[lid].s1 += at1 * bt; - work[lid].s2 += at2 * bt; - work[lid].s3 += at3 * bt; + if(i == N / 4) { + short trail_item = N % 4; + + if(trail_item != 0) { + const __global Dtype *src0_trail = src0_read + i * 4; + const __global Dtype *src1_trail = src1_read + i * 4 * incv; +#pragma unroll + for(short i = 0; i < trail_item; ++i) { + const Dtype bt = src1_trail[i*incv]; 
+#pragma unroll + for(int j = 0; j < rows; ++j) { + work_each[lid * 4 + j] += bt * src0_trail[i + j * lda]; + } } } - } - for(unsigned int stride=get_local_size(0)/2 ; stride>0 ; stride>>=1) { - barrier(CLK_LOCAL_MEM_FENCE); - if(lid < stride) - work[lid] += work[lid+stride]; + for(int stride = get_local_size(0) >> 1; stride > 0 ; stride >>= 1) { + barrier(CLK_LOCAL_MEM_FENCE); + if(lid < stride) + work[lid] += work[lid+stride]; } if(lid == 0) { - if(beta == (Dtype)0) - result[row_gid] = alpha * work[0]; - else - result[row_gid] = alpha * work[0] + beta * result[row_gid]; +#pragma unroll + for(int j = 0; j < rows; ++j) { + result[(row_gid * 4 + j) * incr] = alpha * work_each[j] + beta * result[(row_gid * 4 + j) * incr]; + } } } -/* This kernel used for the trailing rows when row_of_A %4 !=0 */ -__kernel void TEMPLATE(matvec_mul1,Dtype)( - __global const float * A, +__kernel void TEMPLATE(matvec_mul,Dtype)( + unsigned int M, + unsigned int N, + __global const Dtype * A, int offA, - unsigned int A_col_size, - unsigned int row_offset, - unsigned int trail_item, - __global const float * v, + int lda, + __global const Dtype * v, int offv, - float alpha, - float beta, - __global float * result, + int incv, + KERNEL_ARG_DTYPE alpha, + KERNEL_ARG_DTYPE beta, + __global Dtype * result, int offr, - __local float * work) + int incr) { - unsigned int row_gid = get_group_id(0); - unsigned int lid = get_local_id(0); - - const __global float *src0_read = A + (row_offset + row_gid) * A_col_size + offA; - const __global float *src1_read = v + + offv; + int row_gid = get_group_id(0); + int lid = get_local_id(0); + const __global Dtype *src0_read = A + row_gid * 4 * lda + offA; + const __global Dtype *src1_read = v + offv; result = result + offr; - float4 dot0 = (float4)(0.f); - unsigned int i = lid; - while( i < A_col_size / 4) - { - const float4 a0 = vload4(i, src0_read); - const float4 b0 = vload4(i, src1_read); + src1_read += incv > 0 ? 0 : (1 - N) * incv; + result += incr > 0 ? 
0 : (1 - M) * incr; + __local Dtype4 work[128]; + __local Dtype* work_each = (__local Dtype*)work; - dot0 += a0 * b0; - i += get_local_size(0); - } - - work[lid] = dot0.x + dot0.y + dot0.z + dot0.w; - - if(i == A_col_size / 4) + if(row_gid == M / 4) + TEMPLATE(matvec_mul_trail_rows,Dtype)(M, N, row_gid, lid, src0_read, lda, src1_read, incv, work, alpha, beta, result, incr); + else { - if(trail_item != 0) - { - const __global float *src0_trail = src0_read + i * 4; - const __global float *src1_trail = src1_read + i * 4; - for(unsigned int i = 0; i < trail_item; ++i) { - const float at0 = src0_trail[i]; - const float bt = src1_trail[i]; - - work[lid] += at0 * bt; + Dtype4 dot[4] = {(Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.f), (Dtype4)(0.f)}; + int i = lid; + while( i < N / 4) { + const Dtype4 b0 = {src1_read[i*4*incv], src1_read[(i*4+1)*incv], src1_read[(i*4+2)*incv], src1_read[(i*4+3)*incv]}; +#pragma unroll + for(int j = 0; j < 4; ++j) { + dot[j] += b0 * vload4(i, src0_read + j * lda); } + i += get_local_size(0); + } +#pragma unroll + for(int j = 0; j < 4; ++j) { + work_each[lid * 4 + j] = dot[j].x + dot[j].y + dot[j].z + dot[j].w; } - } - for(unsigned int stride=get_local_size(0)/2 ; stride>0 ; stride>>=1) { + if(i == N / 4) { + short trail_item = N % 4; + if(trail_item != 0) { + const __global Dtype *src0_trail = src0_read + i * 4; + const __global Dtype *src1_trail = src1_read + i * 4 * incv; +#pragma unroll + for(short i = 0; i < trail_item; ++i) { + const Dtype bt = src1_trail[i * incv]; +#pragma unroll + for(int j = 0; j < 4; ++j) { + work_each[lid * 4 + j] += bt * src0_trail[i + j * lda]; + } + } + } + } + + for(int stride = get_local_size(0) >> 1; stride > 0 ; stride >>= 1) { barrier(CLK_LOCAL_MEM_FENCE); if(lid < stride) work[lid] += work[lid+stride]; - } + } - if(lid == 0) { - if(beta == (Dtype)0) { - result[row_gid+row_offset] = alpha * work[0]; - } else { - result[row_gid+row_offset] *= beta; - result[row_gid+row_offset] += alpha * work[0]; + if(lid == 0) 
{ + // vstore4(alpha * work[0] + beta * vload4(row_gid, result), row_gid, result); + result[row_gid*4*incr] = alpha * work[0].s0 + beta * result[row_gid*4*incr]; + result[(row_gid*4+1)*incr] = alpha * work[0].s1 + beta * result[(row_gid*4+1)*incr]; + result[(row_gid*4+2)*incr] = alpha * work[0].s2 + beta * result[(row_gid*4+2)*incr]; + result[(row_gid*4+3)*incr] = alpha * work[0].s3 + beta * result[(row_gid*4+3)*incr]; } } } + +__kernel void TEMPLATE(trans_matvec_mul,Dtype)( + unsigned int M, + unsigned int N, + __global const Dtype * A, + int offA, + int lda, + __global const Dtype * v, + int offv, + int incv, + KERNEL_ARG_DTYPE alpha, + KERNEL_ARG_DTYPE beta, + __global Dtype * result, + int offr, + int incr) +{ + int col_gid = get_global_id(0); + A += offA + col_gid; + v += offv; + result += offr; + + v += incv > 0 ? 0 : (1 - M) * incv; + result += incr > 0 ? 0 : (1 - N) * incr; + + Dtype dot_prod = 0; + int row_id = 0; +#pragma unroll + for(int row = 0; row < M; ++row) { + dot_prod += A[row_id] * v[row * incv]; + row_id += lda; + } + result[col_gid * incr] = beta * result[col_gid * incr]; + result[col_gid * incr] += alpha * dot_prod; +} diff --git a/src/caffe/greentea/cl_kernels/pooling.cl b/src/caffe/greentea/cl_kernels/pooling.cl index 37400e6f84a..86281770ca6 100644 --- a/src/caffe/greentea/cl_kernels/pooling.cl +++ b/src/caffe/greentea/cl_kernels/pooling.cl @@ -22,7 +22,7 @@ void TEMPLATE(max_pool_forward_impl, Dtype)( const int_tp wend = min(wstart + kernel_w, width); hstart = max(hstart, (int_tp)0); wstart = max(wstart, (int_tp)0); - Dtype maxval = -FLT_MAX; + Dtype maxval = -DTYPE_MAX; int_tp maxidx = -1; __global const Dtype* bottom_slice = bottom_data + (n * channels + c) * height * width; @@ -137,7 +137,7 @@ __kernel void TEMPLATE(sto_pool_forward_train,Dtype)( cumsum += bottom_slice[h * width + w]; } } - const float thres = rand_idx[index] * cumsum; + const Dtype thres = rand_idx[index] * cumsum; // Second pass: get value, and set index. 
cumsum = 0; for (int_tp h = hstart; h < hend; ++h) { @@ -171,7 +171,7 @@ __kernel void TEMPLATE(sto_pool_forward_test,Dtype)( const int_tp wstart = pw * stride_w; const int_tp wend = min(wstart + kernel_w, width); // We set cumsum to be 0 to avoid divide-by-zero problems - Dtype cumsum = FLT_MIN; + Dtype cumsum = DTYPE_MIN; Dtype cumvalues = 0.; __global const Dtype* bottom_slice = bottom_data + (n * channels + c) * height * width; diff --git a/src/caffe/greentea/cl_kernels/pooling_nd.cl b/src/caffe/greentea/cl_kernels/pooling_nd.cl index 119f6a09787..33bab6fe884 100644 --- a/src/caffe/greentea/cl_kernels/pooling_nd.cl +++ b/src/caffe/greentea/cl_kernels/pooling_nd.cl @@ -41,7 +41,7 @@ __kernel void TEMPLATE(max_pool_forward_nd, Dtype)(const int_tp n, d_iter[i] = d_start[i]; if (d_start[i] >= d_end[i]) { - top_data[index] = -FLT_MAX; + top_data[index] = -DTYPE_MAX; if (use_mask) { mask[index] = -1; } else { @@ -59,7 +59,7 @@ __kernel void TEMPLATE(max_pool_forward_nd, Dtype)(const int_tp n, num /= channels; offset *= (num * channels + chan); - Dtype maxval = -FLT_MAX; + Dtype maxval = -DTYPE_MAX; int_tp maxidx = -1; int_tp final_offset = 0; diff --git a/src/caffe/greentea/cl_kernels/pooling_sk.cl b/src/caffe/greentea/cl_kernels/pooling_sk.cl index 73d18b900b7..1ad8c18862a 100644 --- a/src/caffe/greentea/cl_kernels/pooling_sk.cl +++ b/src/caffe/greentea/cl_kernels/pooling_sk.cl @@ -40,7 +40,7 @@ __global Dtype* bottom_data, while (wstart < 0) { wstart += dilation_w; } - Dtype maxval = -FLT_MAX; + Dtype maxval = -DTYPE_MAX; int_tp maxidx = -1; __global Dtype* bottom_data_ptr = bottom_data + (n * channels + c) * height * width; @@ -252,7 +252,7 @@ __kernel void TEMPLATE(sto_pool_forward_train_sk,Dtype)( cumsum += bottom_data_ptr[h * width + w]; } } - float thres = rand_idx[index] * cumsum; + Dtype thres = rand_idx[index] * cumsum; // Second pass: get value, and set index. 
cumsum = 0; for (int_tp h = hstart; h < hend; h += dilation_h) { @@ -289,7 +289,7 @@ __kernel void TEMPLATE(sto_pool_forward_test_sk,Dtype)( int_tp wstart = pw * stride_w; int_tp wend = min(wstart + ext_kernel_w, width); // We set cumsum to be 0 to avoid divide-by-zero problems - Dtype cumsum = FLT_MIN; + Dtype cumsum = DTYPE_MIN; Dtype cumvalues = 0.; __global const Dtype* bottom_data_ptr = bottom_data; bottom_data_ptr += (n * channels + c) * height * width; diff --git a/src/caffe/greentea/cl_kernels/softmax_loss.cl b/src/caffe/greentea/cl_kernels/softmax_loss.cl index 019f784d7b8..610cada4c8b 100644 --- a/src/caffe/greentea/cl_kernels/softmax_loss.cl +++ b/src/caffe/greentea/cl_kernels/softmax_loss.cl @@ -21,7 +21,7 @@ __kernel void TEMPLATE(softmax_forward_slm,Dtype)(const int_tp num, const int_tp int_tp n = get_global_id(1); for (int_tp index = get_global_id(0), s = 0; index < spatial_dim * get_local_size(0); index += get_global_size(0), ++s) { - float maxval = -FLT_MAX; + Dtype maxval = -DTYPE_MAX; for (int_tp c = get_global_id(0); c < channels; c += get_global_size(0)) { Dtype tmp = data[(n * channels + c) * spatial_dim + s]; maxval = max((Dtype)tmp, (Dtype)maxval); @@ -87,7 +87,7 @@ __kernel void TEMPLATE(softmax_forward,Dtype)(const int_tp num, const int_tp cha __global Dtype *group_tmp = scale + spatial_dim * num + n * get_max_sub_group_size() * spatial_dim; for (int_tp index = get_global_id(0), s = 0; index < spatial_dim * get_local_size(0); index += get_global_size(0), ++s) { - float maxval = -FLT_MAX; + Dtype maxval = -DTYPE_MAX; for (int_tp c = get_global_id(0); c < channels; c += get_global_size(0)) { Dtype tmp = data[(n * channels + c) * spatial_dim + s]; maxval = max((Dtype)tmp, (Dtype)maxval); @@ -231,7 +231,7 @@ __kernel void TEMPLATE(softmax_loss_forward,Dtype)( } else { loss[index] = -log((Dtype)( max((Dtype) (prob_data[n * dim + label_value * spatial_dim + s]), - (Dtype) FLT_MIN))); + (Dtype) DTYPE_MIN))); counts[index] = 1; } } diff --git 
a/src/caffe/greentea/cl_kernels/solvers.cl b/src/caffe/greentea/cl_kernels/solvers.cl index 7d792cd9d5a..79afdb54c1b 100644 --- a/src/caffe/greentea/cl_kernels/solvers.cl +++ b/src/caffe/greentea/cl_kernels/solvers.cl @@ -5,9 +5,9 @@ __kernel void TEMPLATE(ada_delta_update,Dtype)(int_tp N, __global Dtype* g, __global Dtype* h, __global Dtype* h2, - Dtype momentum, - Dtype delta, - Dtype local_rate) { + KERNEL_ARG_DTYPE momentum, + KERNEL_ARG_DTYPE delta, + KERNEL_ARG_DTYPE local_rate) { for (int_tp i = get_global_id(0); i < N; i += get_global_size(0)) { Dtype gi = g[i]; Dtype hi = h[i] = momentum * h[i] + ((Dtype)1.0 - momentum) * gi * gi; @@ -19,8 +19,8 @@ __kernel void TEMPLATE(ada_delta_update,Dtype)(int_tp N, __global Dtype* g, __kernel void TEMPLATE(ada_grad_update,Dtype)(int_tp N, __global Dtype* g, __global Dtype* h, - Dtype delta, - Dtype local_rate) { + KERNEL_ARG_DTYPE delta, + KERNEL_ARG_DTYPE local_rate) { for (int_tp i = get_global_id(0); i < N; i += get_global_size(0)) { Dtype gi = g[i]; Dtype hi = h[i] = h[i] + gi * gi; @@ -31,10 +31,10 @@ __kernel void TEMPLATE(ada_grad_update,Dtype)(int_tp N, __global Dtype* g, __kernel void TEMPLATE(adam_update,Dtype)(int_tp N, __global Dtype* g, __global Dtype* m, __global Dtype* v, - Dtype beta1, - Dtype beta2, - Dtype eps_hat, - Dtype corrected_local_rate) { + KERNEL_ARG_DTYPE beta1, + KERNEL_ARG_DTYPE beta2, + KERNEL_ARG_DTYPE eps_hat, + KERNEL_ARG_DTYPE corrected_local_rate) { for (int_tp i = get_global_id(0); i < N; i += get_global_size(0)) { Dtype gi = g[i]; Dtype mi = m[i] = m[i] * beta1 + gi * (1 - beta1); @@ -46,8 +46,8 @@ __kernel void TEMPLATE(adam_update,Dtype)(int_tp N, __global Dtype* g, __kernel void TEMPLATE(nesterov_update,Dtype)(int_tp N, __global Dtype* g, __global Dtype* h, - Dtype momentum, - Dtype local_rate) { + KERNEL_ARG_DTYPE momentum, + KERNEL_ARG_DTYPE local_rate) { for (int_tp i = get_global_id(0); i < N; i += get_global_size(0)) { Dtype hi = h[i]; Dtype hi_new = h[i] = momentum * hi 
+ local_rate * g[i]; @@ -57,9 +57,9 @@ __kernel void TEMPLATE(nesterov_update,Dtype)(int_tp N, __global Dtype* g, __kernel void TEMPLATE(rms_prop_update,Dtype)(int_tp N, __global Dtype* g, __global Dtype* h, - Dtype rms_decay, - Dtype delta, - Dtype local_rate) { + KERNEL_ARG_DTYPE rms_decay, + KERNEL_ARG_DTYPE delta, + KERNEL_ARG_DTYPE local_rate) { for (int_tp i = get_global_id(0); i < N; i += get_global_size(0)) { Dtype gi = g[i]; Dtype hi = h[i] = rms_decay * h[i] + (1 - rms_decay) * gi * gi; @@ -69,8 +69,8 @@ __kernel void TEMPLATE(rms_prop_update,Dtype)(int_tp N, __global Dtype* g, __kernel void TEMPLATE(sgd_update,Dtype)(int_tp N, __global Dtype* g, __global Dtype* h, - Dtype momentum, - Dtype local_rate) { + KERNEL_ARG_DTYPE momentum, + KERNEL_ARG_DTYPE local_rate) { for (int_tp i = get_global_id(0); i < N; i += get_global_size(0)) { g[i] = h[i] = momentum * h[i] + local_rate * g[i]; } diff --git a/src/caffe/greentea/greentea_im2col.cpp b/src/caffe/greentea/greentea_im2col.cpp index 11a0e59ee3f..2eb0c425f53 100644 --- a/src/caffe/greentea/greentea_im2col.cpp +++ b/src/caffe/greentea/greentea_im2col.cpp @@ -37,6 +37,25 @@ void greentea_im2col_gpu(viennacl::ocl::program *prog, } // Explicit instantiation +#ifdef HAS_HALF_SUPPORT +template void greentea_im2col_gpu(viennacl::ocl::program *prog, + viennacl::ocl::context *ctx, + const cl_mem data_im, + const int_tp data_offset, + const int_tp channels, + const int_tp height, + const int_tp width, + const int_tp kernel_h, + const int_tp kernel_w, + const int_tp pad_h, const int_tp pad_w, + const int_tp stride_h, + const int_tp stride_w, + const int_tp dilation_h, + const int_tp dilation_w, + cl_mem data_col, + const int_tp data_col_off); +#endif + template void greentea_im2col_gpu(viennacl::ocl::program *prog, viennacl::ocl::context *ctx, const cl_mem data_im, @@ -97,6 +116,25 @@ void greentea_col2im_gpu(viennacl::ocl::program *prog, ctx->get_queue()); } +#ifdef HAS_HALF_SUPPORT +template void 
greentea_col2im_gpu(viennacl::ocl::program *prog, + viennacl::ocl::context *ctx, + const cl_mem data_col, + const int_tp data_col_off, + const int_tp channels, + const int_tp height, + const int_tp width, + const int_tp patch_h, + const int_tp patch_w, + const int_tp pad_h, const int_tp pad_w, + const int_tp stride_h, + const int_tp stride_w, + const int_tp dilation_h, + const int_tp dilation_w, + cl_mem data_im, + const int_tp data_offset); +#endif + template void greentea_col2im_gpu(viennacl::ocl::program *prog, viennacl::ocl::context *ctx, const cl_mem data_col, @@ -156,6 +194,21 @@ void greentea_im2col_nd_gpu(viennacl::ocl::program *prog, } // Explicit instantiation +#ifdef HAS_HALF_SUPPORT +template void greentea_im2col_nd_gpu(viennacl::ocl::program *prog, + viennacl::ocl::context *ctx, + cl_mem data_im, + const int_tp data_off, + const int_tp num_spatial_axes, + const int_tp channel_axis, + const int_tp num_kernels, + cl_mem im_shape, cl_mem col_shape, + cl_mem kernel_shape, cl_mem pad, + cl_mem stride, cl_mem dilation, + cl_mem data_col, + const int_tp data_col_off); +#endif + template void greentea_im2col_nd_gpu(viennacl::ocl::program *prog, viennacl::ocl::context *ctx, cl_mem data_im, @@ -207,6 +260,20 @@ void greentea_col2im_nd_gpu(viennacl::ocl::program *prog, } // Explicit instantiation +#ifdef HAS_HALF_SUPPORT +template void greentea_col2im_nd_gpu(viennacl::ocl::program *prog, + viennacl::ocl::context *ctx, + cl_mem data_col, + const int_tp data_col_off, + const int_tp num_spatial_axes, + const int_tp channel_axis, + const int_tp im_size, + cl_mem im_shape, cl_mem col_shape, + cl_mem kernel_shape, cl_mem pad, + cl_mem stride, cl_mem dilation, + cl_mem data_im, int_tp data_off); +#endif + template void greentea_col2im_nd_gpu(viennacl::ocl::program *prog, viennacl::ocl::context *ctx, cl_mem data_col, diff --git a/src/caffe/greentea/greentea_math_functions.cpp b/src/caffe/greentea/greentea_math_functions.cpp index 1db968d730b..c2aca7db0bf 100644 --- 
a/src/caffe/greentea/greentea_math_functions.cpp +++ b/src/caffe/greentea/greentea_math_functions.cpp @@ -175,6 +175,19 @@ template void greentea_copy(const int_tp N, const cl_mem X, const int_tp offX, cl_mem Y, const int_tp offY, viennacl::ocl::context *ctx); +#ifdef HAS_HALF_SUPPORT +template void greentea_copy(const int_tp N, const cl_mem X, + const int_tp offX, half* Y, + viennacl::ocl::context *ctx); +template void greentea_copy(const int_tp N, const half* X, cl_mem Y, + const int_tp offY, + viennacl::ocl::context *ctx); +template void greentea_copy(const int_tp N, const cl_mem X, + const int_tp offX, cl_mem Y, + const int_tp offY, + viennacl::ocl::context *ctx); +#endif + template void greentea_gpu_gemm(const int_tp ctx_id, const CBLAS_TRANSPOSE TransA, @@ -209,9 +222,7 @@ void greentea_gpu_gemm(const int_tp ctx_id, const CBLAS_TRANSPOSE TransA, int_tp lda = (TransA == CblasNoTrans) ? K : M; int_tp ldb = (TransB == CblasNoTrans) ? N : K; int_tp ldc = N; - #if defined(USE_CLBLAS) - clblasOrder clOrder = clblasRowMajor; clblasTranspose clTransA = (TransA == CblasNoTrans) ? 
clblasNoTrans : clblasTrans; @@ -224,13 +235,20 @@ void greentea_gpu_gemm(const int_tp ctx_id, const CBLAS_TRANSPOSE TransA, clblasSgemm(clOrder, clTransA, clTransB, M, N, K, alpha, A, offA, lda, B, offB, ldb, beta, C, offC, ldc, 1, &queue, 0, NULL, NULL)); - } else { + } else if (std::is_same::value) { GREENTEA_CL_BLAS_CHECK( clblasDgemm(clOrder, clTransA, clTransB, M, N, K, alpha, A, offA, lda, B, offB, ldb, beta, C, offC, ldc, 1, &queue, 0, NULL, NULL)); } - +#ifdef HAS_HALF_SUPPORT + else if (std::is_same::value) { + GREENTEA_CL_BLAS_CHECK( + clblasHgemm(clOrder, clTransA, clTransB, + M, N, K, alpha, A, offA, lda, B, offB, ldb, beta, + C, offC, ldc, 1, &queue, 0, NULL, NULL)); + } +#endif #elif defined(USE_CLBLAST) cl_command_queue queue = ctx.get_queue().handle().get(); @@ -321,7 +339,17 @@ void greentea_gpu_gemm(const int_tp ctx_id, const CBLAS_TRANSPOSE TransA, #endif // clBLAS, CLBlast, or default (ViennaCL) } } - +#ifdef HAS_HALF_SUPPORT +template void greentea_gpu_gemm(const int_tp ctx_id, + const CBLAS_TRANSPOSE TransA, + const CBLAS_TRANSPOSE TransB, + const int_tp M, const int_tp N, + const int_tp K, const half alpha, + const cl_mem A, const int_tp offA, + const cl_mem B, const int_tp offB, + const half beta, cl_mem C, + const int_tp offC); +#endif template void greentea_gpu_gemm(const int_tp ctx_id, const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB, @@ -373,65 +401,40 @@ void greentea_gpu_gemv(const int_tp ctx_id, const CBLAS_TRANSPOSE TransA, clEnqueueUnmapMemObject(ctx.get_queue().handle().get(), y, yptr, 0, NULL, NULL); } else { - if (std::is_same::value && TransA == CblasNoTrans) { + if (!std::is_same::value && TransA == CblasNoTrans) { viennacl::ocl::program &program = (Caffe::Get().GetDevice(ctx_id, false)) ->program(); - viennacl::ocl::kernel &k = - program.get_kernel(CL_KERNEL_SELECT("matvec_mul4")); + bool isTransA = (TransA == CblasTrans); + viennacl::ocl::kernel &k = (isTransA ? 
+ program.get_kernel(CL_KERNEL_SELECT("trans_matvec_mul")) : + program.get_kernel(CL_KERNEL_SELECT("matvec_mul"))); uint row_size = M; uint col_size = N; size_t localsize = 128; - size_t globalsize = row_size / 4 * localsize; + size_t globalsize = (isTransA ? col_size : (row_size + 3) / 4 * localsize); uint argId = 0; + k.arg(argId++, row_size); + k.arg(argId++, col_size); k.arg(argId++, WrapHandle(A, &ctx)); k.arg(argId++, offA); - k.arg(argId++, cl_uint(col_size)); - k.arg(argId++, cl_uint(col_size%4)); + k.arg(argId++, col_size); k.arg(argId++, WrapHandle(x, &ctx)); k.arg(argId++, offx); - k.arg(argId++, alpha); - k.arg(argId++, beta); + k.arg(argId++, 1); + k.arg(argId++, fixup_arg_type(alpha)); + k.arg(argId++, fixup_arg_type(beta)); k.arg(argId++, WrapHandle(y, &ctx)); k.arg(argId++, offy); - k.arg(argId++, viennacl::ocl::local_mem(sizeof(cl_float4) * localsize)); + k.arg(argId++, 1); clEnqueueNDRangeKernel(ctx.get_queue().handle().get(), k.handle().get(), 1, NULL, &globalsize, - &localsize, 0, NULL, + (isTransA ? 
NULL : &localsize), 0, NULL, NULL); - if ((row_size % 4) != 0) { - viennacl::ocl::kernel &k_1 = - program.get_kernel(CL_KERNEL_SELECT("matvec_mul1")); - size_t localsize = 128; - size_t globalsize = row_size % 4 * localsize; - uint row_offset = row_size - (row_size % 4); - - uint argId = 0; - k_1.arg(argId++, WrapHandle(A, &ctx)); - k_1.arg(argId++, offA); - k_1.arg(argId++, cl_uint(col_size)); - k_1.arg(argId++, cl_uint(row_offset)); - k_1.arg(argId++, cl_uint(col_size%4)); - k_1.arg(argId++, WrapHandle(x, &ctx)); - k_1.arg(argId++, offx); - k_1.arg(argId++, alpha); - k_1.arg(argId++, beta); - k_1.arg(argId++, WrapHandle(y, &ctx)); - k_1.arg(argId++, offy); - k_1.arg(argId++, - viennacl::ocl::local_mem(sizeof(cl_float) * localsize)); - - clEnqueueNDRangeKernel(ctx.get_queue().handle().get(), - k_1.handle().get(), 1, - NULL, - &globalsize, - &localsize, 0, NULL, - NULL); - } } else { #if defined(USE_CLBLAS) @@ -445,13 +448,20 @@ void greentea_gpu_gemv(const int_tp ctx_id, const CBLAS_TRANSPOSE TransA, clblasSgemv(clblasRowMajor, clTransA, M, N, alpha, A, offA, N, x, offx, 1, beta, y, offy, 1, 1, &queue, 0, NULL, NULL)); - } else { + } else if (std::is_same::value) { GREENTEA_CL_BLAS_CHECK( clblasDgemv(clblasRowMajor, clTransA, M, N, alpha, A, offA, N, x, offx, 1, beta, y, offy, 1, 1, &queue, 0, NULL, NULL)); } - +#ifdef HAS_HALF_SUPPORT + else if (std::is_same::value) { + GREENTEA_CL_BLAS_CHECK( + clblasHgemv(clblasRowMajor, + clTransA, M, N, alpha, A, offA, N, x, offx, 1, + beta, y, offy, 1, 1, &queue, 0, NULL, NULL)); + } +#endif #elif defined(USE_CLBLAST) cl_command_queue queue = ctx.get_queue().handle().get(); @@ -523,6 +533,15 @@ void greentea_gpu_gemv(const int_tp ctx_id, const CBLAS_TRANSPOSE TransA, } } +#ifdef HAS_HALF_SUPPORT +template void greentea_gpu_gemv(const int_tp ctx_id, + const CBLAS_TRANSPOSE TransA, + const int_tp M, const int_tp N, + const half alpha, const cl_mem A, + const int_tp offA, const cl_mem x, + const int_tp offx, const half beta, + 
cl_mem y, const int_tp offy); +#endif template void greentea_gpu_gemv(const int_tp ctx_id, const CBLAS_TRANSPOSE TransA, const int_tp M, const int_tp N, @@ -567,12 +586,18 @@ void greentea_gpu_axpy(const int_tp ctx_id, const int_tp N, const Dtype alpha, GREENTEA_CL_BLAS_CHECK( clblasSaxpy(N, alpha, X, offX, 1, Y, offY, 1, 1, &queue, 0, NULL, NULL)); - } else { + } else if (std::is_same::value){ GREENTEA_CL_BLAS_CHECK( clblasDaxpy(N, alpha, X, offX, 1, Y, offY, 1, 1, &queue, 0, NULL, NULL)); } - +#ifdef HAS_HALF_SUPPORT + else if (std::is_same::value) { + GREENTEA_CL_BLAS_CHECK( + clblasHaxpy(N, alpha, X, offX, + 1, Y, offY, 1, 1, &queue, 0, NULL, NULL)); + } +#endif #elif defined(USE_CLBLAST) cl_command_queue queue = ctx.get_queue().handle().get(); @@ -617,6 +642,12 @@ void greentea_gpu_axpy(const int_tp ctx_id, const int_tp N, const Dtype alpha, } } +#ifdef HAS_HALF_SUPPORT +template void greentea_gpu_axpy(const int_tp ctx_id, const int_tp N, + const half alpha, const cl_mem X, + const int_tp offX, cl_mem Y, + const int_tp offY); +#endif template void greentea_gpu_axpy(const int_tp ctx_id, const int_tp N, const float alpha, const cl_mem X, const int_tp offX, cl_mem Y, @@ -641,6 +672,12 @@ void greentea_gpu_mul(const int_tp ctx_id, const int_tp N, const cl_mem a, ctx.get_queue()); } +#ifdef HAS_HALF_SUPPORT +template void greentea_gpu_mul(const int_tp ctx_id, const int_tp N, + const cl_mem a, const int_tp offa, + const cl_mem b, const int_tp offb, + cl_mem y, const int_tp offy); +#endif template void greentea_gpu_mul(const int_tp ctx_id, const int_tp N, const cl_mem a, const int_tp offa, const cl_mem b, const int_tp offb, @@ -665,6 +702,12 @@ void greentea_gpu_div(const int_tp ctx_id, const int_tp N, const cl_mem a, ctx.get_queue()); } +#ifdef HAS_HALF_SUPPORT +template void greentea_gpu_div(const int_tp ctx_id, const int_tp N, + const cl_mem a, const int_tp offa, + const cl_mem b, const int_tp offb, + cl_mem y, const int_tp offy); +#endif template void 
greentea_gpu_div(const int_tp ctx_id, const int_tp N, const cl_mem a, const int_tp offa, const cl_mem b, const int_tp offb, @@ -696,11 +739,16 @@ void greentea_gpu_scal(const int_tp ctx_id, const int_tp N, const Dtype alpha, if (std::is_same::value) { GREENTEA_CL_BLAS_CHECK(clblasSscal(N, alpha, x, offx, 1, 1, &queue, 0, NULL, NULL)); - } else { + } else if (std::is_same::value) { GREENTEA_CL_BLAS_CHECK(clblasDscal(N, alpha, x, offx, 1, 1, &queue, 0, NULL, NULL)); } - +#ifdef HAS_HALF_SUPPORT + else if (std::is_same::value) { + GREENTEA_CL_BLAS_CHECK(clblasHscal(N, alpha, x, offx, + 1, 1, &queue, 0, NULL, NULL)); + } +#endif #elif defined(USE_CLBLAST) cl_command_queue queue = ctx.get_queue().handle().get(); @@ -738,7 +786,11 @@ void greentea_gpu_scal(const int_tp ctx_id, const int_tp N, const Dtype alpha, #endif // clBLAS, CLBlast, or default (ViennaCL) } } - +#ifdef HAS_HALF_SUPPORT +template void greentea_gpu_scal(const int_tp ctx_id, const int_tp N, + const half alpha, cl_mem x, + const int_tp offx); +#endif template void greentea_gpu_scal(const int_tp ctx_id, const int_tp N, const float alpha, cl_mem x, const int_tp offx); @@ -754,6 +806,13 @@ void greentea_gpu_axpby(const int_tp ctx_id, const int_tp N, const Dtype alpha, greentea_gpu_axpy(ctx_id, N, alpha, X, offX, Y, offY); } +#ifdef HAS_HALF_SUPPORT +template void greentea_gpu_axpby(const int_tp ctx_id, const int_tp N, + const half alpha, const cl_mem X, + const int_tp offX, const half beta, + cl_mem Y, const int_tp offY); +#endif + template void greentea_gpu_axpby(const int_tp ctx_id, const int_tp N, const float alpha, const cl_mem X, const int_tp offX, const float beta, @@ -800,12 +859,18 @@ void greentea_gpu_dot(const int_tp ctx_id, const int_tp n, const cl_mem X, GREENTEA_CL_BLAS_CHECK( clblasSdot(n, gpuout, 0, X, offX, 1, Y, offY, 1, scratch, 1, &queue, 0, NULL, NULL)); - } else { + } else if (std::is_same::value){ GREENTEA_CL_BLAS_CHECK( clblasDdot(n, gpuout, 0, X, offX, 1, Y, offY, 1, scratch, 1, 
&queue, 0, NULL, NULL)); } - +#ifdef HAS_HALF_SUPPORT + else if (std::is_same::value) { + GREENTEA_CL_BLAS_CHECK( + clblasHdot(n, gpuout, 0, X, offX, 1, Y, + offY, 1, scratch, 1, &queue, 0, NULL, NULL)); + } +#endif greentea_gpu_memcpy(sizeof(Dtype), gpuout, 0, out, &ctx); clReleaseMemObject(gpuout); @@ -865,6 +930,12 @@ void greentea_gpu_dot(const int_tp ctx_id, const int_tp n, const cl_mem X, } } +#ifdef HAS_HALF_SUPPORT +template void greentea_gpu_dot(const int_tp ctx_id, const int_tp n, + const cl_mem X, const int_tp offX, + const cl_mem Y, const int_tp offY, + half* out); +#endif template void greentea_gpu_dot(const int_tp ctx_id, const int_tp n, const cl_mem X, const int_tp offX, const cl_mem Y, const int_tp offY, @@ -903,11 +974,18 @@ void greentea_gpu_asum(const int_tp ctx_id, const int_tp n, const cl_mem X, GREENTEA_CL_BLAS_CHECK( clblasSasum(n, gpuout, 0, X, offX, 1, scratch, 1, &queue, 0, NULL, NULL)); - } else { + } else if (std::is_same::value) { GREENTEA_CL_BLAS_CHECK( clblasDasum(n, gpuout, 0, X, offX, 1, scratch, 1, &queue, 0, NULL, NULL)); } +#ifdef HAS_HALF_SUPPORT + else if (std::is_same::value) { + GREENTEA_CL_BLAS_CHECK( + clblasHasum(n, gpuout, 0, X, offX, 1, + scratch, 1, &queue, 0, NULL, NULL)); + } +#endif greentea_gpu_memcpy(sizeof(Dtype), gpuout, 0, Y, &ctx); @@ -963,6 +1041,11 @@ void greentea_gpu_asum(const int_tp ctx_id, const int_tp n, const cl_mem X, } } +#ifdef HAS_HALF_SUPPORT +template void greentea_gpu_asum(const int_tp ctx_id, const int_tp n, + const cl_mem X, const int_tp offX, + half* Y); +#endif template void greentea_gpu_asum(const int_tp ctx_id, const int_tp n, const cl_mem X, const int_tp offX, float* Y); @@ -1003,12 +1086,20 @@ void greentea_gpu_scale(const int_tp ctx_id, const int_tp n, const Dtype alpha, clblasScopy(n, X, offX, 1, Y, offY, 1, 1, &queue, 0, NULL, NULL)); GREENTEA_CL_BLAS_CHECK( clblasSscal(n, alpha, Y, offY, 1, 1, &queue, 0, NULL, NULL)); - } else { + } else if (std::is_same::value) { 
GREENTEA_CL_BLAS_CHECK( clblasDcopy(n, X, offX, 1, Y, offY, 1, 1, &queue, 0, NULL, NULL)); GREENTEA_CL_BLAS_CHECK( clblasDscal(n, alpha, Y, offY, 1, 1, &queue, 0, NULL, NULL)); } +#ifdef HAS_HALF_SUPPORT + else if (std::is_same::value) { + GREENTEA_CL_BLAS_CHECK( + clblasHcopy(n, X, offX, 1, Y, offY, 1, 1, &queue, 0, NULL, NULL)); + GREENTEA_CL_BLAS_CHECK( + clblasHscal(n, alpha, Y, offY, 1, 1, &queue, 0, NULL, NULL)); + } +#endif #elif defined(USE_CLBLAST) @@ -1065,11 +1156,16 @@ void greentea_gpu_scale(const int_tp ctx_id, const int_tp n, const Dtype alpha, } } +#ifdef HAS_HALF_SUPPORT +template void greentea_gpu_scale(const int_tp ctx_id, const int_tp n, + const half alpha, const cl_mem X, + const int_tp offX, cl_mem Y, + const int_tp offY); +#endif template void greentea_gpu_scale(const int_tp ctx_id, const int_tp n, const float alpha, const cl_mem X, const int_tp offX, cl_mem Y, const int_tp offY); - template void greentea_gpu_scale(const int_tp ctx_id, const int_tp n, const double alpha, const cl_mem X, const int_tp offX, cl_mem Y, @@ -1089,13 +1185,19 @@ void greentea_gpu_set(const int_tp ctx_id, const int_tp N, const Dtype alpha, // OpenCL Version < 1.2 fallback viennacl::ocl::kernel &oclk_fill = program.get_kernel( CL_KERNEL_SELECT("fill")); - viennacl::ocl::enqueue(oclk_fill(N, alpha, WrapHandle(Y, &ctx), offY), + viennacl::ocl::enqueue(oclk_fill(N, fixup_arg_type(alpha), + WrapHandle(Y, &ctx), offY), ctx.get_queue()); } template void greentea_gpu_set(const int_tp ctx_id, const int_tp N, const int_tp alpha, cl_mem Y, const int_tp offY); +#ifdef HAS_HALF_SUPPORT +template void greentea_gpu_set(const int_tp ctx_id, const int_tp N, + const half alpha, cl_mem Y, + const int_tp offY); +#endif template void greentea_gpu_set(const int_tp ctx_id, const int_tp N, const float alpha, cl_mem Y, const int_tp offY); @@ -1112,10 +1214,15 @@ void greentea_gpu_add_scalar(const int_tp ctx_id, const int_tp N, viennacl::ocl::kernel &oclk_add_scalar = program.get_kernel( 
CL_KERNEL_SELECT("add_scalar")); - viennacl::ocl::enqueue(oclk_add_scalar(N, alpha, WrapHandle(Y, &ctx), offY), + viennacl::ocl::enqueue(oclk_add_scalar(N, fixup_arg_type(alpha), WrapHandle(Y, &ctx), offY), ctx.get_queue()); } +#ifdef HAS_HALF_SUPPORT +template void greentea_gpu_add_scalar(const int_tp ctx_id, + const int_tp N, const half alpha, + cl_mem Y, const int_tp offY); +#endif template void greentea_gpu_add_scalar(const int_tp ctx_id, const int_tp N, const float alpha, cl_mem Y, const int_tp offY); @@ -1139,6 +1246,12 @@ void greentea_gpu_add(const int_tp ctx_id, const int_tp n, const cl_mem a, ctx.get_queue()); } +#ifdef HAS_HALF_SUPPORT +template void greentea_gpu_add(const int_tp ctx_id, const int_tp n, + const cl_mem a, const int_tp offa, + const cl_mem b, const int_tp offb, + cl_mem y, const int_tp offy); +#endif template void greentea_gpu_add(const int_tp ctx_id, const int_tp n, const cl_mem a, const int_tp offa, const cl_mem b, const int_tp offb, @@ -1163,6 +1276,12 @@ void greentea_gpu_sub(const int_tp ctx_id, const int_tp n, const cl_mem a, ctx.get_queue()); } +#ifdef HAS_HALF_SUPPORT +template void greentea_gpu_sub(const int_tp ctx_id, const int_tp n, + const cl_mem a, const int_tp offa, + const cl_mem b, const int_tp offb, + cl_mem y, const int_tp offy); +#endif template void greentea_gpu_sub(const int_tp ctx_id, const int_tp n, const cl_mem a, const int_tp offa, const cl_mem b, const int_tp offb, @@ -1185,6 +1304,11 @@ void greentea_gpu_abs(const int_tp ctx_id, const int_tp N, const cl_mem a, ctx.get_queue()); } +#ifdef HAS_HALF_SUPPORT +template void greentea_gpu_abs(const int_tp ctx_id, const int_tp N, + const cl_mem a, const int_tp offa, + cl_mem y, const int_tp offy); +#endif template void greentea_gpu_abs(const int_tp ctx_id, const int_tp N, const cl_mem a, const int_tp offa, cl_mem y, const int_tp offy); @@ -1205,6 +1329,11 @@ void greentea_gpu_exp(const int_tp ctx_id, const int_tp N, const cl_mem a, ctx.get_queue()); } +#ifdef 
HAS_HALF_SUPPORT +template void greentea_gpu_exp(const int_tp ctx_id, const int_tp N, + const cl_mem a, const int_tp offa, + cl_mem y, const int_tp offy); +#endif template void greentea_gpu_exp(const int_tp ctx_id, const int_tp N, const cl_mem a, const int_tp offa, cl_mem y, const int_tp offy); @@ -1227,6 +1356,9 @@ void greentea_gpu_sqrt(const int_tp ctx_id, const int_tp n, ctx.get_queue()); } +template void greentea_gpu_sqrt(const int_tp ctx_id, const int_tp n, + const cl_mem a, const int_tp offa, + cl_mem y, const int_tp offy); template void greentea_gpu_sqrt(const int_tp ctx_id, const int_tp n, const cl_mem a, const int_tp offa, cl_mem y, const int_tp offy); @@ -1245,10 +1377,17 @@ void greentea_gpu_powx(const int_tp ctx_id, const int_tp N, const cl_mem a, viennacl::ocl::kernel &oclk_powx = program.get_kernel( CL_KERNEL_SELECT("powx")); viennacl::ocl::enqueue( - oclk_powx(N, WrapHandle(a, &ctx), offa, alpha, WrapHandle(y, &ctx), offy), + oclk_powx(N, WrapHandle(a, &ctx), offa, fixup_arg_type(alpha), + WrapHandle(y, &ctx), offy), ctx.get_queue()); } +#ifdef HAS_HALF_SUPPORT +template void greentea_gpu_powx(const int_tp ctx_id, const int_tp N, + const cl_mem a, const int_tp offa, + const half alpha, cl_mem y, + const int_tp offy); +#endif template void greentea_gpu_powx(const int_tp ctx_id, const int_tp N, const cl_mem a, const int_tp offa, const float alpha, cl_mem y, @@ -1271,6 +1410,11 @@ void greentea_gpu_log(const int_tp ctx_id, const int_tp N, const cl_mem a, ctx.get_queue()); } +#ifdef HAS_HALF_SUPPORT +template void greentea_gpu_log(const int_tp ctx_id, const int_tp N, + const cl_mem a, const int_tp offa, + cl_mem y, const int_tp offy); +#endif template void greentea_gpu_log(const int_tp ctx_id, const int_tp N, const cl_mem a, const int_tp offa, cl_mem y, const int_tp offy); @@ -1293,6 +1437,11 @@ int_tp offx, ctx.get_queue()); } +#ifdef HAS_HALF_SUPPORT +template void greentea_gpu_sign(const int_tp ctx_id, const int_tp n, + const cl_mem x, int_tp offx, 
cl_mem y, + const int_tp offy); +#endif template void greentea_gpu_sign(const int_tp ctx_id, const int_tp n, const cl_mem x, int_tp offx, cl_mem y, const int_tp offy); @@ -1315,6 +1464,11 @@ int_tp offx, ctx.get_queue()); } +#ifdef HAS_HALF_SUPPORT +template void greentea_gpu_sgnbit(const int_tp ctx_id, const int_tp n, + const cl_mem x, int_tp offx, cl_mem y, + const int_tp offy); +#endif template void greentea_gpu_sgnbit(const int_tp ctx_id, const int_tp n, const cl_mem x, int_tp offx, cl_mem y, const int_tp offy); @@ -1340,6 +1494,12 @@ void greentea_gpu_rng_uniform(const int_tp ctx_id, const int_tp n, greentea_gpu_memcpy(sizeof(Dtype) * n, &random[0], r, offr, &ctx); } +#ifdef HAS_HALF_SUPPORT +template void greentea_gpu_rng_uniform(const int_tp ctx_id, + const int_tp n, const half a, + const half b, cl_mem r, + const int_tp offr); +#endif template void greentea_gpu_rng_uniform(const int_tp ctx_id, const int_tp n, const float a, const float b, cl_mem r, @@ -1359,6 +1519,13 @@ void greentea_gpu_rng_gaussian(const int_tp ctx_id, const int_tp n, greentea_gpu_memcpy(sizeof(Dtype) * n, &random[0], r, offr, &ctx); } +#ifdef HAS_HALF_SUPPORT +template void greentea_gpu_rng_gaussian(const int_tp ctx_id, + const int_tp n, const half mu, + const half sigma, cl_mem r, + const int_tp offr); +#endif + template void greentea_gpu_rng_gaussian(const int_tp ctx_id, const int_tp n, const float mu, const float sigma, cl_mem r, diff --git a/src/caffe/greentea/libdnn.cpp b/src/caffe/greentea/libdnn.cpp index 1affcde7966..e7b90f0ef22 100644 --- a/src/caffe/greentea/libdnn.cpp +++ b/src/caffe/greentea/libdnn.cpp @@ -32,6 +32,13 @@ std::string LibDNN::generate_header() { ss << "#endif" << std::endl; } + if (std::is_same::value) { + ss << "#if defined(cl_khr_fp16)" << std::endl; + ss << "#pragma OPENCL EXTENSION cl_khr_fp16 : enable" << std::endl; + ss << "#define HALF_SUPPORT_AVAILABLE" << std::endl; + ss << "#endif" << std::endl; + } + // Test/enable 32 bit atomics ss << "#if 
defined(cl_khr_int32_base_atomics)" << std::endl; ss << "#pragma OPENCL EXTENSION cl_khr_int32_base_atomics : enable" @@ -62,7 +69,7 @@ std::string LibDNN::generate_header() { for (int_tp i = 2; i <= 16; i *= 2) { ss << "#define Dtype" << i << " double" << i << std::endl; } - } else { + } else if (std::is_same::value){ ss << "#define Dtype float" << std::endl; ss << "#define Dtype1 float" << std::endl; // float2, float4, float8, float16 @@ -70,6 +77,22 @@ std::string LibDNN::generate_header() { ss << "#define Dtype" << i << " float" << i << std::endl; } } +#ifdef HAS_HALF_SUPPORT + else if (std::is_same::value) { + ss << "#define Dtype half" << std::endl; + ss << "#define Dtype1 half" << std::endl; + // half2, half4, half8, half16 + for (int_tp i = 2; i <= 16; i *= 2) { + ss << "#define Dtype" << i << " half" << i << std::endl; + } + } +#endif + + if (std::is_same::value) { + ss << "#define KERNEL_ARG_DTYPE float" << std::endl; + } else { + ss << "#define KERNEL_ARG_DTYPE Dtype" << std::endl; + } std::vector elems4({ "x", "y", "z", "w" }); @@ -162,6 +185,7 @@ std::string LibDNN::generate_header() { } else { ss << "#ifdef ATOMICS_32_AVAILABLE" << std::endl; } + // FIXME, half version has bug. 
for (int i = 0; i < atomic_funcs.size(); ++i) { ss << "inline void atomic" << atomic_funcs[i]; ss << "(volatile __global Dtype* source, const Dtype operand) {" @@ -172,13 +196,22 @@ std::string LibDNN::generate_header() { } else { ss << "unsigned int intVal;" << std::endl; } - ss << "Dtype floatVal;" << std::endl; + if (std::is_same::value) { + ss << "Dtype floatVal[2];" << std::endl; + } else { + ss << "Dtype floatVal[1];" << std::endl; + } ss << "} next, expected, current;" << std::endl; - ss << "current.floatVal = *source;" << std::endl; + ss << "current.floatVal[0] = *source;" << std::endl; + if (std::is_same::value) + ss << "current.floatVal[1] = *(source + 1);" << std::endl; ss << "do {" << std::endl; - ss << "expected.floatVal = current.floatVal;" << std::endl; - ss << "next.floatVal = expected.floatVal " << atomic_ops[i] << " operand;" + ss << "expected.intVal = current.intVal;" << std::endl; + ss << "next.floatVal[0] = expected.floatVal[0] " << atomic_ops[i] << " operand;" << std::endl; + if (std::is_same::value) { + ss << "next.floatVal[1] = expected.floatVal[1]; " << std::endl; + } ss << "current.intVal = "; if (std::is_same::value) { ss << "atom_cmpxchg((volatile __global unsigned long *)"; @@ -197,7 +230,7 @@ std::string LibDNN::generate_header() { } // Memory set - ss << "__kernel void fill_memory(const int_tp n, const Dtype alpha," + ss << "__kernel void fill_memory(const int_tp n, const KERNEL_ARG_DTYPE alpha," << "__global Dtype* x, const int_tp offx) {" << std::endl; ss << "for (int_tp index = get_global_id(0); index < n; " << "index += get_global_size(0)) {" << std::endl; @@ -361,8 +394,8 @@ void LibDNN::SetMemory(Dtype* memory, int_tp count, int_tp offset, kernel.global_work_size(2, 1); viennacl::ocl::enqueue( - kernel(count, value, WrapHandle((cl_mem) memory, &ctx), offset), - ctx.get_queue()); + kernel(count, fixup_arg_type(value), WrapHandle((cl_mem) memory, &ctx), + offset), ctx.get_queue()); #endif // USE_GREENTEA } else { #ifdef USE_CUDA 
diff --git a/src/caffe/greentea/libdnn_conv_spatial.cpp b/src/caffe/greentea/libdnn_conv_spatial.cpp index f6d8b66b558..d168fc7169a 100644 --- a/src/caffe/greentea/libdnn_conv_spatial.cpp +++ b/src/caffe/greentea/libdnn_conv_spatial.cpp @@ -258,23 +258,44 @@ std::string LibDNNConvSpatial::generate_fw_kernels(int_tp kernelType, // SIMD16/8 mode will be used, // the compiler could choose to use two SIMD8 threads, // and if that happens the code will break. + ss << "#if defined(convolve_simd) || defined(Conv_Interleaved)" << std::endl; + ss << "#if TYPE == TYPE_HALF" << std::endl; + ss << "#define INT_TYPE ushort" << std::endl; + ss << "#define INT_TYPE2 ushort2" << std::endl; + ss << "#define INT_TYPE4 ushort4" << std::endl; + ss << "#define INT_TYPE8 ushort8" << std::endl; + ss << "#define SUB_GROUP_BLOCK_READ2 intel_sub_group_block_read_us2" << std::endl; + ss << "#define SUB_GROUP_BLOCK_READ4 intel_sub_group_block_read_us4" << std::endl; + ss << "#define SUB_GROUP_BLOCK_READ8 intel_sub_group_block_read_us8" << std::endl; + ss << "#define SUB_GROUP_BLOCK_READ intel_sub_group_block_read_us" << std::endl; + ss << "#else" << std::endl; + ss << "#define INT_TYPE uint" << std::endl; + ss << "#define INT_TYPE2 uint2" << std::endl; + ss << "#define INT_TYPE4 uint4" << std::endl; + ss << "#define INT_TYPE8 uint8" << std::endl; + ss << "#define SUB_GROUP_BLOCK_READ2 intel_sub_group_block_read2" << std::endl; + ss << "#define SUB_GROUP_BLOCK_READ4 intel_sub_group_block_read4" << std::endl; + ss << "#define SUB_GROUP_BLOCK_READ8 intel_sub_group_block_read8" << std::endl; + ss << "#define SUB_GROUP_BLOCK_READ intel_sub_group_block_read" << std::endl; + ss << "#endif" << std::endl; + ss << "#endif" << std::endl; ss << "#define activation_function(x) (x)" << std::endl; ss << "__attribute__((reqd_work_group_size(1, 1, SIMD_SIZE)))" << std::endl; ss << "kernel void" << std::endl; ss << "convolve_simd(" << std::endl; - ss << "__global float* inputs_base," << std::endl; - ss << 
"filter_qualifier float* weights_base," << std::endl; - ss << "__global float* biases_base," << std::endl; - ss << "__global float* outputs_base," << std::endl; + ss << "__global Dtype* inputs_base," << std::endl; + ss << "filter_qualifier Dtype* weights_base," << std::endl; + ss << "__global Dtype* biases_base," << std::endl; + ss << "__global Dtype* outputs_base," << std::endl; ss << "const ushort input_width," << std::endl; ss << "const ushort input_height," << std::endl; ss << "const ushort output_width," << std::endl; ss << "const ushort output_height)" << std::endl; ss << "{" << std::endl; - ss << "__global float* outputs = outputs_base;" << std::endl; - ss << "__global float* inputs = inputs_base;" << std::endl; - ss << "filter_qualifier float* weights = weights_base;" << std::endl; - ss << "__global float* biases = biases_base;" << std::endl; + ss << "__global Dtype* outputs = outputs_base;" << std::endl; + ss << "__global Dtype* inputs = inputs_base;" << std::endl; + ss << "filter_qualifier Dtype* weights = weights_base;" << std::endl; + ss << "__global Dtype* biases = biases_base;" << std::endl; // oc = Output Column ss << "uint_tp oc = get_global_id(0) * OUT_BLOCK_WIDTH;" << std::endl; // or = Output Row @@ -283,7 +304,7 @@ std::string LibDNNConvSpatial::generate_fw_kernels(int_tp kernelType, ss << "uint_tp fm = get_global_id(2);" << std::endl; ss << "uint_tp fmg = get_group_id(2);" << std::endl; ss << "uint_tp lid = get_local_id(2);" << std::endl; - ss << "float out[OUT_BLOCK_SIZE];" << std::endl; + ss << "Dtype out[OUT_BLOCK_SIZE];" << std::endl; ss << "int_tp in_addr;" << std::endl; // find weights adress of given neuron (lid is index) ss << "uint_tp weight_addr = (fmg % (ALIGNED_NUM_FILTERS/SIMD_SIZE)) * " @@ -313,8 +334,8 @@ std::string LibDNNConvSpatial::generate_fw_kernels(int_tp kernelType, << "+ curr_x - INPUT_PAD_W;" << std::endl; ss << "union {" << std::endl; - ss << "float4 in_vec[INVEC_SIZE];" << std::endl; - ss << "float in_array[INVEC_SIZE 
* 4];" << std::endl; + ss << "Dtype4 in_vec[INVEC_SIZE];" << std::endl; + ss << "Dtype in_array[INVEC_SIZE * 4];" << std::endl; ss << "} in_buf;" << std::endl; ss << "for(int_tp kd = 0; kd < INPUT_DEPTH; kd++)" << std::endl; ss << "{" << std::endl; @@ -340,7 +361,7 @@ std::string LibDNNConvSpatial::generate_fw_kernels(int_tp kernelType, ss << "in_buf.in_vec[reg].s3 = *(inputs + in_offset + 3);" << std::endl; ss << "} else {" << std::endl; // read SIMD_SIZE elements - ss << "in_buf.in_vec[reg] = *(global float4*)(inputs + in_offset);" + ss << "in_buf.in_vec[reg] = vload4(0, inputs + in_offset);" << std::endl; ss << "if (curr_x + 1 >= input_width + INPUT_PAD_W)" << std::endl; ss << "in_buf.in_vec[reg].s1 = 0;" << std::endl; @@ -355,7 +376,7 @@ std::string LibDNNConvSpatial::generate_fw_kernels(int_tp kernelType, ss << "curr_y += TILE_Y_STRIDE;" << std::endl; ss << "#else" << std::endl; // read SIMD_SIZE elements - ss << "in_buf.in_vec[reg] = *(global float4*)(inputs + in_offset);" + ss << "in_buf.in_vec[reg] = *(global Dtype4*)(inputs + in_offset);" << std::endl; ss << "#endif" << std::endl; ss << "in_offset += input_width * TILE_Y_STRIDE;" << std::endl; @@ -370,21 +391,21 @@ std::string LibDNNConvSpatial::generate_fw_kernels(int_tp kernelType, ss << "#define WEIGHT_PREF 1" << std::endl; ss << "#endif" << std::endl; ss << "union {" << std::endl; - ss << "float w[WEIGHT_PREF];" << std::endl; + ss << "Dtype w[WEIGHT_PREF];" << std::endl; ss << "#if KERNEL_WIDTH * KERNEL_HEIGHT != 1" << std::endl; - ss << "uint8 ui8;" << std::endl; + ss << "INT_TYPE8 ui8;" << std::endl; ss << "#endif" << std::endl; ss << "} weight_buf;" << std::endl; ss << "int_tp w_idx=0;" << std::endl; ss << "uint_tp orig_weight_addr = weight_addr;" << std::endl; ss << "#if KERNEL_WIDTH * KERNEL_HEIGHT != 1" << std::endl; ss << "weight_buf.ui8 = " - << "intel_sub_group_block_read8((__global uint *)&weights[weight_addr]);" + << "SUB_GROUP_BLOCK_READ8((__global INT_TYPE *)&weights[weight_addr]);" << 
std::endl; ss << "weight_addr += SIMD_SIZE * WEIGHT_PREF;" << std::endl; ss << "#else" << std::endl; - ss << "weight_buf.w[0] = as_float(" - << "intel_sub_group_block_read((__global uint *)&weights[weight_addr]));" + ss << "weight_buf.w[0] = as_Dtype(" + << "SUB_GROUP_BLOCK_READ((__global INT_TYPE *)&weights[weight_addr]));" << std::endl; ss << "weight_addr += SIMD_SIZE * 1;" << std::endl; ss << "#endif" << std::endl; @@ -402,7 +423,7 @@ std::string LibDNNConvSpatial::generate_fw_kernels(int_tp kernelType, ss << "{" << std::endl; ss << "for(int_tp br=0; br < OUT_BLOCK_HEIGHT; br++) {" << std::endl; ss << "for(int_tp bc=0; bc < OUT_BLOCK_WIDTH; bc++) {" << std::endl; - ss << "float input = BLOCK_IN((br * STRIDE_Y + kr * DILATION_Y) * " + ss << "Dtype input = BLOCK_IN((br * STRIDE_Y + kr * DILATION_Y) * " << "TILE_X + bc * STRIDE_X + kc * DILATION_X);" << std::endl; ss << "out[br * OUT_BLOCK_WIDTH + bc] = " << "mad(weight_buf.w[w_idx % WEIGHT_PREF], " @@ -418,7 +439,7 @@ std::string LibDNNConvSpatial::generate_fw_kernels(int_tp kernelType, ss << "#endif" << std::endl; ss << ") {" << std::endl; ss << "weight_buf.ui8 = " - << "intel_sub_group_block_read8((__global uint *)&weights[weight_addr]);" + << "SUB_GROUP_BLOCK_READ8((__global INT_TYPE *)&weights[weight_addr]);" << std::endl; // weights must be stored in just the right SIMD swizzled format // for this to work, see host code for details. 
@@ -434,15 +455,15 @@ std::string LibDNNConvSpatial::generate_fw_kernels(int_tp kernelType, ss << "weight_buf.w[0] = weights[weight_addr];" << std::endl; ss << "#elif KERNEL_WIDTH * KERNEL_HEIGHT % 8 == 2" << std::endl; ss << "weight_buf.ui8.s01 = " - << "intel_sub_group_block_read2((__global uint *)&weights[weight_addr]);" + << "SUB_GROUP_BLOCK_READ2((__global INT_TYPE *)&weights[weight_addr]);" << std::endl; ss << "#elif KERNEL_WIDTH * KERNEL_HEIGHT % 8 <= 4" << std::endl; ss << "weight_buf.ui8.s0123 = " - << "intel_sub_group_block_read4((__global uint *)&weights[weight_addr]);" + << "SUB_GROUP_BLOCK_READ4((__global INT_TYPE *)&weights[weight_addr]);" << std::endl; ss << "#else" << std::endl; ss << "weight_buf.ui8 = " - << "intel_sub_group_block_read8((__global uint *)&weights[weight_addr]);" + << "SUB_GROUP_BLOCK_READ8((__global INT_TYPE *)&weights[weight_addr]);" << std::endl; ss << "#endif" << std::endl; ss << "#endif" << std::endl; @@ -468,7 +489,7 @@ std::string LibDNNConvSpatial::generate_fw_kernels(int_tp kernelType, << "output_width * output_height;" << std::endl; ss << "out_addr += or * output_width + oc;" << std::endl; - ss << "float bias = biases[(fm % ALIGNED_NUM_FILTERS)];" << std::endl; + ss << "Dtype bias = biases[(fm % ALIGNED_NUM_FILTERS)];" << std::endl; ss << "#ifndef WRITE_PADDED_VALUES" << std::endl; ss << "if(get_global_id(0) != (get_global_size(0)-1) &&" << std::endl; ss << "get_global_id(1) != (get_global_size(1)-1) )" << std::endl; @@ -563,9 +584,9 @@ std::string LibDNNConvSpatial::generate_fw_kernels(int_tp kernelType, #define TYPEDEF_FLOAT_N(ele_num) \ do { \ - ss << "typedef struct float" << ele_num << " { "; \ - for (int_tp i = 0; i < ele_num; i++) { ss << "float s" << i << "; ";} \ - ss << "} float" << ele_num << ";" << std::endl; \ + ss << "typedef struct Dtype" << ele_num << " { "; \ + for (int_tp i = 0; i < ele_num; i++) { ss << "Dtype s" << i << "; ";} \ + ss << "} Dtype" << ele_num << ";" << std::endl; \ } while (0) 
TYPEDEF_FLOAT_N(1); @@ -580,7 +601,7 @@ std::string LibDNNConvSpatial::generate_fw_kernels(int_tp kernelType, TYPEDEF_FLOAT_N(14); TYPEDEF_FLOAT_N(15); // never used but makes compiler happy. - ss << "typedef struct float0 { float s0; } float0;" << std::endl; + ss << "typedef struct Dtype0 { Dtype s0; } Dtype0;" << std::endl; LibDNN::add_def(ss, "OUT_PITCH_X", "output_width"); LibDNN::add_def(ss, "OUT_PITCH_Y", "(output_width * output_height)"); @@ -624,10 +645,10 @@ std::string LibDNNConvSpatial::generate_fw_kernels(int_tp kernelType, else if (!IsBeignet(&ctx)) ss << "__attribute__((intel_reqd_sub_group_size(16)))" << std::endl; ss << "__kernel void Conv_Interleaved(" << std::endl; - ss << "const __global float *src0," << std::endl; - ss << "const __global float *src1," << std::endl; - ss << "const __global float *biases," << std::endl; - ss << "__global float *dst," << std::endl; + ss << "const __global Dtype *src0," << std::endl; + ss << "const __global Dtype *src1," << std::endl; + ss << "const __global Dtype *biases," << std::endl; + ss << "__global Dtype *dst," << std::endl; ss << "const ushort input_width," << std::endl; ss << "const ushort input_height," << std::endl; ss << "const ushort output_width," << std::endl; @@ -641,7 +662,7 @@ std::string LibDNNConvSpatial::generate_fw_kernels(int_tp kernelType, ss << "int_tp interleaved_y;" << std::endl; ss << "int_tp kernel_y;" << std::endl; ss << "int_tp kernel_idx;" << std::endl; - ss << "typedef CAT( float, KERNEL_WIDTH ) float_t;" << std::endl; + ss << "typedef CAT( Dtype, KERNEL_WIDTH ) Dtype_t;" << std::endl; // True for all threads if filter_width is multiple of TILE_N // else, true for all but right-most column of threads. ss << "if( TILE_N_LAST == 0 || global_x < WIDTH1 / TILE_N ) " << std::endl; @@ -650,19 +671,19 @@ std::string LibDNNConvSpatial::generate_fw_kernels(int_tp kernelType, // LWG size is 1x8 or 1x16. // Thus each thread calculates (8 or 16) *M rows x N cols of ctile. 
if (simd_size == 16) { - ss << "float16 blockC00 = 0.f;" << std::endl; - ss << "float16 blockC10 = 0.f;" << std::endl; + ss << "Dtype16 blockC00 = 0.f;" << std::endl; + ss << "Dtype16 blockC10 = 0.f;" << std::endl; } else { - ss << "float8 blockC00 = 0.f;" << std::endl; - ss << "float8 blockC10 = 0.f;" << std::endl; - ss << "float8 blockC20 = 0.f;" << std::endl; - ss << "float8 blockC30 = 0.f;" << std::endl; + ss << "Dtype8 blockC00 = 0.f;" << std::endl; + ss << "Dtype8 blockC10 = 0.f;" << std::endl; + ss << "Dtype8 blockC20 = 0.f;" << std::endl; + ss << "Dtype8 blockC30 = 0.f;" << std::endl; } if (blockM == 2 && simd_size == 8) { - ss << "float8 blockC01 = 0.f;" << std::endl; - ss << "float8 blockC11 = 0.f;" << std::endl; - ss << "float8 blockC21 = 0.f;" << std::endl; - ss << "float8 blockC31 = 0.f;" << std::endl; + ss << "Dtype8 blockC01 = 0.f;" << std::endl; + ss << "Dtype8 blockC11 = 0.f;" << std::endl; + ss << "Dtype8 blockC21 = 0.f;" << std::endl; + ss << "Dtype8 blockC31 = 0.f;" << std::endl; } // Src0 (patch input) is directly used as atile. // Each work item points to the start of a different patch. @@ -686,7 +707,7 @@ std::string LibDNNConvSpatial::generate_fw_kernels(int_tp kernelType, ss << "int_tp saved_y1 = curr_y1;" << std::endl; } } - ss << "const __global float *src0_read = src0" << std::endl; + ss << "const __global Dtype *src0_read = src0" << std::endl; // batch offset ss << "+ ALIGNED_INPUT_SIZE * global_z" << std::endl; // y offset @@ -694,7 +715,7 @@ std::string LibDNNConvSpatial::generate_fw_kernels(int_tp kernelType, // x offset ss << "+ (curr_x - INPUT_PAD_W);" << std::endl; if (blockM == 2) { - ss << "const __global float *src0_read1 = src0" << std::endl; + ss << "const __global Dtype *src0_read1 = src0" << std::endl; // batch offset ss << "+ ALIGNED_INPUT_SIZE * global_z" << std::endl; // y offset @@ -705,7 +726,7 @@ std::string LibDNNConvSpatial::generate_fw_kernels(int_tp kernelType, // Src1 (filter) is directly used as btile. 
// It starts at the top of src1 and walks down. // btile is K rows x N columns. - ss << "const __global float *src1_read = src1 + ( global_x * TILE_N * 2);" + ss << "const __global Dtype *src1_read = src1 + ( global_x * TILE_N * 2);" << std::endl; // Walk DOWN src0 (patch 0, 1, 2, ...) and DOWN src1. // Inner loop loads and FMADs one row (KERNEL_WIDTH) of each input patch @@ -730,7 +751,7 @@ std::string LibDNNConvSpatial::generate_fw_kernels(int_tp kernelType, * Load atile and btile. * * Kernel data is partially interleaved. - * Every 2 rows are interleaved at float8 granularity. + * Every 2 rows are interleaved at Dtype8 granularity. * The exception is that if KERNEL_WIDTH is odd the last row is not * interleaved. * The non interleaved row is padded with zero to ensure same size @@ -748,17 +769,17 @@ std::string LibDNNConvSpatial::generate_fw_kernels(int_tp kernelType, << std::endl; if (this->pad_[0] == 0 && this->pad_[1] == 0 && this->dilation_[1] == 1 && this->dilation_[0] == 1) { - ss << "float_t blockA00 = ( (const __global float_t*)src0_read )[0];" + ss << "Dtype_t blockA00 = ( (const __global Dtype_t*)src0_read )[0];" << std::endl; - ss << "float* pblockA00 = (float*)(&blockA00);" << std::endl; + ss << "Dtype* pblockA00 = (Dtype*)(&blockA00);" << std::endl; if (blockM == 2) { - ss << "float_t blockA01 = ( (const __global float_t*)src0_read1 )[0];" + ss << "Dtype_t blockA01 = ( (const __global Dtype_t*)src0_read1 )[0];" << std::endl; - ss << "float* pblockA01 = (float*)(&blockA01);" << std::endl; + ss << "Dtype* pblockA01 = (Dtype*)(&blockA01);" << std::endl; } } else { - ss << "float_t blockA00;" << std::endl; - ss << "float* pblockA00 = (float*)(&blockA00);" << std::endl; + ss << "Dtype_t blockA00;" << std::endl; + ss << "Dtype* pblockA00 = (Dtype*)(&blockA00);" << std::endl; ss << "int_tp pos = 0;" << std::endl; ss << "LOOP(KERNEL_WIDTH, pos," << std::endl; ss << "{" << std::endl; @@ -773,8 +794,8 @@ std::string 
LibDNNConvSpatial::generate_fw_kernels(int_tp kernelType, ss << "})" << std::endl; ss << "curr_y += DILATION_Y;" << std::endl; if (blockM == 2) { - ss << "float_t blockA01;" << std::endl; - ss << "float* pblockA01 = (float*)(&blockA01);" << std::endl; + ss << "Dtype_t blockA01;" << std::endl; + ss << "Dtype* pblockA01 = (Dtype*)(&blockA01);" << std::endl; ss << "pos = 0;" << std::endl; ss << "LOOP(KERNEL_WIDTH, pos," << std::endl; ss << "{" << std::endl; @@ -795,20 +816,20 @@ std::string LibDNNConvSpatial::generate_fw_kernels(int_tp kernelType, ss << "src0_read1 += (ROW_PITCH * DILATION_Y);" << std::endl; } ss << "uint blockB00[KERNEL_WIDTH * (TILE_N_PER_LANE)];" << std::endl; - ss << "float8* p8BlockB00 = (float8*)blockB00;" << std::endl; - ss << "float4* p4BlockB00 = (float4*)blockB00;" << std::endl; - ss << "float2* p2BlockB00 = (float2*)blockB00;" << std::endl; - ss << "float* pBlockB00 = (float* )blockB00;" << std::endl; + ss << "Dtype8* p8BlockB00 = (Dtype8*)blockB00;" << std::endl; + ss << "Dtype4* p4BlockB00 = (Dtype4*)blockB00;" << std::endl; + ss << "Dtype2* p2BlockB00 = (Dtype2*)blockB00;" << std::endl; + ss << "Dtype* pBlockB00 = (Dtype* )blockB00;" << std::endl; ss << "interleaved_y = 0;" << std::endl; ss << "LOOP(KERNEL_WIDTH_DIV2, interleaved_y, " << std::endl; ss << "{ " << std::endl; if (simd_size == 8) { - ss << "p8BlockB00[interleaved_y] = as_float8(" - << "intel_sub_group_block_read8( (const __global uint*)src1_read ) ); " + ss << "p8BlockB00[interleaved_y] = as_Dtype8(" + << "SUB_GROUP_BLOCK_READ8( (const __global INT_TYPE *)src1_read ) ); " << std::endl; } else { - ss << "p4BlockB00[interleaved_y] = as_float4(" - << "intel_sub_group_block_read4( (const __global uint*)src1_read ) ); " + ss << "p4BlockB00[interleaved_y] = as_Dtype4(" + << "SUB_GROUP_BLOCK_READ4( (const __global INT_TYPE *)src1_read ) ); " << std::endl; } ss << "src1_read += WIDTH1 * 2;" << std::endl; @@ -816,12 +837,12 @@ std::string LibDNNConvSpatial::generate_fw_kernels(int_tp 
kernelType, ss << "if ( kernel_width_is_odd )" << std::endl; ss << "{" << std::endl; if (simd_size == 8) { - ss << "p4BlockB00[KERNEL_WIDTH - 1] = as_float4(" - << "intel_sub_group_block_read4( (const __global uint*)src1_read ) ); " + ss << "p4BlockB00[KERNEL_WIDTH - 1] = as_Dtype4(" + << "SUB_GROUP_BLOCK_READ4( (const __global INT_TYPE *)src1_read ) ); " << std::endl; } else { - ss << "p2BlockB00[KERNEL_WIDTH - 1] = as_float2(" - << "intel_sub_group_block_read2( (const __global uint*)src1_read ) ); " + ss << "p2BlockB00[KERNEL_WIDTH - 1] = as_Dtype2(" + << "SUB_GROUP_BLOCK_READ2( (const __global INT_TYPE *)src1_read ) ); " << std::endl; } ss << "src1_read += WIDTH1 * 2;" << std::endl; @@ -934,7 +955,7 @@ std::string LibDNNConvSpatial::generate_fw_kernels(int_tp kernelType, // Dst resembles a cube of width x height x (output channel * batches). // Each tile writes: (SIMD * TILE_M) x 1 x TILE_N. // Partial writes most likely generated if padding used. - ss << "__global float *out = dst " << std::endl; + ss << "__global Dtype *out = dst " << std::endl; // batch offset ss << "+ global_z * OUT_PITCH_Z" << std::endl; // channel offset @@ -946,7 +967,7 @@ std::string LibDNNConvSpatial::generate_fw_kernels(int_tp kernelType, ss << "+ ( ( global_y * TILE_M ) % output_width ) + OUT_PADDING_LEFT;" << std::endl; if (blockM == 2) { - ss << "__global float *out1 = dst " << std::endl; + ss << "__global Dtype *out1 = dst " << std::endl; ss << "+ global_z * OUT_PITCH_Z" << std::endl; ss << "+ ( group_x * TILE_N ) * OUT_PITCH_Y" << std::endl; ss << "+ ((global_y * TILE_M + 1) / output_width + OUT_PADDING_HEIGHT)*" @@ -954,22 +975,22 @@ std::string LibDNNConvSpatial::generate_fw_kernels(int_tp kernelType, ss << "+ ( ( global_y * TILE_M + 1 ) % output_width ) + OUT_PADDING_LEFT;" << std::endl; } - ss << "float bias[TILE_N_PER_LANE];" << std::endl; - ss << "typedef CAT( float, TILE_N_PER_LANE) float_flex;" << std::endl; - ss << "float_flex *bias_vec;" << std::endl; - ss << "bias_vec = 
(float_flex*)bias;" << std::endl; + ss << "Dtype bias[TILE_N_PER_LANE];" << std::endl; + ss << "typedef CAT( Dtype, TILE_N_PER_LANE) Dtype_flex;" << std::endl; + ss << "Dtype_flex *bias_vec;" << std::endl; + ss << "bias_vec = (Dtype_flex*)bias;" << std::endl; if (simd_size == 16) { ss << "*bias_vec = " - << "as_float2(intel_sub_group_block_read2(" - << "(__global uint *)biases + group_x * TILE_N));" + << "as_Dtype2(SUB_GROUP_BLOCK_READ42(" + << "(__global INT_TYPE *)biases + group_x * TILE_N));" << std::endl; // Work around a potential compiler bug ss << "if (group_x > 0xFFFFFFFEul)" << std::endl; ss << "out[0] = bias[0] + bias[1];" << std::endl; } else { ss << "*bias_vec = " - << "as_float4(intel_sub_group_block_read4(" - << "(__global uint *)biases + group_x * TILE_N));" + << "as_Dtype4(SUB_GROUP_BLOCK_READ4(" + << "(__global INT_TYPE *)biases + group_x * TILE_N));" << std::endl; } ss << "if (global_y * TILE_M < output_width * output_height )" << std::endl; @@ -1020,7 +1041,7 @@ std::string LibDNNConvSpatial::generate_fw_kernels(int_tp kernelType, // Result ctile (*dst) is M rows x N columns // LWG size is 1x8. Thus each thread calculates 8*M rows x N cols of ctile. 
ss << "int_tp i = 0;" << std::endl; - ss << "float8 blockC[TILE_N_LAST_DIV8];" << std::endl; + ss << "Dtype8 blockC[TILE_N_LAST_DIV8];" << std::endl; ss << "LOOP(TILE_N_LAST_DIV8, i," << std::endl; ss << "{" << std::endl; ss << "blockC[i] = 0.f;" << std::endl; @@ -1033,13 +1054,13 @@ std::string LibDNNConvSpatial::generate_fw_kernels(int_tp kernelType, || this->dilation_[1] != 1 || this->dilation_[0] != 1) { ss << "int_tp saved_y = curr_y;" << std::endl; } - ss << "const __global float *src0_read = src0" << std::endl; + ss << "const __global Dtype *src0_read = src0" << std::endl; ss << "+ ALIGNED_INPUT_SIZE * global_z" << std::endl; ss << "+ (curr_y - INPUT_PAD_H) * ROW_PITCH" << std::endl; ss << "+ (curr_x - INPUT_PAD_W);" << std::endl; if (blockM == 2) { ss << "i = 0;" << std::endl; - ss << "float8 blockC1[TILE_N_LAST_DIV8];" << std::endl; + ss << "Dtype8 blockC1[TILE_N_LAST_DIV8];" << std::endl; ss << "LOOP(TILE_N_LAST_DIV8, i," << std::endl; ss << "{" << std::endl; ss << "blockC1[i] = 0.f;" << std::endl; @@ -1054,12 +1075,12 @@ std::string LibDNNConvSpatial::generate_fw_kernels(int_tp kernelType, || this->dilation_[1] != 1 || this->dilation_[0] != 1) { ss << "int_tp saved_y1 = curr_y1;" << std::endl; } - ss << "const __global float *src0_read1 = src0" << std::endl; + ss << "const __global Dtype *src0_read1 = src0" << std::endl; ss << "+ ALIGNED_INPUT_SIZE * global_z" << std::endl; ss << "+ (curr_y1 - INPUT_PAD_H) * ROW_PITCH" << std::endl; ss << "+ (curr_x1 - INPUT_PAD_W);" << std::endl; } - ss << "const __global float *src1_read = src1 + ( global_x * TILE_N * 2);" + ss << "const __global Dtype *src1_read = src1 + ( global_x * TILE_N * 2);" << std::endl; ss << "int_tp patch_depth = 0;" << std::endl; ss << "do" << std::endl; @@ -1078,17 +1099,17 @@ std::string LibDNNConvSpatial::generate_fw_kernels(int_tp kernelType, << std::endl; if (this->pad_[0] == 0 && this->pad_[1] == 0 && this->dilation_[1] == 1 && this->dilation_[0] == 1) { - ss << "float_t blockA00 = ( 
(const __global float_t*)src0_read )[0];" + ss << "Dtype_t blockA00 = ( (const __global Dtype_t*)src0_read )[0];" << std::endl; - ss << "float* pblockA00 = (float*)(&blockA00);" << std::endl; + ss << "Dtype* pblockA00 = (Dtype*)(&blockA00);" << std::endl; if (blockM == 2) { - ss << "float_t blockA01 = ( (const __global float_t*)src0_read1 )[0];" + ss << "Dtype_t blockA01 = ( (const __global Dtype_t*)src0_read1 )[0];" << std::endl; - ss << "float* pblockA01 = (float*)(&blockA01);" << std::endl; + ss << "Dtype* pblockA01 = (Dtype*)(&blockA01);" << std::endl; } } else { - ss << "float_t blockA00;" << std::endl; - ss << "float* pblockA00 = (float*)(&blockA00);" << std::endl; + ss << "Dtype_t blockA00;" << std::endl; + ss << "Dtype* pblockA00 = (Dtype*)(&blockA00);" << std::endl; ss << "int_tp pos = 0;" << std::endl; ss << "LOOP(KERNEL_WIDTH, pos," << std::endl; ss << "{" << std::endl; @@ -1103,8 +1124,8 @@ std::string LibDNNConvSpatial::generate_fw_kernels(int_tp kernelType, ss << "})" << std::endl; ss << "curr_y += DILATION_Y;" << std::endl; if (blockM == 2) { - ss << "float_t blockA01;" << std::endl; - ss << "float* pblockA01 = (float*)(&blockA01);" << std::endl; + ss << "Dtype_t blockA01;" << std::endl; + ss << "Dtype* pblockA01 = (Dtype*)(&blockA01);" << std::endl; ss << "pos = 0;" << std::endl; ss << "LOOP(KERNEL_WIDTH, pos," << std::endl; ss << "{" << std::endl; @@ -1124,57 +1145,57 @@ std::string LibDNNConvSpatial::generate_fw_kernels(int_tp kernelType, if (blockM == 2) { ss << "src0_read1 += (ROW_PITCH * DILATION_Y);" << std::endl; } - ss << "float blockB[KERNEL_WIDTH * TILE_N_LAST_DIV8];" << std::endl; + ss << "Dtype blockB[KERNEL_WIDTH * TILE_N_LAST_DIV8];" << std::endl; ss << "interleaved_y = 0;" << std::endl; ss << "LOOP(KERNEL_WIDTH_DIV2, interleaved_y, " << std::endl; ss << "{ " << std::endl; ss << "#if TILE_N_LAST_DIV8 == 1" << std::endl; - ss << "float2* p2BlockB = (float2* )blockB;" << std::endl; - ss << "p2BlockB[interleaved_y] = as_float2(" - << 
"intel_sub_group_block_read2( (const __global uint*)src1_read ) );" + ss << "Dtype2* p2BlockB = (Dtype2* )blockB;" << std::endl; + ss << "p2BlockB[interleaved_y] = as_Dtype2(" + << "SUB_GROUP_BLOCK_READ2( (const __global INT_TYPE *)src1_read ) );" << std::endl; ss << "#elif TILE_N_LAST_DIV8 == 2" << std::endl; - ss << "float4* p4BlockB = (float4* )blockB;" << std::endl; - ss << "p4BlockB[interleaved_y] = as_float4(" - << "intel_sub_group_block_read4( (const __global uint*)src1_read ) );" + ss << "Dtype4* p4BlockB = (Dtype4* )blockB;" << std::endl; + ss << "p4BlockB[interleaved_y] = as_Dtype4(" + << "SUB_GROUP_BLOCK_READ4( (const __global INT_TYPE *)src1_read ) );" << std::endl; ss << "#elif TILE_N_LAST_DIV8 == 3" << std::endl; ss << "//TODO: broken. No block_read6" << std::endl; - ss << "float6* p6BlockB = (float6* )blockB;" << std::endl; - ss << "(*((float8*)(&p6BlockB[interleaved_y]))).s0123 = as_float4(" - << "intel_sub_group_block_read4( (const __global uint*)src1_read ) );" + ss << "Dtype6* p6BlockB = (Dtype6* )blockB;" << std::endl; + ss << "(*((Dtype8*)(&p6BlockB[interleaved_y]))).s0123 = as_Dtype4(" + << "SUB_GROUP_BLOCK_READ4( (const __global INT_TYPE *)src1_read ) );" << std::endl; - ss << "(*((float8*)(&p6BlockB[interleaved_y]))).s45 = as_float2(" - << "intel_sub_group_block_read2(" - << "(const __global uint*)(src1_read + 4 * 8)));" << std::endl; + ss << "(*((Dtype8*)(&p6BlockB[interleaved_y]))).s45 = as_Dtype2(" + << "SUB_GROUP_BLOCK_READ2(" + << "(const __global INT_TYPE *)(src1_read + 4 * 8)));" << std::endl; ss << "#endif" << std::endl; ss << "src1_read += WIDTH1 * 2;" << std::endl; ss << "} )" << std::endl; ss << "if ( kernel_width_is_odd )" << std::endl; ss << "{" << std::endl; ss << "#if TILE_N_LAST_DIV8 == 1" << std::endl; - ss << "float* pBlockB = (float* )blockB;" << std::endl; - ss << "pBlockB[KERNEL_WIDTH - 1] = as_float(" - << "intel_sub_group_block_read( (const __global uint*)src1_read ) );" + ss << "Dtype* pBlockB = (Dtype* )blockB;" << 
std::endl; + ss << "pBlockB[KERNEL_WIDTH - 1] = as_Dtype(" + << "SUB_GROUP_BLOCK_READ( (const __global INT_TYPE *)src1_read ) );" << std::endl; ss << "#elif TILE_N_LAST_DIV8 == 2" << std::endl; - ss << "float2* p2BlockB = (float2* )blockB;" << std::endl; - ss << "p2BlockB[KERNEL_WIDTH - 1] = as_float2(" - << "intel_sub_group_block_read2( (const __global uint*)src1_read ) );" + ss << "Dtype2* p2BlockB = (Dtype2* )blockB;" << std::endl; + ss << "p2BlockB[KERNEL_WIDTH - 1] = as_Dtype2(" + << "SUB_GROUP_BLOCK_READ2( (const __global INT_TYPE *)src1_read ) );" << std::endl; ss << "#elif TILE_N_LAST_DIV8 == 3" << std::endl; - ss << "float3* p3BlockB = (float3* )blockB;" << std::endl; - ss << "p3BlockB[KERNEL_WIDTH - 1].s01 = as_float2(" - << "intel_sub_group_block_read2( (const __global uint*)src1_read ) );" + ss << "Dtype3* p3BlockB = (Dtype3* )blockB;" << std::endl; + ss << "p3BlockB[KERNEL_WIDTH - 1].s01 = as_Dtype2(" + << "SUB_GROUP_BLOCK_READ2( (const __global INT_TYPE *)src1_read ) );" << std::endl; - ss << "p3BlockB[KERNEL_WIDTH - 1].s2 = as_float(" - << "intel_sub_group_block_read( (const __global uint*)" + ss << "p3BlockB[KERNEL_WIDTH - 1].s2 = as_Dtype(" + << "SUB_GROUP_BLOCK_READ( (const __global INT_TYPE *)" << "(src1_read + 2 * 8)));" << std::endl; ss << "#endif" << std::endl; ss << "src1_read += WIDTH1 * 2;" << std::endl; ss << "}" << std::endl; ss << "// Perform MADs" << std::endl; - ss << "float* pBlockB = (float*)blockB;" << std::endl; + ss << "Dtype* pBlockB = (Dtype*)blockB;" << std::endl; ss << "kernel_idx = 0;" << std::endl; ss << "interleaved_y = 0;" << std::endl; ss << "LOOP(KERNEL_WIDTH_DIV2, interleaved_y, " << std::endl; @@ -1251,7 +1272,7 @@ std::string LibDNNConvSpatial::generate_fw_kernels(int_tp kernelType, << std::endl; ss << "} " << std::endl; ss << "while ( ++patch_depth < INPUT_DEPTH );" << std::endl; - ss << "__global float *out = dst " << std::endl; + ss << "__global Dtype *out = dst " << std::endl; ss << "+ global_z * OUT_PITCH_Z" << 
std::endl; ss << "+ (group_x * TILE_N) * OUT_PITCH_Y" << std::endl; ss << "+ ((global_y * TILE_M) / output_width + " @@ -1259,7 +1280,7 @@ std::string LibDNNConvSpatial::generate_fw_kernels(int_tp kernelType, ss << "+ ((global_y * TILE_M) % output_width ) + OUT_PADDING_LEFT;" << std::endl; if (blockM == 2) { - ss << "__global float *out1 = dst " << std::endl; + ss << "__global Dtype *out1 = dst " << std::endl; ss << "+ global_z * OUT_PITCH_Z" << std::endl; ss << "+ ( group_x * TILE_N ) * OUT_PITCH_Y" << std::endl; ss << "+ ((global_y * TILE_M + 1) / output_width + OUT_PADDING_HEIGHT ) *" @@ -1267,11 +1288,11 @@ std::string LibDNNConvSpatial::generate_fw_kernels(int_tp kernelType, ss << "+ ((global_y * TILE_M + 1) % output_width ) + OUT_PADDING_LEFT;" << std::endl; } - ss << "float bias[4];" << std::endl; - ss << "float4 *bias_vec;" << std::endl; - ss << "bias_vec = (float4*)bias;" << std::endl; - ss << "*bias_vec = as_float4(intel_sub_group_block_read4(" - << "(__global uint *)biases + group_x * TILE_N));" << std::endl; + ss << "Dtype bias[4];" << std::endl; + ss << "Dtype4 *bias_vec;" << std::endl; + ss << "bias_vec = (Dtype4*)bias;" << std::endl; + ss << "*bias_vec = as_Dtype4(SUB_GROUP_BLOCK_READ4(" + << "(__global INT_TYPE *)biases + group_x * TILE_N));" << std::endl; ss << "if (global_y * TILE_M < output_width * output_height )" << std::endl; ss << "{" << std::endl; ss << "for (int_tp i = 0; i < 8; i++)" << std::endl; @@ -1567,7 +1588,7 @@ void LibDNNConvSpatial::calculate_verify_data(const Dtype* bottom, clEnqueueCopyBuffer(ctx.get_queue().handle().get(), (cl_mem)top_data_, (cl_mem)verify_data, 0, 0, - sizeof(float) * num_ * this->top_dim_, 0, NULL, NULL); + sizeof(Dtype) * num_ * this->top_dim_, 0, NULL, NULL); ctx.delete_program(kernelQueue[kernel_index_]->kernelName); kernelQueue.pop_back(); return; @@ -1605,7 +1626,12 @@ void LibDNNConvSpatial::ForwardBenchmark(const Dtype* bottom, template void LibDNNConvSpatial::generate_key() { + 
CHECK((!std::is_same::value)); std::stringstream keyBuilder; + if (std::is_same::value) + keyBuilder << "float_"; + else + keyBuilder << "half_"; // FIXME: to support fuse? keyBuilder << this->kernel_shape_[1] << "_" << this->kernel_shape_[0] << "_" @@ -1637,6 +1663,7 @@ template std::string LibDNNConvSpatial::generate_specific_key( int_tp type, int_tp blockWidth, int_tp blockHeight, int_tp blockDepth) { std::stringstream keyBuilder; + CHECK((!std::is_same::value)); keyBuilder << short_key_ << "_" << type << "_" << blockWidth @@ -1653,7 +1680,7 @@ void interleaveMatrix( CHECK_EQ(interleavedRows % 2, 0) << "interleaveMatrix only supports even values for interleavedRows."; - size_t memSize = r * c * sizeof(float); + size_t memSize = r * c * sizeof(Dtype); size_t dstSize = memSize * (interleavedRows + nonInterleavedRows * 2) / (interleavedRows + nonInterleavedRows); @@ -1822,11 +1849,12 @@ void LibDNNConvSpatial::swizzleWeights( } } -template<> -void LibDNNConvSpatial::calculate_global_size(int_tp batch, +template +void LibDNNConvSpatial::calculate_global_size(int_tp batch, int_tp* wio, // work item output size size_t* lSize, // local size size_t* gSize) { // global size + CHECK((!std::is_same::value)); gSize[0] = ceil( (fmax(static_cast(output_w_) / wio[0], 1.0)) / lSize[0]) * lSize[0]; @@ -1839,9 +1867,9 @@ void LibDNNConvSpatial::calculate_global_size(int_tp batch, / lSize[2]) * lSize[2]; } -template<> -bool LibDNNConvSpatial::create_basic_kernel( - const float *bottom, const float *top, +template +bool LibDNNConvSpatial::create_basic_kernel( + const Dtype *bottom, const Dtype *top, int_tp blockWidth, int_tp blockHeight, int_tp blockDepth) { int_tp workItemOutput[3]; @@ -1877,6 +1905,7 @@ void LibDNNConvSpatial::setBufferKernelArg( size_t size, bool readOnly, bool preserved) { + CHECK((!std::is_same::value)); if (offset == 0) { kernel->arg(argIdx, WrapHandle((cl_mem) buffer, ctx)); return; @@ -1920,13 +1949,14 @@ void LibDNNConvSpatial::cleanTmpSubBuffers( 
tmpSubBuffers.clear(); } -template<> -cl_int LibDNNConvSpatial::convolve( - const float *bottom, const float *top, +template +cl_int LibDNNConvSpatial::convolve( + const Dtype *bottom, const Dtype *top, int_tp index, int_tp numImages, kernelConfig* config) { + CHECK((!std::is_same::value)); viennacl::ocl::context &ctx = - viennacl::ocl::get_context(LibDNN::dev_ptr_->id()); + viennacl::ocl::get_context(LibDNN::dev_ptr_->id()); viennacl::ocl::program & program = ctx.get_program(config->kernelName); viennacl::ocl::kernel &kernel = program.get_kernel(config->kernelName); cl_int err = CL_SUCCESS; @@ -2046,7 +2076,7 @@ cl_int LibDNNConvSpatial::convolve( kernel.arg(argIdx++, (uint16_t)output_w_); kernel.arg(argIdx++, (uint16_t)output_h_); viennacl::ocl::context &ctx = - viennacl::ocl::get_context(LibDNN::dev_ptr_->id()); + viennacl::ocl::get_context(LibDNN::dev_ptr_->id()); err = clEnqueueNDRangeKernel(ctx.get_queue().handle().get(), kernel.handle().get(), 3, NULL, @@ -2116,9 +2146,9 @@ cl_int LibDNNConvSpatial::convolve( return err; } -template<> -float LibDNNConvSpatial::timed_convolve( - const float *bottom, const float *top, +template +float LibDNNConvSpatial::timed_convolve( + const Dtype *bottom, const Dtype *top, int_tp index, int_tp numImages, kernelConfig* config) { // warm up. 
@@ -2158,11 +2188,11 @@ float LibDNNConvSpatial::timed_convolve( return elapsedTime; } -template<> -bool LibDNNConvSpatial::verify_result( - const float *bottom, const float *top, +template +bool LibDNNConvSpatial::verify_result( + const Dtype *bottom, const Dtype *top, int_tp index, - int_tp numImages, const float *verify_blob, kernelConfig* config) { + int_tp numImages, const Dtype *verify_blob, kernelConfig* config) { uint_tp verificationFail = 0; @@ -2171,26 +2201,29 @@ bool LibDNNConvSpatial::verify_result( else if (config->tested) return false; - greentea_memset(LibDNN::dev_ptr_->id(), - sizeof(float) * numImages * this->top_dim_, + if (std::is_same::value) + return true; + + greentea_memset(LibDNN::dev_ptr_->id(), + sizeof(Dtype) * numImages * this->top_dim_, 0, (cl_mem)top, 0); config->executionTime = timed_convolve(bottom, top, index, numImages, config); - const float *verify_data; - float *data; - float *tmp_verify_data; + const Dtype *verify_data; + Dtype *data; + Dtype *tmp_verify_data; viennacl::ocl::context &ctx = - viennacl::ocl::get_context(LibDNN::dev_ptr_->id()); - data = reinterpret_cast(clEnqueueMapBuffer( + viennacl::ocl::get_context(LibDNN::dev_ptr_->id()); + data = reinterpret_cast(clEnqueueMapBuffer( ctx.get_queue().handle().get(), (cl_mem)top, true, CL_MAP_READ, - 0, sizeof(float) * numImages * this->top_dim_, 0, NULL, NULL, NULL)); - tmp_verify_data = reinterpret_cast(clEnqueueMapBuffer( + 0, sizeof(Dtype) * numImages * this->top_dim_, 0, NULL, NULL, NULL)); + tmp_verify_data = reinterpret_cast(clEnqueueMapBuffer( ctx.get_queue().handle().get(), (cl_mem)verify_blob, true, CL_MAP_READ, - 0, sizeof(float) * numImages * this->top_dim_, + 0, sizeof(Dtype) * numImages * this->top_dim_, 0, NULL, NULL, NULL)); verify_data = tmp_verify_data; @@ -2210,7 +2243,8 @@ bool LibDNNConvSpatial::verify_result( && fabs(data[offset] - verify_data[offset]) < 1.e-4)) { dbgPrint(printf("test verification failed @ image %d group %d" "out_ch %d h %d w %d got %G 
expected %G\n", - n, g, out_ch, h, w, data[offset], verify_data[offset])); + n, g, out_ch, h, w, + (double)data[offset], (double)verify_data[offset])); verificationFail = 1; goto out; } @@ -2236,9 +2270,9 @@ viennacl::ocl::program LibDNNConvSpatial::compile_fw_kernel() { return ctx.add_program(LibDNN::kernel_.c_str(), kernel_name_); } -template<> -bool LibDNNConvSpatial::create_gemm_like_conv_kernel( - const float *bottom, const float *top, +template +bool LibDNNConvSpatial::create_gemm_like_conv_kernel( + const Dtype *bottom, const Dtype *top, int_tp blockM, int_tp blockK, int_tp blockN) { @@ -2253,10 +2287,10 @@ bool LibDNNConvSpatial::create_gemm_like_conv_kernel( int_tp globalWorkSizeDY = blockM; size_t sgemm_m = alignedExpandHeight; size_t sgemm_n = alignedFilterWidth; - size_t gx = static_cast(ceil(static_cast(sgemm_n) - / static_cast(globalWorkSizeDX))); - size_t gy = static_cast(ceil(static_cast(sgemm_m) - / static_cast(globalWorkSizeDY))); + size_t gx = static_cast(ceil(static_cast(sgemm_n) + / static_cast(globalWorkSizeDX))); + size_t gy = static_cast(ceil(static_cast(sgemm_m) + / static_cast(globalWorkSizeDY))); gy = ALIGN(gy, blockK); size_t gz = num_batches; size_t global_size[3] = { gx, gy, gz }; @@ -2278,7 +2312,7 @@ bool LibDNNConvSpatial::create_gemm_like_conv_kernel( NULL); viennacl::ocl::context &ctx = - viennacl::ocl::get_context(LibDNN::dev_ptr_->id()); + viennacl::ocl::get_context(LibDNN::dev_ptr_->id()); if (workgroupSize_used != simd_size) { ctx.delete_program(kernel_name_); return false; @@ -2295,9 +2329,9 @@ bool LibDNNConvSpatial::create_gemm_like_conv_kernel( } } -template<> -bool LibDNNConvSpatial::setup_IDLF( - const float *bottom, const float *top, +template +bool LibDNNConvSpatial::setup_IDLF( + const Dtype *bottom, const Dtype *top, int_tp blockWidth, int_tp blockHeight, int_tp simd_size) { int_tp workItemOutput[3] = { blockWidth, blockHeight, simd_size }; @@ -2333,7 +2367,7 @@ bool LibDNNConvSpatial::setup_IDLF( NULL); 
viennacl::ocl::context &ctx = - viennacl::ocl::get_context(LibDNN::dev_ptr_->id()); + viennacl::ocl::get_context(LibDNN::dev_ptr_->id()); if (workgroupSize_used != simd_size) { ctx.delete_program(kernel_name_); return false; @@ -2350,9 +2384,9 @@ bool LibDNNConvSpatial::setup_IDLF( } } -template<> -bool LibDNNConvSpatial::tune_local_size( - const float *bottom, const float *top, +template +bool LibDNNConvSpatial::tune_local_size( + const Dtype *bottom, const Dtype *top, kernelConfig* config) { if (config->use_null_local || !config->autoTune) return true; @@ -2440,9 +2474,9 @@ bool LibDNNConvSpatial::tune_local_size( return true; } -template<> -void LibDNNConvSpatial::create_convolution_kernel( - const float *bottom, const float *top, +template +void LibDNNConvSpatial::create_convolution_kernel( + const Dtype *bottom, const Dtype *top, int_tp kernelType, int_tp blockWidth, int_tp blockHeight, int_tp blockDepth) { @@ -2457,19 +2491,19 @@ void LibDNNConvSpatial::create_convolution_kernel( assert(0); } -template<> -void LibDNNConvSpatial::setup_convolution( - const float *bottom, const float *top, - const float *verify_blob) { +template +void LibDNNConvSpatial::setup_convolution( + const Dtype *bottom, const Dtype *top, + const Dtype *verify_blob) { // Initializes unique kernel ID kernel_uid_ = 0; - if (LibDNN::dev_ptr_->CheckCapability("cl_intel_subgroups")) { + if (LibDNN::dev_ptr_->CheckCapability("cl_intel_subgroups")) { /* IDLF kernels are using Intel specific extension which make them intel only. 
*/ // Generates static key_ viennacl::ocl::context &ctx = - viennacl::ocl::get_context(LibDNN::dev_ptr_->id()); + viennacl::ocl::get_context(LibDNN::dev_ptr_->id()); int_tp max_compute_units = ctx.current_device().max_compute_units(); int_tp kernelCnt = 0; if (this->group_ == 1 @@ -2477,7 +2511,9 @@ void LibDNNConvSpatial::setup_convolution( && (this->M_FW_ % 32 != 24))) { create_convolution_kernel(bottom, top, 5, 1, 8, 32); create_convolution_kernel(bottom, top, 5, 2, 8, 32); - if (this->kernel_shape_[1] < 4 && this->M_FW_ % 32 == 0) + if ((this->kernel_shape_[1] < 4 || + (std::is_same::value)) + && this->M_FW_ % 32 == 0) create_convolution_kernel(bottom, top, 5, 1, 16, 32); } @@ -2511,7 +2547,7 @@ void LibDNNConvSpatial::setup_convolution( if (simd_size == 8 && this->M_FW_ >= 16 && ((num_ * this->M_FW_ * output_w_ * output_h_ / - static_cast(width * height)) + static_cast(width * height)) >= max_compute_units * 7 * 16)) continue; int_tp tile_x = (this->kernel_shape_[1] * this->dilation_[1] @@ -2755,94 +2791,6 @@ void LibDNNConvSpatial::SetUp( } } -template<> -void LibDNNConvSpatial::create_convolution_kernel( - const double *bottom, const double *top, - int_tp kernelType, - int_tp blockWidth, int_tp blockHeight, - int_tp blockDepth) { - NOT_IMPLEMENTED; - return; -} - -template<> -bool LibDNNConvSpatial::setup_IDLF( - const double *bottom, const double *top, - int_tp blockWidth, - int_tp blockHeight, int_tp blockDepth) { - NOT_IMPLEMENTED; - return false; -} - -template<> -bool LibDNNConvSpatial::create_gemm_like_conv_kernel( - const double *bottom, const double *top, - int_tp blockWidth, - int_tp blockHeight, int_tp blockDepth) { - NOT_IMPLEMENTED; - return false; -} - - -template<> -bool LibDNNConvSpatial::verify_result( - const double *bottom, const double *top, - int_tp index, - int_tp numImages, const double *verify_blob, kernelConfig* config) { - NOT_IMPLEMENTED; - return false; -} - -template<> -bool LibDNNConvSpatial::create_basic_kernel( - const double 
*bottom, const double *top, - int_tp blockWidth, - int_tp blockHeight, int_tp blockDepth) { - NOT_IMPLEMENTED; - return false; -} - -template<> -bool LibDNNConvSpatial::tune_local_size( - const double *bottom, const double *top, - kernelConfig* config) { - NOT_IMPLEMENTED; - return false; -} - -template<> -cl_int LibDNNConvSpatial::convolve( - const double *bottom, const double *top, - int_tp index, - int_tp numImages, kernelConfig* config) { - NOT_IMPLEMENTED; - return false; -} - -template<> -float LibDNNConvSpatial::timed_convolve( - const double *bottom, const double *top, - int_tp index, - int_tp numImages, kernelConfig* config) { - NOT_IMPLEMENTED; - return 0.f; -} - -template<> -void LibDNNConvSpatial::setup_convolution( - const double *bottom, const double *top, - const double *verify_blob) { - NOT_IMPLEMENTED; -} - -template<> -void LibDNNConvSpatial::calculate_global_size( - int_tp batch, - int_tp* workItemOutput, - size_t* localSizes, size_t* globalSizes) { - NOT_IMPLEMENTED; -} - INSTANTIATE_CLASS(LibDNNConvSpatial); } // namespace caffe diff --git a/src/caffe/greentea/libdnn_pool.cpp b/src/caffe/greentea/libdnn_pool.cpp index 65522a4c194..2074c69a9a8 100755 --- a/src/caffe/greentea/libdnn_pool.cpp +++ b/src/caffe/greentea/libdnn_pool.cpp @@ -201,6 +201,16 @@ template std::string LibDNNPool::generate_fw_kernels(std::string name, bool test_mode) { std::stringstream ss; +#ifdef HAS_HALF_SUPPORT + if (std::is_same::value) { + ss << "#define DTYPE_MAX HALF_MAX" << std::endl; + ss << "#define DTYPE_MIN HALF_MIN" << std::endl; + } else +#endif + { + ss << "#define DTYPE_MAX FLT_MAX" << std::endl; + ss << "#define DTYPE_MIN FLT_MIN" << std::endl; + } ss << "__kernel void " + name + "("; ss << "__global const Dtype* __restrict bottom_data, "; @@ -248,7 +258,7 @@ std::string LibDNNPool::generate_fw_kernels(std::string name, ss << "__global int_tp* mask_ptr = mask + get_global_id(1) * v_imso;" << std::endl; } - ss << "Dtype val = -FLT_MAX;" << std::endl; + ss << 
"Dtype val = -DTYPE_MAX;" << std::endl; ss << "int_tp maxidx = -1;" << std::endl; } @@ -258,7 +268,7 @@ std::string LibDNNPool::generate_fw_kernels(std::string name, if (pool_method_ == LIBDNN_POOLING_METHOD_STO) { if (test_mode) { - ss << "Dtype cumsum = FLT_MIN;" << std::endl; + ss << "Dtype cumsum = DTYPE_MIN;" << std::endl; ss << "Dtype cumvalues = 0;" << std::endl; } else { ss << "__global Dtype* rand_ptr = rand_idx + get_global_id(1) * v_imso;" @@ -387,7 +397,7 @@ std::string LibDNNPool::generate_fw_kernels(std::string name, ss << "if (cumsum > thres) {" << std::endl; ss << "stoidx = in_idx + " << kernel_offset << ";" << std::endl; ss << "val = in_ptr[" << kernel_offset << "];" << std::endl; - ss << "thres = FLT_MAX;" << std::endl; + ss << "thres = DTYPE_MAX;" << std::endl; ss << "}" << std::endl; } } diff --git a/src/caffe/layers/batch_norm_layer.cu b/src/caffe/layers/batch_norm_layer.cu index 3fbf1f0bbff..73747e0427d 100644 --- a/src/caffe/layers/batch_norm_layer.cu +++ b/src/caffe/layers/batch_norm_layer.cu @@ -10,8 +10,8 @@ namespace caffe { oclk_bn_use_global_stats.arg(argIdx++, num); \ oclk_bn_use_global_stats.arg(argIdx++, channels_); \ oclk_bn_use_global_stats.arg(argIdx++, spatial_dim); \ - oclk_bn_use_global_stats.arg(argIdx++, scale_factor); \ - oclk_bn_use_global_stats.arg(argIdx++, eps_); \ + oclk_bn_use_global_stats.arg(argIdx++, fixup_arg_type(scale_factor)); \ + oclk_bn_use_global_stats.arg(argIdx++, fixup_arg_type(eps_)); \ oclk_bn_use_global_stats.arg(argIdx++, WrapHandle((cl_mem) this->blobs_[0]->gpu_data(), &ctx)); \ oclk_bn_use_global_stats.arg(argIdx++, WrapHandle((cl_mem) this->blobs_[1]->gpu_data(), &ctx)); diff --git a/src/caffe/layers/contrastive_loss_layer.cpp b/src/caffe/layers/contrastive_loss_layer.cpp index 83e96c0f643..90b961359c4 100644 --- a/src/caffe/layers/contrastive_loss_layer.cpp +++ b/src/caffe/layers/contrastive_loss_layer.cpp @@ -49,7 +49,7 @@ void ContrastiveLossLayer::Forward_cpu( loss += dist_sq_.cpu_data()[i]; } 
else { // dissimilar pairs if (legacy_version) { - loss += std::max(margin - dist_sq_.cpu_data()[i], Dtype(0.0)); + loss += fmax(margin - dist_sq_.cpu_data()[i], Dtype(0.0)); } else { Dtype dist = std::max(margin - sqrt(dist_sq_.cpu_data()[i]), Dtype(0.0)); diff --git a/src/caffe/layers/contrastive_loss_layer.cu b/src/caffe/layers/contrastive_loss_layer.cu index 80c120d75a5..54c50093504 100644 --- a/src/caffe/layers/contrastive_loss_layer.cu +++ b/src/caffe/layers/contrastive_loss_layer.cu @@ -59,9 +59,9 @@ void ContrastiveLossLayer::Forward_gpu( loss += dist_sq_.cpu_data()[i]; } else { // dissimilar pairs if (legacy_version) { - loss += std::max(margin - dist_sq_.cpu_data()[i], Dtype(0.0)); + loss += fmax(margin - dist_sq_.cpu_data()[i], Dtype(0.0)); } else { - Dtype dist = std::max(margin - (Dtype) sqrt(dist_sq_.cpu_data()[i]), + Dtype dist = fmax(margin - (Dtype) sqrt(dist_sq_.cpu_data()[i]), Dtype(0.0)); loss += dist * dist; } @@ -141,7 +141,8 @@ void ContrastiveLossLayer::Backward_gpu( CL_KERNEL_SELECT("cll_backward")); viennacl::ocl::enqueue( oclk_cll( - count, channels, margin, alpha, + count, channels, fixup_arg_type(margin), + fixup_arg_type(alpha), WrapHandle((cl_mem) (bottom[2]->gpu_data()), &ctx), WrapHandle((cl_mem) (diff_.gpu_data()), &ctx), WrapHandle((cl_mem) (dist_sq_.gpu_data()), &ctx), diff --git a/src/caffe/layers/conv_layer_spatial.cpp b/src/caffe/layers/conv_layer_spatial.cpp index f3bbb4a5799..ad8c7c48a25 100644 --- a/src/caffe/layers/conv_layer_spatial.cpp +++ b/src/caffe/layers/conv_layer_spatial.cpp @@ -145,7 +145,7 @@ void ConvolutionLayerSpatial::Reshape(const vector*>& bottom, caffe_set(N_, Dtype(1), bias_multiplier_.mutable_cpu_data()); } - if (std::is_same::value) { + if (!std::is_same::value) { this->num_ = bottom[0]->count(0, this->channel_axis_); SetUp(bottom, top, Caffe::GetDefaultDevice()->backend()); } @@ -207,7 +207,7 @@ void ConvolutionLayerSpatial::Backward_cpu( #ifndef CPU_ONLY #ifdef USE_GREENTEA -// #define dbg + #define 
dbg #ifdef dbg #define dbgPrint(x) (x) #else @@ -217,28 +217,33 @@ void ConvolutionLayerSpatial::Backward_cpu( // For large enough input size, we do not need to tune kernels for different // size. The reason is with large input size, there will be enough work items // to feed al the EUs. -// FIXME for the gemm like convolution, switch back to eaxct image size. #define TUNING_SIZE(x) ((x) > 256 ? 256 : (ALIGN(x, 16))) -template<> -void ConvolutionLayerSpatial::generate_key() { + +template +void ConvolutionLayerSpatial::generate_key() { + CHECK((!std::is_same::value)); std::stringstream keyBuilder; + if (std::is_same::value) + keyBuilder << "float_"; + else + keyBuilder << "half_"; keyBuilder << this->layer_param_.convolution_param().fuse_type() << "_" << kernel_w_ << "_" << kernel_h_ << "_" - << channels_ << "_" - << group_ << "_" + << this->channels_ << "_" + << this->group_ << "_" << stride_h_ << "_" << stride_w_ << "_" << dilation_h_ << "_" << dilation_w_ << "_" - << bias_term_ << "_" + << this->bias_term_ << "_" << TUNING_SIZE(width_) << "_" << TUNING_SIZE(height_) << "_" << pad_w_ << "_" << pad_h_ << "_" - << num_ << "_" + << this->num_ << "_" << M_; viennacl::ocl::context &ctx = viennacl::ocl::get_context @@ -251,9 +256,10 @@ void ConvolutionLayerSpatial::generate_key() { short_key_ = keyBuilder.str(); } -template<> -std::string ConvolutionLayerSpatial::generate_specific_key( +template +std::string ConvolutionLayerSpatial::generate_specific_key( int_tp type, int_tp blockWidth, int_tp blockHeight, int_tp blockDepth) { + CHECK((!std::is_same::value)); std::stringstream keyBuilder; keyBuilder << short_key_ << "_" << type @@ -271,7 +277,7 @@ void interleaveMatrix( CHECK_EQ(interleavedRows % 2, 0) << "interleaveMatrix only supports even values for interleavedRows."; - size_t memSize = r * c * sizeof(float); + size_t memSize = r * c * sizeof(Dtype); size_t dstSize = memSize * (interleavedRows + nonInterleavedRows * 2) / (interleavedRows + nonInterleavedRows); @@ 
-397,11 +403,12 @@ void ConvolutionLayerSpatial::swizzleWeights( } } -template<> -void ConvolutionLayerSpatial::calculate_global_size(int_tp batch, +template +void ConvolutionLayerSpatial::calculate_global_size(int_tp batch, int_tp* wio, // work item output size size_t* lSize, // local size size_t* gSize) { // global size + CHECK((!std::is_same::value)); gSize[0] = ceil( (fmax(static_cast(output_w_) / wio[0], 1.0)) / lSize[0]) * lSize[0]; @@ -413,11 +420,13 @@ void ConvolutionLayerSpatial::calculate_global_size(int_tp batch, / lSize[2]) * lSize[2]; } -template<> -bool ConvolutionLayerSpatial::create_basic_kernel( - const vector*>& bottom, const vector*>& top, +template +bool ConvolutionLayerSpatial::create_basic_kernel( + const vector*>& bottom, + const vector*>& top, int_tp blockWidth, int_tp blockHeight, int_tp blockDepth) { + CHECK((!std::is_same::value)); // Standard spatial setup is done here std::stringstream keyBuilder; std::stringstream multFunctionBuilder; @@ -439,11 +448,12 @@ bool ConvolutionLayerSpatial::create_basic_kernel( optionsString << "-cl-fast-relaxed-math " << " -D KERNELSIZE=" << kernel_w_ * kernel_h_ << " -D KERNEL_W=" << kernel_w_ << " -D KERNEL_H=" << kernel_h_ << " -D CHANNELS=" - << channels_ / group_ << " -D STRIDE_H=" << stride_h_ + << this->channels_ / this->group_ + << " -D STRIDE_H=" << stride_h_ << " -DDILATION_X=" << dilation_w_ << " -DDILATION_Y=" << dilation_h_ << " -D STRIDE_W=" << stride_w_ << " -D APPLY_BIAS=" - << bias_term_ << " -D OUTPUT_Z=" << M_ + << this->bias_term_ << " -D OUTPUT_Z=" << M_ << " -D XPAR=" << workItemOutput[0] << " -D YPAR=" << workItemOutput[1] << " -D ZPAR=" << workItemOutput[2] << " -D " << kernelDef.c_str() << " -D CFMultiNoPadding=" @@ -462,7 +472,7 @@ bool ConvolutionLayerSpatial::create_basic_kernel( optionsString << " -D__BEIGNET__"; string options = optionsString.str(); try { - submit_conv_spatial_program(&ctx, kernel_name_, options); + submit_conv_spatial_program(&ctx, kernel_name_, options); } 
catch (std::exception& e) { dbgPrint(std::cout << "Basic kernel generation failed" << std::endl); return false; @@ -531,28 +541,30 @@ void ConvolutionLayerSpatial::cleanTmpSubBuffers( tmpSubBuffers.clear(); } -template<> -cl_int ConvolutionLayerSpatial::convolve( - const vector*>& bottom, const vector*>& top, +template +cl_int ConvolutionLayerSpatial::convolve( + const vector*>& bottom, const vector*>& top, int_tp index, int_tp numImages, kernelConfig* config) { + CHECK((!std::is_same::value)); viennacl::ocl::context &ctx = viennacl::ocl::get_context(this->device_->id()); viennacl::ocl::program & program = ctx.get_program(config->kernelName); viennacl::ocl::kernel &kernel = program.get_kernel(config->kernelName); cl_int err = CL_SUCCESS; if (config->kernelType == 2) { swizzleWeights(bottom, top, config->workItem_output[2], false); - size_t total_bottom_size = bottom_dim_ * numImages; - size_t total_kernel_size = kernel_h_ * kernel_w_ * channels_ * M_; - size_t total_bias_size = M_ * group_; - size_t total_top_size = top_dim_ * numImages; - for (int_tp g = 0; g < group_; ++g) { + size_t total_bottom_size = this->bottom_dim_ * numImages; + size_t total_kernel_size = kernel_h_ * kernel_w_ * this->channels_ * M_; + size_t total_bias_size = M_ * this->group_; + size_t total_top_size = this->top_dim_ * numImages; + for (int_tp g = 0; g < this->group_; ++g) { bias_offset_ = M_ * g; - int_tp image_offset = width_ * height_ * (channels_ / group_) * g; + int_tp image_offset = width_ * height_ * + (this->channels_ / this->group_) * g; int_tp output_image_offset = output_w_ * output_h_ * M_ * g; int_tp kernel_offset = kernel_h_ * kernel_w_ - * (channels_ / group_) * M_ * g; + * (this->channels_ / this->group_) * M_ * g; cl_uint argIdx = 0; if (IsFusedWithEltwiseReLU()) kernel.arg(argIdx++, WrapHandle((cl_mem) bottom[1]->gpu_data(), &ctx)); @@ -604,20 +616,21 @@ cl_int ConvolutionLayerSpatial::convolve( break; } - if (group_ > 1) { + if (this->group_ > 1) { 
cleanTmpSubBuffers(bottom, top); } if (err != CL_SUCCESS) return err; } else if (config->kernelType == 5) { swizzleWeights(bottom, top, config->workItem_output[1], true); - size_t total_bottom_size = bottom_dim_ * numImages; - size_t total_kernel_size = kernel_h_ * kernel_w_ * channels_ * M_; - size_t total_bias_size = M_ * group_; - size_t total_top_size = top_dim_ * numImages; - for (int_tp g = 0; g < group_; ++g) { + size_t total_bottom_size = this->bottom_dim_ * numImages; + size_t total_kernel_size = kernel_h_ * kernel_w_ * this->channels_ * M_; + size_t total_bias_size = M_ * this->group_; + size_t total_top_size = this->top_dim_ * numImages; + for (int_tp g = 0; g < this->group_; ++g) { bias_offset_ = M_ * g; - int_tp image_offset = width_ * height_ * (channels_ / group_) * g; + int_tp image_offset = width_ * height_ * + (this->channels_ / this->group_) * g; int_tp output_image_offset = output_w_ * output_h_ * M_ * g; cl_uint argIdx = 0; @@ -625,7 +638,7 @@ cl_int ConvolutionLayerSpatial::convolve( kernel.arg(argIdx++, WrapHandle((cl_mem) bottom[1]->gpu_data(), &ctx)); int_tp kernel_offset = kernel_h_ * kernel_w_ - * (channels_ / group_) * M_ * g; + * (this->channels_ / this->group_) * M_ * g; try { setBufferKernelArg(bottom, top, &kernel, argIdx++, &ctx, (cl_mem) bottom_data, @@ -660,7 +673,8 @@ cl_int ConvolutionLayerSpatial::convolve( viennacl::ocl::get_context(this->device_->id()); int out_pitch_y = output_w_ * output_h_; int out_pitch_z = out_pitch_y * M_; - int aligned_input_size = height_ * width_ * channels_ / group_; + int aligned_input_size = height_ * width_ * + this->channels_ / this->group_; int slice_pitch = width_ * height_; kernel.arg(argIdx++, (uint32_t)out_pitch_y); kernel.arg(argIdx++, (uint32_t)out_pitch_z); @@ -695,17 +709,17 @@ cl_int ConvolutionLayerSpatial::convolve( break; } - if (group_ > 1) { + if (this->group_ > 1) { cleanTmpSubBuffers(bottom, top); } if (err != CL_SUCCESS) return err; } else { for (int_tp n = 0; n < numImages; 
++n) { - for (int_tp g = 0; g < group_; ++g) { + for (int_tp g = 0; g < this->group_; ++g) { bias_offset_ = M_ * g; int_tp image_offset = n * this->bottom_dim_ - + width_ * height_ * (channels_ / group_) * g; + + width_ * height_ * (this->channels_ / this->group_) * g; int_tp output_image_offset = n * this->top_dim_ + output_w_ * output_h_ * M_ * g; @@ -714,7 +728,8 @@ cl_int ConvolutionLayerSpatial::convolve( kernel.arg(argIdx++, WrapHandle((cl_mem) bottom[1]->gpu_data(), &ctx)); - int_tp kernel_offset = kernel_h_ * kernel_w_ * (channels_ / group_) * M_ + int_tp kernel_offset = kernel_h_ * kernel_w_ * + (this->channels_ / this->group_) * M_ * g; kernel.arg(argIdx++, WrapHandle((cl_mem) bottom_data, &ctx)); @@ -760,15 +775,16 @@ cl_int ConvolutionLayerSpatial::convolve( return err; } -template<> -float ConvolutionLayerSpatial::timed_convolve( - const vector*>& bottom, const vector*>& top, +template +float ConvolutionLayerSpatial::timed_convolve( + const vector*>& bottom, const vector*>& top, int_tp index, int_tp numImages, kernelConfig* config) { // warm up. 
+ CHECK((!std::is_same::value)); bool saved_tuned = tuned_; tuned_ = false; - convolve(bottom, top, index, num_, config); + convolve(bottom, top, index, this->num_, config); Timer timer; timer.initted(); timer.Start(); @@ -778,7 +794,7 @@ float ConvolutionLayerSpatial::timed_convolve( tuned_ = true; int loop_cnt = 4; for (int i = 0; i < loop_cnt; i++) { - err = convolve(bottom, top, index, num_, config); + err = convolve(bottom, top, index, this->num_, config); if (err != CL_SUCCESS) break; } @@ -799,8 +815,8 @@ float ConvolutionLayerSpatial::timed_convolve( double out_z = M_; double k_w = kernel_w_; double k_h = kernel_h_; - double k_z = channels_; - double totalFlops = ((k_w*k_h*k_z -1)*2)*(out_w*out_h*out_z)*num_; + double k_z = this->channels_; + double totalFlops = ((k_w*k_h*k_z -1)*2)*(out_w*out_h*out_z) * this->num_; std::cout << "\tEstimated Gflops:" << ((totalFlops/1000)/1000)/1000 << std::endl; std::cout << "\tEstimated GFLOPS/S: " << @@ -814,11 +830,14 @@ float ConvolutionLayerSpatial::timed_convolve( return elapsedTime; } -template<> -bool ConvolutionLayerSpatial::verify_result( - const vector*>& bottom, const vector*>& top, +template +bool ConvolutionLayerSpatial::verify_result( + const vector*>& bottom, + const vector*>& top, int_tp index, - int_tp numImages, const Blob &verify_blob, kernelConfig* config) { + int_tp numImages, + const Blob &verify_blob, + kernelConfig* config) { uint_tp verificationFail = 0; @@ -827,20 +846,25 @@ bool ConvolutionLayerSpatial::verify_result( else if (config->tested) return false; - greentea_memset(this->device_->id(), top[index]->count() * sizeof(float), + greentea_memset(this->device_->id(), + top[index]->count() * sizeof(Dtype), 0xff, - (cl_mem)top[index]->mutable_gpu_data(), 0); + (cl_mem)top[index]->mutable_gpu_data(), + 0); config->executionTime = timed_convolve(bottom, top, index, numImages, config); // Currently we can't do verification when conv is fused because the results // won't match the results of 
forward_gpu_gemm. Need more work to fix it. if (IsFused()) return true; - const float *verify_data = verify_blob.cpu_data(); - const float *data = top[index]->cpu_data(); + const Dtype *verify_data = verify_blob.cpu_data(); + const Dtype *data = top[index]->cpu_data(); + Dtype err_factor = 1; + if (std::is_same::value) + err_factor = 8; for (int_tp n = 0; n < numImages; ++n) { - for (int_tp g = 0; g < group_; ++g) { + for (int_tp g = 0; g < this->group_; ++g) { int_tp output_image_offset = n * this->top_dim_ + output_w_ * output_h_ * M_ * g; for (int out_ch = 0; out_ch < M_ && !verificationFail; out_ch++) @@ -849,12 +873,12 @@ bool ConvolutionLayerSpatial::verify_result( size_t offset = output_image_offset + out_ch * output_w_ * output_h_ + h * output_w_ + w; if (fabs(data[offset] - verify_data[offset]) > - 0.1 * fabs(verify_data[offset]) && - !(fabs(verify_data[offset]) < 1.e-3 - && fabs(data[offset] - verify_data[offset]) < 1.e-4)) { + 0.1 * fabs(verify_data[offset] * err_factor) && + !(fabs(verify_data[offset]) < 1e-3 * err_factor + && fabs(data[offset] - verify_data[offset]) < 1e-4 * err_factor)) { dbgPrint(printf("test verification failed @ image %d group %d" "out_ch %d h %d w %d got %G expected %G\n", - n, g, out_ch, h, w, data[offset], verify_data[offset])); + n, g, out_ch, h, w, float(data[offset]), float(verify_data[offset]))); verificationFail = 1; break; } @@ -866,11 +890,13 @@ bool ConvolutionLayerSpatial::verify_result( return true; } -template<> -bool ConvolutionLayerSpatial::create_gemm_like_conv_kernel( - const vector*>& bottom, const vector*>& top, +template +bool ConvolutionLayerSpatial::create_gemm_like_conv_kernel( + const vector*>& bottom, + const vector*>& top, int_tp blockM, - int_tp blockK, int_tp blockN) { + int_tp blockK, + int_tp blockN) { std::stringstream multFunctionBuilder; std::string stringBuilder; std::stringstream optionsString; @@ -879,7 +905,7 @@ bool ConvolutionLayerSpatial::create_gemm_like_conv_kernel( int_tp 
workItemOutput[3] = { blockM, blockK, blockN }; int_tp simd_size = blockK; - int_tp num_batches = num_; + int_tp num_batches = this->num_; int_tp globalWorkSizeDX = blockN; int_tp globalWorkSizeDY = blockM; @@ -907,12 +933,12 @@ bool ConvolutionLayerSpatial::create_gemm_like_conv_kernel( " -DSTRIDE_Y=" << stride_h_ << " -DDILATION_X=" << dilation_w_ << " -DDILATION_Y=" << dilation_h_ << - " -DINPUT_DEPTH=" << channels_ << + " -DINPUT_DEPTH=" << this->channels_ << " -DWIDTH1=" << M_ << " -DOUT_PADDING_LEFT=" << 0 << " -DOUT_PADDING_HEIGHT=" << 0 << " -DOUT_DEPTH=" << M_ << - " -DNUM_BATCHES=" << num_ << + " -DNUM_BATCHES=" << this->num_ << " -DDY=" << globalWorkSizeDY << " -DDX=" << globalWorkSizeDX << " -DKERNEL_WIDTH_DIV2=" << kernel_w_ / 2 << @@ -937,7 +963,7 @@ bool ConvolutionLayerSpatial::create_gemm_like_conv_kernel( optionsString << " -D__BEIGNET__"; string options = optionsString.str(); - viennacl::ocl::program & program = submit_conv_spatial_program(&ctx, + viennacl::ocl::program & program = submit_conv_spatial_program(&ctx, kernel_name_, options); size_t workgroupSize_used; @@ -965,11 +991,13 @@ bool ConvolutionLayerSpatial::create_gemm_like_conv_kernel( } -template<> -bool ConvolutionLayerSpatial::setup_IDLF( - const vector*>& bottom, const vector*>& top, +template +bool ConvolutionLayerSpatial::setup_IDLF( + const vector*>& bottom, + const vector*>& top, int_tp blockWidth, - int_tp blockHeight, int_tp simd_size) { + int_tp blockHeight, + int_tp simd_size) { std::stringstream multFunctionBuilder; std::string stringBuilder; std::stringstream optionsString; @@ -980,7 +1008,7 @@ bool ConvolutionLayerSpatial::setup_IDLF( const int_tp num_output_maps = M_; int_tp output_block_width = blockWidth; int_tp output_block_height = blockHeight; - int_tp num_batches = num_; + int_tp num_batches = this->num_; kernel_name_ = "IDLF_"; kernel_name_ += kernelUKey.c_str(); @@ -1010,9 +1038,9 @@ bool ConvolutionLayerSpatial::setup_IDLF( << " -D filter_qualifier=__global" << 
" -D OUT_BLOCK_WIDTH=" << output_block_width << " -D OUT_BLOCK_HEIGHT=" << output_block_height - << " -D INPUT_DEPTH=" << channels_ / group_ - << " -DTOTAL_INPUT_DEPTH_SIZE=" << channels_ - << " -DTOTAL_OUTPUT_DEPTH=" << num_output_ + << " -D INPUT_DEPTH=" << this->channels_ / this->group_ + << " -DTOTAL_INPUT_DEPTH_SIZE=" << this->channels_ + << " -DTOTAL_OUTPUT_DEPTH=" << this->num_output_ << " -DINPUT_START_X=" << 0 << " -DINPUT_START_Y=" << 0 << " -DINPUT_START_Z=" << 0 << " -DKERNEL_WIDTH=" << kernel_w_ @@ -1042,7 +1070,7 @@ bool ConvolutionLayerSpatial::setup_IDLF( if (IsBeignet(&ctx)) optionsString << " -D__BEIGNET__"; string options = optionsString.str(); - viennacl::ocl::program & program = submit_conv_spatial_program(&ctx, + viennacl::ocl::program & program = submit_conv_spatial_program(&ctx, kernel_name_, options); @@ -1071,9 +1099,10 @@ bool ConvolutionLayerSpatial::setup_IDLF( } } -template<> -bool ConvolutionLayerSpatial::tune_local_size( - const vector*>& bottom, const vector*>& top, +template +bool ConvolutionLayerSpatial::tune_local_size( + const vector*>& bottom, + const vector*>& top, kernelConfig* config) { if (config->use_null_local || !config->autoTune) return true; @@ -1156,9 +1185,10 @@ bool ConvolutionLayerSpatial::tune_local_size( return true; } -template<> -void ConvolutionLayerSpatial::create_convolution_kernel( - const vector*>& bottom, const vector*>& top, +template +void ConvolutionLayerSpatial::create_convolution_kernel( + const vector*>& bottom, + const vector*>& top, int_tp kernelType, int_tp blockWidth, int_tp blockHeight, int_tp blockDepth) { @@ -1173,10 +1203,11 @@ void ConvolutionLayerSpatial::create_convolution_kernel( assert(0); } -template<> -void ConvolutionLayerSpatial::setup_convolution( - const vector*>& bottom, const vector*>& top, - const Blob &verify_blob) { +template +void ConvolutionLayerSpatial::setup_convolution( + const vector*>& bottom, + const vector*>& top, + const Blob &verify_blob) { // Initializes unique 
kernel ID kernel_uid_ = 0; std::string viennacl_cache_path; @@ -1200,8 +1231,10 @@ void ConvolutionLayerSpatial::setup_convolution( if (this->group_ == 1 && ((M_ % 8 == 0) && (M_ % 32 != 24))) { create_convolution_kernel(bottom, top, 5, 1, 8, 32); create_convolution_kernel(bottom, top, 5, 2, 8, 32); - if (kernel_w_ < 4 && M_ % 32 == 0) + if ((kernel_w_ < 4 || (!std::is_same::value)) && M_ % 32 == 0) create_convolution_kernel(bottom, top, 5, 1, 16, 32); + if (kernel_w_ < 4 && (!std::is_same::value)) + create_convolution_kernel(bottom, top, 5, 2, 16, 32); } for (int simd_size = 8; simd_size <= 16; simd_size += 8) { @@ -1233,7 +1266,7 @@ void ConvolutionLayerSpatial::setup_convolution( // for simd 8. if (simd_size == 8 && M_ >= 16 - && ((num_ * M_ * output_w_ * output_h_ / + && ((this->num_ * M_ * output_w_ * output_h_ / static_cast(width * height)) >= max_compute_units * 7 * 16)) continue; @@ -1260,7 +1293,7 @@ void ConvolutionLayerSpatial::setup_convolution( for (int_tp x = 0; x < kernelQueue.size(); x++) { if (tune_local_size(bottom, top, kernelQueue[x])) { kernelQueue[x]->executionTime = timed_convolve(bottom, top, bottom_index_, - num_, kernelQueue[x]); + this->num_, kernelQueue[x]); } else { // skip those kernels without a good local size. 
kernelQueue[x]->verified = false; @@ -1268,7 +1301,7 @@ void ConvolutionLayerSpatial::setup_convolution( } #ifdef TEST_ALL_KERNELS if (kernelQueue[x]->tested == false) { - bool verified = verify_result(bottom, top, bottom_index_, num_, + bool verified = verify_result(bottom, top, bottom_index_, this->num_, verify_blob, kernelQueue[x]); if (verified == false) { dbgPrint(std::cout << "Kernel " @@ -1320,7 +1353,7 @@ void ConvolutionLayerSpatial::setup_convolution( } if (fastestKernel < 0) break; // Test fastest kernel - bool verified = verify_result(bottom, top, bottom_index_, num_, + bool verified = verify_result(bottom, top, bottom_index_, this->num_, verify_blob, kernelQueue[fastestKernel]); if (verified == true) { kernelQueue[fastestKernel]->verified = true; @@ -1344,7 +1377,7 @@ void ConvolutionLayerSpatial::setup_convolution( << "fallback to basic kernel" << std::endl); create_basic_kernel(bottom, top, 1, 1, 1); kernel_index_ = kernelQueue.size() - 1; - verification = verify_result(bottom, top, bottom_index_, num_, + verification = verify_result(bottom, top, bottom_index_, this->num_, verify_blob, kernelQueue[kernel_index_]); CHECK_EQ(verification, true) << "Basic kernel failed verification." 
<< std::endl; @@ -1390,12 +1423,13 @@ void ConvolutionLayerSpatial::setup_convolution( ctx.cache_path(viennacl_cache_path); } -template<> -void ConvolutionLayerSpatial::Forward_gpu( - const vector*>& bottom, const vector*>& top) { +template +void ConvolutionLayerSpatial::Forward_gpu( + const vector*>& bottom, + const vector*>& top) { weight = this->blobs_[0]->gpu_data(); - weight_cpu = static_cast(this->blobs_[0]->cpu_data()); - if (bias_term_) + weight_cpu = static_cast(this->blobs_[0]->cpu_data()); + if (this->bias_term_) bias_ = this->blobs_[1]->gpu_data(); int bottom_size = bottom.size(); @@ -1411,17 +1445,17 @@ void ConvolutionLayerSpatial::Forward_gpu( bias_offset_ = 0; if (!tuned_) { - Blob verify_blob; + Blob verify_blob; verify_blob.ReshapeLike(*top[i]); - float *verify_data = verify_blob.mutable_gpu_data(); - const float *weight_gpu_data = this->blobs_[0]->gpu_data(); - const float *bottom_gpu_data = bottom[i]->gpu_data(); + Dtype *verify_data = verify_blob.mutable_gpu_data(); + const Dtype *weight_gpu_data = this->blobs_[0]->gpu_data(); + const Dtype *bottom_gpu_data = bottom[i]->gpu_data(); for (int_tp n = 0; n < this->num_; ++n) { this->forward_gpu_gemm(bottom_gpu_data, n * this->bottom_dim_, weight_gpu_data, verify_data, n * this->top_dim_); if (this->bias_term_) { - const float* bias = this->blobs_[1]->gpu_data(); + const Dtype* bias = this->blobs_[1]->gpu_data(); this->forward_gpu_bias(verify_data, n * this->top_dim_, bias); } } @@ -1430,28 +1464,29 @@ void ConvolutionLayerSpatial::Forward_gpu( CHECK_EQ(tuned_, true) << "Spatial convolution auto-tuning failed."; } - convolve(bottom, top, i, num_, bestKernelConfig); + convolve(bottom, top, i, this->num_, bestKernelConfig); } } -template<> -void ConvolutionLayerSpatial::Backward_gpu( - const vector*>& top, const vector& propagate_down, - const vector*>& bottom) { - const float* weight = this->blobs_[0]->gpu_data(); - float* weight_diff = this->blobs_[0]->mutable_gpu_diff(); +template +void 
ConvolutionLayerSpatial::Backward_gpu( + const vector*>& top, + const vector& propagate_down, + const vector*>& bottom) { + const Dtype* weight = this->blobs_[0]->gpu_data(); + Dtype* weight_diff = this->blobs_[0]->mutable_gpu_diff(); for (int_tp i = 0; i < top.size(); ++i) { - const float* top_diff = top[i]->gpu_diff(); + const Dtype* top_diff = top[i]->gpu_diff(); // Bias gradient, if necessary. if (this->bias_term_ && this->param_propagate_down_[1]) { - float* bias_diff = this->blobs_[1]->mutable_gpu_diff(); + Dtype* bias_diff = this->blobs_[1]->mutable_gpu_diff(); for (int_tp n = 0; n < this->num_; ++n) { this->backward_gpu_bias(bias_diff, top_diff, n * this->top_dim_); } } if (this->param_propagate_down_[0] || propagate_down[i]) { - const float* bottom_data = bottom[i]->gpu_data(); - float* bottom_diff = bottom[i]->mutable_gpu_diff(); + const Dtype* bottom_data = bottom[i]->gpu_data(); + Dtype* bottom_diff = bottom[i]->mutable_gpu_diff(); for (int_tp n = 0; n < this->num_; ++n) { // gradient w.r.t. weight. Note that we will accumulate diffs. 
if (this->param_propagate_down_[0]) { @@ -1478,7 +1513,8 @@ void ConvolutionLayerSpatial::Backward_gpu( template void ConvolutionLayerSpatial::load_cached_kernels( - const vector*>& bottom, const vector*>& top) { + const vector*>& bottom, + const vector*>& top) { // Generates static key_ std::string previous_key = key_; generate_key(); @@ -1563,137 +1599,6 @@ void ConvolutionLayerSpatial::SetUp( } } -template void ConvolutionLayerSpatial::SetUp( - const vector*>& bottom, const vector*>& top, - caffe::Backend backend); - -template void ConvolutionLayerSpatial::SetUp( - const vector*>& bottom, const vector*>& top, - caffe::Backend backend); - -template void ConvolutionLayerSpatial::swizzleWeights( - const vector*>& bottom, - const vector*>& top, - int_tp swizzle_factor, - bool interleave = false); -template void ConvolutionLayerSpatial::swizzleWeights( - const vector*>& bottom, - const vector*>& top, - int_tp swizzle_factor, - bool interleave = false); - -template<> -void ConvolutionLayerSpatial::create_convolution_kernel( - const vector*>& bottom, const vector*>& top, - int_tp kernelType, - int_tp blockWidth, int_tp blockHeight, - int_tp blockDepth) { - NOT_IMPLEMENTED; - return; -} - -template<> -bool ConvolutionLayerSpatial::setup_IDLF( - const vector*>& bottom, const vector*>& top, - int_tp blockWidth, - int_tp blockHeight, int_tp blockDepth) { - NOT_IMPLEMENTED; - return false; -} - -template<> -bool ConvolutionLayerSpatial::create_gemm_like_conv_kernel( - const vector*>& bottom, const vector*>& top, - int_tp blockWidth, - int_tp blockHeight, int_tp blockDepth) { - NOT_IMPLEMENTED; - return false; -} - - -template<> -bool ConvolutionLayerSpatial::verify_result( - const vector*>& bottom, const vector*>& top, - int_tp index, - int_tp numImages, const Blob &verify_blob, kernelConfig* config) { - NOT_IMPLEMENTED; - return false; -} - -template<> -bool ConvolutionLayerSpatial::create_basic_kernel( - const vector*>& bottom, const vector*>& top, - int_tp blockWidth, - 
int_tp blockHeight, int_tp blockDepth) { - NOT_IMPLEMENTED; - return false; -} - -template<> -bool ConvolutionLayerSpatial::tune_local_size( - const vector*>& bottom, const vector*>& top, - kernelConfig* config) { - NOT_IMPLEMENTED; - return false; -} - -template<> -cl_int ConvolutionLayerSpatial::convolve( - const vector*>& bottom, const vector*>& top, - int_tp index, - int_tp numImages, kernelConfig* config) { - NOT_IMPLEMENTED; - return false; -} - -template<> -float ConvolutionLayerSpatial::timed_convolve( - const vector*>& bottom, const vector*>& top, - int_tp index, - int_tp numImages, kernelConfig* config) { - NOT_IMPLEMENTED; - return 0.f; -} - -template<> -void ConvolutionLayerSpatial::setup_convolution( - const vector*>& bottom, const vector*>& top, - const Blob &verify_blob) { - NOT_IMPLEMENTED; -} - -template<> -void ConvolutionLayerSpatial::calculate_global_size( - int_tp batch, - int_tp* workItemOutput, - size_t* localSizes, size_t* globalSizes) { - NOT_IMPLEMENTED; -} - -template<> -void ConvolutionLayerSpatial::generate_key() { - NOT_IMPLEMENTED; -} - -template<> -std::string ConvolutionLayerSpatial::generate_specific_key( - int_tp type, int_tp blockWidth, int_tp blockHeight, int_tp blockDepth) { - NOT_IMPLEMENTED; - return ""; -} - -template<> -void ConvolutionLayerSpatial::Forward_gpu( - const vector*>& bottom, const vector*>& top) { - NOT_IMPLEMENTED; -} - -template<> -void ConvolutionLayerSpatial::Backward_gpu( - const vector*>& top, const vector& propagate_down, - const vector*>& bottom) { - NOT_IMPLEMENTED; -} #else template void ConvolutionLayerSpatial::Forward_gpu( diff --git a/src/caffe/layers/dropout_layer.cu b/src/caffe/layers/dropout_layer.cu index aba3c790826..bf93befa1eb 100644 --- a/src/caffe/layers/dropout_layer.cu +++ b/src/caffe/layers/dropout_layer.cu @@ -54,7 +54,8 @@ void DropoutLayer::Forward_gpu(const vector*>& bottom, CL_KERNEL_SELECT("dropout_forward")); viennacl::ocl::enqueue( oclk_dropout(count, WrapHandle((cl_mem) 
bottom_data, &ctx), - WrapHandle(mask, &ctx), uint_thres_, scale_, + WrapHandle(mask, &ctx), fixup_arg_type(uint_thres_), + fixup_arg_type(scale_), WrapHandle((cl_mem) top_data, &ctx)), ctx.get_queue()); } else { @@ -113,7 +114,8 @@ void DropoutLayer::Backward_gpu(const vector*>& top, CL_KERNEL_SELECT("dropout_backward")); viennacl::ocl::enqueue( oclk_dropout(count, WrapHandle((cl_mem) top_diff, &ctx), - WrapHandle(mask, &ctx), uint_thres_, scale_, + WrapHandle(mask, &ctx), fixup_arg_type(uint_thres_), + fixup_arg_type(scale_), WrapHandle((cl_mem) bottom_diff, &ctx)), ctx.get_queue()); } else { diff --git a/src/caffe/layers/eltwise_layer.cpp b/src/caffe/layers/eltwise_layer.cpp index 4b9163bdeaa..1c78a235b72 100644 --- a/src/caffe/layers/eltwise_layer.cpp +++ b/src/caffe/layers/eltwise_layer.cpp @@ -51,6 +51,9 @@ void EltwiseLayer::Forward_cpu( const Dtype* bottom_data_b = NULL; const int_tp count = top[0]->count(); Dtype* top_data = top[0]->mutable_cpu_data(); + Dtype maxVal = FLT_MAX; + if (std::is_same::value) + maxVal = HALF_MAX; switch (op_) { case EltwiseParameter_EltwiseOp_PROD: caffe_mul(count, bottom[0]->cpu_data(), bottom[1]->cpu_data(), top_data); @@ -69,7 +72,7 @@ void EltwiseLayer::Forward_cpu( // Initialize mask = max_idx_.mutable_cpu_data(); caffe_set(count, (int_tp)-1, mask); - caffe_set(count, Dtype(-FLT_MAX), top_data); + caffe_set(count, Dtype(-maxVal), top_data); // bottom 0 & 1 bottom_data_a = bottom[0]->cpu_data(); bottom_data_b = bottom[1]->cpu_data(); diff --git a/src/caffe/layers/elu_layer.cu b/src/caffe/layers/elu_layer.cu index 0b57cf83379..d8afc2b7355 100644 --- a/src/caffe/layers/elu_layer.cu +++ b/src/caffe/layers/elu_layer.cu @@ -42,7 +42,8 @@ void ELULayer::Forward_gpu(const vector*>& bottom, CL_KERNEL_SELECT("elu_forward")); viennacl::ocl::enqueue( oclk_elu(count, WrapHandle((cl_mem) bottom_data, &ctx), - WrapHandle((cl_mem) top_data, &ctx), alpha), + WrapHandle((cl_mem) top_data, &ctx), + fixup_arg_type(alpha)), ctx.get_queue()); 
#endif // USE_GREENTEA } @@ -92,7 +93,8 @@ void ELULayer::Backward_gpu(const vector*>& top, oclk_elu(count, WrapHandle((cl_mem) top_diff, &ctx), WrapHandle((cl_mem) top_data, &ctx), WrapHandle((cl_mem) bottom_data, &ctx), - WrapHandle((cl_mem) bottom_diff, &ctx), alpha), + WrapHandle((cl_mem) bottom_diff, &ctx), + fixup_arg_type(alpha)), ctx.get_queue()); #endif // USE_GREENTEA } diff --git a/src/caffe/layers/embed_layer.cu b/src/caffe/layers/embed_layer.cu index 7d479a0ec67..9d632f9d2c7 100644 --- a/src/caffe/layers/embed_layer.cu +++ b/src/caffe/layers/embed_layer.cu @@ -92,6 +92,13 @@ template void EmbedLayer::Backward_gpu(const vector*>& top, const vector& propagate_down, const vector*>& bottom) { CHECK(!propagate_down[0]) << "Can't backpropagate to EmbedLayer input."; +#ifdef USE_GREENTEA + // FIXM, the half data type ocl kernel has bug, have to fall back. + if (std::is_same::value) { + Backward_cpu(top, propagate_down, bottom); + return; + } +#endif if (this->param_propagate_down_[0]) { const int_tp top_count = top[0]->count(); const Dtype* top_diff = top[0]->gpu_diff(); diff --git a/src/caffe/layers/hinge_loss_layer.cpp b/src/caffe/layers/hinge_loss_layer.cpp index 3869433883e..0044e36bb30 100644 --- a/src/caffe/layers/hinge_loss_layer.cpp +++ b/src/caffe/layers/hinge_loss_layer.cpp @@ -22,8 +22,8 @@ void HingeLossLayer::Forward_cpu(const vector*>& bottom, } for (int_tp i = 0; i < num; ++i) { for (int_tp j = 0; j < dim; ++j) { - bottom_diff[i * dim + j] = std::max( - Dtype(0), 1 + bottom_diff[i * dim + j]); + bottom_diff[i * dim + j] = fmax( + Dtype(0), Dtype(Dtype(1) + bottom_diff[i * dim + j])); } } Dtype* loss = top[0]->mutable_cpu_data(); @@ -61,10 +61,10 @@ void HingeLossLayer::Backward_cpu(const vector*>& top, switch (this->layer_param_.hinge_loss_param().norm()) { case HingeLossParameter_Norm_L1: caffe_cpu_sign(count, bottom_diff, bottom_diff); - caffe_scal(count, loss_weight / num, bottom_diff); + caffe_scal(count, Dtype(loss_weight / num), 
bottom_diff); break; case HingeLossParameter_Norm_L2: - caffe_scal(count, loss_weight * 2 / num, bottom_diff); + caffe_scal(count, Dtype(loss_weight * 2 / num), bottom_diff); break; default: LOG(FATAL) << "Unknown Norm"; diff --git a/src/caffe/layers/inner_product_layer.cu b/src/caffe/layers/inner_product_layer.cu index 90f023b3b1b..0b2ecec35a2 100644 --- a/src/caffe/layers/inner_product_layer.cu +++ b/src/caffe/layers/inner_product_layer.cu @@ -33,6 +33,7 @@ static void CL_CALLBACK gemm_callback (cl_event event, // Will return image to caller if the input image is NULL. Otherwise, // will use the image directly. It's caller's responsibility to // release the created image. +template static void greentea_gpu_gemm_copy_buffer_to_image(int_tp ctx_id, cl_mem *image, cl_mem buffer, int offset, bool is_matrix_a, bool transpose, @@ -48,7 +49,7 @@ static void greentea_gpu_gemm_copy_buffer_to_image(int_tp ctx_id, cl_image_desc desc; cl_image_format format; - bool halfPrecisionMode = false; + bool halfPrecisionMode = !std::is_same::value; memset(&desc, 0, sizeof(desc)); int src_offset = halfPrecisionMode ? 
sizeof(unsigned short) * offset : sizeof(float) * offset; @@ -88,9 +89,8 @@ static void greentea_gpu_gemm_copy_buffer_to_image(int_tp ctx_id, origin, region, wait_list_size, wait_list, event)); } else { - std::string kernel_name("gemm_buffer_copy_image_transpose_float"); - - viennacl::ocl::kernel &oclk_gemm_copy = program.get_kernel(kernel_name); + viennacl::ocl::kernel &oclk_gemm_copy = program.get_kernel( + CL_KERNEL_SELECT("gemm_buffer_copy_image_transpose")); size_t global_copy[2]; global_copy[0] = width; @@ -144,9 +144,8 @@ static void greentea_gpu_gemm_copy_buffer_to_image(int_tp ctx_id, buffer, *image, src_offset, origin, region, wait_list_size, wait_list, event)); } else { - std::string kernel_name("gemm_buffer_copy_image_no_transpose_float"); - - viennacl::ocl::kernel &oclk_gemm_copy = program.get_kernel(kernel_name); + viennacl::ocl::kernel &oclk_gemm_copy = program.get_kernel( + CL_KERNEL_SELECT("gemm_buffer_copy_image_no_transpose")); size_t global_copy[2]; global_copy[0] = padding ? padded_width : width; @@ -166,11 +165,12 @@ static void greentea_gpu_gemm_copy_buffer_to_image(int_tp ctx_id, } } +template static void greentea_gpu_fast_image_gemm(const int_tp ctx_id, const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB, const int_tp M, - const int_tp N, const int_tp K, const float alpha, + const int_tp N, const int_tp K, const Dtype alpha, const cl_mem A, const int_tp offA, const cl_mem B, - const int_tp offB, const float beta, cl_mem C, + const int_tp offB, const Dtype beta, cl_mem C, const int_tp offC, bool is_image_a, bool is_image_b, enum gemm_type_t gemm_type, const size_t max_image_size) { CHECK_EQ(gemm_type == GEMM_TYPE_FAST_IMAGE_32_1 @@ -183,7 +183,7 @@ static void greentea_gpu_fast_image_gemm(const int_tp ctx_id, const CBLAS_TRANSP if (is_image_b) CHECK_EQ(offB, 0) << "Invalid input image offset." << std::endl; - bool halfPrecisionMode = false; + bool halfPrecisionMode = !std::is_same::value; int widthA = (TransA == CblasNoTrans) ? 
K : M; int heightA = (TransA == CblasNoTrans) ? M : K; int widthB = (TransB == CblasNoTrans) ? N : K; @@ -250,7 +250,12 @@ static void greentea_gpu_fast_image_gemm(const int_tp ctx_id, const CBLAS_TRANSP kernel_name += "0"; else kernel_name += "1"; - kernel_name += "_float"; + + if(halfPrecisionMode) { + kernel_name += "_half"; + } else { + kernel_name += "_float"; + } oclk_gemm_float = &program.get_kernel(kernel_name); while(C_start_y < M) { @@ -295,7 +300,7 @@ static void greentea_gpu_fast_image_gemm(const int_tp ctx_id, const CBLAS_TRANSP } if (!is_image_a) { - greentea_gpu_gemm_copy_buffer_to_image(ctx_id, &ImA, A, blockA_offset, + greentea_gpu_gemm_copy_buffer_to_image(ctx_id, &ImA, A, blockA_offset, true, TransA != CblasNoTrans, padding_A, imageA_h, imageA_w, blockA_height, blockA_width, ldA, 0, NULL, &ev[ev_idx]); @@ -303,7 +308,7 @@ static void greentea_gpu_fast_image_gemm(const int_tp ctx_id, const CBLAS_TRANSP ev_idx++; } if (!is_image_b) { - greentea_gpu_gemm_copy_buffer_to_image(ctx_id, &ImB, B, blockB_offset, + greentea_gpu_gemm_copy_buffer_to_image(ctx_id, &ImB, B, blockB_offset, false, false, padding_B, imageB_h, imageB_w, blockB_height, blockB_width, ldB, 0, NULL, &ev[ev_idx]); @@ -316,7 +321,7 @@ static void greentea_gpu_fast_image_gemm(const int_tp ctx_id, const CBLAS_TRANSP if (!is_image_a) { bool padding; padding = !is_image_b || halfPrecisionMode; - greentea_gpu_gemm_copy_buffer_to_image(ctx_id, &ImA, A, blockA_offset, + greentea_gpu_gemm_copy_buffer_to_image(ctx_id, &ImA, A, blockA_offset, true, TransA != CblasNoTrans, padding, imageA_h, imageA_w, blockA_height, blockA_width, ldA, 0, NULL, &ev[ev_idx]); @@ -325,7 +330,7 @@ static void greentea_gpu_fast_image_gemm(const int_tp ctx_id, const CBLAS_TRANSP } if(!is_image_b && (K % use_buffer_indicator != 0)) { - greentea_gpu_gemm_copy_buffer_to_image(ctx_id, &ImB, B, blockB_offset, + greentea_gpu_gemm_copy_buffer_to_image(ctx_id, &ImB, B, blockB_offset, false, true, false, imageB_h, imageB_w, 
blockB_height, blockB_width, ldB, 0, NULL, &ev[ev_idx]); if (ev[ev_idx] != NULL) @@ -378,8 +383,8 @@ static void greentea_gpu_fast_image_gemm(const int_tp ctx_id, const CBLAS_TRANSP oclk_gemm_float->arg(arg_idx++, blockC_height); oclk_gemm_float->arg(arg_idx++, blockC_width); oclk_gemm_float->arg(arg_idx++, ldC); - oclk_gemm_float->arg(arg_idx++, alpha); - oclk_gemm_float->arg(arg_idx++, beta); + oclk_gemm_float->arg(arg_idx++, fixup_arg_type(alpha)); + oclk_gemm_float->arg(arg_idx++, fixup_arg_type(beta)); oclk_gemm_float->arg(arg_idx++, padded_k); if (TransB != CblasNoTrans) oclk_gemm_float->arg(arg_idx++, block_Ksize); @@ -437,11 +442,12 @@ static void greentea_gpu_fast_image_gemm(const int_tp ctx_id, const CBLAS_TRANSP clReleaseMemObject(ImB); } +template static void greentea_gpu_fast_buffer_gemm(const int_tp ctx_id, const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB, const int_tp M, - const int_tp N, const int_tp K, const float alpha, + const int_tp N, const int_tp K, const Dtype alpha, const cl_mem A, const int_tp offA, const cl_mem B, - const int_tp offB, const float beta, cl_mem C, + const int_tp offB, const Dtype beta, cl_mem C, const int_tp offC, enum gemm_type_t gemm_type) { CHECK_EQ(gemm_type == GEMM_TYPE_FAST_BUFFER, true) << "Invalid fast buffer gemm type." 
<< std::endl; @@ -451,33 +457,39 @@ static void greentea_gpu_fast_buffer_gemm(const int_tp ctx_id, const CBLAS_TRANS viennacl::ocl::context &ctx = viennacl::ocl::get_context(ctx_id); viennacl::ocl::program &program = (Caffe::Get().GetDevice(ctx_id, false)) ->program(); - bool halfPrecisionMode= false; + bool halfPrecisionMode = !std::is_same::value; size_t sub_group_size = 8; bool is_small_batch = (M == 2 || M == 4 || M == 8); viennacl::ocl::kernel *oclk_gemm_float; std::string kernel_name("gemm_buffer_"); if(TransA == CblasNoTrans && TransB == CblasNoTrans) { - kernel_name += "NN_float"; + kernel_name += "NN"; if(halfPrecisionMode) { sub_group_size = 16; } } else if(TransA == CblasNoTrans && TransB != CblasNoTrans) { if (M == 2) - kernel_name +="NT_M_2_float"; + kernel_name +="NT_M_2"; else if (M == 4) - kernel_name +="NT_M_4_float"; + kernel_name +="NT_M_4"; else if (M == 8) - kernel_name +="NT_M_8_float"; + kernel_name +="NT_M_8"; else - kernel_name += "NT_float"; + kernel_name += "NT"; } else if(TransA != CblasNoTrans && TransB == CblasNoTrans) { - kernel_name += "TN_float"; + kernel_name += "TN"; if(halfPrecisionMode) { sub_group_size = 16; } } else { - kernel_name += "TT_float"; + kernel_name += "TT"; + } + + if(halfPrecisionMode) { + kernel_name += "_half"; + } else { + kernel_name += "_float"; } oclk_gemm_float = &program.get_kernel(kernel_name); @@ -520,8 +532,8 @@ static void greentea_gpu_fast_buffer_gemm(const int_tp ctx_id, const CBLAS_TRANS oclk_gemm_float->arg(arg_idx++, M); oclk_gemm_float->arg(arg_idx++, N); oclk_gemm_float->arg(arg_idx++, K); - oclk_gemm_float->arg(arg_idx++, alpha); - oclk_gemm_float->arg(arg_idx++, beta); + oclk_gemm_float->arg(arg_idx++, fixup_arg_type(alpha)); + oclk_gemm_float->arg(arg_idx++, fixup_arg_type(beta)); if(TransB == CblasNoTrans || TransA != CblasNoTrans) { int stride = 256; @@ -549,16 +561,16 @@ static void innerprod_common(const int_tp ctx_id, const CBLAS_TRANSPOSE TransB, if (gemm_type == 
GEMM_TYPE_FAST_IMAGE_32_1 || gemm_type == GEMM_TYPE_FAST_IMAGE_32_2) { - greentea_gpu_fast_image_gemm(ctx_id, CblasNoTrans, TransB, M, N, K, + greentea_gpu_fast_image_gemm(ctx_id, CblasNoTrans, TransB, M, N, K, (Dtype)1., A, 0, B, 0, (Dtype)0., C, 0, false, false, gemm_type, max_image_size); } else if (gemm_type == GEMM_TYPE_FAST_IMAGE_B_IMAGE) { - greentea_gpu_fast_image_gemm(ctx_id, CblasNoTrans, TransB, M, N, K, + greentea_gpu_fast_image_gemm(ctx_id, CblasNoTrans, TransB, M, N, K, (Dtype)1., A, 0, B_image, 0, (Dtype)0., C, 0, false, true, GEMM_TYPE_FAST_IMAGE_B_IMAGE, max_image_size); } else if (gemm_type == GEMM_TYPE_FAST_BUFFER) { - greentea_gpu_fast_buffer_gemm(ctx_id, CblasNoTrans, TransB, M, N, K, + greentea_gpu_fast_buffer_gemm(ctx_id, CblasNoTrans, TransB, M, N, K, 1.f, A, 0, B, 0, 0.f, C, 0, gemm_type); } else @@ -584,7 +596,9 @@ void InnerProductLayer::generate_key() { key_ = viennacl::tools::sha1(prefix + keyBuilder.str()); // short_key_ = keyBuilder.str(); } - +#ifdef HAS_HALF_SUPPORT +template void InnerProductLayer::generate_key(); +#endif template void InnerProductLayer::generate_key(); template void InnerProductLayer::generate_key(); @@ -610,6 +624,9 @@ bool InnerProductLayer::load_cache() { } } +#ifdef HAS_HALF_SUPPORT +template bool InnerProductLayer::load_cache(); +#endif template bool InnerProductLayer::load_cache(); template bool InnerProductLayer::load_cache(); @@ -618,6 +635,7 @@ void InnerProductLayer::tune_innerprod_type(const int_tp ctx_id, const CBLAS_TRANSPOSE TransB, const cl_mem A, const cl_mem B, const cl_mem B_image, const size_t max_image_size) { if (std::is_same::value) { + innerprod_type_ = GEMM_TYPE_DEFAULT; return; } else { //1. load cache @@ -626,7 +644,7 @@ void InnerProductLayer::tune_innerprod_type(const int_tp ctx_id, } else { //2. 
if not cached generate tuning uint element_size = 0; - bool halfPrecisionMode= false; + bool halfPrecisionMode = !std::is_same::value; if(halfPrecisionMode) { element_size = sizeof(uint16_t); } else { @@ -649,7 +667,7 @@ void InnerProductLayer::tune_innerprod_type(const int_tp ctx_id, // warm up. for( int i = 0; i < gemm_tests.size(); i++ ) { - innerprod_common(ctx_id, TransB, M_, N_, K_, + innerprod_common(ctx_id, TransB, M_, N_, K_, A, B, B_image, C, gemm_tests[i], max_image_size); } float fastest_time = 1e10; @@ -659,7 +677,7 @@ void InnerProductLayer::tune_innerprod_type(const int_tp ctx_id, Timer timer; timer.initted(); timer.Start(); - innerprod_common(ctx_id, TransB, M_, N_, K_, + innerprod_common(ctx_id, TransB, M_, N_, K_, A, B, B_image, C, gemm_tests[i], max_image_size); timer.Stop(); float elapsedTime = timer.MilliSeconds(); @@ -692,6 +710,11 @@ void InnerProductLayer::tune_innerprod_type(const int_tp ctx_id, return; } +#ifdef HAS_HALF_SUPPORT +template void InnerProductLayer::tune_innerprod_type(const int_tp ctx_id, + const CBLAS_TRANSPOSE TransB, const cl_mem A, const cl_mem B, + const cl_mem B_image, const size_t max_image_size); +#endif template void InnerProductLayer::tune_innerprod_type(const int_tp ctx_id, const CBLAS_TRANSPOSE TransB, const cl_mem A, const cl_mem B, const cl_mem B_image, const size_t max_image_size); @@ -728,17 +751,16 @@ void InnerProductLayer::Forward_gpu(const vector*>& bottom, } else { #ifdef USE_GREENTEA int padded_height = 0, padded_width = 0; - Dtype *bias_mult_data; - Dtype *bias_term_data; - if (bias_term_) { - bias_mult_data = (Dtype*)bias_multiplier_.gpu_data(); - bias_term_data = (Dtype*)this->blobs_[1]->gpu_data(); - } int height = !transpose_ ? N_ : K_; int width = !transpose_ ? K_ : N_; if (M_ != 1) { - padded_height = !transpose_ ? height : (height + ((height & 7) ? 1 : 0)); - padded_width = !transpose_ ? width : (width + ((width & 7) ? 1 : 0)); + if (std::is_same::value) { + padded_height = !transpose_ ? 
height : (height + ((height & 7) ? 1 : 0)); + padded_width = !transpose_ ? width : (width + ((width & 7) ? 1 : 0)); + } else { + padded_height = !transpose_ ? height : (height + ((height & 7) ? (8-(height%8)) : 0)); + padded_width = !transpose_ ? width : (width + ((width & 7) ? (8-(width%8)) : 0)); + } } if (M_ == 1) { @@ -759,18 +781,16 @@ void InnerProductLayer::Forward_gpu(const vector*>& bottom, if (M_ <= max_image_size && N_ <= max_image_size && K_ <= max_image_size && - std::is_same::value && + !std::is_same::value && this->device_->CheckCapability("cl_intel_subgroups")) { if (!test_only_ || copied_weight_data_ != this->blobs_[0]->data().get()) { int height = !transpose_ ? N_ : K_; int width = !transpose_ ? K_ : N_; - int padded_height = !transpose_ ? height : (height + ((height & 7) ? 1 : 0)); - int padded_width = !transpose_ ? width : (width + ((width & 7) ? 1 : 0)); if (weight_image_) { clReleaseMemObject((cl_mem)weight_image_); weight_image_ = NULL; } - greentea_gpu_gemm_copy_buffer_to_image(this->device_->id(), + greentea_gpu_gemm_copy_buffer_to_image(this->device_->id(), &weight_image_, (cl_mem) weight, 0, false, !transpose_, true, padded_height, padded_width, diff --git a/src/caffe/layers/lrn_layer.cu b/src/caffe/layers/lrn_layer.cu index cbe831172e6..9ba52373ed3 100644 --- a/src/caffe/layers/lrn_layer.cu +++ b/src/caffe/layers/lrn_layer.cu @@ -120,9 +120,9 @@ void LRNLayer::CrossChannelForward_fuse_pooling_gpu( oclk_lrn_fill.arg(argIdx++, tiled_width); oclk_lrn_fill.arg(argIdx++, size_); oclk_lrn_fill.arg(argIdx++, alpha_ / size_); - oclk_lrn_fill.arg(argIdx++, k_); + oclk_lrn_fill.arg(argIdx++, fixup_arg_type(k_)); oclk_lrn_fill.arg(argIdx++, WrapHandle((cl_mem) top_data, &ctx)); - oclk_lrn_fill.arg(argIdx++, -beta_); + oclk_lrn_fill.arg(argIdx++, fixup_arg_type(-beta_)); oclk_lrn_fill.arg(argIdx++, pool_h_); oclk_lrn_fill.arg(argIdx++, pool_w_); oclk_lrn_fill.arg(argIdx++, pool_stride_h_); @@ -151,9 +151,9 @@ void 
LRNLayer::CrossChannelForward_fuse_pooling_gpu( oclk_lrn_fill.arg(argIdx++, width_); oclk_lrn_fill.arg(argIdx++, size_); oclk_lrn_fill.arg(argIdx++, alpha_ / size_); - oclk_lrn_fill.arg(argIdx++, k_); + oclk_lrn_fill.arg(argIdx++, fixup_arg_type(k_)); oclk_lrn_fill.arg(argIdx++, WrapHandle((cl_mem) top_lrn_data, &ctx)); - oclk_lrn_fill.arg(argIdx++, -beta_); + oclk_lrn_fill.arg(argIdx++, fixup_arg_type(-beta_)); OCL_CHECK(clEnqueueNDRangeKernel(ctx.get_queue().handle().get(), oclk_lrn_fill.handle().get(), 1, NULL, global_work_size_, NULL, 0, NULL, @@ -224,10 +224,10 @@ void LRNLayer::CrossChannelForward_gpu( oclk_lrn_fill.arg(argIdx++, width_); oclk_lrn_fill.arg(argIdx++, size_); oclk_lrn_fill.arg(argIdx++, alpha_ / size_); - oclk_lrn_fill.arg(argIdx++, k_); + oclk_lrn_fill.arg(argIdx++, fixup_arg_type(k_)); oclk_lrn_fill.arg(argIdx++, WrapHandle((cl_mem) scale_data, &ctx)); oclk_lrn_fill.arg(argIdx++, WrapHandle((cl_mem) top_data, &ctx)); - oclk_lrn_fill.arg(argIdx++, -beta_); + oclk_lrn_fill.arg(argIdx++, fixup_arg_type(-beta_)); OCL_CHECK(clEnqueueNDRangeKernel(ctx.get_queue().handle().get(), oclk_lrn_fill.handle().get(), 1, NULL, @@ -249,9 +249,9 @@ void LRNLayer::CrossChannelForward_gpu( oclk_lrn_fill.arg(argIdx++, width_); oclk_lrn_fill.arg(argIdx++, size_); oclk_lrn_fill.arg(argIdx++, alpha_ / size_); - oclk_lrn_fill.arg(argIdx++, k_); + oclk_lrn_fill.arg(argIdx++, fixup_arg_type(k_)); oclk_lrn_fill.arg(argIdx++, WrapHandle((cl_mem) top_data, &ctx)); - oclk_lrn_fill.arg(argIdx++, -beta_); + oclk_lrn_fill.arg(argIdx++, fixup_arg_type(-beta_)); OCL_CHECK(clEnqueueNDRangeKernel(ctx.get_queue().handle().get(), oclk_lrn_fill.handle().get(), 1, NULL, global_work_size_, NULL, 0, NULL, @@ -286,6 +286,12 @@ void LRNLayer::CrossChannelForward_gpu( #endif // USE_GREENTEA } } +#ifdef HAS_HALF_SUPPORT +template void LRNLayer::CrossChannelForward_gpu( + const vector*>& bottom, const vector*>& top); +template void LRNLayer::CrossChannelForward_fuse_pooling_gpu( + const 
vector*>& bottom, const vector*>& top, bool); +#endif template void LRNLayer::CrossChannelForward_gpu( const vector*>& bottom, const vector*>& top); template void LRNLayer::CrossChannelForward_gpu( @@ -404,13 +410,18 @@ void LRNLayer::CrossChannelBackward_gpu( WrapHandle((cl_mem) (top[0]->gpu_data()), &ctx), WrapHandle((cl_mem) (scale_.gpu_data()), &ctx), WrapHandle((cl_mem) (top[0]->gpu_diff()), &ctx), num_, - channels_, height_, width_, size_, -beta_, - Dtype(2. * alpha_ * beta_ / size_), + channels_, height_, width_, size_, fixup_arg_type(-beta_), + fixup_arg_type(Dtype(2. * alpha_ * beta_ / size_)), WrapHandle((cl_mem) (bottom[0]->mutable_gpu_diff()), &ctx)), ctx.get_queue()); #endif // USE_GREENTEA } } +#ifdef HAS_HALF_SUPPORT +template void LRNLayer::CrossChannelBackward_gpu( + const vector*>& top, const vector& propagate_down, + const vector*>& bottom); +#endif template void LRNLayer::CrossChannelBackward_gpu( const vector*>& top, const vector& propagate_down, const vector*>& bottom); diff --git a/src/caffe/layers/pooling_layer.cpp b/src/caffe/layers/pooling_layer.cpp index a0c2a02b650..977d0125445 100644 --- a/src/caffe/layers/pooling_layer.cpp +++ b/src/caffe/layers/pooling_layer.cpp @@ -247,6 +247,10 @@ void PoolingLayer::Forward_cpu(const vector*>& bottom, Dtype* top_mask = NULL; // Different pooling methods. We explicitly do the switch outside the for // loop to save time, although this results in more code. 
+ + Dtype maxVal = FLT_MAX; + if (std::is_same::value) + maxVal = HALF_MAX; switch (this->layer_param_.pooling_param().pool()) { case PoolingParameter_PoolMethod_MAX: // Initialize @@ -257,7 +261,7 @@ void PoolingLayer::Forward_cpu(const vector*>& bottom, mask = max_idx_.mutable_cpu_data(); caffe_set(top_count, (int_tp)-1, mask); } - caffe_set(top_count, Dtype(-FLT_MAX), top_data); + caffe_set(top_count, Dtype(-maxVal), top_data); // The main loop for (int_tp n = 0; n < bottom[0]->num(); ++n) { for (int_tp c = 0; c < channels_; ++c) { @@ -377,7 +381,7 @@ void PoolingLayer::Backward_cpu(const vector*>& top, for (int_tp pw = 0; pw < pooled_width_; ++pw) { const int_tp index = ph * pooled_width_ + pw; const int_tp bottom_index = - use_top_mask ? top_mask[index] : mask[index]; + use_top_mask ? int_tp(top_mask[index]) : mask[index]; bottom_diff[bottom_index] += top_diff[index]; } } diff --git a/src/caffe/layers/power_layer.cpp b/src/caffe/layers/power_layer.cpp index 605661dae51..0bdf32a92df 100644 --- a/src/caffe/layers/power_layer.cpp +++ b/src/caffe/layers/power_layer.cpp @@ -58,10 +58,10 @@ void PowerLayer::Backward_cpu(const vector*>& top, // Special case for y = (shift + scale * x)^2 // -> dy/dx = 2 * scale * (shift + scale * x) // = diff_scale * shift + diff_scale * scale * x - caffe_cpu_axpby(count, diff_scale_ * scale_, bottom_data, + caffe_cpu_axpby(count, Dtype(diff_scale_ * scale_), bottom_data, Dtype(0), bottom_diff); if (shift_ != Dtype(0)) { - caffe_add_scalar(count, diff_scale_ * shift_, bottom_diff); + caffe_add_scalar(count, Dtype(diff_scale_ * shift_), bottom_diff); } } else if (shift_ == Dtype(0)) { // Special case for y = (scale * x)^power diff --git a/src/caffe/layers/power_layer.cu b/src/caffe/layers/power_layer.cu index 73396ac8096..2d2a9d250ac 100644 --- a/src/caffe/layers/power_layer.cu +++ b/src/caffe/layers/power_layer.cu @@ -135,11 +135,11 @@ void PowerLayer::Backward_gpu(const vector*>& top, // -> dy/dx = 2 * scale * (shift + scale * x) // 
= diff_scale * shift + diff_scale * scale * x greentea_gpu_axpby(this->device_->id(), count, - diff_scale_ * scale_, (cl_mem) bottom_data, 0, + Dtype(diff_scale_ * scale_), (cl_mem) bottom_data, 0, Dtype(0), (cl_mem) bottom_diff, 0); if (shift_ != Dtype(0)) { greentea_gpu_add_scalar(this->device_->id(), count, - diff_scale_ * shift_, (cl_mem) bottom_diff, + Dtype(diff_scale_ * shift_), (cl_mem) bottom_diff, 0); } } else if (shift_ == Dtype(0)) { diff --git a/src/caffe/layers/reduction_layer.cpp b/src/caffe/layers/reduction_layer.cpp index efd5a0b08e1..22b0d3eae03 100644 --- a/src/caffe/layers/reduction_layer.cpp +++ b/src/caffe/layers/reduction_layer.cpp @@ -107,7 +107,7 @@ void ReductionLayer::Backward_cpu(const vector*>& top, caffe_scal(dim_, bottom_coeff, bottom_diff); break; case ReductionParameter_ReductionOp_SUMSQ: - caffe_cpu_scale(dim_, 2 * bottom_coeff, bottom_data, bottom_diff); + caffe_cpu_scale(dim_, Dtype(2 * bottom_coeff), bottom_data, bottom_diff); break; default: LOG(FATAL) << "Unknown reduction op: " diff --git a/src/caffe/layers/relu_layer.cu b/src/caffe/layers/relu_layer.cu index e23cd245912..640feeb5fed 100644 --- a/src/caffe/layers/relu_layer.cu +++ b/src/caffe/layers/relu_layer.cu @@ -44,7 +44,8 @@ void ReLULayer::Forward_gpu(const vector*>& bottom, CL_KERNEL_SELECT("relu_forward")); viennacl::ocl::enqueue( oclk_relu_forward(count, WrapHandle((cl_mem) bottom_data, &ctx), - WrapHandle((cl_mem) top_data, &ctx), negative_slope), + WrapHandle((cl_mem) top_data, &ctx), + fixup_arg_type(negative_slope)), ctx.get_queue()); #endif // USE_GREENTEA } @@ -96,7 +97,7 @@ void ReLULayer::Backward_gpu(const vector*>& top, oclk_relu_backward(count, WrapHandle((cl_mem) top_diff, &ctx), WrapHandle((cl_mem) bottom_data, &ctx), WrapHandle((cl_mem) bottom_diff, &ctx), - negative_slope), + fixup_arg_type(negative_slope)), ctx.get_queue()); #endif // USE_GREENTEA } diff --git a/src/caffe/layers/silence_layer.cu b/src/caffe/layers/silence_layer.cu index 
c7b5b3e261d..cb5a2107bcc 100644 --- a/src/caffe/layers/silence_layer.cu +++ b/src/caffe/layers/silence_layer.cu @@ -37,7 +37,7 @@ void SilenceLayer::Backward_gpu(const vector*>& top, CL_KERNEL_SELECT("gpu_set")); viennacl::ocl::enqueue( oclk_gpu_set( - bottom[i]->count(), Dtype(0), + bottom[i]->count(), fixup_arg_type(Dtype(0)), WrapHandle((cl_mem) bottom[i]->mutable_gpu_diff(), &ctx)), ctx.get_queue()); ctx.get_queue().finish(); diff --git a/src/caffe/layers/softmax_loss_layer.cpp b/src/caffe/layers/softmax_loss_layer.cpp index cfdbc891c89..bd304429c4b 100644 --- a/src/caffe/layers/softmax_loss_layer.cpp +++ b/src/caffe/layers/softmax_loss_layer.cpp @@ -103,8 +103,11 @@ void SoftmaxWithLossLayer::Forward_cpu( } DCHECK_GE(label_value, 0); DCHECK_LT(label_value, prob_.shape(softmax_axis_)); + Dtype min_value = FLT_MIN; + if (std::is_same::value) + min_value = HALF_MIN; loss -= log(std::max(prob_data[i * dim + label_value * inner_num_ + j], - Dtype(FLT_MIN))); + Dtype(min_value))); ++count; } } diff --git a/src/caffe/layers/threshold_layer.cu b/src/caffe/layers/threshold_layer.cu index b3486f4c318..9e19b24d849 100644 --- a/src/caffe/layers/threshold_layer.cu +++ b/src/caffe/layers/threshold_layer.cu @@ -43,7 +43,7 @@ void ThresholdLayer::Forward_gpu(const vector*>& bottom, viennacl::ocl::kernel &oclk_threshold = program.get_kernel( CL_KERNEL_SELECT("threshold")); viennacl::ocl::enqueue( - oclk_threshold(count, threshold_, + oclk_threshold(count, fixup_arg_type(threshold_), WrapHandle((cl_mem) bottom_data, &ctx), WrapHandle((cl_mem) top_data, &ctx)), ctx.get_queue()); diff --git a/src/caffe/solvers/adadelta_solver.cpp b/src/caffe/solvers/adadelta_solver.cpp index 8b9b8f05134..d83314d85b5 100644 --- a/src/caffe/solvers/adadelta_solver.cpp +++ b/src/caffe/solvers/adadelta_solver.cpp @@ -38,7 +38,7 @@ void AdaDeltaSolver::ComputeUpdateValue(int param_id, Dtype rate) { this->update_[param_id]->mutable_cpu_data()); // update history of gradients - 
caffe_cpu_axpby(net_params[param_id]->count(), Dtype(1) - momentum, + caffe_cpu_axpby(net_params[param_id]->count(), Dtype(Dtype(1) - momentum), this->update_[param_id]->cpu_data(), momentum, this->history_[param_id]->mutable_cpu_data()); @@ -79,7 +79,7 @@ void AdaDeltaSolver::ComputeUpdateValue(int param_id, Dtype rate) { this->update_[param_id]->mutable_cpu_data()); // update history of updates - caffe_cpu_axpby(net_params[param_id]->count(), Dtype(1) - momentum, + caffe_cpu_axpby(net_params[param_id]->count(), Dtype(Dtype(1) - momentum), this->update_[param_id]->cpu_data(), momentum, this->history_[update_history_offset + param_id]->mutable_cpu_data()); diff --git a/src/caffe/solvers/adadelta_solver.cu b/src/caffe/solvers/adadelta_solver.cu index 97daecf41a9..c180d031e79 100644 --- a/src/caffe/solvers/adadelta_solver.cu +++ b/src/caffe/solvers/adadelta_solver.cu @@ -41,12 +41,18 @@ void adadelta_update_gpu(device* dev, int_tp N, Dtype* g, Dtype* h, Dtype* h2, viennacl::ocl::enqueue( oclk_ada_delta_update(N, WrapHandle((cl_mem) g, &ctx), WrapHandle((cl_mem) h, &ctx), - WrapHandle((cl_mem) h2, &ctx), momentum, delta, - local_rate), + WrapHandle((cl_mem) h2, &ctx), + fixup_arg_type(momentum), fixup_arg_type(delta), + fixup_arg_type(local_rate)), ctx.get_queue()); #endif // USE_GREENTEA } } + +#ifdef HAS_HALF_SUPPORT +template void adadelta_update_gpu(device*, int_tp, half*, half*, + half*, half, half, half); +#endif template void adadelta_update_gpu(device*, int_tp, float*, float*, float*, float, float, float); template void adadelta_update_gpu(device*, int_tp, double*, double*, diff --git a/src/caffe/solvers/adagrad_solver.cu b/src/caffe/solvers/adagrad_solver.cu index 347285807c7..18065a4e607 100644 --- a/src/caffe/solvers/adagrad_solver.cu +++ b/src/caffe/solvers/adagrad_solver.cu @@ -38,12 +38,16 @@ void adagrad_update_gpu(device* dev, int_tp N, Dtype* g, Dtype* h, Dtype delta, CL_KERNEL_SELECT("ada_grad_update")); viennacl::ocl::enqueue( 
oclk_ada_grad_update(N, WrapHandle((cl_mem) g, &ctx), - WrapHandle((cl_mem) h, &ctx), delta, local_rate), + WrapHandle((cl_mem) h, &ctx), fixup_arg_type(delta), fixup_arg_type(local_rate)), ctx.get_queue()); #endif // USE_GREENTEA } } +#ifdef HAS_HALF_SUPPORT +template void adagrad_update_gpu(device*, int_tp, half*, half*, half, + half); +#endif template void adagrad_update_gpu(device*, int_tp, float*, float*, float, float); template void adagrad_update_gpu(device*, int_tp, double*, double*, diff --git a/src/caffe/solvers/adam_solver.cpp b/src/caffe/solvers/adam_solver.cpp index 651fc0177e5..41bc6dce0e9 100644 --- a/src/caffe/solvers/adam_solver.cpp +++ b/src/caffe/solvers/adam_solver.cpp @@ -38,15 +38,15 @@ void AdamSolver::ComputeUpdateValue(int param_id, Dtype rate) { Blob* val_t = this->temp_[param_id].get(); const uint_tp t = this->iter_ + 1; - const Dtype correction = std::sqrt(Dtype(1) - pow(beta2, t)) / - (Dtype(1.) - pow(beta1, t)); + const Dtype correction = sqrt(Dtype(1) - pow(beta2, Dtype(t))) / + (Dtype(1.) 
- pow(beta1, Dtype(t))); const uint_tp N = net_params[param_id]->count(); const Dtype eps_hat = this->param_.delta(); switch (Caffe::mode()) { case Caffe::CPU: { // update m <- \beta_1 m_{t-1} + (1-\beta_1)g_t - caffe_cpu_axpby(N, Dtype(1)-beta1, + caffe_cpu_axpby(N, Dtype(Dtype(1)-beta1), net_params[param_id]->cpu_diff(), beta1, val_m->mutable_cpu_data()); @@ -55,7 +55,7 @@ void AdamSolver::ComputeUpdateValue(int param_id, Dtype rate) { net_params[param_id]->cpu_diff(), net_params[param_id]->cpu_diff(), val_t->mutable_cpu_data()); - caffe_cpu_axpby(N, Dtype(1)-beta2, + caffe_cpu_axpby(N, Dtype(Dtype(1)-beta2), val_t->cpu_data(), beta2, val_v->mutable_cpu_data()); @@ -69,7 +69,7 @@ void AdamSolver::ComputeUpdateValue(int param_id, Dtype rate) { val_t->cpu_data(), val_t->mutable_cpu_data()); - caffe_cpu_scale(N, local_rate*correction, + caffe_cpu_scale(N, Dtype(local_rate*correction), val_t->cpu_data(), net_params[param_id]->mutable_cpu_diff()); break; @@ -79,7 +79,7 @@ void AdamSolver::ComputeUpdateValue(int param_id, Dtype rate) { adam_update_gpu(this->device_, N, net_params[param_id]->mutable_gpu_diff(), val_m->mutable_gpu_data(), val_v->mutable_gpu_data(), - beta1, beta2, eps_hat, local_rate * correction); + beta1, beta2, eps_hat, Dtype(local_rate * correction)); #else NO_GPU; #endif diff --git a/src/caffe/solvers/adam_solver.cu b/src/caffe/solvers/adam_solver.cu index 5fc35918ad5..18102405154 100644 --- a/src/caffe/solvers/adam_solver.cu +++ b/src/caffe/solvers/adam_solver.cu @@ -42,14 +42,18 @@ void adam_update_gpu(device* dev, int_tp N, Dtype* g, Dtype* m, Dtype* v, viennacl::ocl::enqueue( oclk_adam_update(N, WrapHandle((cl_mem) g, &ctx), WrapHandle((cl_mem) m, &ctx), - WrapHandle((cl_mem) v, &ctx), beta1, beta2, eps_hat, - corrected_local_rate), + WrapHandle((cl_mem) v, &ctx), fixup_arg_type(beta1), + fixup_arg_type(beta2), fixup_arg_type(eps_hat), + fixup_arg_type(corrected_local_rate)), ctx.get_queue()); #endif // USE_GREENTEA } } - +#ifdef 
HAS_HALF_SUPPORT +template void adam_update_gpu(device*, int_tp, half*, half*, half*, + half, half, half, half); +#endif template void adam_update_gpu(device*, int_tp, float*, float*, float*, float, float, float, float); template void adam_update_gpu(device*, int_tp, double*, double*, diff --git a/src/caffe/solvers/nesterov_solver.cpp b/src/caffe/solvers/nesterov_solver.cpp index 906e55e350b..3a22add3b16 100644 --- a/src/caffe/solvers/nesterov_solver.cpp +++ b/src/caffe/solvers/nesterov_solver.cpp @@ -29,8 +29,8 @@ void NesterovSolver::ComputeUpdateValue(int param_id, Dtype rate) { this->history_[param_id]->mutable_cpu_data()); // compute update: step back then over step - caffe_cpu_axpby(net_params[param_id]->count(), Dtype(1) + momentum, - this->history_[param_id]->cpu_data(), -momentum, + caffe_cpu_axpby(net_params[param_id]->count(), Dtype(Dtype(1) + momentum), + this->history_[param_id]->cpu_data(), Dtype(-momentum), this->update_[param_id]->mutable_cpu_data()); // copy diff --git a/src/caffe/solvers/nesterov_solver.cu b/src/caffe/solvers/nesterov_solver.cu index 9a0d491a59e..fcb30bd6920 100644 --- a/src/caffe/solvers/nesterov_solver.cu +++ b/src/caffe/solvers/nesterov_solver.cu @@ -38,13 +38,17 @@ void nesterov_update_gpu(device* dev, int_tp N, Dtype* g, Dtype* h, CL_KERNEL_SELECT("nesterov_update")); viennacl::ocl::enqueue( oclk_nesterov_update(N, WrapHandle((cl_mem) g, &ctx), - WrapHandle((cl_mem) h, &ctx), momentum, - local_rate), + WrapHandle((cl_mem) h, &ctx), fixup_arg_type(momentum), + fixup_arg_type(local_rate)), ctx.get_queue()); #endif // USE_GREENTEA } } +#ifdef HAS_HALF_SUPPORT +template void nesterov_update_gpu(device*, int_tp, half*, half*, half, + half); +#endif template void nesterov_update_gpu(device*, int_tp, float*, float*, float, float); template void nesterov_update_gpu(device*, int_tp, double*, double*, diff --git a/src/caffe/solvers/rmsprop_solver.cu b/src/caffe/solvers/rmsprop_solver.cu index dc62df571f0..3ab795ad962 100644 --- 
a/src/caffe/solvers/rmsprop_solver.cu +++ b/src/caffe/solvers/rmsprop_solver.cu @@ -38,14 +38,18 @@ void rmsprop_update_gpu(device* dev, int_tp N, Dtype* g, Dtype* h, CL_KERNEL_SELECT("rms_prop_update")); viennacl::ocl::enqueue( oclk_rms_prop_update(N, WrapHandle((cl_mem) g, &ctx), - WrapHandle((cl_mem) h, &ctx), - rms_decay, delta, - local_rate), + WrapHandle((cl_mem) h, &ctx), + fixup_arg_type(rms_decay), fixup_arg_type(delta), + fixup_arg_type(local_rate)), ctx.get_queue()); #endif // USE_GREENTEA } } +#ifdef HAS_HALF_SUPPORT +template void rmsprop_update_gpu(device*, int_tp, half*, half*, half, + half, half); +#endif template void rmsprop_update_gpu(device*, int_tp, float*, float*, float, float, float); template void rmsprop_update_gpu(device*, int_tp, double*, double*, diff --git a/src/caffe/solvers/sgd_solver.cpp b/src/caffe/solvers/sgd_solver.cpp index 56c75ab9e0e..50ce47cf577 100644 --- a/src/caffe/solvers/sgd_solver.cpp +++ b/src/caffe/solvers/sgd_solver.cpp @@ -51,7 +51,7 @@ Dtype SGDSolver::GetLearningRate() { } else if (lr_policy == "poly") { rate = this->param_.base_lr() * pow(Dtype(1.) - (Dtype(this->iter_) / Dtype(this->param_.max_iter())), - this->param_.power()); + Dtype(this->param_.power())); } else if (lr_policy == "sigmoid") { rate = this->param_.base_lr() * (Dtype(1.) / (Dtype(1.) 
+ exp(-this->param_.gamma() * (Dtype(this->iter_) - diff --git a/src/caffe/solvers/sgd_solver.cu b/src/caffe/solvers/sgd_solver.cu index d0cd2cb26f0..c086abb49eb 100644 --- a/src/caffe/solvers/sgd_solver.cu +++ b/src/caffe/solvers/sgd_solver.cu @@ -36,11 +36,17 @@ void sgd_update_gpu(device* dev, int_tp N, Dtype* g, Dtype* h, Dtype momentum, CL_KERNEL_SELECT("sgd_update")); viennacl::ocl::enqueue( oclk_sgd_update(N, WrapHandle((cl_mem) g, &ctx), - WrapHandle((cl_mem) h, &ctx), momentum, local_rate), + WrapHandle((cl_mem) h, &ctx), fixup_arg_type(momentum), + fixup_arg_type(local_rate)), ctx.get_queue()); #endif // USE_GREENTEA } } + +#ifdef HAS_HALF_SUPPORT +template void sgd_update_gpu(device*, int_tp, half*, half*, half, + half); +#endif template void sgd_update_gpu(device*, int_tp, float*, float*, float, float); template void sgd_update_gpu(device*, int_tp, double*, double*, double, diff --git a/src/caffe/test/test_accuracy_layer.cpp b/src/caffe/test/test_accuracy_layer.cpp index 45aa94349c0..5b613af4a1d 100644 --- a/src/caffe/test/test_accuracy_layer.cpp +++ b/src/caffe/test/test_accuracy_layer.cpp @@ -119,6 +119,8 @@ TYPED_TEST(AccuracyLayerTest, TestForwardCPU) { int_tp num_correct_labels = 0; for (int_tp i = 0; i < 100; ++i) { max_value = -FLT_MAX; + if (std::is_same::value) + max_value = -HALF_MAX; max_id = 0; for (int_tp j = 0; j < 10; ++j) { if (this->blob_bottom_data_->data_at(i, j, 0, 0) > max_value) { @@ -130,8 +132,10 @@ TYPED_TEST(AccuracyLayerTest, TestForwardCPU) { ++num_correct_labels; } } + const TypeParam delta = std::is_same::value ? 
+ 1e-2 : 1e-4; EXPECT_NEAR(this->blob_top_->data_at(0, 0, 0, 0), - num_correct_labels / 100.0, 1e-4); + num_correct_labels / 100.0, delta); } TYPED_TEST(AccuracyLayerTest, TestForwardWithSpatialAxes) { @@ -155,6 +159,8 @@ TYPED_TEST(AccuracyLayerTest, TestForwardWithSpatialAxes) { for (int_tp h = 0; h < this->blob_bottom_data_->height(); ++h) { for (int_tp w = 0; w < this->blob_bottom_data_->width(); ++w) { max_value = -FLT_MAX; + if (std::is_same::value) + max_value = -HALF_MAX; max_id = 0; for (int_tp c = 0; c < this->blob_bottom_data_->channels(); ++c) { const TypeParam pred_value = @@ -174,8 +180,10 @@ TYPED_TEST(AccuracyLayerTest, TestForwardWithSpatialAxes) { } } } + const TypeParam delta = std::is_same::value ? + 1e-2 : 1e-4; EXPECT_NEAR(this->blob_top_->data_at(0, 0, 0, 0), - num_correct_labels / TypeParam(num_labels), 1e-4); + num_correct_labels / TypeParam(num_labels), delta); } TYPED_TEST(AccuracyLayerTest, TestForwardIgnoreLabel) { @@ -200,6 +208,8 @@ TYPED_TEST(AccuracyLayerTest, TestForwardIgnoreLabel) { } ++count; max_value = -FLT_MAX; + if (std::is_same::value) + max_value = -HALF_MAX; max_id = 0; for (int_tp j = 0; j < 10; ++j) { if (this->blob_bottom_data_->data_at(i, j, 0, 0) > max_value) { @@ -212,8 +222,10 @@ TYPED_TEST(AccuracyLayerTest, TestForwardIgnoreLabel) { } } EXPECT_EQ(count, 97); // We set 3 out of 100 labels to kIgnoreLabelValue. + const TypeParam delta = std::is_same::value ? + 1e-2 : 1e-4; EXPECT_NEAR(this->blob_top_->data_at(0, 0, 0, 0), - num_correct_labels / TypeParam(count), 1e-4); + num_correct_labels / TypeParam(count), delta); } TYPED_TEST(AccuracyLayerTest, TestForwardCPUTopK) { @@ -243,8 +255,10 @@ TYPED_TEST(AccuracyLayerTest, TestForwardCPUTopK) { } } + const TypeParam delta = std::is_same::value ? 
+ 1e-2 : 1e-4; EXPECT_NEAR(this->blob_top_->data_at(0, 0, 0, 0), - num_correct_labels / 100.0, 1e-4); + num_correct_labels / 100.0, delta); } TYPED_TEST(AccuracyLayerTest, TestForwardCPUPerClass) { @@ -261,6 +275,8 @@ TYPED_TEST(AccuracyLayerTest, TestForwardCPUPerClass) { vector num_per_class(num_class, 0); for (int_tp i = 0; i < 100; ++i) { max_value = -FLT_MAX; + if (std::is_same::value) + max_value = -HALF_MAX; max_id = 0; for (int_tp j = 0; j < 10; ++j) { if (this->blob_bottom_data_->data_at(i, j, 0, 0) > max_value) { @@ -274,13 +290,15 @@ TYPED_TEST(AccuracyLayerTest, TestForwardCPUPerClass) { ++correct_per_class[max_id]; } } + const TypeParam delta = std::is_same::value ? + 1e-2 : 1e-4; EXPECT_NEAR(this->blob_top_->data_at(0, 0, 0, 0), - num_correct_labels / 100.0, 1e-4); + num_correct_labels / 100.0, delta); for (int_tp i = 0; i < num_class; ++i) { TypeParam accuracy_per_class = (num_per_class[i] > 0 ? static_cast(correct_per_class[i]) / num_per_class[i] : 0); EXPECT_NEAR(this->blob_top_per_class_->data_at(i, 0, 0, 0), - accuracy_per_class, 1e-4); + accuracy_per_class, delta); } } @@ -310,6 +328,8 @@ TYPED_TEST(AccuracyLayerTest, TestForwardCPUPerClassWithIgnoreLabel) { } ++count; max_value = -FLT_MAX; + if (std::is_same::value) + max_value = -HALF_MAX; max_id = 0; for (int_tp j = 0; j < 10; ++j) { if (this->blob_bottom_data_->data_at(i, j, 0, 0) > max_value) { @@ -324,13 +344,15 @@ TYPED_TEST(AccuracyLayerTest, TestForwardCPUPerClassWithIgnoreLabel) { } } EXPECT_EQ(count, 97); + const TypeParam delta = std::is_same::value ? + 1e-2 : 1e-4; EXPECT_NEAR(this->blob_top_->data_at(0, 0, 0, 0), - num_correct_labels / TypeParam(count), 1e-4); + num_correct_labels / TypeParam(count), delta); for (int_tp i = 0; i < 10; ++i) { TypeParam accuracy_per_class = (num_per_class[i] > 0 ? 
static_cast(correct_per_class[i]) / num_per_class[i] : 0); EXPECT_NEAR(this->blob_top_per_class_->data_at(i, 0, 0, 0), - accuracy_per_class, 1e-4); + accuracy_per_class, delta); } } diff --git a/src/caffe/test/test_batch_norm_layer.cpp b/src/caffe/test/test_batch_norm_layer.cpp index 936b93a1756..d19c4b8d6d4 100644 --- a/src/caffe/test/test_batch_norm_layer.cpp +++ b/src/caffe/test/test_batch_norm_layer.cpp @@ -68,7 +68,8 @@ namespace caffe { sum /= height * width * num; var /= height * width * num; - const Dtype kErrorBound = 0.001; + const Dtype kErrorBound = std::is_same::value ? + 1e-1 : 1e-3; // expect zero mean EXPECT_NEAR(0, sum, kErrorBound); // expect unit variance @@ -112,7 +113,8 @@ namespace caffe { sum /= height * width * num; var /= height * width * num; - const Dtype kErrorBound = 0.001; + const Dtype kErrorBound = std::is_same::value ? + 1e-1 : 1e-3; // expect zero mean EXPECT_NEAR(0, sum, kErrorBound); // expect unit variance diff --git a/src/caffe/test/test_bias_layer.cpp b/src/caffe/test/test_bias_layer.cpp index 42bd342d7ea..22df442ec23 100644 --- a/src/caffe/test/test_bias_layer.cpp +++ b/src/caffe/test/test_bias_layer.cpp @@ -82,8 +82,11 @@ TYPED_TEST(BiasLayerTest, TestForwardEltwise) { const int_tp count = this->blob_top_->count(); const Dtype* in_data_a = this->blob_bottom_->cpu_data(); const Dtype* in_data_b = this->blob_bottom_eltwise_->cpu_data(); + const Dtype delta = std::is_same::value ? + 1e-2 : 1e-5; for (int_tp i = 0; i < count; ++i) { - EXPECT_NEAR(data[i], in_data_a[i] + in_data_b[i], 1e-5); + EXPECT_NEAR(data[i], in_data_a[i] + in_data_b[i], + delta * fabs(in_data_a[i] + in_data_b[i])); } } @@ -102,8 +105,11 @@ TYPED_TEST(BiasLayerTest, TestForwardEltwiseInPlace) { const int_tp count = this->blob_bottom_->count(); const Dtype* in_data_a = orig_bottom.cpu_data(); const Dtype* in_data_b = this->blob_bottom_eltwise_->cpu_data(); + const Dtype delta = std::is_same::value ? 
+ 1e-2 : 1e-5; for (int_tp i = 0; i < count; ++i) { - EXPECT_NEAR(data[i], in_data_a[i] + in_data_b[i], 1e-5); + EXPECT_NEAR(data[i], in_data_a[i] + in_data_b[i], + delta * fabs(in_data_a[i] + in_data_b[i])); } } @@ -143,13 +149,15 @@ TYPED_TEST(BiasLayerTest, TestBackwardEltwiseInPlace) { caffe_copy(top_diff.count(), top_diff.cpu_data(), this->blob_bottom_->mutable_cpu_diff()); layer->Backward(this->blob_top_vec_, propagate_down, this->blob_bottom_vec_); + const Dtype delta = std::is_same::value ? + 1e-2 : 1e-5; for (int_tp i = 0; i < this->blob_bottom_->count(); ++i) { EXPECT_NEAR(orig_bottom_diff.cpu_diff()[i], - this->blob_bottom_->cpu_diff()[i], 1e-5); + this->blob_bottom_->cpu_diff()[i], delta); } for (int_tp i = 0; i < this->blob_bottom_eltwise_->count(); ++i) { EXPECT_NEAR(orig_bias_diff.cpu_diff()[i], - this->blob_bottom_eltwise_->cpu_diff()[i], 1e-5); + this->blob_bottom_eltwise_->cpu_diff()[i], delta); } } @@ -168,8 +176,10 @@ TYPED_TEST(BiasLayerTest, TestForwardEltwiseWithParam) { const int_tp count = this->blob_top_->count(); const Dtype* in_data_a = this->blob_bottom_->cpu_data(); const Dtype* in_data_b = layer->blobs()[0]->cpu_data(); + const Dtype delta = std::is_same::value ? + 1e-2 : 1e-5; for (int_tp i = 0; i < count; ++i) { - EXPECT_NEAR(data[i], in_data_a[i] + in_data_b[i], 1e-5); + EXPECT_NEAR(data[i], in_data_a[i] + in_data_b[i], delta); } } @@ -182,6 +192,8 @@ TYPED_TEST(BiasLayerTest, TestForwardBroadcastBegin) { layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_); ASSERT_EQ(this->blob_bottom_->shape(), this->blob_top_->shape()); layer->Forward(this->blob_bottom_vec_, this->blob_top_vec_); + const Dtype delta = std::is_same::value ? 
+ 1e-2 : 1e-5; for (int_tp n = 0; n < this->blob_bottom_->num(); ++n) { for (int_tp c = 0; c < this->blob_bottom_->channels(); ++c) { for (int_tp h = 0; h < this->blob_bottom_->height(); ++h) { @@ -189,7 +201,7 @@ TYPED_TEST(BiasLayerTest, TestForwardBroadcastBegin) { EXPECT_NEAR(this->blob_top_->data_at(n, c, h, w), this->blob_bottom_->data_at(n, c, h, w) + this->blob_bottom_broadcast_0_->data_at(n, c, 0, 0), - 1e-5); + delta * fabs(this->blob_top_->data_at(n, c, h, w))); } } } @@ -205,6 +217,8 @@ TYPED_TEST(BiasLayerTest, TestForwardBroadcastMiddle) { layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_); ASSERT_EQ(this->blob_bottom_->shape(), this->blob_top_->shape()); layer->Forward(this->blob_bottom_vec_, this->blob_top_vec_); + const Dtype delta = std::is_same::value ? + 1e-2 : 1e-5; for (int_tp n = 0; n < this->blob_bottom_->num(); ++n) { for (int_tp c = 0; c < this->blob_bottom_->channels(); ++c) { for (int_tp h = 0; h < this->blob_bottom_->height(); ++h) { @@ -212,7 +226,7 @@ TYPED_TEST(BiasLayerTest, TestForwardBroadcastMiddle) { EXPECT_NEAR(this->blob_top_->data_at(n, c, h, w), this->blob_bottom_->data_at(n, c, h, w) + this->blob_bottom_broadcast_1_->data_at(c, h, 0, 0), - 1e-5); + delta * fabs(this->blob_top_->data_at(n, c, h, w))); } } } @@ -230,6 +244,8 @@ TYPED_TEST(BiasLayerTest, TestForwardBroadcastMiddleInPlace) { shared_ptr > layer(new BiasLayer(layer_param)); layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_); layer->Forward(this->blob_bottom_vec_, this->blob_top_vec_); + const Dtype delta = std::is_same::value ? 
+ 1e-2 : 1e-5; for (int_tp n = 0; n < this->blob_bottom_->num(); ++n) { for (int_tp c = 0; c < this->blob_bottom_->channels(); ++c) { for (int_tp h = 0; h < this->blob_bottom_->height(); ++h) { @@ -237,7 +253,7 @@ TYPED_TEST(BiasLayerTest, TestForwardBroadcastMiddleInPlace) { EXPECT_NEAR(this->blob_bottom_->data_at(n, c, h, w), orig_bottom.data_at(n, c, h, w) + this->blob_bottom_broadcast_1_->data_at(c, h, 0, 0), - 1e-5); + delta); } } } @@ -280,13 +296,15 @@ TYPED_TEST(BiasLayerTest, TestBackwardBroadcastMiddleInPlace) { caffe_copy(top_diff.count(), top_diff.cpu_data(), this->blob_bottom_->mutable_cpu_diff()); layer->Backward(this->blob_top_vec_, propagate_down, this->blob_bottom_vec_); + const Dtype delta = std::is_same::value ? + 1e-2 : 1e-5; for (int_tp i = 0; i < this->blob_bottom_->count(); ++i) { EXPECT_NEAR(orig_bottom_diff.cpu_diff()[i], - this->blob_bottom_->cpu_diff()[i], 1e-5); + this->blob_bottom_->cpu_diff()[i], delta); } for (int_tp i = 0; i < this->blob_bottom_broadcast_1_->count(); ++i) { EXPECT_NEAR(orig_bias_diff.cpu_diff()[i], - this->blob_bottom_broadcast_1_->cpu_diff()[i], 1e-5); + this->blob_bottom_broadcast_1_->cpu_diff()[i], delta); } } @@ -301,13 +319,15 @@ TYPED_TEST(BiasLayerTest, TestForwardBroadcastMiddleWithParam) { layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_); ASSERT_EQ(this->blob_bottom_->shape(), this->blob_top_->shape()); layer->Forward(this->blob_bottom_vec_, this->blob_top_vec_); + const Dtype delta = std::is_same::value ? 
+ 1e-2 : 1e-5; for (int_tp n = 0; n < this->blob_bottom_->num(); ++n) { for (int_tp c = 0; c < this->blob_bottom_->channels(); ++c) { for (int_tp h = 0; h < this->blob_bottom_->height(); ++h) { for (int_tp w = 0; w < this->blob_bottom_->width(); ++w) { EXPECT_NEAR(this->blob_top_->data_at(n, c, h, w), this->blob_bottom_->data_at(n, c, h, w) + - layer->blobs()[0]->data_at(c, h, 0, 0), 1e-5); + layer->blobs()[0]->data_at(c, h, 0, 0), delta); } } } @@ -323,6 +343,8 @@ TYPED_TEST(BiasLayerTest, TestForwardBroadcastEnd) { layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_); ASSERT_EQ(this->blob_bottom_->shape(), this->blob_top_->shape()); layer->Forward(this->blob_bottom_vec_, this->blob_top_vec_); + const Dtype delta = std::is_same::value ? + 1e-2 : 1e-5; for (int_tp n = 0; n < this->blob_bottom_->num(); ++n) { for (int_tp c = 0; c < this->blob_bottom_->channels(); ++c) { for (int_tp h = 0; h < this->blob_bottom_->height(); ++h) { @@ -330,7 +352,7 @@ TYPED_TEST(BiasLayerTest, TestForwardBroadcastEnd) { EXPECT_NEAR(this->blob_top_->data_at(n, c, h, w), this->blob_bottom_->data_at(n, c, h, w) + this->blob_bottom_broadcast_2_->data_at(h, w, 0, 0), - 1e-5); + delta); } } } @@ -349,8 +371,10 @@ TYPED_TEST(BiasLayerTest, TestForwardBias) { const int_tp count = this->blob_top_->count(); const Dtype* in_data = this->blob_bottom_->cpu_data(); const Dtype bias = *this->blob_bottom_bias_->cpu_data(); + const Dtype delta = std::is_same::value ? + 1e-2 : 1e-5; for (int_tp i = 0; i < count; ++i) { - EXPECT_NEAR(data[i], in_data[i] + bias, 1e-5); + EXPECT_NEAR(data[i], in_data[i] + bias, delta); } } @@ -367,8 +391,10 @@ TYPED_TEST(BiasLayerTest, TestForwardBiasAxis2) { const int_tp count = this->blob_top_->count(); const Dtype* in_data = this->blob_bottom_->cpu_data(); const Dtype bias = *this->blob_bottom_bias_->cpu_data(); + const Dtype delta = std::is_same::value ? 
+ 1e-2 : 1e-5; for (int_tp i = 0; i < count; ++i) { - EXPECT_NEAR(data[i], in_data[i] + bias, 1e-5); + EXPECT_NEAR(data[i], in_data[i] + bias, delta); } } diff --git a/src/caffe/test/test_blob.cpp b/src/caffe/test/test_blob.cpp index 056efd6b9ef..0c51af3be27 100644 --- a/src/caffe/test/test_blob.cpp +++ b/src/caffe/test/test_blob.cpp @@ -120,7 +120,10 @@ class BlobMathTest : public MultiDeviceTest { protected: BlobMathTest() : blob_(new Blob(2, 3, 4, 5)), - epsilon_(1e-6) {} + epsilon_(1e-6) { + if (std::is_same::value) + epsilon_ = 1e-2; + } virtual ~BlobMathTest() { delete blob_; } Blob* const blob_; diff --git a/src/caffe/test/test_caffe_main.cpp b/src/caffe/test/test_caffe_main.cpp index 0f676b365c0..66affb26877 100644 --- a/src/caffe/test/test_caffe_main.cpp +++ b/src/caffe/test/test_caffe_main.cpp @@ -36,8 +36,24 @@ bool caffe::isSupported(void) { return caffe::Caffe::GetDefaultDevice()->backend() != caffe::BACKEND_OpenCL || caffe::Caffe::GetDefaultDevice()->CheckCapability("cl_khr_fp64"); } +#ifdef HAS_HALF_SUPPORT +template <> +bool caffe::isSupported(void) { + return caffe::Caffe::GetDefaultDevice()->backend() != caffe::BACKEND_OpenCL || + caffe::Caffe::GetDefaultDevice()->CheckCapability("cl_khr_fp16"); +} + +template <> +bool caffe::isSupported>(void) { + return caffe::isSupported(); +} template <> +bool caffe::isSupported>(void) { + return true; +} +#endif +template <> bool caffe::isSupported>(void) { return caffe::isSupported(); } @@ -63,6 +79,7 @@ bool caffe::isSupported(void) { return true; } #endif + #endif #ifndef CPU_ONLY diff --git a/src/caffe/test/test_contrastive_loss_layer.cpp b/src/caffe/test/test_contrastive_loss_layer.cpp index 9f5ba00f122..96fc1ec593d 100644 --- a/src/caffe/test/test_contrastive_loss_layer.cpp +++ b/src/caffe/test/test_contrastive_loss_layer.cpp @@ -77,12 +77,13 @@ TYPED_TEST(ContrastiveLossLayerTest, TestForward) { if (this->blob_bottom_y_->cpu_data()[i]) { // similar pairs loss += dist_sq; } else { - Dtype dist = 
std::max(margin - sqrt(dist_sq), 0.0); + Dtype dist = fmax(Dtype(margin - sqrt(dist_sq)), Dtype(0.0)); loss += dist*dist; } } loss /= static_cast(num) * Dtype(2); - EXPECT_NEAR(this->blob_top_loss_->cpu_data()[0], loss, 1e-5); + Dtype delta = 1e-5 * std::is_same::value ? 100 : 1; + EXPECT_NEAR(this->blob_top_loss_->cpu_data()[0], loss, delta); } TYPED_TEST(ContrastiveLossLayerTest, TestGradient) { @@ -120,11 +121,12 @@ TYPED_TEST(ContrastiveLossLayerTest, TestForwardLegacy) { if (this->blob_bottom_y_->cpu_data()[i]) { // similar pairs loss += dist_sq; } else { - loss += std::max(margin - dist_sq, Dtype(0.0)); + loss += fmax(Dtype(margin - dist_sq), Dtype(0.0)); } } loss /= static_cast(num) * Dtype(2); - EXPECT_NEAR(this->blob_top_loss_->cpu_data()[0], loss, 1e-5); + Dtype delta = 1e-5 * std::is_same::value ? 100 : 1; + EXPECT_NEAR(this->blob_top_loss_->cpu_data()[0], loss, delta); } TYPED_TEST(ContrastiveLossLayerTest, TestGradientLegacy) { diff --git a/src/caffe/test/test_convolution_layer.cpp b/src/caffe/test/test_convolution_layer.cpp index 8c630980032..0382fd4973b 100644 --- a/src/caffe/test/test_convolution_layer.cpp +++ b/src/caffe/test/test_convolution_layer.cpp @@ -252,15 +252,17 @@ TYPED_TEST(ConvolutionLayerTest, TestSimpleConvolution) { this->MakeReferenceTop(this->blob_top_)); top_data = this->blob_top_->cpu_data(); ref_top_data = this->ref_blob_top_->cpu_data(); + const Dtype delta = std::is_same::value ? 
+ 1e-1 : 1e-4; for (int_tp i = 0; i < this->blob_top_->count(); ++i) { - EXPECT_NEAR(top_data[i], ref_top_data[i], 1e-4); + EXPECT_NEAR(top_data[i], ref_top_data[i], delta); } caffe_conv(this->blob_bottom_2_, convolution_param, layer->blobs(), this->MakeReferenceTop(this->blob_top_2_)); top_data = this->blob_top_2_->cpu_data(); ref_top_data = this->ref_blob_top_->cpu_data(); for (int_tp i = 0; i < this->blob_top_->count(); ++i) { - EXPECT_NEAR(top_data[i], ref_top_data[i], 1e-4); + EXPECT_NEAR(top_data[i], ref_top_data[i], delta); } } @@ -296,15 +298,17 @@ TYPED_TEST(ConvolutionLayerTest, TestDilatedConvolution) { this->MakeReferenceTop(this->blob_top_)); top_data = this->blob_top_->cpu_data(); ref_top_data = this->ref_blob_top_->cpu_data(); + const Dtype delta = std::is_same::value ? + 1e-2 : 1e-4; for (int_tp i = 0; i < this->blob_top_->count(); ++i) { - EXPECT_NEAR(top_data[i], ref_top_data[i], 1e-4); + EXPECT_NEAR(top_data[i], ref_top_data[i], delta); } caffe_conv(this->blob_bottom_2_, convolution_param, layer->blobs(), this->MakeReferenceTop(this->blob_top_2_)); top_data = this->blob_top_2_->cpu_data(); ref_top_data = this->ref_blob_top_->cpu_data(); for (int_tp i = 0; i < this->blob_top_->count(); ++i) { - EXPECT_NEAR(top_data[i], ref_top_data[i], 1e-4); + EXPECT_NEAR(top_data[i], ref_top_data[i], delta); } } @@ -332,6 +336,8 @@ TYPED_TEST(ConvolutionLayerTest, Test0DConvolution) { const int_tp num = this->blob_top_->count(3); const int_tp dim = this->blob_top_->shape(3); const int_tp bottom_dim = this->blob_bottom_->shape(3); + const Dtype delta = std::is_same::value ? 
+ 1e-2 : 1e-4; for (int_tp n = 0; n < num; ++n) { for (int_tp d = 0; d < dim; ++d) { weight_offset[0] = d; @@ -341,7 +347,7 @@ TYPED_TEST(ConvolutionLayerTest, Test0DConvolution) { value += weight->data_at(weight_offset) * this->blob_bottom_->cpu_data()[n * bottom_dim + bottom_d]; } - EXPECT_NEAR(value, this->blob_top_->cpu_data()[n * dim + d], 1e-4); + EXPECT_NEAR(value, this->blob_top_->cpu_data()[n * dim + d], delta); } } } @@ -381,15 +387,17 @@ TYPED_TEST(ConvolutionLayerTest, TestSimple3DConvolution) { this->MakeReferenceTop(this->blob_top_)); top_data = this->blob_top_->cpu_data(); ref_top_data = this->ref_blob_top_->cpu_data(); + const Dtype delta = std::is_same::value ? + 5e-1 : 1e-4; for (int_tp i = 0; i < this->blob_top_->count(); ++i) { - EXPECT_NEAR(top_data[i], ref_top_data[i], 1e-4); + EXPECT_NEAR(top_data[i], ref_top_data[i], delta); } caffe_conv(this->blob_bottom_2_, convolution_param, layer->blobs(), this->MakeReferenceTop(this->blob_top_2_)); top_data = this->blob_top_2_->cpu_data(); ref_top_data = this->ref_blob_top_->cpu_data(); for (int_tp i = 0; i < this->blob_top_->count(); ++i) { - EXPECT_NEAR(top_data[i], ref_top_data[i], 1e-4); + EXPECT_NEAR(top_data[i], ref_top_data[i], delta); } } @@ -428,15 +436,17 @@ TYPED_TEST(ConvolutionLayerTest, TestDilated3DConvolution) { this->MakeReferenceTop(this->blob_top_)); top_data = this->blob_top_->cpu_data(); ref_top_data = this->ref_blob_top_->cpu_data(); + const Dtype delta = std::is_same::value ? 
+ 5e-1 : 1e-4; for (int_tp i = 0; i < this->blob_top_->count(); ++i) { - EXPECT_NEAR(top_data[i], ref_top_data[i], 1e-4); + EXPECT_NEAR(top_data[i], ref_top_data[i], delta); } caffe_conv(this->blob_bottom_2_, convolution_param, layer->blobs(), this->MakeReferenceTop(this->blob_top_2_)); top_data = this->blob_top_2_->cpu_data(); ref_top_data = this->ref_blob_top_->cpu_data(); for (int_tp i = 0; i < this->blob_top_->count(); ++i) { - EXPECT_NEAR(top_data[i], ref_top_data[i], 1e-4); + EXPECT_NEAR(top_data[i], ref_top_data[i], delta); } } @@ -462,8 +472,10 @@ TYPED_TEST(ConvolutionLayerTest, Test1x1Convolution) { this->MakeReferenceTop(this->blob_top_)); top_data = this->blob_top_->cpu_data(); ref_top_data = this->ref_blob_top_->cpu_data(); + const Dtype delta = std::is_same::value ? + 5e-1 : 1e-4; for (int_tp i = 0; i < this->blob_top_->count(); ++i) { - EXPECT_NEAR(top_data[i], ref_top_data[i], 1e-4); + EXPECT_NEAR(top_data[i], ref_top_data[i], delta); } } @@ -490,8 +502,10 @@ TYPED_TEST(ConvolutionLayerTest, TestSimpleConvolutionGroup) { this->MakeReferenceTop(this->blob_top_)); top_data = this->blob_top_->cpu_data(); ref_top_data = this->ref_blob_top_->cpu_data(); + const Dtype delta = std::is_same::value ? + 5e-2 : 1e-4; for (int_tp i = 0; i < this->blob_top_->count(); ++i) { - EXPECT_NEAR(top_data[i], ref_top_data[i], 1e-4); + EXPECT_NEAR(top_data[i], ref_top_data[i], delta); } } @@ -583,8 +597,10 @@ TYPED_TEST(ConvolutionLayerTest, TestSobelConvolution) { // Test equivalence of full and separable filters. const Dtype* top_data = this->blob_top_->cpu_data(); const Dtype* sep_top_data = this->blob_top_2_->cpu_data(); + const Dtype delta = std::is_same::value ? 
+ 1e-2 : 1e-4; for (int_tp i = 0; i < this->blob_top_->count(); ++i) { - EXPECT_NEAR(top_data[i], sep_top_data[i], 1e-4); + EXPECT_NEAR(top_data[i], sep_top_data[i], delta); } } @@ -919,15 +935,17 @@ TYPED_TEST(CuDNNConvolutionLayerTest, TestSimpleConvolutionCuDNN) { this->MakeReferenceTop(this->blob_top_)); top_data = this->blob_top_->cpu_data(); ref_top_data = this->ref_blob_top_->cpu_data(); + const Dtype delta = std::is_same::value ? + 1e-2 : 1e-4; for (int_tp i = 0; i < this->blob_top_->count(); ++i) { - EXPECT_NEAR(top_data[i], ref_top_data[i], 1e-4); + EXPECT_NEAR(top_data[i], ref_top_data[i], delta); } caffe_conv(this->blob_bottom_2_, convolution_param, layer->blobs(), this->MakeReferenceTop(this->blob_top_2_)); top_data = this->blob_top_2_->cpu_data(); ref_top_data = this->ref_blob_top_->cpu_data(); for (int_tp i = 0; i < this->blob_top_->count(); ++i) { - EXPECT_NEAR(top_data[i], ref_top_data[i], 1e-4); + EXPECT_NEAR(top_data[i], ref_top_data[i], delta); } } } @@ -955,8 +973,10 @@ TYPED_TEST(CuDNNConvolutionLayerTest, TestSimpleConvolutionGroupCuDNN) { this->MakeReferenceTop(this->blob_top_)); top_data = this->blob_top_->cpu_data(); ref_top_data = this->ref_blob_top_->cpu_data(); + const Dtype delta = std::is_same::value ? + 1e-2 : 1e-4; for (int_tp i = 0; i < this->blob_top_->count(); ++i) { - EXPECT_NEAR(top_data[i], ref_top_data[i], 1e-4); + EXPECT_NEAR(top_data[i], ref_top_data[i], delta); } } } @@ -1049,8 +1069,10 @@ TYPED_TEST(CuDNNConvolutionLayerTest, TestSobelConvolutionCuDNN) { // Test equivalence of full and separable filters. const TypeParam* top_data = this->blob_top_->cpu_data(); const TypeParam* sep_top_data = this->blob_top_2_->cpu_data(); + const Dtype delta = std::is_same::value ? 
+ 1e-2 : 1e-4; for (int_tp i = 0; i < this->blob_top_->count(); ++i) { - EXPECT_NEAR(top_data[i], sep_top_data[i], 1e-4); + EXPECT_NEAR(top_data[i], sep_top_data[i], delta); } } } diff --git a/src/caffe/test/test_convolution_layer_spatial.cpp b/src/caffe/test/test_convolution_layer_spatial.cpp index 944ccec5702..25a12b0e788 100644 --- a/src/caffe/test/test_convolution_layer_spatial.cpp +++ b/src/caffe/test/test_convolution_layer_spatial.cpp @@ -136,6 +136,12 @@ void caffe_conv(const Blob* in, ConvolutionParameter* conv_param, } } } +#ifdef HAS_HALF_SUPPORT +template void caffe_conv(const Blob* in, + ConvolutionParameter* conv_param, + const vector > >& weights, + Blob* out); +#endif template void caffe_conv(const Blob* in, ConvolutionParameter* conv_param, const vector > >& weights, @@ -257,15 +263,17 @@ TYPED_TEST(ConvolutionLayerTest_Spatial, TestSimpleConvolution_Spatial) { this->MakeReferenceTop(this->blob_top_)); top_data = this->blob_top_->cpu_data(); ref_top_data = this->ref_blob_top_->cpu_data(); + Dtype delta = std::is_same::value ? + 5e-1 : 1e-4; for (int_tp i = 0; i < this->blob_top_->count(); ++i) { - EXPECT_NEAR(top_data[i], ref_top_data[i], 1e-4); + EXPECT_NEAR(top_data[i], ref_top_data[i], delta); } caffe_conv(this->blob_bottom_2_, convolution_param, layer->blobs(), this->MakeReferenceTop(this->blob_top_2_)); top_data = this->blob_top_2_->cpu_data(); ref_top_data = this->ref_blob_top_->cpu_data(); for (int_tp i = 0; i < this->blob_top_->count(); ++i) { - EXPECT_NEAR(top_data[i], ref_top_data[i], 1e-4); + EXPECT_NEAR(top_data[i], ref_top_data[i], delta); } } } @@ -297,15 +305,17 @@ TYPED_TEST(ConvolutionLayerTest_Spatial, TestSimpleConvolution_Spatial3x3) { this->MakeReferenceTop(this->blob_top_)); top_data = this->blob_top_->cpu_data(); ref_top_data = this->ref_blob_top_->cpu_data(); + Dtype delta = std::is_same::value ? 
+ 5e-1 : 1e-4; for (int_tp i = 0; i < this->blob_top_->count(); ++i) { - EXPECT_NEAR(top_data[i], ref_top_data[i], 1e-4); + EXPECT_NEAR(top_data[i], ref_top_data[i], delta); } caffe_conv(this->blob_bottom_2_, convolution_param, layer->blobs(), this->MakeReferenceTop(this->blob_top_2_)); top_data = this->blob_top_2_->cpu_data(); ref_top_data = this->ref_blob_top_->cpu_data(); for (int_tp i = 0; i < this->blob_top_->count(); ++i) { - EXPECT_NEAR(top_data[i], ref_top_data[i], 1e-4); + EXPECT_NEAR(top_data[i], ref_top_data[i], delta); } } } @@ -340,15 +350,17 @@ TYPED_TEST(ConvolutionLayerTest_Spatial, this->MakeReferenceTop(this->blob_top_)); top_data = this->blob_top_->cpu_data(); ref_top_data = this->ref_blob_top_->cpu_data(); + Dtype delta = std::is_same::value ? + 5e-1 : 1e-4; for (int_tp i = 0; i < this->blob_top_->count(); ++i) { - EXPECT_NEAR(top_data[i], ref_top_data[i], 1e-4); + EXPECT_NEAR(top_data[i], ref_top_data[i], delta); } caffe_conv(this->blob_bottom_2_, convolution_param, layer->blobs(), this->MakeReferenceTop(this->blob_top_2_)); top_data = this->blob_top_2_->cpu_data(); ref_top_data = this->ref_blob_top_->cpu_data(); for (int_tp i = 0; i < this->blob_top_->count(); ++i) { - EXPECT_NEAR(top_data[i], ref_top_data[i], 1e-4); + EXPECT_NEAR(top_data[i], ref_top_data[i], delta); } } } @@ -393,15 +405,17 @@ TYPED_TEST(ConvolutionLayerTest_Spatial, TestDilatedConvolution) { this->MakeReferenceTop(this->blob_top_)); top_data = this->blob_top_->cpu_data(); ref_top_data = this->ref_blob_top_->cpu_data(); + Dtype delta = std::is_same::value ? 
+ 5e-1 : 1e-4; for (int_tp i = 0; i < this->blob_top_->count(); ++i) { - EXPECT_NEAR(top_data[i], ref_top_data[i], 1e-4); + EXPECT_NEAR(top_data[i], ref_top_data[i], delta); } caffe_conv(this->blob_bottom_vec_[1], convolution_param, layer->blobs(), this->MakeReferenceTop(this->blob_top_2_)); top_data = this->blob_top_2_->cpu_data(); ref_top_data = this->ref_blob_top_->cpu_data(); for (int_tp i = 0; i < this->blob_top_->count(); ++i) { - EXPECT_NEAR(top_data[i], ref_top_data[i], 1e-4); + EXPECT_NEAR(top_data[i], ref_top_data[i], delta); } } } @@ -435,15 +449,17 @@ TYPED_TEST(ConvolutionLayerTest_Spatial, this->MakeReferenceTop(this->blob_top_)); top_data = this->blob_top_->cpu_data(); ref_top_data = this->ref_blob_top_->cpu_data(); + Dtype delta = std::is_same::value ? + 5e-1 : 1e-4; for (int_tp i = 0; i < this->blob_top_->count(); ++i) { - EXPECT_NEAR(top_data[i], ref_top_data[i], 1e-4); + EXPECT_NEAR(top_data[i], ref_top_data[i], delta); } caffe_conv(this->blob_bottom_2_, convolution_param, layer->blobs(), this->MakeReferenceTop(this->blob_top_2_)); top_data = this->blob_top_2_->cpu_data(); ref_top_data = this->ref_blob_top_->cpu_data(); for (int_tp i = 0; i < this->blob_top_->count(); ++i) { - EXPECT_NEAR(top_data[i], ref_top_data[i], 1e-4); + EXPECT_NEAR(top_data[i], ref_top_data[i], delta); } } } @@ -479,15 +495,17 @@ TYPED_TEST(ConvolutionLayerTest_Spatial, this->MakeReferenceTop(this->blob_top_)); top_data = this->blob_top_->cpu_data(); ref_top_data = this->ref_blob_top_->cpu_data(); + Dtype delta = std::is_same::value ? 
+ 5e-1 : 1e-4; for (int_tp i = 0; i < this->blob_top_->count(); ++i) { - EXPECT_NEAR(top_data[i], ref_top_data[i], 1e-4); + EXPECT_NEAR(top_data[i], ref_top_data[i], delta); } caffe_conv(this->blob_bottom_2_, convolution_param, layer->blobs(), this->MakeReferenceTop(this->blob_top_2_)); top_data = this->blob_top_2_->cpu_data(); ref_top_data = this->ref_blob_top_->cpu_data(); for (int_tp i = 0; i < this->blob_top_->count(); ++i) { - EXPECT_NEAR(top_data[i], ref_top_data[i], 1e-4); + EXPECT_NEAR(top_data[i], ref_top_data[i], delta); } } } @@ -522,15 +540,17 @@ TYPED_TEST(ConvolutionLayerTest_Spatial, this->MakeReferenceTop(this->blob_top_)); top_data = this->blob_top_->cpu_data(); ref_top_data = this->ref_blob_top_->cpu_data(); + Dtype delta = std::is_same::value ? + 5e-1 : 1e-4; for (int_tp i = 0; i < this->blob_top_->count(); ++i) { - EXPECT_NEAR(top_data[i], ref_top_data[i], 1e-4); + EXPECT_NEAR(top_data[i], ref_top_data[i], delta); } caffe_conv(this->blob_bottom_2_, convolution_param, layer->blobs(), this->MakeReferenceTop(this->blob_top_2_)); top_data = this->blob_top_2_->cpu_data(); ref_top_data = this->ref_blob_top_->cpu_data(); for (int_tp i = 0; i < this->blob_top_->count(); ++i) { - EXPECT_NEAR(top_data[i], ref_top_data[i], 1e-4); + EXPECT_NEAR(top_data[i], ref_top_data[i], delta); } } } @@ -565,15 +585,17 @@ TYPED_TEST(ConvolutionLayerTest_Spatial, this->MakeReferenceTop(this->blob_top_)); top_data = this->blob_top_->cpu_data(); ref_top_data = this->ref_blob_top_->cpu_data(); + Dtype delta = std::is_same::value ? 
+ 5e-1 : 1e-4; for (int_tp i = 0; i < this->blob_top_->count(); ++i) { - EXPECT_NEAR(top_data[i], ref_top_data[i], 1e-4); + EXPECT_NEAR(top_data[i], ref_top_data[i], delta); } caffe_conv(this->blob_bottom_2_, convolution_param, layer->blobs(), this->MakeReferenceTop(this->blob_top_2_)); top_data = this->blob_top_2_->cpu_data(); ref_top_data = this->ref_blob_top_->cpu_data(); for (int_tp i = 0; i < this->blob_top_->count(); ++i) { - EXPECT_NEAR(top_data[i], ref_top_data[i], 1e-4); + EXPECT_NEAR(top_data[i], ref_top_data[i], delta); } } } @@ -608,15 +630,17 @@ TYPED_TEST(ConvolutionLayerTest_Spatial, this->MakeReferenceTop(this->blob_top_)); top_data = this->blob_top_->cpu_data(); ref_top_data = this->ref_blob_top_->cpu_data(); + Dtype delta = std::is_same::value ? + 5e-1 : 1e-4; for (int_tp i = 0; i < this->blob_top_->count(); ++i) { - EXPECT_NEAR(top_data[i], ref_top_data[i], 1e-4); + EXPECT_NEAR(top_data[i], ref_top_data[i], delta); } caffe_conv(this->blob_bottom_2_, convolution_param, layer->blobs(), this->MakeReferenceTop(this->blob_top_2_)); top_data = this->blob_top_2_->cpu_data(); ref_top_data = this->ref_blob_top_->cpu_data(); for (int_tp i = 0; i < this->blob_top_->count(); ++i) { - EXPECT_NEAR(top_data[i], ref_top_data[i], 1e-4); + EXPECT_NEAR(top_data[i], ref_top_data[i], delta); } } } @@ -650,15 +674,17 @@ TYPED_TEST(ConvolutionLayerTest_Spatial, TestSimpleConvolution_Spatial5x5) { this->MakeReferenceTop(this->blob_top_)); top_data = this->blob_top_->cpu_data(); ref_top_data = this->ref_blob_top_->cpu_data(); + Dtype delta = std::is_same::value ? 
+ 5e-1 : 1e-4; for (int_tp i = 0; i < this->blob_top_->count(); ++i) { - EXPECT_NEAR(top_data[i], ref_top_data[i], 1e-4); + EXPECT_NEAR(top_data[i], ref_top_data[i], delta); } caffe_conv(this->blob_bottom_2_, convolution_param, layer->blobs(), this->MakeReferenceTop(this->blob_top_2_)); top_data = this->blob_top_2_->cpu_data(); ref_top_data = this->ref_blob_top_->cpu_data(); for (int_tp i = 0; i < this->blob_top_->count(); ++i) { - EXPECT_NEAR(top_data[i], ref_top_data[i], 1e-4); + EXPECT_NEAR(top_data[i], ref_top_data[i], delta); } } } @@ -688,8 +714,10 @@ TYPED_TEST(ConvolutionLayerTest_Spatial, Test1x1Convolution_Spatial) { this->MakeReferenceTop(this->blob_top_)); top_data = this->blob_top_->cpu_data(); ref_top_data = this->ref_blob_top_->cpu_data(); + Dtype delta = std::is_same::value ? + 5e-2 : 1e-4; for (int_tp i = 0; i < this->blob_top_->count(); ++i) { - EXPECT_NEAR(top_data[i], ref_top_data[i], 1e-4); + EXPECT_NEAR(top_data[i], ref_top_data[i], delta); } } } @@ -720,8 +748,10 @@ TYPED_TEST(ConvolutionLayerTest_Spatial, TestSimpleConvolutionGroup_Spatial) { this->MakeReferenceTop(this->blob_top_)); top_data = this->blob_top_->cpu_data(); ref_top_data = this->ref_blob_top_->cpu_data(); + Dtype delta = std::is_same::value ? + 5e-2 : 1e-4; for (int_tp i = 0; i < this->blob_top_->count(); ++i) { - EXPECT_NEAR(top_data[i], ref_top_data[i], 1e-4); + EXPECT_NEAR(top_data[i], ref_top_data[i], delta); } } } @@ -818,8 +848,10 @@ TYPED_TEST(ConvolutionLayerTest_Spatial, TestSobelConvolution_Spatial) { // Test equivalence of full and separable filters. const Dtype* top_data = this->blob_top_->cpu_data(); const Dtype* sep_top_data = this->blob_top_2_->cpu_data(); + Dtype delta = std::is_same::value ? 
+ 5e-2 : 1e-4; for (int_tp i = 0; i < this->blob_top_->count(); ++i) { - EXPECT_NEAR(top_data[i], sep_top_data[i], 1e-4); + EXPECT_NEAR(top_data[i], sep_top_data[i], delta); } } } diff --git a/src/caffe/test/test_deconvolution_layer.cpp b/src/caffe/test/test_deconvolution_layer.cpp index dd61dc0b90f..fcd2ea4fb6e 100644 --- a/src/caffe/test/test_deconvolution_layer.cpp +++ b/src/caffe/test/test_deconvolution_layer.cpp @@ -128,8 +128,10 @@ TYPED_TEST(DeconvolutionLayerTest, TestSimpleDeconvolution) { } else if (h_overlap || w_overlap) { expected += 3; } + const Dtype delta = std::is_same::value ? + 1e-2 : 1e-4; EXPECT_NEAR(top_data[this->blob_top_->offset(n, c, h, w)], - expected, 1e-4); + expected, delta); } } } diff --git a/src/caffe/test/test_eltwise_layer.cpp b/src/caffe/test/test_eltwise_layer.cpp index f59523ce9ab..884feb2e3fa 100644 --- a/src/caffe/test/test_eltwise_layer.cpp +++ b/src/caffe/test/test_eltwise_layer.cpp @@ -79,8 +79,10 @@ TYPED_TEST(EltwiseLayerTest, TestProd) { const Dtype* in_data_a = this->blob_bottom_a_->cpu_data(); const Dtype* in_data_b = this->blob_bottom_b_->cpu_data(); const Dtype* in_data_c = this->blob_bottom_c_->cpu_data(); + Dtype delta = std::is_same::value ? + 1e-2 : 1e-4; for (int_tp i = 0; i < count; ++i) { - EXPECT_NEAR(data[i], in_data_a[i] * in_data_b[i] * in_data_c[i], 1e-4); + EXPECT_NEAR(data[i], in_data_a[i] * in_data_b[i] * in_data_c[i], delta); } } @@ -98,8 +100,10 @@ TYPED_TEST(EltwiseLayerTest, TestSum) { const Dtype* in_data_a = this->blob_bottom_a_->cpu_data(); const Dtype* in_data_b = this->blob_bottom_b_->cpu_data(); const Dtype* in_data_c = this->blob_bottom_c_->cpu_data(); + Dtype delta = std::is_same::value ? 
+ 1e-2 : 1e-4; for (int_tp i = 0; i < count; ++i) { - EXPECT_NEAR(data[i], in_data_a[i] + in_data_b[i] + in_data_c[i], 1e-4); + EXPECT_NEAR(data[i], in_data_a[i] + in_data_b[i] + in_data_c[i], delta); } } @@ -120,9 +124,11 @@ TYPED_TEST(EltwiseLayerTest, TestSumCoeff) { const Dtype* in_data_a = this->blob_bottom_a_->cpu_data(); const Dtype* in_data_b = this->blob_bottom_b_->cpu_data(); const Dtype* in_data_c = this->blob_bottom_c_->cpu_data(); + Dtype delta = std::is_same::value ? + 1e-2 : 1e-4; for (int_tp i = 0; i < count; ++i) { EXPECT_NEAR(data[i], in_data_a[i] - 0.5*in_data_b[i] + 2*in_data_c[i], - 1e-4); + delta); } } diff --git a/src/caffe/test/test_embed_layer.cpp b/src/caffe/test/test_embed_layer.cpp index fe17501f9bd..ab4ae62d8ee 100644 --- a/src/caffe/test/test_embed_layer.cpp +++ b/src/caffe/test/test_embed_layer.cpp @@ -124,8 +124,8 @@ TYPED_TEST(EmbedLayerTest, TestForwardWithBias) { top_offset[4] = 0; bias_offset[0] = 0; for (int_tp j = 0; j < kNumOutput; ++j) { - EXPECT_FLOAT_EQ(layer->blobs()[0]->data_at(weight_offset) + - layer->blobs()[1]->data_at(bias_offset), + EXPECT_FLOAT_EQ(Dtype(layer->blobs()[0]->data_at(weight_offset) + + layer->blobs()[1]->data_at(bias_offset)), this->blob_top_->data_at(top_offset)); ++top_offset[4]; ++weight_offset[1]; diff --git a/src/caffe/test/test_euclidean_loss_layer.cpp b/src/caffe/test/test_euclidean_loss_layer.cpp index b026f5b2077..f03a93e601f 100644 --- a/src/caffe/test/test_euclidean_loss_layer.cpp +++ b/src/caffe/test/test_euclidean_loss_layer.cpp @@ -54,8 +54,10 @@ class EuclideanLossLayerTest : public MultiDeviceTest { layer_weight_2.SetUp(this->blob_bottom_vec_, this->blob_top_vec_); const Dtype loss_weight_2 = layer_weight_2.Forward(this->blob_bottom_vec_, this->blob_top_vec_); - const Dtype kErrorMargin = 1e-5; - EXPECT_NEAR(loss_weight_1 * kLossWeight, loss_weight_2, kErrorMargin); + const Dtype kErrorMargin = std::is_same::value ? 
+ 1e-3 : 1e-5; + EXPECT_NEAR(loss_weight_1 * kLossWeight, loss_weight_2, + kErrorMargin * fabs(loss_weight_2)); // Make sure the loss is non-trivial. const Dtype kNonTrivialAbsThresh = 1e-1; EXPECT_GE(fabs(loss_weight_1), kNonTrivialAbsThresh); diff --git a/src/caffe/test/test_gradient_based_solver.cpp b/src/caffe/test/test_gradient_based_solver.cpp index b7905d60020..081ab955131 100644 --- a/src/caffe/test/test_gradient_based_solver.cpp +++ b/src/caffe/test/test_gradient_based_solver.cpp @@ -276,8 +276,8 @@ class GradientBasedSolverTest : public MultiDeviceTest { Dtype element = 0; for (int k = 0; k < N; ++k) { // (i, k) in X^T (== (k, i) in X) times (k, j) in X. - const Dtype element_i = (i == D) ? 1 : data.cpu_data()[k * D + i]; - const Dtype element_j = (j == D) ? 1 : data.cpu_data()[k * D + j]; + const Dtype element_i = (i == D) ? Dtype(1) : data.cpu_data()[k * D + i]; + const Dtype element_j = (j == D) ? Dtype(1) : data.cpu_data()[k * D + j]; element += element_i * element_j; } if (j == D) { @@ -287,7 +287,7 @@ class GradientBasedSolverTest : public MultiDeviceTest { } } for (int k = 0; k < N; ++k) { - const Dtype element_i = (i == D) ? 1 : data.cpu_data()[k * D + i]; + const Dtype element_i = (i == D) ? Dtype(1) : data.cpu_data()[k * D + i]; grad -= element_i * targets.cpu_data()[k]; } // Scale the gradient over the N samples. @@ -343,8 +343,8 @@ class GradientBasedSolverTest : public MultiDeviceTest { const Dtype val_m = (1 - momentum) * grad + momentum * m; const Dtype val_v = (1 - momentum2) * grad * grad + momentum2 * v; Dtype alpha_t = learning_rate - * std::sqrt(Dtype(1) - pow(momentum2, num_iters)) - / (Dtype(1.) - pow(momentum, num_iters)); + * std::sqrt(Dtype(1) - pow(momentum2, Dtype(num_iters))) + / (Dtype(1.) 
- pow(momentum, Dtype(num_iters))); update_value = alpha_t * val_m / (std::sqrt(val_v) + delta_); } else { LOG(FATAL)<< "Unknown solver type: " << solver_->type(); diff --git a/src/caffe/test/test_image_data_layer.cpp b/src/caffe/test/test_image_data_layer.cpp index 8ba16f1bbcc..7796cafd683 100644 --- a/src/caffe/test/test_image_data_layer.cpp +++ b/src/caffe/test/test_image_data_layer.cpp @@ -177,7 +177,7 @@ TYPED_TEST(ImageDataLayerTest, TestShuffle) { for (int_tp i = 0; i < 5; ++i) { Dtype value = this->blob_top_label_->cpu_data()[i]; // Check that the value has not been seen already (no duplicates). - EXPECT_EQ(values_to_indices.find(value), values_to_indices.end()); + EXPECT_EQ(values_to_indices.find(value) == values_to_indices.end(), true); values_to_indices[value] = i; num_in_order += (value == Dtype(i)); } diff --git a/src/caffe/test/test_inner_product_layer.cpp b/src/caffe/test/test_inner_product_layer.cpp index e54aabf1dc5..f9b37b32a19 100644 --- a/src/caffe/test/test_inner_product_layer.cpp +++ b/src/caffe/test/test_inner_product_layer.cpp @@ -137,7 +137,7 @@ TYPED_TEST(InnerProductLayerTest, TestForwardVGGFC6) { FillerParameter filler_param; UniformFiller filler(filler_param); caffe::Caffe::SetDevice(0); - + #if 0 for(auto i = 1; i <= 64; i*=2) { Blob* const blob_bottom = new Blob(i, 392, 8, 8); Blob* const blob_top = new Blob(); @@ -167,13 +167,16 @@ TYPED_TEST(InnerProductLayerTest, TestForwardVGGFC6) { int_tp N = layer->blobs()[0]->shape(0); int_tp K = layer->blobs()[0]->shape(1); - caffe_cpu_gemm(CblasNoTrans, CblasTrans, M, N, K, (Dtype)1., A, B, (Dtype)0., C); + if (!std::is_same::value || i <= 2) { + caffe_cpu_gemm(CblasNoTrans, CblasTrans, M, N, K, (Dtype)1., A, B, (Dtype)0., C); - const Dtype* data = blob_top->cpu_data(); - const int_tp count = blob_top->count(); - for (int_tp i = 0; i < count; ++i) { - EXPECT_NEAR(data[i], C[i], 1e-1); + const Dtype* data = blob_top->cpu_data(); + const int_tp count = blob_top->count(); + for (int_tp i = 0; i 
< count; ++i) { + EXPECT_NEAR(data[i], C[i], 1e-1); + } } + if (Caffe::mode() == Caffe::GPU) { Timer timer; timer.initted(); @@ -193,6 +196,7 @@ TYPED_TEST(InnerProductLayerTest, TestForwardVGGFC6) { delete blob_bottom; delete blob_top; } + #endif } TYPED_TEST(InnerProductLayerTest, TestForwardVGGFC6_AddEdge) { @@ -200,7 +204,7 @@ TYPED_TEST(InnerProductLayerTest, TestForwardVGGFC6_AddEdge) { FillerParameter filler_param; UniformFiller filler(filler_param); caffe::Caffe::SetDevice(0); - +#if 0 for(auto i = 1; i <= 64; i*=2) { Blob* const blob_bottom = new Blob(i, 25088+1, 1, 1); Blob* const blob_top = new Blob(); @@ -230,14 +234,17 @@ TYPED_TEST(InnerProductLayerTest, TestForwardVGGFC6_AddEdge) { int_tp N = layer->blobs()[0]->shape(0); int_tp K = layer->blobs()[0]->shape(1); - caffe_cpu_gemm(CblasNoTrans, CblasTrans, M, N, K, (Dtype)1., A, B, (Dtype)0., C); + if (!std::is_same::value || i <= 2) { + caffe_cpu_gemm(CblasNoTrans, CblasTrans, M, N, K, (Dtype)1., A, B, (Dtype)0., C); - const Dtype* data = blob_top->cpu_data(); - const int_tp count = blob_top->count(); - std::cout << blob_top->count() << std::endl; - for (int_tp i = 0; i < count; ++i) { - EXPECT_NEAR(data[i], C[i], 1e-1); + const Dtype* data = blob_top->cpu_data(); + const int_tp count = blob_top->count(); + std::cout << blob_top->count() << std::endl; + for (int_tp i = 0; i < count; ++i) { + EXPECT_NEAR(data[i], C[i], 1e-1); + } } + if (Caffe::mode() == Caffe::GPU) { Timer timer; timer.initted(); @@ -257,6 +264,7 @@ TYPED_TEST(InnerProductLayerTest, TestForwardVGGFC6_AddEdge) { delete blob_bottom; delete blob_top; } +#endif } template diff --git a/src/caffe/test/test_lrn_layer.cpp b/src/caffe/test/test_lrn_layer.cpp index 8967d8efd28..67d02149647 100644 --- a/src/caffe/test/test_lrn_layer.cpp +++ b/src/caffe/test/test_lrn_layer.cpp @@ -39,6 +39,8 @@ class LRNLayerTest : public MultiDeviceTest { filler.Fill(this->blob_bottom_); blob_bottom_vec_.push_back(blob_bottom_); blob_top_vec_.push_back(blob_top_); 
+ if (std::is_same::value) + epsilon_ = 5e-2; } virtual ~LRNLayerTest() { delete blob_bottom_; delete blob_top_; } void ReferenceLRNForward(const Blob& blob_bottom, diff --git a/src/caffe/test/test_lstm_layer.cpp b/src/caffe/test/test_lstm_layer.cpp index d68ed10666f..ba7a78c4e6d 100644 --- a/src/caffe/test/test_lstm_layer.cpp +++ b/src/caffe/test/test_lstm_layer.cpp @@ -141,7 +141,8 @@ TYPED_TEST(LSTMLayerTest, TestForward) { layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_); const int bottom_count = this->blob_bottom_.count(); const int top_count = this->blob_top_.count(); - const Dtype kEpsilon = 1e-5; + const Dtype kEpsilon = 1e-5 + * std::is_same::value ? 100 : 1; for (int t = 0; t < kNumTimesteps; ++t) { caffe_cpu_copy(bottom_count, bottom_copy.cpu_data() + t * bottom_count, this->blob_bottom_.mutable_cpu_data()); diff --git a/src/caffe/test/test_math_functions.cpp b/src/caffe/test/test_math_functions.cpp index 384e3ea8e56..29df2146cab 100644 --- a/src/caffe/test/test_math_functions.cpp +++ b/src/caffe/test/test_math_functions.cpp @@ -134,10 +134,15 @@ TYPED_TEST_CASE(GPUMathFunctionsTest, TestDtypes); TYPED_TEST(GPUMathFunctionsTest, TestAsum) { int_tp n = this->blob_bottom_->count(); + TypeParam precision = 0.01; + if (std::is_same::value) { + n = 512; + precision = 0.1; + } const TypeParam* x = this->blob_bottom_->cpu_data(); TypeParam std_asum = 0; for (int_tp i = 0; i < n; ++i) { - std_asum += std::fabs(x[i]); + std_asum += fabs(x[i]); } TypeParam gpu_asum; @@ -153,7 +158,7 @@ TYPED_TEST(GPUMathFunctionsTest, TestAsum) { (cl_mem)(this->blob_bottom_->gpu_data()), 0, &gpu_asum); #endif // USE_GREENTEA } - EXPECT_LT((gpu_asum - std_asum) / std_asum, 1e-2); + EXPECT_LT((gpu_asum - std_asum) / std_asum, precision); } TYPED_TEST(GPUMathFunctionsTest, TestSign) { diff --git a/src/caffe/test/test_mvn_layer.cpp b/src/caffe/test/test_mvn_layer.cpp index 377b2b85357..f522ea03324 100644 --- a/src/caffe/test/test_mvn_layer.cpp +++ 
b/src/caffe/test/test_mvn_layer.cpp @@ -60,7 +60,8 @@ TYPED_TEST(MVNLayerTest, TestForward) { sum /= height * width; var /= height * width; - const Dtype kErrorBound = 0.001; + const Dtype kErrorBound = std::is_same::value ? + 1e-1 : 1e-3; // expect zero mean EXPECT_NEAR(0, sum, kErrorBound); // expect unit variance @@ -95,7 +96,8 @@ TYPED_TEST(MVNLayerTest, TestForwardMeanOnly) { } sum /= height * width; - const Dtype kErrorBound = 0.001; + const Dtype kErrorBound = std::is_same::value ? + 1e-1 : 1e-3; // expect zero mean EXPECT_NEAR(0, sum, kErrorBound); } @@ -130,7 +132,8 @@ TYPED_TEST(MVNLayerTest, TestForwardAcrossChannels) { sum /= height * width * channels; var /= height * width * channels; - const Dtype kErrorBound = 0.001; + const Dtype kErrorBound = std::is_same::value ? + 1e-1 : 1e-3; // expect zero mean EXPECT_NEAR(0, sum, kErrorBound); // expect unit variance diff --git a/src/caffe/test/test_net.cpp b/src/caffe/test/test_net.cpp index 549f0db19e7..bbed79f29cc 100644 --- a/src/caffe/test/test_net.cpp +++ b/src/caffe/test/test_net.cpp @@ -964,7 +964,8 @@ TYPED_TEST(NetTest, TestLossWeight) { Caffe::set_random_seed(this->seed_, Caffe::GetDefaultDevice()); this->InitUnsharedWeightsNet(&kLossWeights[i], NULL, kForceBackward); const Dtype weighted_loss = this->net_->ForwardBackward(); - const Dtype error_margin = kErrorMargin * fabs(kLossWeights[i]); + const Dtype error_margin = kErrorMargin * fabs(kLossWeights[i]) + * std::is_same::value ? 100 : 1; EXPECT_NEAR(loss * kLossWeights[i], weighted_loss, error_margin) << "loss weight = " << kLossWeights[i]; const vector > >& weighted_blobs = @@ -1006,7 +1007,8 @@ TYPED_TEST(NetTest, TestLossWeightMidNet) { // Check that the loss is non-trivial, otherwise the test doesn't prove much. const Dtype kMinLossAbsValue = 1e-2; ASSERT_GE(fabs(loss), kMinLossAbsValue); - const Dtype kErrorMargin = 1e-4; + const Dtype kErrorMargin = 1e-4 + * std::is_same::value ? 
100 : 1; const int_tp kNumLossWeights = 6; Dtype kLossWeights[kNumLossWeights] = {2, 0, 1, -1, -2.5, 3.7}; for (int_tp i = 0; i < kNumLossWeights; ++i) { @@ -1032,7 +1034,8 @@ TYPED_TEST(NetTest, TestComboLossWeight) { Dtype loss_weight; Dtype midnet_loss_weight; const bool kForceBackward = true; - const Dtype kErrorMargin = 1e-4; + const Dtype kErrorMargin = 1e-4 + * is_same::value ? 100 : 1; // Get the loss and gradients with 'EuclideanLoss' weight 1, // 'InnerProduct' weight 1. @@ -1274,12 +1277,17 @@ TYPED_TEST(NetTest, TestSharedWeightsUpdate) { unshared_params2.CopyFrom(*ip2_weights, copy_diff, reshape); unshared_params2.CopyFrom(*ip2_weights, !copy_diff, reshape); // Make sure the diffs are non-trivial and sum to the diff in the shared net. + + const Dtype kErrorMargin = 1e-4 + * is_same::value ? 1000 : 1; for (int_tp i = 0; i < count; ++i) { EXPECT_NE(0, ip1_weights->cpu_diff()[i]); EXPECT_NE(0, ip2_weights->cpu_diff()[i]); EXPECT_NE(ip1_weights->cpu_diff()[i], ip2_weights->cpu_diff()[i]); - EXPECT_FLOAT_EQ(ip1_weights->cpu_diff()[i] + ip2_weights->cpu_diff()[i], - shared_params.cpu_diff()[i]); + Dtype error_margin = kErrorMargin * + fabs(ip1_weights->cpu_diff()[i] + ip2_weights->cpu_diff()[i]); + EXPECT_NEAR(Dtype(ip1_weights->cpu_diff()[i] + ip2_weights->cpu_diff()[i]), + shared_params.cpu_diff()[i], error_margin); } caffe_axpy(count, Dtype(-1), ip1_weights->cpu_diff(), unshared_params1.mutable_cpu_data()); diff --git a/src/caffe/test/test_neuron_layer.cpp b/src/caffe/test/test_neuron_layer.cpp index 060460c5062..45f7fe4cf77 100644 --- a/src/caffe/test/test_neuron_layer.cpp +++ b/src/caffe/test/test_neuron_layer.cpp @@ -79,7 +79,9 @@ class NeuronLayerTest : public MultiDeviceTest { EXPECT_EQ(top_data[i], bottom_data[i] * scale); } } - const Dtype std_error = sqrt(dropout_ratio * (1 - dropout_ratio) / count); + const Dtype std_error = sqrt(dropout_ratio * (1 - dropout_ratio) / count) * + std::is_same::value ? 
+ 10 : 1; // Fail if the number dropped was more than 1.96 * std_error away from the // expected number -- requires 95% confidence that the dropout layer is not // obeying the given dropout_ratio for test failure. @@ -95,7 +97,11 @@ class NeuronLayerTest : public MultiDeviceTest { ExpLayer layer(layer_param); layer.SetUp(blob_bottom_vec_, blob_top_vec_); layer.Forward(blob_bottom_vec_, blob_top_vec_); - const Dtype kDelta = 2e-2; + Dtype kDelta; + if (!std::is_same::value) + kDelta = 2e-2; + else + kDelta = 10; const Dtype* bottom_data = blob_bottom_->cpu_data(); const Dtype* top_data = blob_top_->cpu_data(); for (int_tp i = 0; i < blob_bottom_->count(); ++i) { @@ -130,7 +136,7 @@ class NeuronLayerTest : public MultiDeviceTest { bool channel_shared = layer->layer_param().prelu_param().channel_shared(); for (int_tp i = 0; i < this->blob_bottom_->count(); ++i) { int_tp c = channel_shared ? 0 : (i / hw) % channels; - EXPECT_EQ(top_data[i], + EXPECT_FLOAT_EQ(top_data[i], std::max(bottom_data[i], (Dtype)(0)) + slope_data[c] * std::min(bottom_data[i], (Dtype)(0))); } @@ -153,7 +159,11 @@ class NeuronLayerTest : public MultiDeviceTest { LogLayer layer(layer_param); layer.SetUp(blob_bottom_vec_, blob_top_vec_); layer.Forward(blob_bottom_vec_, blob_top_vec_); - const Dtype kDelta = 2e-3; + Dtype kDelta; + if (!std::is_same::value) + kDelta = 2e-3; + else + kDelta = 2e-1; const Dtype* bottom_data = blob_bottom_->cpu_data(); const Dtype* top_data = blob_top_->cpu_data(); for (int_tp i = 0; i < blob_bottom_->count(); ++i) { @@ -268,7 +278,11 @@ TYPED_TEST(NeuronLayerTest, TestELU) { ELULayer layer(layer_param); layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_); layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_); - const Dtype kDelta = 2e-4; + Dtype kDelta; + if (!std::is_same::value) + kDelta = 2e-4; + else + kDelta = 2e-2; // Now, check values const Dtype* bottom_data = this->blob_bottom_->cpu_data(); const Dtype* top_data = this->blob_top_->cpu_data(); @@ 
-327,7 +341,12 @@ TYPED_TEST(NeuronLayerTest, TestSigmoid) { // Now, check values const Dtype* bottom_data = this->blob_bottom_->cpu_data(); const Dtype* top_data = this->blob_top_->cpu_data(); - const Dtype kDelta = 2e-3; + Dtype kDelta; + if (!std::is_same::value) + kDelta = 2e-3; + else + kDelta = 2e-1; + for (int_tp i = 0; i < this->blob_bottom_->count(); ++i) { EXPECT_NEAR(top_data[i], 1. / (1 + exp(-bottom_data[i])), kDelta); // check that we squashed the value between 0 and 1 @@ -356,10 +375,10 @@ TYPED_TEST(NeuronLayerTest, TestTanH) { for (int_tp j = 0; j < this->blob_bottom_->channels(); ++j) { for (int_tp k = 0; k < this->blob_bottom_->height(); ++k) { for (int_tp l = 0; l < this->blob_bottom_->width(); ++l) { - EXPECT_GE(this->blob_top_->data_at(i, j, k, l) + 1e-4, + EXPECT_GE(Dtype(this->blob_top_->data_at(i, j, k, l) + 1e-4), (exp(2*this->blob_bottom_->data_at(i, j, k, l)) - 1) / (exp(2*this->blob_bottom_->data_at(i, j, k, l)) + 1)); - EXPECT_LE(this->blob_top_->data_at(i, j, k, l) - 1e-4, + EXPECT_LE(Dtype(this->blob_top_->data_at(i, j, k, l) - 1e-4), (exp(2*this->blob_bottom_->data_at(i, j, k, l)) - 1) / (exp(2*this->blob_bottom_->data_at(i, j, k, l)) + 1)); } @@ -618,7 +637,7 @@ TYPED_TEST(NeuronLayerTest, TestBNLL) { const Dtype* bottom_data = this->blob_bottom_->cpu_data(); const Dtype* top_data = this->blob_top_->cpu_data(); for (int_tp i = 0; i < this->blob_bottom_->count(); ++i) { - EXPECT_GE(top_data[i], 0.); + EXPECT_GE(top_data[i], Dtype(0.)); EXPECT_GE(top_data[i], bottom_data[i]); } } diff --git a/src/caffe/test/test_pooling_layer.cpp b/src/caffe/test/test_pooling_layer.cpp index 102a44109c6..01054667998 100644 --- a/src/caffe/test/test_pooling_layer.cpp +++ b/src/caffe/test/test_pooling_layer.cpp @@ -560,7 +560,8 @@ TYPED_TEST(PoolingLayerTest, TestForwardAve) { EXPECT_EQ(this->blob_top_->height(), 3); EXPECT_EQ(this->blob_top_->width(), 3); layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_); - Dtype epsilon = 1e-5; + Dtype 
epsilon = std::is_same::value ? + 1e-3 : 1e-5; EXPECT_NEAR(this->blob_top_->cpu_data()[0], 8.0 / 9, epsilon); EXPECT_NEAR(this->blob_top_->cpu_data()[1], 4.0 / 3, epsilon); EXPECT_NEAR(this->blob_top_->cpu_data()[2], 8.0 / 9, epsilon); diff --git a/src/caffe/test/test_power_layer.cpp b/src/caffe/test/test_power_layer.cpp index 61e1d4d9288..fe68870ed65 100644 --- a/src/caffe/test/test_power_layer.cpp +++ b/src/caffe/test/test_power_layer.cpp @@ -42,7 +42,10 @@ class PowerLayerTest : public MultiDeviceTest { // Now, check values const Dtype* bottom_data = this->blob_bottom_->cpu_data(); const Dtype* top_data = this->blob_top_->cpu_data(); - const Dtype min_precision = 1e-5; + const Dtype min_precision = std::is_same::value ? + 1e-3 : 1e-5; + const Dtype precision_factor = std::is_same::value ? + 1e-2 : 1e-4; for (int_tp i = 0; i < this->blob_bottom_->count(); ++i) { Dtype expected_value = pow(shift + scale * bottom_data[i], power); if (power == Dtype(0) || power == Dtype(1) || power == Dtype(2)) { @@ -52,7 +55,7 @@ class PowerLayerTest : public MultiDeviceTest { EXPECT_TRUE(isnan(top_data[i])); } else { Dtype precision = std::max( - Dtype(std::abs(expected_value * Dtype(1e-4))), min_precision); + Dtype(std::abs(expected_value * precision_factor)), min_precision); EXPECT_NEAR(expected_value, top_data[i], precision); } } diff --git a/src/caffe/test/test_random_number_generator.cpp b/src/caffe/test/test_random_number_generator.cpp index c81c27f61cd..6ac4b06b3b2 100644 --- a/src/caffe/test/test_random_number_generator.cpp +++ b/src/caffe/test/test_random_number_generator.cpp @@ -325,8 +325,8 @@ TYPED_TEST(RandomNumberGeneratorTest, TestRngGaussianTimesGaussian) { } // Check that result has mean 0. 
- TypeParam mu_product = pow(mu, 2); - TypeParam sigma_product = sqrt(pow(sigma, 2) / 2); + TypeParam mu_product = pow(mu, TypeParam(2)); + TypeParam sigma_product = sqrt(TypeParam(pow(sigma, TypeParam(2)) / 2)); this->RngGaussianChecks(mu_product, sigma_product, gaussian_data_1); } @@ -526,8 +526,8 @@ TYPED_TEST(RandomNumberGeneratorTest, TestRngGaussianTimesGaussianGPU) { // Check that result does not violate checked properties of Gaussian // (though it is not actually a Gaussian). - TypeParam mu_product = pow(mu, 2); - TypeParam sigma_product = sqrt(pow(sigma, 2) / 2); + TypeParam mu_product = pow(mu, TypeParam(2)); + TypeParam sigma_product = sqrt(TypeParam(pow(sigma, TypeParam(2)) / 2)); this->RngGaussianChecks(mu_product, sigma_product, gaussian_data_1); } diff --git a/src/caffe/test/test_scale_layer.cpp b/src/caffe/test/test_scale_layer.cpp index 69c7643dbaf..cf3f538fbd3 100644 --- a/src/caffe/test/test_scale_layer.cpp +++ b/src/caffe/test/test_scale_layer.cpp @@ -85,8 +85,11 @@ TYPED_TEST(ScaleLayerTest, TestForwardEltwise) { const int_tp count = this->blob_top_->count(); const Dtype* in_data_a = this->blob_bottom_->cpu_data(); const Dtype* in_data_b = this->blob_bottom_eltwise_->cpu_data(); + const Dtype delta = std::is_same::value ? + 1e-3 : 1e-5; for (int_tp i = 0; i < count; ++i) { - EXPECT_NEAR(data[i], in_data_a[i] * in_data_b[i], 1e-5); + EXPECT_NEAR(data[i], in_data_a[i] * in_data_b[i], + delta * fabs(in_data_a[i] * in_data_b[i])); } } @@ -105,8 +108,11 @@ TYPED_TEST(ScaleLayerTest, TestForwardEltwiseInPlace) { const int_tp count = this->blob_bottom_->count(); const Dtype* in_data_a = orig_bottom.cpu_data(); const Dtype* in_data_b = this->blob_bottom_eltwise_->cpu_data(); + const Dtype delta = std::is_same::value ? 
+ 1e-3 : 1e-5; for (int_tp i = 0; i < count; ++i) { - EXPECT_NEAR(data[i], in_data_a[i] * in_data_b[i], 1e-5); + EXPECT_NEAR(data[i], in_data_a[i] * in_data_b[i], + delta * fabs(in_data_a[i] * in_data_b[i])); } } @@ -146,13 +152,17 @@ TYPED_TEST(ScaleLayerTest, TestBackwardEltwiseInPlace) { caffe_copy(top_diff.count(), top_diff.cpu_data(), this->blob_bottom_->mutable_cpu_diff()); layer->Backward(this->blob_top_vec_, propagate_down, this->blob_bottom_vec_); + const Dtype delta = std::is_same::value ? + 1e-3 : 1e-5; for (int_tp i = 0; i < this->blob_bottom_->count(); ++i) { EXPECT_NEAR(orig_bottom_diff.cpu_diff()[i], - this->blob_bottom_->cpu_diff()[i], 1e-5); + this->blob_bottom_->cpu_diff()[i], + delta * fabs(orig_bottom_diff.cpu_diff()[i])); } for (int_tp i = 0; i < this->blob_bottom_eltwise_->count(); ++i) { EXPECT_NEAR(orig_scale_diff.cpu_diff()[i], - this->blob_bottom_eltwise_->cpu_diff()[i], 1e-5); + this->blob_bottom_eltwise_->cpu_diff()[i], + delta * fabs(orig_scale_diff.cpu_diff()[i])); } } @@ -171,8 +181,11 @@ TYPED_TEST(ScaleLayerTest, TestForwardEltwiseWithParam) { const int_tp count = this->blob_top_->count(); const Dtype* in_data_a = this->blob_bottom_->cpu_data(); const Dtype* in_data_b = layer->blobs()[0]->cpu_data(); + const Dtype delta = std::is_same::value ? + 1e-3 : 1e-5; for (int_tp i = 0; i < count; ++i) { - EXPECT_NEAR(data[i], in_data_a[i] * in_data_b[i], 1e-5); + EXPECT_NEAR(data[i], in_data_a[i] * in_data_b[i], + delta * fabs(in_data_a[i] * in_data_b[i])); } } @@ -185,6 +198,8 @@ TYPED_TEST(ScaleLayerTest, TestForwardBroadcastBegin) { layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_); ASSERT_EQ(this->blob_bottom_->shape(), this->blob_top_->shape()); layer->Forward(this->blob_bottom_vec_, this->blob_top_vec_); + const Dtype delta = std::is_same::value ? 
+ 1e-2 : 1e-5; for (int_tp n = 0; n < this->blob_bottom_->num(); ++n) { for (int_tp c = 0; c < this->blob_bottom_->channels(); ++c) { for (int_tp h = 0; h < this->blob_bottom_->height(); ++h) { @@ -192,7 +207,7 @@ TYPED_TEST(ScaleLayerTest, TestForwardBroadcastBegin) { EXPECT_NEAR(this->blob_top_->data_at(n, c, h, w), this->blob_bottom_->data_at(n, c, h, w) * this->blob_bottom_broadcast_0_->data_at(n, c, 0, 0), - 1e-5); + delta * fabs(this->blob_top_->data_at(n, c, h, w))); } } } @@ -208,6 +223,8 @@ TYPED_TEST(ScaleLayerTest, TestForwardBroadcastMiddle) { layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_); ASSERT_EQ(this->blob_bottom_->shape(), this->blob_top_->shape()); layer->Forward(this->blob_bottom_vec_, this->blob_top_vec_); + const Dtype delta = std::is_same::value ? + 1e-3 : 1e-5; for (int_tp n = 0; n < this->blob_bottom_->num(); ++n) { for (int_tp c = 0; c < this->blob_bottom_->channels(); ++c) { for (int_tp h = 0; h < this->blob_bottom_->height(); ++h) { @@ -215,7 +232,7 @@ TYPED_TEST(ScaleLayerTest, TestForwardBroadcastMiddle) { EXPECT_NEAR(this->blob_top_->data_at(n, c, h, w), this->blob_bottom_->data_at(n, c, h, w) * this->blob_bottom_broadcast_1_->data_at(c, h, 0, 0), - 1e-5); + delta * fabs(this->blob_top_->data_at(n, c, h, w))); } } } @@ -233,6 +250,8 @@ TYPED_TEST(ScaleLayerTest, TestForwardBroadcastMiddleInPlace) { shared_ptr > layer(new ScaleLayer(layer_param)); layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_); layer->Forward(this->blob_bottom_vec_, this->blob_top_vec_); + const Dtype delta = std::is_same::value ? 
+ 1e-3 : 1e-5; for (int_tp n = 0; n < this->blob_bottom_->num(); ++n) { for (int_tp c = 0; c < this->blob_bottom_->channels(); ++c) { for (int_tp h = 0; h < this->blob_bottom_->height(); ++h) { @@ -240,7 +259,7 @@ TYPED_TEST(ScaleLayerTest, TestForwardBroadcastMiddleInPlace) { EXPECT_NEAR(this->blob_bottom_->data_at(n, c, h, w), orig_bottom.data_at(n, c, h, w) * this->blob_bottom_broadcast_1_->data_at(c, h, 0, 0), - 1e-5); + delta * fabs(this->blob_bottom_->data_at(n, c, h, w))); } } } @@ -283,13 +302,15 @@ TYPED_TEST(ScaleLayerTest, TestBackwardBroadcastMiddleInPlace) { caffe_copy(top_diff.count(), top_diff.cpu_data(), this->blob_bottom_->mutable_cpu_diff()); layer->Backward(this->blob_top_vec_, propagate_down, this->blob_bottom_vec_); + const Dtype delta = std::is_same::value ? + 1e-3 : 1e-5; for (int_tp i = 0; i < this->blob_bottom_->count(); ++i) { EXPECT_NEAR(orig_bottom_diff.cpu_diff()[i], - this->blob_bottom_->cpu_diff()[i], 1e-5); + this->blob_bottom_->cpu_diff()[i], delta); } for (int_tp i = 0; i < this->blob_bottom_broadcast_1_->count(); ++i) { EXPECT_NEAR(orig_scale_diff.cpu_diff()[i], - this->blob_bottom_broadcast_1_->cpu_diff()[i], 1e-5); + this->blob_bottom_broadcast_1_->cpu_diff()[i], delta); } } @@ -304,13 +325,16 @@ TYPED_TEST(ScaleLayerTest, TestForwardBroadcastMiddleWithParam) { layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_); ASSERT_EQ(this->blob_bottom_->shape(), this->blob_top_->shape()); layer->Forward(this->blob_bottom_vec_, this->blob_top_vec_); + const Dtype delta = std::is_same::value ? 
+ 1e-3 : 1e-5; for (int_tp n = 0; n < this->blob_bottom_->num(); ++n) { for (int_tp c = 0; c < this->blob_bottom_->channels(); ++c) { for (int_tp h = 0; h < this->blob_bottom_->height(); ++h) { for (int_tp w = 0; w < this->blob_bottom_->width(); ++w) { EXPECT_NEAR(this->blob_top_->data_at(n, c, h, w), this->blob_bottom_->data_at(n, c, h, w) * - layer->blobs()[0]->data_at(c, h, 0, 0), 1e-5); + layer->blobs()[0]->data_at(c, h, 0, 0), + delta * fabs(this->blob_top_->data_at(n, c, h, w))); } } } @@ -330,6 +354,8 @@ TYPED_TEST(ScaleLayerTest, TestForwardBroadcastMiddleWithParamAndBias) { layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_); ASSERT_EQ(this->blob_bottom_->shape(), this->blob_top_->shape()); layer->Forward(this->blob_bottom_vec_, this->blob_top_vec_); + const Dtype delta = std::is_same::value ? + 5e-3 : 1e-5; for (int_tp n = 0; n < this->blob_bottom_->num(); ++n) { for (int_tp c = 0; c < this->blob_bottom_->channels(); ++c) { for (int_tp h = 0; h < this->blob_bottom_->height(); ++h) { @@ -337,7 +363,8 @@ TYPED_TEST(ScaleLayerTest, TestForwardBroadcastMiddleWithParamAndBias) { EXPECT_NEAR(this->blob_top_->data_at(n, c, h, w), this->blob_bottom_->data_at(n, c, h, w) * layer->blobs()[0]->data_at(c, h, 0, 0) + - layer->blobs()[1]->data_at(c, h, 0, 0), 1e-5); + layer->blobs()[1]->data_at(c, h, 0, 0), + delta * fabs(this->blob_top_->data_at(n, c, h, w))); } } } @@ -353,6 +380,8 @@ TYPED_TEST(ScaleLayerTest, TestForwardBroadcastEnd) { layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_); ASSERT_EQ(this->blob_bottom_->shape(), this->blob_top_->shape()); layer->Forward(this->blob_bottom_vec_, this->blob_top_vec_); + const Dtype delta = std::is_same::value ? 
+ 1e-3 : 1e-5; for (int_tp n = 0; n < this->blob_bottom_->num(); ++n) { for (int_tp c = 0; c < this->blob_bottom_->channels(); ++c) { for (int_tp h = 0; h < this->blob_bottom_->height(); ++h) { @@ -360,7 +389,7 @@ TYPED_TEST(ScaleLayerTest, TestForwardBroadcastEnd) { EXPECT_NEAR(this->blob_top_->data_at(n, c, h, w), this->blob_bottom_->data_at(n, c, h, w) * this->blob_bottom_broadcast_2_->data_at(h, w, 0, 0), - 1e-5); + delta * fabs(this->blob_top_->data_at(n, c, h, w))); } } } @@ -379,8 +408,10 @@ TYPED_TEST(ScaleLayerTest, TestForwardScale) { const int_tp count = this->blob_top_->count(); const Dtype* in_data = this->blob_bottom_->cpu_data(); const Dtype scale = *this->blob_bottom_scale_->cpu_data(); + const Dtype delta = std::is_same::value ? + 1e-3 : 1e-5; for (int_tp i = 0; i < count; ++i) { - EXPECT_NEAR(data[i], in_data[i] * scale, 1e-5); + EXPECT_NEAR(data[i], in_data[i] * scale, delta * fabs(data[i])); } } @@ -397,8 +428,10 @@ TYPED_TEST(ScaleLayerTest, TestForwardScaleAxis2) { const int_tp count = this->blob_top_->count(); const Dtype* in_data = this->blob_bottom_->cpu_data(); const Dtype scale = *this->blob_bottom_scale_->cpu_data(); + const Dtype delta = std::is_same::value ? + 1e-3 : 1e-5; for (int_tp i = 0; i < count; ++i) { - EXPECT_NEAR(data[i], in_data[i] * scale, 1e-5); + EXPECT_NEAR(data[i], in_data[i] * scale, delta * fabs(data[i])); } } diff --git a/src/caffe/test/test_sigmoid_cross_entropy_loss_layer.cpp b/src/caffe/test/test_sigmoid_cross_entropy_loss_layer.cpp index 3ff96c4ee70..b3aa378d3e6 100644 --- a/src/caffe/test/test_sigmoid_cross_entropy_loss_layer.cpp +++ b/src/caffe/test/test_sigmoid_cross_entropy_loss_layer.cpp @@ -70,7 +70,8 @@ class SigmoidCrossEntropyLossLayerTest : public MultiDeviceTest { targets_filler_param.set_min(0.0); targets_filler_param.set_max(1.0); UniformFiller targets_filler(targets_filler_param); - Dtype eps = 2e-2; + Dtype eps = std::is_same::value ? 
+ 2e-1 : 2e-2; for (int_tp i = 0; i < 100; ++i) { // Fill the data vector data_filler.Fill(this->blob_bottom_data_); @@ -87,7 +88,8 @@ class SigmoidCrossEntropyLossLayerTest : public MultiDeviceTest { this->blob_bottom_targets_->cpu_data(); Dtype reference_loss = kLossWeight * SigmoidCrossEntropyLossReference( count, num, blob_bottom_data, blob_bottom_targets); - EXPECT_NEAR(reference_loss, layer_loss, eps) << "debug: trial #" << i; + EXPECT_NEAR(reference_loss, layer_loss, eps * reference_loss) + << "debug: trial #" << i; } } diff --git a/src/caffe/test/test_softmax_layer.cpp b/src/caffe/test/test_softmax_layer.cpp index f988097bde9..b567afdba8f 100644 --- a/src/caffe/test/test_softmax_layer.cpp +++ b/src/caffe/test/test_softmax_layer.cpp @@ -61,12 +61,15 @@ TYPED_TEST(SoftmaxLayerTest, TestForward) { for (int_tp j = 0; j < this->blob_bottom_->channels(); ++j) { scale += exp(this->blob_bottom_->data_at(i, j, k, l)); } + + const Dtype delta = std::is_same::value ? + 1e-2 : 1e-4; for (int_tp j = 0; j < this->blob_bottom_->channels(); ++j) { - EXPECT_GE(this->blob_top_->data_at(i, j, k, l) + 1e-4, - exp(this->blob_bottom_->data_at(i, j, k, l)) / scale) + EXPECT_GE(Dtype(this->blob_top_->data_at(i, j, k, l) + delta), + Dtype(exp(this->blob_bottom_->data_at(i, j, k, l)) / scale)) << "debug: " << i << " " << j; - EXPECT_LE(this->blob_top_->data_at(i, j, k, l) - 1e-4, - exp(this->blob_bottom_->data_at(i, j, k, l)) / scale) + EXPECT_LE(Dtype(this->blob_top_->data_at(i, j, k, l) - delta), + Dtype(exp(this->blob_bottom_->data_at(i, j, k, l)) / scale)) << "debug: " << i << " " << j; } } @@ -128,11 +131,11 @@ TYPED_TEST(CuDNNSoftmaxLayerTest, TestForwardCuDNN) { scale += exp(this->blob_bottom_->data_at(i, j, k, l)); } for (int_tp j = 0; j < this->blob_bottom_->channels(); ++j) { - EXPECT_GE(this->blob_top_->data_at(i, j, k, l) + 1e-4, - exp(this->blob_bottom_->data_at(i, j, k, l)) / scale) + EXPECT_GE(Dtype(this->blob_top_->data_at(i, j, k, l) + 1e-4), + 
Dtype(exp(this->blob_bottom_->data_at(i, j, k, l)) / scale)) << "debug: " << i << " " << j; - EXPECT_LE(this->blob_top_->data_at(i, j, k, l) - 1e-4, - exp(this->blob_bottom_->data_at(i, j, k, l)) / scale) + EXPECT_LE(Dtype(this->blob_top_->data_at(i, j, k, l) - 1e-4), + Dtype(exp(this->blob_bottom_->data_at(i, j, k, l)) / scale)) << "debug: " << i << " " << j; } } diff --git a/src/caffe/test/test_softmax_with_loss_layer.cpp b/src/caffe/test/test_softmax_with_loss_layer.cpp index 34402d90438..728a0a33894 100644 --- a/src/caffe/test/test_softmax_with_loss_layer.cpp +++ b/src/caffe/test/test_softmax_with_loss_layer.cpp @@ -81,7 +81,9 @@ TYPED_TEST(SoftmaxWithLossLayerTest, TestForwardIgnoreLabel) { accum_loss += this->blob_top_loss_->cpu_data()[0]; } // Check that each label was included all but once. - EXPECT_NEAR(4 * full_loss, accum_loss, 1e-4); + Dtype delta = std::is_same::value ? + 1e-1 : 1e-4; + EXPECT_NEAR(4 * full_loss, accum_loss, delta * fabs(accum_loss)); } TYPED_TEST(SoftmaxWithLossLayerTest, TestGradientIgnoreLabel) { diff --git a/src/caffe/test/test_syncedmem.cpp b/src/caffe/test/test_syncedmem.cpp index ba09895f94d..f3db2fc6621 100644 --- a/src/caffe/test/test_syncedmem.cpp +++ b/src/caffe/test/test_syncedmem.cpp @@ -1,12 +1,12 @@ #include -#include "gtest/gtest.h" #include "caffe/common.hpp" #include "caffe/syncedmem.hpp" #include "caffe/util/device_alternate.hpp" #include "caffe/util/math_functions.hpp" +#include "gtest/gtest.h" #include "caffe/test/test_caffe_main.hpp" #ifdef USE_GREENTEA diff --git a/src/caffe/test/test_tanh_layer.cpp b/src/caffe/test/test_tanh_layer.cpp index 5b42f211a5a..4568fd08208 100644 --- a/src/caffe/test/test_tanh_layer.cpp +++ b/src/caffe/test/test_tanh_layer.cpp @@ -55,11 +55,16 @@ class TanHLayerTest : public MultiDeviceTest { // Now, check values const Dtype* bottom_data = this->blob_bottom_->cpu_data(); const Dtype* top_data = this->blob_top_->cpu_data(); - const Dtype min_precision = 1e-5; + Dtype min_precision = 
1e-5; + Dtype precision_factor = 1e-4; + if (std::is_same::value) { + min_precision = 100. * min_precision; + precision_factor = 100. * precision_factor; + } for (int_tp i = 0; i < this->blob_bottom_->count(); ++i) { Dtype expected_value = tanh_naive(bottom_data[i]); Dtype precision = std::max( - Dtype(std::abs(expected_value * Dtype(1e-4))), min_precision); + Dtype(std::abs(expected_value * Dtype(precision_factor))), min_precision); EXPECT_NEAR(expected_value, top_data[i], precision); } } diff --git a/src/caffe/util/blocking_queue.cpp b/src/caffe/util/blocking_queue.cpp index 333f36508b0..cf45c2af315 100644 --- a/src/caffe/util/blocking_queue.cpp +++ b/src/caffe/util/blocking_queue.cpp @@ -84,7 +84,9 @@ uint_tp BlockingQueue::size() const { boost::mutex::scoped_lock lock(sync_->mutex_); return queue_.size(); } - +#ifdef HAS_HALF_SUPPORT +template class BlockingQueue*>; +#endif template class BlockingQueue*>; template class BlockingQueue*>; diff --git a/src/caffe/util/hdf5.cpp b/src/caffe/util/hdf5.cpp index 88c9a0c76e6..b38f891d15f 100755 --- a/src/caffe/util/hdf5.cpp +++ b/src/caffe/util/hdf5.cpp @@ -83,6 +83,19 @@ void hdf5_load_nd_dataset_helper( } } +#ifdef HAS_HALF_SUPPORT +template <> +void hdf5_load_nd_dataset(hid_t file_id, const char* dataset_name_, + int min_dim, int max_dim, Blob* blob, bool reshape) { + // FIXME + hdf5_load_nd_dataset_helper(file_id, dataset_name_, min_dim, max_dim, blob, + reshape); + herr_t status = H5LTread_dataset_short( + file_id, dataset_name_, (short*)blob->mutable_cpu_data()); + CHECK_GE(status, 0) << "Failed to read float dataset " << dataset_name_; +} +#endif + template <> void hdf5_load_nd_dataset(hid_t file_id, const char* dataset_name_, int min_dim, int max_dim, Blob* blob, bool reshape) { @@ -103,6 +116,31 @@ void hdf5_load_nd_dataset(hid_t file_id, const char* dataset_name_, CHECK_GE(status, 0) << "Failed to read double dataset " << dataset_name_; } +#ifdef HAS_HALF_SUPPORT +template <> +void hdf5_save_nd_dataset( + 
const hid_t file_id, const string& dataset_name, const Blob& blob, + bool write_diff) { + //FIXME + int_tp num_axes = blob.num_axes(); + hsize_t *dims = new hsize_t[num_axes]; + for (int_tp i = 0; i < num_axes; ++i) { + dims[i] = blob.shape(i); + } + const half* data; + if (write_diff) { + data = blob.cpu_diff(); + } else { + data = blob.cpu_data(); + } + herr_t status = H5LTmake_dataset_short( + file_id, dataset_name.c_str(), num_axes, dims, (const short*)(data)); + CHECK_GE(status, 0) << "Failed to make float dataset " << dataset_name; + delete[] dims; + +} +#endif + template <> void hdf5_save_nd_dataset( const hid_t file_id, const string& dataset_name, const Blob& blob, diff --git a/src/caffe/util/im2col.cpp b/src/caffe/util/im2col.cpp index 12698268846..77879e6ad9a 100644 --- a/src/caffe/util/im2col.cpp +++ b/src/caffe/util/im2col.cpp @@ -55,6 +55,15 @@ void im2col_cpu(const Dtype* data_im, const int_tp channels, } // Explicit instantiation +#ifdef HAS_HALF_SUPPORT +template void im2col_cpu(const half* data_im, const int_tp channels, + const int_tp height, const int_tp width, + const int_tp kernel_h, const int_tp kernel_w, + const int_tp pad_h, const int_tp pad_w, + const int_tp stride_h, const int_tp stride_w, + const int_tp dilation_h, + const int_tp dilation_w, half* data_col); +#endif template void im2col_cpu(const float* data_im, const int_tp channels, const int_tp height, const int_tp width, const int_tp kernel_h, const int_tp kernel_w, @@ -155,6 +164,15 @@ void im2col_nd_cpu(const Dtype* data_im, const int_tp num_spatial_axes, } // Explicit instantiation +#ifdef HAS_HALF_SUPPORT +template void im2col_nd_cpu(const half* data_im, + const int_tp num_spatial_axes, + const int_tp* im_shape, + const int_tp* col_shape, + const int_tp* kernel_shape, + const int_tp* pad, const int_tp* stride, + const int_tp* dilation, half* data_col); +#endif template void im2col_nd_cpu(const float* data_im, const int_tp num_spatial_axes, const int_tp* im_shape, @@ -208,6 +226,15 
@@ void col2im_cpu(const Dtype* data_col, const int_tp channels, } // Explicit instantiation +#ifdef HAS_HALF_SUPPORT +template void col2im_cpu(const half* data_col, const int_tp channels, + const int_tp height, const int_tp width, + const int_tp kernel_h, const int_tp kernel_w, + const int_tp pad_h, const int_tp pad_w, + const int_tp stride_h, const int_tp stride_w, + const int_tp dilation_h, + const int_tp dilation_w, half* data_im); +#endif template void col2im_cpu(const float* data_col, const int_tp channels, const int_tp height, const int_tp width, const int_tp kernel_h, const int_tp kernel_w, @@ -235,6 +262,15 @@ void col2im_nd_cpu(const Dtype* data_col, const int_tp num_spatial_axes, } // Explicit instantiation +#ifdef HAS_HALF_SUPPORT +template void col2im_nd_cpu(const half* data_col, + const int_tp num_spatial_axes, + const int_tp* im_shape, + const int_tp* col_shape, + const int_tp* kernel_shape, + const int_tp* pad, const int_tp* stride, + const int_tp* dilation, half* data_im); +#endif template void col2im_nd_cpu(const float* data_col, const int_tp num_spatial_axes, const int_tp* im_shape, diff --git a/src/caffe/util/math_functions.cpp b/src/caffe/util/math_functions.cpp index e129b94aba7..39501374327 100644 --- a/src/caffe/util/math_functions.cpp +++ b/src/caffe/util/math_functions.cpp @@ -9,6 +9,295 @@ namespace caffe { +#ifdef HAS_HALF_SUPPORT +template<> +void caffe_add_scalar(const int_tp N, const half alpha, half* Y) { + for (int_tp i = 0; i < N; ++i) { + Y[i] += alpha; + } +} + +template<> +void caffe_cpu_gemm(const CBLAS_TRANSPOSE TransA, + const CBLAS_TRANSPOSE TransB, const int_tp M, + const int_tp N, const int_tp K, const half alpha, + const half* A, const half* B, const half beta, + half* C) { + int_tp inc_a = (TransA == CblasNoTrans) ? 1 : M; + int_tp inc_b = (TransB == CblasNoTrans) ? N : 1; + for (int_tp m = 0; m < M; m++) { + for (int_tp n = 0; n < N; n++) { + half acc = 0; + int_tp b_index = TransB == CblasNoTrans ? 
+ n : K * n; + int_tp a_index = TransA == CblasNoTrans ? + K * m : m; + for (int_tp k = 0; k < K; k++) { + acc += A[a_index] * B[b_index]; + a_index += inc_a; + b_index += inc_b; + } + if (beta != 0) + C[m * N + n] = acc * alpha + beta * C[m * N + n]; + else + C[m * N + n] = acc * alpha; + } + } +} + +template<> +void caffe_cpu_gemv(const CBLAS_TRANSPOSE TransA, const int_tp M, + const int_tp N, const half alpha, const half* A, + const half* x, const half beta, half* y) { + int_tp a_inc = (TransA == CblasNoTrans) ? 1 : N; + int_tp y_cnt = (TransA == CblasNoTrans) ? M : N; + int_tp x_cnt = (TransA == CblasNoTrans) ? N : M; + for (int_tp m = 0; m < y_cnt; m++) { + int_tp a_index = (TransA == CblasNoTrans) ? m * N : m; + half acc = 0; + for (int_tp n = 0; n < x_cnt; n++) { + acc += A[a_index] * x[n]; + a_index += a_inc; + } + if (beta == 0) + y[m] = acc * alpha; + else + y[m] = acc * alpha + beta * y[m]; + } +} + +template<> +void caffe_axpy(const int_tp N, const half alpha, const half* X, + half* Y) { + for (int_tp n = 0; n < N; n++) { + Y[n] += alpha * X[n]; + } +} + +template<> +void caffe_scal(const int_tp N, const half alpha, half *X) { + for (int_tp n = 0; n < N; n++) + X[n] *= alpha; +} + +template<> +void caffe_cpu_axpby(const int_tp N, const half alpha, const half* X, + const half beta, half* Y) { + cblas_haxpby(N, alpha, X, 1, beta, Y, 1); +} + +void vhAdd(const int_tp n, const half* a, const half* b, + half* y) { + for(int i = 0; i < n; i++) { + y[i] = a[i] + b[i]; + } +} + +template<> +void caffe_add(const int_tp n, const half* a, const half* b, + half* y) { + vhAdd(n, a, b, y); +} +void vhSub(const int_tp n, const half* a, const half* b, + half* y) { + for(int i = 0; i < n; i++) { + y[i] = a[i] - b[i]; + } +} + +template<> +void caffe_sub(const int_tp n, const half* a, const half* b, + half* y) { + vhSub(n, a, b, y); +} + +void vhMul(const int_tp n, const half* a, const half* b, + half* y) { + for(int i = 0; i < n; i++) { + y[i] = a[i] * b[i]; + } +} + 
+template<> +void caffe_mul(const int_tp n, const half* a, const half* b, + half* y) { + vhMul(n, a, b, y); +} + +void vhDiv(const int_tp n, const half* a, const half* b, + half* y) { + for(int i = 0; i < n; i++) { + y[i] = a[i] / b[i]; + } +} + +template<> +void caffe_div(const int_tp n, const half* a, const half* b, + half* y) { + vhDiv(n, a, b, y); +} + +void vhPowx(const int_tp n, const half*a, const half b, half* y) +{ + for( int i = 0; i < n; i++) + y[i] = pow(a[i], b); +} + +template<> +void caffe_powx(const int_tp n, const half* a, const half b, + half* y) { + vhPowx(n, a, b, y); +} + +void vhSqr(const int_tp n, const half *a, half* y) { + for(int i = 0; i < n; i++) { + y[i] = sqrt(a[i]); + } +} + +template<> +void caffe_sqr(const int_tp n, const half* a, half* y) { + vhSqr(n, a, y); +} + +void vhExp(const int_tp n, const half* a, half* y) { + for(int i = 0; i < n; i++) { + y[i] = exp(a[i]); + } +} + +template<> +void caffe_exp(const int_tp n, const half* a, half* y) { + vhExp(n, a, y); +} + +void vhLn(const int_tp n, const half* a, half* y) { + for(int i = 0; i < n; i++) { + y[i] = log(a[i]); + } +} + +template<> +void caffe_log(const int_tp n, const half* a, half* y) { + vhLn(n, a, y); +} + +void vhAbs(const int_tp n, const half *a, half* y) { + for(int i = 0; i < n; i++) { + y[i] = fabs(a[i]); + } +} + +template<> +void caffe_abs(const int_tp n, const half* a, half* y) { + vhAbs(n, a, y); +} + +void vsHqrt(const int_tp n, const half* a, half* y) { + for (int_tp i = 0; i < n; i++) { + y[i] = sqrt(a[i]); + } +} +template <> +void caffe_sqrt(const int_tp n, const half* a, half* y) { + vsHqrt(n, a, y); +} + +template<> +void caffe_rng_uniform(const int_tp n, const half a, const half b, half* r) { + CHECK_GE(n, 0); + CHECK(r); + CHECK_LE(a, b); + boost::uniform_real random_distribution(float(a), caffe_nextafter(float(b))); + + boost::variate_generator> variate_generator( + caffe_rng(), random_distribution); + for (int_tp i = 0; i < n; ++i) { + r[i] = 
variate_generator(); + } +} + +template<> +void caffe_rng_gaussian(const int_tp n, const half a, const half sigma, + half* r) { + CHECK_GE(n, 0); + CHECK(r); + CHECK_GT(sigma, 0); + float fsigma = sigma; + float fa = a; + boost::normal_distribution random_distribution(fa, fsigma); + boost::variate_generator> variate_generator( + caffe_rng(), random_distribution); + for (int_tp i = 0; i < n; ++i) { + r[i] = variate_generator(); + } +} + +template<> +void caffe_rng_bernoulli(const int_tp n, const half p, + unsigned int* r) { + // FIXME + CHECK_GE(n, 0); + CHECK(r); + CHECK_GE(p, 0); + CHECK_LE(p, 1); + float f_p = p; + boost::bernoulli_distribution random_distribution(f_p); + boost::variate_generator> variate_generator( + caffe_rng(), random_distribution); + for (int_tp i = 0; i < n; ++i) { + //r[i] = static_cast(variate_generator()); + } +} +template<> +void caffe_rng_bernoulli(const int_tp n, const half p, + int* r) { + // FIXME + CHECK_GE(n, 0); + CHECK(r); + CHECK_GE(p, 0); + CHECK_LE(p, 1); + float f_p = p; + boost::bernoulli_distribution random_distribution(f_p); + boost::variate_generator> variate_generator( + caffe_rng(), random_distribution); + for (int_tp i = 0; i < n; ++i) { + //r[i] = static_cast(variate_generator()); + } +} + +template<> +void caffe_cpu_scale(const int_tp n, const half alpha, const half *x, + half* y) { + for (int_tp i = 0; i < n; i++) + y[i] = x[i]; + //cblas_hcopy(n, x, 1, y, 1); + caffe_scal(n, alpha, y); +} + +template<> +half caffe_cpu_strided_dot(const int_tp n, const half* x, + const int_tp incx, const half* y, + const int_tp incy) { + half sum = 0; + for (int_tp i = 0; i < n; i++) + sum += x[i * incx] * y[i * incy]; + return sum; +} + +template<> +half caffe_cpu_asum(const int_tp n, const half* x) { + half sum = 0; + for (int_tp i = 0; i < n; i++) + sum += fabs(x[i]); + return sum; +} +#endif + template<> void caffe_cpu_gemm(const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB, const int_tp M, @@ -47,6 +336,7 @@ void 
caffe_cpu_gemv(const CBLAS_TRANSPOSE TransA, const int_tp M, cblas_dgemv(CblasRowMajor, TransA, M, N, alpha, A, N, x, 1, beta, y, 1); } + template<> void caffe_axpy(const int_tp N, const float alpha, const float* X, float* Y) { @@ -76,6 +366,9 @@ template void caffe_set(const int_tp N, const uint32_t alpha, template void caffe_set(const int_tp N, int64_t alpha, int64_t* Y); template void caffe_set(const int_tp N, const uint64_t alpha, uint64_t* Y); +#ifdef HAS_HALF_SUPPORT +template void caffe_set(const int_tp N, const half alpha, half* Y); +#endif template void caffe_set(const int_tp N, const float alpha, float* Y); template void caffe_set(const int_tp N, const double alpha, double* Y); @@ -104,6 +397,9 @@ template void caffe_cpu_copy(const int_tp N, const int_tp* X, int_tp* Y); template void caffe_cpu_copy(const int_tp N, const uint_tp* X, uint_tp* Y); +#ifdef HAS_HALF_SUPPORT +template void caffe_cpu_copy(const int_tp N, const half* X, half* Y); +#endif template void caffe_cpu_copy(const int_tp N, const float* X, float* Y); template void caffe_cpu_copy(const int_tp N, const double* X, double* Y); @@ -129,6 +425,9 @@ void caffe_copy(const int_tp N, const Dtype* X, Dtype* Y) { template void caffe_copy(const int_tp N, const int_tp* X, int_tp* Y); template void caffe_copy(const int_tp N, const uint_tp* X, uint_tp* Y); +#ifdef HAS_HALF_SUPPORT +template void caffe_copy(const int_tp N, const half* X, half* Y); +#endif template void caffe_copy(const int_tp N, const float* X, float* Y); template void caffe_copy(const int_tp N, const double* X, double* Y); @@ -141,7 +440,6 @@ template<> void caffe_scal(const int_tp N, const double alpha, double *X) { cblas_dscal(N, alpha, X, 1); } - template<> void caffe_cpu_axpby(const int_tp N, const float alpha, const float* X, const float beta, float* Y) { @@ -154,6 +452,7 @@ void caffe_cpu_axpby(const int_tp N, const double alpha, cblas_daxpby(N, alpha, X, 1, beta, Y, 1); } + template<> void caffe_add(const int_tp n, const float* a, 
const float* b, float* y) { @@ -399,12 +698,18 @@ Dtype caffe_cpu_dot(const int_tp n, const Dtype* x, const Dtype* y) { return caffe_cpu_strided_dot(n, x, 1, y, 1); } +#ifdef HAS_HALF_SUPPORT +template +half caffe_cpu_dot(const int_tp n, const half* x, const half* y); +#endif + template float caffe_cpu_dot(const int_tp n, const float* x, const float* y); template double caffe_cpu_dot(const int_tp n, const double* x, const double* y); + template<> float caffe_cpu_asum(const int_tp n, const float* x) { return cblas_sasum(n, x, 1); diff --git a/src/gtest/gtest-all.cpp b/src/gtest/gtest-all.cpp index 81cdb578cd5..e2c24d67c24 100644 --- a/src/gtest/gtest-all.cpp +++ b/src/gtest/gtest-all.cpp @@ -9111,6 +9111,55 @@ const char* TypedTestCasePState::VerifyRegisteredTestNames( return registered_tests; } +template <> +AssertionResult CmpHelperFloatingPointEQ(const char* expected_expression, + const char* actual_expression, + half expected, + half actual); +template <> +AssertionResult CmpHelperFloatingPointEQ(const char* expected_expression, + const char* actual_expression, + float expected, + float actual); +template <> +AssertionResult CmpHelperFloatingPointEQ(const char* expected_expression, + const char* actual_expression, + double expected, + double actual); + +template <> +AssertionResult CmpHelperFloatingPointLE(const char* expected_expression, + const char* actual_expression, + half expected, + half actual); +template <> +AssertionResult CmpHelperFloatingPointLE(const char* expected_expression, + const char* actual_expression, + float expected, + float actual); +template <> +AssertionResult CmpHelperFloatingPointLE(const char* expected_expression, + const char* actual_expression, + double expected, + double actual); + +template <> +AssertionResult CmpHelperFloatingPointGE(const char* expected_expression, + const char* actual_expression, + half expected, + half actual); +template <> +AssertionResult CmpHelperFloatingPointGE(const char* expected_expression, + const 
char* actual_expression, + float expected, + float actual); +template <> +AssertionResult CmpHelperFloatingPointGE(const char* expected_expression, + const char* actual_expression, + double expected, + double actual); + + #endif // GTEST_HAS_TYPED_TEST_P } // namespace internal diff --git a/src/gtest/gtest.h b/src/gtest/gtest.h index 124fb2321f9..abf6307bd13 100644 --- a/src/gtest/gtest.h +++ b/src/gtest/gtest.h @@ -51,6 +51,7 @@ #ifndef GTEST_INCLUDE_GTEST_GTEST_H_ #define GTEST_INCLUDE_GTEST_GTEST_H_ +#include "caffe/util/fp16.hpp" #include #include @@ -2766,6 +2767,18 @@ class TypeWithSize { // The specialization for size 4. template <> +class TypeWithSize<2> { + public: + // unsigned int has size 4 in both gcc and MSVC. + // + // As base/basictypes.h doesn't compile on Windows, we cannot use + // uint32, uint64, and etc here. + typedef int16_t Int; + typedef uint16_t UInt; +}; + +// The specialization for size 4. +template <> class TypeWithSize<4> { public: // unsigned int has size 4 in both gcc and MSVC. @@ -2791,6 +2804,8 @@ class TypeWithSize<8> { }; // Integer types of known sizes. +typedef TypeWithSize<2>::Int Int16; +typedef TypeWithSize<2>::UInt UInt16; typedef TypeWithSize<4>::Int Int32; typedef TypeWithSize<4>::UInt UInt32; typedef TypeWithSize<8>::Int Int64; @@ -7017,10 +7032,10 @@ class FloatingPoint { // The mask for the fraction bits. static const Bits kFractionBitMask = - ~static_cast(0) >> (kExponentBitCount + 1); + static_cast(~0ull) >> (kExponentBitCount + 1); // The mask for the exponent bits. - static const Bits kExponentBitMask = ~(kSignBitMask | kFractionBitMask); + static const Bits kExponentBitMask = static_cast(~(kSignBitMask | kFractionBitMask)); // How many ULP's (Units in the Last Place) we want to tolerate when // comparing two numbers. The larger the value, the more error we @@ -7042,7 +7057,8 @@ class FloatingPoint { // around may change its bits, although the new value is guaranteed // to be also a NAN. 
Therefore, don't expect this constructor to // preserve the bits in x when x is a NAN. - explicit FloatingPoint(const RawType& x) { u_.value_ = x; } + //explicit FloatingPoint(const RawType& x) { u_.bits_ = *(const Bits*)(&x); } + explicit FloatingPoint(const RawType x) { memcpy(&u_.bits_, &x, sizeof(RawType)); } // Static methods @@ -7050,9 +7066,11 @@ class FloatingPoint { // // This function is needed to test the AlmostEquals() method. static RawType ReinterpretBits(const Bits bits) { - FloatingPoint fp(0); - fp.u_.bits_ = bits; - return fp.u_.value_; + //FloatingPoint fp(0); + //fp.u_.bits_ = bits; + RawType val; + memcpy(&val, &bits, sizeof(RawType)); + return val; } // Returns the floating-point number that represent positive infinity. @@ -7091,15 +7109,23 @@ class FloatingPoint { // The IEEE standard says that any comparison operation involving // a NAN must return false. if (is_nan() || rhs.is_nan()) return false; + int distance = DistanceBetweenSignAndMagnitudeNumbers(u_.bits_, rhs.u_.bits_); + return distance <= kMaxUlps; + } - return DistanceBetweenSignAndMagnitudeNumbers(u_.bits_, rhs.u_.bits_) - <= kMaxUlps; + bool LT(const FloatingPoint&rhs) const { + if (is_nan() || rhs.is_nan()) return false; + return ReinterpretBits(u_.bits_) < ReinterpretBits(rhs.u_.bits_); + } + + bool GT(const FloatingPoint&rhs) const { + if (is_nan() || rhs.is_nan()) return false; + return ReinterpretBits(u_.bits_) > ReinterpretBits(rhs.u_.bits_); } private: // The data type used to store the actual floating-point number. union FloatingPointUnion { - RawType value_; // The raw floating-point number. Bits bits_; // The bits that represent the number. }; @@ -7142,6 +7168,7 @@ class FloatingPoint { // Typedefs the instances of the FloatingPoint template class that we // care to use. 
+typedef FloatingPoint Half; typedef FloatingPoint Float; typedef FloatingPoint Double; @@ -9427,7 +9454,7 @@ void DefaultPrintNonContainerTo(const T& value, ::std::ostream* os) { // impossible to define #1 (e.g. when foo is ::std, defining // anything in it is undefined behavior unless you are a compiler // vendor.). - *os << value; + *os << fixup_arg_type(value); } } // namespace testing_internal @@ -18580,7 +18607,10 @@ AssertionResult CmpHelperFloatingPointEQ(const char* expected_expression, const char* actual_expression, RawType expected, RawType actual) { - const FloatingPoint lhs(expected), rhs(actual); + + auto expect_val = expected; + auto actual_val = actual; + const FloatingPoint lhs(expect_val), rhs(actual_val); if (lhs.AlmostEquals(rhs)) { return AssertionSuccess(); @@ -18601,6 +18631,62 @@ AssertionResult CmpHelperFloatingPointEQ(const char* expected_expression, false); } +template +AssertionResult CmpHelperFloatingPointGE(const char* expected_expression, + const char* actual_expression, + RawType expected, + RawType actual) { + auto expect_val = expected; + auto actual_val = actual; + const FloatingPoint lhs(expect_val), rhs(actual_val); + if (lhs.AlmostEquals(rhs) || lhs.GT(rhs)) { + return AssertionSuccess(); + } + + ::std::stringstream expected_ss; + expected_ss << std::setprecision(std::numeric_limits::digits10 + 2) + << expected; + + ::std::stringstream actual_ss; + actual_ss << std::setprecision(std::numeric_limits::digits10 + 2) + << actual; + + return EqFailure(expected_expression, + actual_expression, + StringStreamToString(&expected_ss), + StringStreamToString(&actual_ss), + false); +} + +template +AssertionResult CmpHelperFloatingPointLE(const char* expected_expression, + const char* actual_expression, + RawType expected, + RawType actual) { + + auto expect_val = expected; + auto actual_val = actual; + const FloatingPoint lhs(expect_val), rhs(actual_val); + + if (lhs.AlmostEquals(rhs) || lhs.LT(rhs)) { + return AssertionSuccess(); + } + + 
::std::stringstream expected_ss; + expected_ss << std::setprecision(std::numeric_limits::digits10 + 2) + << expected; + + ::std::stringstream actual_ss; + actual_ss << std::setprecision(std::numeric_limits::digits10 + 2) + << actual; + + return EqFailure(expected_expression, + actual_expression, + StringStreamToString(&expected_ss), + StringStreamToString(&actual_ss), + false); +} + // Helper function for implementing ASSERT_NEAR. // // INTERNAL IMPLEMENTATION - DO NOT USE IN A USER PROGRAM. @@ -19226,17 +19312,25 @@ AssertionResult AssertPred5Helper(const char* pred_text, // ASSERT_GT(records.size(), 0) << "There is no record left."; #define EXPECT_EQ(expected, actual) \ - EXPECT_PRED_FORMAT2(::testing::internal:: \ - EqHelper::Compare, \ - expected, actual) + EXPECT_PRED_FORMAT2(::testing::internal:: \ + EqHelper::Compare, \ + expected, actual) + #define EXPECT_NE(expected, actual) \ EXPECT_PRED_FORMAT2(::testing::internal::CmpHelperNE, expected, actual) #define EXPECT_LE(val1, val2) \ - EXPECT_PRED_FORMAT2(::testing::internal::CmpHelperLE, val1, val2) + if (!std::is_same::value) \ + EXPECT_PRED_FORMAT2(::testing::internal::CmpHelperLE, val1, val2); \ + else \ + EXPECT_FLOAT_LE(val1, val2) + #define EXPECT_LT(val1, val2) \ EXPECT_PRED_FORMAT2(::testing::internal::CmpHelperLT, val1, val2) #define EXPECT_GE(val1, val2) \ - EXPECT_PRED_FORMAT2(::testing::internal::CmpHelperGE, val1, val2) + if (!std::is_same::value) \ + EXPECT_PRED_FORMAT2(::testing::internal::CmpHelperGE, val1, val2); \ + else \ + EXPECT_FLOAT_GE(val1, val2) #define EXPECT_GT(val1, val2) \ EXPECT_PRED_FORMAT2(::testing::internal::CmpHelperGT, val1, val2) @@ -19330,9 +19424,41 @@ AssertionResult AssertPred5Helper(const char* pred_text, // FloatingPoint template class in gtest-internal.h if you are // interested in the implementation details. 
-#define EXPECT_FLOAT_EQ(expected, actual)\ - EXPECT_PRED_FORMAT2(::testing::internal::CmpHelperFloatingPointEQ, \ - expected, actual) +#define IS_DOUBLE_TYPE(val) \ + (std::is_same::value || \ + std::is_same::value || \ + std::is_same::value || \ + std::is_same::value) + +#define EXPECT_FLOAT_GE(expected, actual)\ + if (IS_DOUBLE_TYPE(expected)) \ + EXPECT_PRED_FORMAT2(::testing::internal::CmpHelperFloatingPointGE, \ + expected,\ + actual);\ + else \ + EXPECT_PRED_FORMAT2(::testing::internal::CmpHelperFloatingPointGE, \ + expected,\ + actual) + +#define EXPECT_FLOAT_LE(expected, actual)\ + if (IS_DOUBLE_TYPE(expected)) \ + EXPECT_PRED_FORMAT2(::testing::internal::CmpHelperFloatingPointLE, \ + expected,\ + actual); \ + else \ + EXPECT_PRED_FORMAT2(::testing::internal::CmpHelperFloatingPointLE, \ + expected,\ + actual) + +#define EXPECT_FLOAT_EQ(expected, actual) \ + if (IS_DOUBLE_TYPE(expected)) \ + EXPECT_PRED_FORMAT2(::testing::internal::CmpHelperFloatingPointEQ, \ + (float)expected,\ + (float)actual); \ + else \ + EXPECT_PRED_FORMAT2(::testing::internal::CmpHelperFloatingPointEQ, \ + expected,\ + actual) #define EXPECT_DOUBLE_EQ(expected, actual)\ EXPECT_PRED_FORMAT2(::testing::internal::CmpHelperFloatingPointEQ, \ diff --git a/tools/caffe-fp16.cpp b/tools/caffe-fp16.cpp new file mode 100644 index 00000000000..5b734d789a0 --- /dev/null +++ b/tools/caffe-fp16.cpp @@ -0,0 +1,592 @@ +#ifdef HAS_HALF_SUPPORT +#ifdef WITH_PYTHON_LAYER +#include "boost/python.hpp" +namespace bp = boost::python; +#endif + +#include +#include + +#include +#include +#include +#include + +#include "boost/algorithm/string.hpp" +#include "caffe/caffe.hpp" +#include "caffe/device.hpp" +#include "caffe/util/signal_handler.h" + +#ifdef USE_LIBDNN +#include "caffe/layers/libdnn_conv_layer.hpp" +#endif + +using caffe::Blob; +using caffe::Caffe; +using caffe::Net; +using caffe::Layer; +using caffe::Solver; +using caffe::shared_ptr; +using caffe::string; +using caffe::Timer; +using caffe::vector; 
+using caffe::device; +using std::ostringstream; + +DEFINE_string(gpu, "", + "Optional; run in GPU mode on given device IDs separated by ','." + "Use '-gpu all' to run on all available GPUs. The effective training " + "batch size is multiplied by the number of devices."); +DEFINE_string(solver, "", + "The solver definition protocol buffer text file."); +DEFINE_string(model, "", + "The model definition protocol buffer text file."); +DEFINE_string(phase, "", + "Optional; network phase (TRAIN or TEST). Only used for 'time'."); +DEFINE_int32(level, 0, + "Optional; network level."); +DEFINE_string(stage, "", + "Optional; network stages (not to be confused with phase), " + "separated by ','."); +DEFINE_string(snapshot, "", + "Optional; the snapshot solver state to resume training."); +DEFINE_string(weights, "", + "Optional; the pretrained weights to initialize finetuning, " + "separated by ','. Cannot be set simultaneously with snapshot."); +DEFINE_int32(iterations, 50, + "The number of iterations to run."); +DEFINE_string(sigint_effect, "stop", + "Optional; action to take when a SIGINT signal is received: " + "snapshot, stop or none."); +DEFINE_string(sighup_effect, "snapshot", + "Optional; action to take when a SIGHUP signal is received: " + "snapshot, stop or none."); +DEFINE_bool(lt, false, + "Optional; enable per layer timings"); + + +// A simple registry for caffe commands. 
+typedef int (*BrewFunction)(); +typedef std::map BrewMap; +BrewMap g_brew_map; + +#define RegisterBrewFunction(func) \ +namespace { \ +class __Registerer_##func { \ + public: /* NOLINT */ \ + __Registerer_##func() { \ + g_brew_map[#func] = &func; \ + } \ +}; \ +__Registerer_##func g_registerer_##func; \ +} + +static BrewFunction GetBrewFunction(const caffe::string& name) { + if (g_brew_map.count(name)) { + return g_brew_map[name]; + } else { + LOG(ERROR) << "Available caffe actions:"; + for (BrewMap::iterator it = g_brew_map.begin(); + it != g_brew_map.end(); ++it) { + LOG(ERROR) << "\t" << it->first; + } + LOG(FATAL) << "Unknown action: " << name; + return NULL; // not reachable, just to suppress old compiler warnings. + } +} + +// Parse GPU ids or use all available devices +static void get_gpus(vector* gpus) { + if (FLAGS_gpu == "all") { + int count = 0; +#ifndef CPU_ONLY + count = Caffe::EnumerateDevices(true); +#else + NO_GPU; +#endif + for (int i = 0; i < count; ++i) { + gpus->push_back(i); + } + } else if (FLAGS_gpu.size()) { + vector strings; + boost::split(strings, FLAGS_gpu, boost::is_any_of(",")); + for (int i = 0; i < strings.size(); ++i) { + gpus->push_back(boost::lexical_cast(strings[i])); + } + } else { + CHECK_EQ(gpus->size(), 0); + } +} + +// Parse phase from flags +caffe::Phase get_phase_from_flags(caffe::Phase default_value) { + if (FLAGS_phase == "") + return default_value; + if (FLAGS_phase == "TRAIN") + return caffe::TRAIN; + if (FLAGS_phase == "TEST") + return caffe::TEST; + LOG(FATAL) << "phase must be \"TRAIN\" or \"TEST\""; + return caffe::TRAIN; // Avoid warning +} + +// Parse stages from flags +vector get_stages_from_flags() { + vector stages; + boost::split(stages, FLAGS_stage, boost::is_any_of(",")); + return stages; +} + +// caffe commands to call by +// caffe +// +// To add a command, define a function "int command()" and register it with +// RegisterBrewFunction(action); + +// Device Query: show diagnostic information for a GPU 
device, or +// enumerate all devices if none is specified. +int device_query() { + if (FLAGS_gpu.size() == 0 || FLAGS_gpu == "all") { + // If no gpu is specified, enumerate all the devices. + caffe::Caffe::EnumerateDevices(); + } else { +#ifndef CPU_ONLY + LOG(INFO) << "Querying GPUs " << FLAGS_gpu; + vector gpus; + get_gpus(&gpus); + Caffe::SetDevices(gpus); + for (int i = 0; i < gpus.size(); ++i) { + caffe::Caffe::SetDevice(gpus[i]); + caffe::Caffe::DeviceQuery(); + } +#ifdef USE_GREENTEA + if (Caffe::GetDefaultDevice()->backend() == caffe::BACKEND_OpenCL) { + if (gpus.size() > 0 && gpus[0] >= 0) { + // Explicitly call for OCL + FFT + caffe::Caffe::TeardownDevice(gpus[0]); + } + } +#endif // USE_GREENTEA +#endif // !CPU_ONLY + } + return 0; +} +RegisterBrewFunction(device_query); + +// Load the weights from the specified caffemodel(s) into the train and +// test nets. +void CopyLayers(caffe::Solver* solver, const std::string& model_list) { + std::vector model_names; + boost::split(model_names, model_list, boost::is_any_of(",") ); + for (int i = 0; i < model_names.size(); ++i) { + LOG(INFO) << "Finetuning from " << model_names[i]; + solver->net()->CopyTrainedLayersFrom(model_names[i]); + for (int j = 0; j < solver->test_nets().size(); ++j) { + solver->test_nets()[j]->CopyTrainedLayersFrom(model_names[i]); + } + } +} + +// Translate the signal effect the user specified on the command-line to the +// corresponding enumeration. +caffe::SolverAction::Enum GetRequestedAction( + const std::string& flag_value) { + if (flag_value == "stop") { + return caffe::SolverAction::STOP; + } + if (flag_value == "snapshot") { + return caffe::SolverAction::SNAPSHOT; + } + if (flag_value == "none") { + return caffe::SolverAction::NONE; + } + LOG(FATAL) << "Invalid signal effect \""<< flag_value << "\" was specified"; + return caffe::SolverAction::NONE; +} + +// Train / Finetune a model. 
+int train() { + CHECK_GT(FLAGS_solver.size(), 0) << "Need a solver definition to train."; + CHECK(!FLAGS_snapshot.size() || !FLAGS_weights.size()) + << "Give a snapshot to resume training or weights to finetune " + "but not both."; + vector stages = get_stages_from_flags(); + + caffe::SolverParameter solver_param; + caffe::ReadSolverParamsFromTextFileOrDie(FLAGS_solver, &solver_param); + + solver_param.mutable_train_state()->set_level(FLAGS_level); + for (int i = 0; i < stages.size(); i++) { + solver_param.mutable_train_state()->add_stage(stages[i]); + } + + // If the gpus flag is not provided, allow the mode and device to be set + // in the solver prototxt. + if (FLAGS_gpu.size() == 0 + && solver_param.has_solver_mode() + && solver_param.solver_mode() == caffe::SolverParameter_SolverMode_GPU) { + if (solver_param.has_device_id()) { + FLAGS_gpu = "" + + boost::lexical_cast(solver_param.device_id()); + } else { // Set default GPU if unspecified + FLAGS_gpu = "" + boost::lexical_cast(0); + } + } + + vector gpus; + get_gpus(&gpus); + if (gpus.size() == 0) { + LOG(INFO) << "Use CPU."; + Caffe::set_mode(Caffe::CPU); + } else { +#ifndef CPU_ONLY + // Load all devices that will be used + Caffe::SetDevices(gpus); + + ostringstream s; + for (int_tp i = 0; i < gpus.size(); ++i) { + s << (i ? 
", " : "") << gpus[i]; + } + LOG(INFO) << "Using GPUs " << s.str(); + solver_param.set_device_id(gpus[0]); + // Initialize the first device + Caffe::SetDevice(gpus[0]); + Caffe::set_mode(Caffe::GPU); + Caffe::set_solver_count(gpus.size()); +#endif // !CPU_ONLY + } + + caffe::SignalHandler signal_handler( + GetRequestedAction(FLAGS_sigint_effect), + GetRequestedAction(FLAGS_sighup_effect)); + + shared_ptr > + solver(caffe::SolverRegistry::CreateSolver(solver_param)); + + solver->SetActionFunction(signal_handler.GetActionFunction()); + + if (FLAGS_snapshot.size()) { + LOG(INFO) << "Resuming from " << FLAGS_snapshot; + solver->Restore(FLAGS_snapshot.c_str()); + } else if (FLAGS_weights.size()) { + CopyLayers(solver.get(), FLAGS_weights); + } + + LOG(INFO) << "Starting Optimization"; + if (gpus.size() > 1) { +#ifdef USE_CUDA +#ifdef USE_NCCL + caffe::NCCL nccl(solver); + nccl.Run(gpus, FLAGS_snapshot.size() > 0 ? FLAGS_snapshot.c_str() : NULL); +#else + LOG(FATAL) << "Multi-GPU execution not available - rebuild with USE_NCCL"; +#endif // USE_NCCL +#endif // USE_CUDA + } else { + solver->Solve(); + } + LOG(INFO) << "Optimization Done."; + +#ifdef USE_GREENTEA + if (Caffe::GetDefaultDevice()->backend() == caffe::BACKEND_OpenCL) { + if (gpus.size() > 0 && gpus[0] >= 0) { + // Explicitly call for OCL + FFT + caffe::Caffe::TeardownDevice(gpus[0]); + } + } +#endif + return 0; +} +RegisterBrewFunction(train); + + +// Test: score a model. 
+int test() { + CHECK_GT(FLAGS_model.size(), 0) << "Need a model definition to score."; + CHECK_GT(FLAGS_weights.size(), 0) << "Need model weights to score."; + vector stages = get_stages_from_flags(); + + // Set device id and mode + vector gpus; + get_gpus(&gpus); + if (gpus.size() != 0) { +#ifndef CPU_ONLY + LOG(INFO) << "Use GPU with device ID " << gpus[0]; + Caffe::SetDevices(gpus); + Caffe::set_mode(Caffe::GPU); + Caffe::SetDevice(gpus[0]); +#endif // !CPU_ONLY + } else { + LOG(INFO) << "Use CPU."; + Caffe::set_mode(Caffe::CPU); + } + // Instantiate the caffe net. + Net caffe_net(FLAGS_model, caffe::TEST, + Caffe::GetDefaultDevice(), FLAGS_level, &stages); + caffe_net.CopyTrainedLayersFrom(FLAGS_weights); + LOG(INFO) << "Running for " << FLAGS_iterations << " iterations."; + + vector test_score_output_id; + vector test_score; + half loss = 0; + for (int_tp i = 0; i < FLAGS_iterations; ++i) { + half iter_loss; + const vector*>& result = + caffe_net.Forward(&iter_loss); + loss += iter_loss; + int_tp idx = 0; + for (int_tp j = 0; j < result.size(); ++j) { + const half* result_vec = result[j]->cpu_data(); + for (int_tp k = 0; k < result[j]->count(); ++k, ++idx) { + const half score = result_vec[k]; + if (i == 0) { + test_score.push_back(score); + test_score_output_id.push_back(j); + } else { + test_score[idx] += score; + } + const std::string& output_name = caffe_net.blob_names()[ + caffe_net.output_blob_indices()[j]]; + LOG(INFO) << "Batch " << i << ", " << output_name << " = " << score; + } + } + } + loss /= FLAGS_iterations; + LOG(INFO) << "Loss: " << loss; + for (int_tp i = 0; i < test_score.size(); ++i) { + const std::string& output_name = caffe_net.blob_names()[ + caffe_net.output_blob_indices()[test_score_output_id[i]]]; + const half loss_weight = caffe_net.blob_loss_weights()[ + caffe_net.output_blob_indices()[test_score_output_id[i]]]; + std::ostringstream loss_msg_stream; + const half mean_score = test_score[i] / FLAGS_iterations; + if (loss_weight) { + 
loss_msg_stream << " (* " << loss_weight + << " = " << loss_weight * mean_score << " loss)"; + } + LOG(INFO) << output_name << " = " << mean_score << loss_msg_stream.str(); + } +#ifdef USE_GREENTEA + if (Caffe::GetDefaultDevice()->backend() == caffe::BACKEND_OpenCL) { + if (gpus.size() > 0 && gpus[0] >= 0) { + // Explicitly call for OCL + FFT + caffe::Caffe::TeardownDevice(gpus[0]); + } + } +#endif + + return 0; +} +RegisterBrewFunction(test); + + +// Time: benchmark the execution time of a model. +int time() { + CHECK_GT(FLAGS_model.size(), 0) << "Need a model definition to time."; + caffe::Phase phase = get_phase_from_flags(caffe::TRAIN); + vector stages = get_stages_from_flags(); + + // Set device id and mode + vector gpus; + get_gpus(&gpus); + if (gpus.size() != 0) { +#ifndef CPU_ONLY + LOG(INFO) << "Use GPU with device ID " << gpus[0]; + Caffe::SetDevices(gpus); + Caffe::set_mode(Caffe::GPU); + Caffe::SetDevice(gpus[0]); +#endif // !CPU_ONLY + } else { + LOG(INFO) << "Use CPU."; + Caffe::set_mode(Caffe::CPU); + } + // Instantiate the caffe net. + Net caffe_net(FLAGS_model, phase, + Caffe::GetDefaultDevice(), FLAGS_level, &stages); + + // Do a clean forward and backward pass, so that memory allocation are done + // and future iterations will be more stable. + LOG(INFO) << "Performing Forward"; + // Note that for the speed benchmark, we will assume that the network does + // not take any input blobs. 
+ half initial_loss; + caffe_net.Forward(&initial_loss); + LOG(INFO) << "Initial loss: " << initial_loss; + if (phase == caffe::TRAIN) { + LOG(INFO) << "Performing Backward"; + caffe_net.Backward(); + } + + const vector > >& layers = caffe_net.layers(); + const vector*> >& bottom_vecs = caffe_net.bottom_vecs(); + const vector*> >& top_vecs = caffe_net.top_vecs(); + const vector >& bottom_need_backward = + caffe_net.bottom_need_backward(); + LOG(INFO) << "*** Benchmark begins ***"; + LOG(INFO) << "Testing for " << FLAGS_iterations << " iterations."; + Timer total_timer; + total_timer.Start(); + Timer forward_timer; + Timer backward_timer; + Timer timer; + std::vector forward_time_per_layer(layers.size(), 0.0); + std::vector backward_time_per_layer(layers.size(), 0.0); + double forward_time = 0.0; + double backward_time = 0.0; + + for (int_tp j = 0; j < FLAGS_iterations; ++j) { + Timer iter_timer; + iter_timer.Start(); + forward_timer.Start(); + for (int_tp i = 0; i < layers.size(); ++i) { + if (FLAGS_lt) { + timer.Start(); + } + + layers[i]->Forward(bottom_vecs[i], top_vecs[i]); + + if (FLAGS_lt) { + Caffe::Synchronize(Caffe::GetDefaultDevice()->id()); + forward_time_per_layer[i] += timer.MicroSeconds(); + } + } + Caffe::Synchronize(Caffe::GetDefaultDevice()->id()); + forward_time += forward_timer.MicroSeconds(); + if (phase == caffe::TRAIN) { + backward_timer.Start(); + for (int_tp i = layers.size() - 1; i >= 0; --i) { + timer.Start(); + layers[i]->Backward(top_vecs[i], bottom_need_backward[i], + bottom_vecs[i]); + Caffe::Synchronize(Caffe::GetDefaultDevice()->id()); + backward_time_per_layer[i] += timer.MicroSeconds(); + } + backward_time += backward_timer.MicroSeconds(); + } + LOG(INFO) << "Iteration: " << j + 1 << " forward-backward time: " + << iter_timer.MilliSeconds() << " ms."; + } + + if (FLAGS_lt) { + LOG(INFO) << "Average time per layer: "; + for (int_tp i = 0; i < layers.size(); ++i) { + const caffe::string& layername = layers[i]->layer_param().name(); + 
LOG(INFO) << std::setfill(' ') << std::setw(10) << layername << + "\tforward: " << forward_time_per_layer[i] / 1000 / + FLAGS_iterations << " ms."; + LOG(INFO) << std::setfill(' ') << std::setw(10) << layername << + "\tbackward: " << backward_time_per_layer[i] / 1000 / + FLAGS_iterations << " ms."; + } + } + total_timer.Stop(); + LOG(INFO) << "Average Forward pass: " << forward_time / 1000 / + FLAGS_iterations << " ms."; + LOG(INFO) << "Average Backward pass: " << backward_time / 1000 / + FLAGS_iterations << " ms."; + LOG(INFO) << "Average Forward-Backward: " << total_timer.MilliSeconds() / + FLAGS_iterations << " ms."; + LOG(INFO) << "Total Time: " << total_timer.MilliSeconds() << " ms."; + LOG(INFO) << "*** Benchmark ends ***"; + +#ifdef USE_GREENTEA + if (Caffe::GetDefaultDevice()->backend() == caffe::BACKEND_OpenCL) { + if (gpus.size() > 0 && gpus[0] >= 0) { + // Explicitly call for OCL + FFT + caffe::Caffe::TeardownDevice(gpus[0]); + } + } +#endif + return 0; +} +RegisterBrewFunction(time); + + +int autotune() { + CHECK_GT(FLAGS_model.size(), 0) << "Need a model definition to time."; + + vector gpus; + get_gpus(&gpus); + if (gpus.size() == 0) { + LOG(INFO) << "Use CPU."; + Caffe::set_mode(Caffe::CPU); + } else { +#ifndef CPU_ONLY + // Load all devices that will be used + Caffe::SetDevices(gpus); + + ostringstream s; + for (int_tp i = 0; i < gpus.size(); ++i) { + s << (i ? 
", " : "") << gpus[i]; + } + LOG(INFO) << "Using GPUs " << s.str(); + // Initialize the first device + Caffe::SetDevice(gpus[0]); + Caffe::set_mode(Caffe::GPU); + Caffe::set_solver_count(gpus.size()); +#endif // !CPU_ONLY + } + + caffe::SignalHandler signal_handler( + GetRequestedAction(FLAGS_sigint_effect), + GetRequestedAction(FLAGS_sighup_effect)); + + Net net(FLAGS_model, caffe::TRAIN, Caffe::GetDefaultDevice()); + + for (int i = 0; i < net.layers().size(); ++i) { +#ifdef USE_LIBDNN + shared_ptr > layer = + boost::dynamic_pointer_cast > + (net.layers()[i]); + if (layer.get() != nullptr) { + half* top_data = net.top_vecs()[i][0]->mutable_gpu_data(); + half* top_diff = net.top_vecs()[i][0]->mutable_gpu_diff(); + half* bottom_data = net.top_vecs()[i][0]->mutable_gpu_data(); + half* bottom_diff = net.top_vecs()[i][0]->mutable_gpu_diff(); + int_tp batch_size = net.top_vecs()[i][0]->shape(0); + layer->Tune(top_data, top_diff, bottom_data, bottom_diff, batch_size); + } +#endif // USE_LIBDNN + } + return 0; +} +RegisterBrewFunction(autotune); + + + + +int main(int argc, char** argv) { + // Print output to stderr (while still logging). + FLAGS_alsologtostderr = 1; + // Set version + gflags::SetVersionString(AS_STRING(CAFFE_VERSION)); + // Usage message. + gflags::SetUsageMessage("command line brew\n" + "usage: caffe \n\n" + "commands:\n" + " train train or finetune a model\n" + " test score a model\n" + " device_query show GPU diagnostic information\n" + " time benchmark model execution time" + " autotune autotune a model"); + // Run tool or show usage. 
+ caffe::GlobalInit(&argc, &argv); + if (argc == 2) { +#ifdef WITH_PYTHON_LAYER + try { +#endif + return GetBrewFunction(caffe::string(argv[1]))(); +#ifdef WITH_PYTHON_LAYER + } catch (bp::error_already_set) { + PyErr_Print(); + return 1; + } +#endif + } else { + gflags::ShowUsageWithFlagsRestrict(argv[0], "tools/caffe"); + } +} +#else +int main(int argc, char** argv) { +} +#endif From ca71f1b323dd93ef43bd72e28be5b1369b4bbb4d Mon Sep 17 00:00:00 2001 From: Junkai Wu Date: Fri, 30 Jun 2017 07:44:50 +0800 Subject: [PATCH 24/33] Optimize buffer based gemm_nt kernel with both float and fp16 versions. --- src/caffe/greentea/cl_kernels.cpp | 164 +++++++++++++++++++++++- src/caffe/greentea/cl_kernels/gemm.cl | 215 +++++++++++++++++++++++++++++++- src/caffe/layers/inner_product_layer.cu | 2 +- 3 files changed, 372 insertions(+), 9 deletions(-) diff --git a/src/caffe/greentea/cl_kernels.cpp b/src/caffe/greentea/cl_kernels.cpp index ab7cedcfab6..9c87d5bcff1 100644 --- a/src/caffe/greentea/cl_kernels.cpp +++ b/src/caffe/greentea/cl_kernels.cpp @@ -3798,13 +3798,168 @@ static std::vector> cl_kernels{ "#undef TILE_K", // NOLINT "#undef TILE_N", // NOLINT "", // NOLINT -"", // NOLINT "#define VEC_SIZE 1", // NOLINT -"#define LWG_HEIGHT 16", // NOLINT "#define TILE_M 8", // NOLINT -"#define TILE_K 32", // NOLINT "#define TILE_N 8", // NOLINT -"#define SLM_BLOCK 512", // NOLINT +"#define SLM_BLOCK 128", // NOLINT +"", // NOLINT +"#if TYPE == TYPE_HALF", // NOLINT +"#define LWG_HEIGHT 2", // NOLINT +"#define TILE_K 64", // NOLINT +"#else", // NOLINT +"#define LWG_HEIGHT 4", // NOLINT +"#define TILE_K 32", // NOLINT +"#endif", // NOLINT +"", // NOLINT +"#if TYPE == TYPE_HALF", // NOLINT +"__attribute__((reqd_work_group_size(8, LWG_HEIGHT, 1)))", // NOLINT +"__attribute__((intel_reqd_sub_group_size(8)))", // NOLINT +"__kernel void TEMPLATE(gemm_buffer_NT, Dtype)(", // NOLINT +"const __global Dtype *src0, int off0,", // NOLINT +"const __global Dtype *src1, int off1,", // NOLINT 
+"__global Dtype *dst, int offd,", // NOLINT +"int M,", // NOLINT +"int N,", // NOLINT +"int K,", // NOLINT +"KERNEL_ARG_DTYPE alpha_in,", // NOLINT +"KERNEL_ARG_DTYPE beta_in)", // NOLINT +"{", // NOLINT +"const Dtype alpha = (Dtype)alpha_in;", // NOLINT +"const Dtype beta = (Dtype)beta_in;", // NOLINT +"const int group_x = get_group_id(0);", // NOLINT +"const int group_y = get_group_id(1);", // NOLINT +"const int local_x = get_local_id(0);", // NOLINT +"const int local_y = get_local_id(1);", // NOLINT +"const int global_x = get_global_id(0);", // NOLINT +"const int global_y = get_global_id(1);", // NOLINT +"", // NOLINT +"Dtype8 dot00 = 0.f;", // NOLINT +"Dtype8 dot01 = 0.f;", // NOLINT +"Dtype8 dot02 = 0.f;", // NOLINT +"Dtype8 dot03 = 0.f;", // NOLINT +"Dtype8 dot04 = 0.f;", // NOLINT +"Dtype8 dot05 = 0.f;", // NOLINT +"Dtype8 dot06 = 0.f;", // NOLINT +"Dtype8 dot07 = 0.f;", // NOLINT +"", // NOLINT +"Dtype8 brow0;", // NOLINT +"Dtype8 brow1;", // NOLINT +"Dtype8 brow2;", // NOLINT +"Dtype8 brow3;", // NOLINT +"Dtype8 brow4;", // NOLINT +"Dtype8 brow5;", // NOLINT +"Dtype8 brow6;", // NOLINT +"Dtype8 brow7;", // NOLINT +"", // NOLINT +"__global Dtype *dst_write0 = dst + local_x * VEC_SIZE + (group_x * TILE_N) + (group_y * LWG_HEIGHT * TILE_M + local_y * TILE_M) * N + offd;", // NOLINT +"", // NOLINT +"const __global Dtype *src0_read = src0 + local_x * (TILE_K / 8) + (group_y * LWG_HEIGHT * TILE_M + local_y * TILE_M) * K + off0;", // NOLINT +"", // NOLINT +"const __global Dtype *src1_read0 = src1 + (group_x * TILE_N) * K + off1;", // NOLINT +"", // NOLINT +"__local Dtype slm_brow[8 * SLM_BLOCK];", // NOLINT +"__local Dtype* slm_brow0;", // NOLINT +"", // NOLINT +"int local_index = mad24(local_y, 8, local_x) * 8;", // NOLINT +"int w;", // NOLINT +"for(int b_tile = 0; b_tile < K; b_tile += SLM_BLOCK) {", // NOLINT +"barrier(CLK_LOCAL_MEM_FENCE);", // NOLINT +"vstore4(vload4(0, (__global float *)(src1_read0 + mad24(0, K, local_index))), 0, (__local float 
*)(slm_brow + mad24(0, SLM_BLOCK, local_index)));", // NOLINT +"vstore4(vload4(0, (__global float *)(src1_read0 + mad24(1, K, local_index))), 0, (__local float *)(slm_brow + mad24(1, SLM_BLOCK, local_index)));", // NOLINT +"vstore4(vload4(0, (__global float *)(src1_read0 + mad24(2, K, local_index))), 0, (__local float *)(slm_brow + mad24(2, SLM_BLOCK, local_index)));", // NOLINT +"vstore4(vload4(0, (__global float *)(src1_read0 + mad24(3, K, local_index))), 0, (__local float *)(slm_brow + mad24(3, SLM_BLOCK, local_index)));", // NOLINT +"vstore4(vload4(0, (__global float *)(src1_read0 + mad24(4, K, local_index))), 0, (__local float *)(slm_brow + mad24(4, SLM_BLOCK, local_index)));", // NOLINT +"vstore4(vload4(0, (__global float *)(src1_read0 + mad24(5, K, local_index))), 0, (__local float *)(slm_brow + mad24(5, SLM_BLOCK, local_index)));", // NOLINT +"vstore4(vload4(0, (__global float *)(src1_read0 + mad24(6, K, local_index))), 0, (__local float *)(slm_brow + mad24(6, SLM_BLOCK, local_index)));", // NOLINT +"vstore4(vload4(0, (__global float *)(src1_read0 + mad24(7, K, local_index))), 0, (__local float *)(slm_brow + mad24(7, SLM_BLOCK, local_index)));", // NOLINT +"barrier(CLK_LOCAL_MEM_FENCE);", // NOLINT +"", // NOLINT +"slm_brow0 = slm_brow + local_x * (TILE_K / 8);", // NOLINT +"w = b_tile;", // NOLINT +"int end_w = min(b_tile + SLM_BLOCK, K);", // NOLINT +"while( w + TILE_K <= end_w ) {", // NOLINT +"Dtype8 arow;", // NOLINT +"", // NOLINT +"brow0 = as_half8(vload4(0, (__local float *)(slm_brow0 + 0 * SLM_BLOCK)));", // NOLINT +"brow1 = as_half8(vload4(0, (__local float *)(slm_brow0 + 1 * SLM_BLOCK)));", // NOLINT +"brow2 = as_half8(vload4(0, (__local float *)(slm_brow0 + 2 * SLM_BLOCK)));", // NOLINT +"brow3 = as_half8(vload4(0, (__local float *)(slm_brow0 + 3 * SLM_BLOCK)));", // NOLINT +"brow4 = as_half8(vload4(0, (__local float *)(slm_brow0 + 4 * SLM_BLOCK)));", // NOLINT +"brow5 = as_half8(vload4(0, (__local float *)(slm_brow0 + 5 * SLM_BLOCK)));", // 
NOLINT +"brow6 = as_half8(vload4(0, (__local float *)(slm_brow0 + 6 * SLM_BLOCK)));", // NOLINT +"brow7 = as_half8(vload4(0, (__local float *)(slm_brow0 + 7 * SLM_BLOCK)));", // NOLINT +"", // NOLINT +"#define MM_DOT_PRODUCT( _row, _dot ) arow = as_half8(vload4(0, (__global float *)(src0_read + _row * K))); _dot = mad( (Dtype8)(arow.s0), (Dtype8)(brow0.s0, brow1.s0, brow2.s0, brow3.s0, brow4.s0, brow5.s0, brow6.s0, brow7.s0), _dot ); _dot = mad( (Dtype8)(arow.s1), (Dtype8)(brow0.s1, brow1.s1, brow2.s1, brow3.s1, brow4.s1, brow5.s1, brow6.s1, brow7.s1), _dot ); _dot = mad( (Dtype8)(arow.s2), (Dtype8)(brow0.s2, brow1.s2, brow2.s2, brow3.s2, brow4.s2, brow5.s2, brow6.s2, brow7.s2), _dot ); _dot = mad( (Dtype8)(arow.s3), (Dtype8)(brow0.s3, brow1.s3, brow2.s3, brow3.s3, brow4.s3, brow5.s3, brow6.s3, brow7.s3), _dot ); _dot = mad( (Dtype8)(arow.s4), (Dtype8)(brow0.s4, brow1.s4, brow2.s4, brow3.s4, brow4.s4, brow5.s4, brow6.s4, brow7.s4), _dot ); _dot = mad( (Dtype8)(arow.s5), (Dtype8)(brow0.s5, brow1.s5, brow2.s5, brow3.s5, brow4.s5, brow5.s5, brow6.s5, brow7.s5), _dot ); _dot = mad( (Dtype8)(arow.s6), (Dtype8)(brow0.s6, brow1.s6, brow2.s6, brow3.s6, brow4.s6, brow5.s6, brow6.s6, brow7.s6), _dot ); _dot = mad( (Dtype8)(arow.s7), (Dtype8)(brow0.s7, brow1.s7, brow2.s7, brow3.s7, brow4.s7, brow5.s7, brow6.s7, brow7.s7), _dot );", // NOLINT +"MM_DOT_PRODUCT( 0, dot00 );", // NOLINT +"MM_DOT_PRODUCT( 1, dot01 );", // NOLINT +"MM_DOT_PRODUCT( 2, dot02 );", // NOLINT +"MM_DOT_PRODUCT( 3, dot03 );", // NOLINT +"MM_DOT_PRODUCT( 4, dot04 );", // NOLINT +"MM_DOT_PRODUCT( 5, dot05 );", // NOLINT +"MM_DOT_PRODUCT( 6, dot06 );", // NOLINT +"MM_DOT_PRODUCT( 7, dot07 );", // NOLINT +"#undef MM_DOT_PRODUCT", // NOLINT +"", // NOLINT +"src0_read += TILE_K;", // NOLINT +"slm_brow0 += TILE_K;", // NOLINT +"w += TILE_K;", // NOLINT +"}", // NOLINT +"src1_read0 += SLM_BLOCK;", // NOLINT +"}", // NOLINT +"", // NOLINT +"if(w < K) {", // NOLINT +"Dtype8 arow;", // NOLINT +"", // NOLINT 
+"#define READ_BROW(_brow, _row) _brow = as_half8(vload4(0, (__local float *)(slm_brow0 + _row * SLM_BLOCK))); _brow.s0 = (mad24(local_x, 8, w) < K) ? _brow.s0 : 0.0f; _brow.s1 = (mad24(local_x, 8, w + 1) < K) ? _brow.s1 : 0.0f; _brow.s2 = (mad24(local_x, 8, w + 2) < K) ? _brow.s2 : 0.0f; _brow.s3 = (mad24(local_x, 8, w + 3) < K) ? _brow.s3 : 0.0f; _brow.s4 = (mad24(local_x, 8, w + 4) < K) ? _brow.s4 : 0.0f; _brow.s5 = (mad24(local_x, 8, w + 5) < K) ? _brow.s5 : 0.0f; _brow.s6 = (mad24(local_x, 8, w + 6) < K) ? _brow.s6 : 0.0f; _brow.s7 = (mad24(local_x, 8, w + 7) < K) ? _brow.s7 : 0.0f;", // NOLINT +"READ_BROW(brow0, 0);", // NOLINT +"READ_BROW(brow1, 1);", // NOLINT +"READ_BROW(brow2, 2);", // NOLINT +"READ_BROW(brow3, 3);", // NOLINT +"READ_BROW(brow4, 4);", // NOLINT +"READ_BROW(brow5, 5);", // NOLINT +"READ_BROW(brow6, 6);", // NOLINT +"READ_BROW(brow7, 7);", // NOLINT +"", // NOLINT +"#define MM_DOT_PRODUCT( _row, _dot ) arow = as_half8(vload4(0, (__global float *)(src0_read + _row * K))); arow.s0 = (mad24(local_x, 8, w) < K) ? arow.s0 : 0.0f; arow.s1 = (mad24(local_x, 8, w + 1) < K) ? arow.s1 : 0.0f; arow.s2 = (mad24(local_x, 8, w + 2) < K) ? arow.s2 : 0.0f; arow.s3 = (mad24(local_x, 8, w + 3) < K) ? arow.s3 : 0.0f; arow.s4 = (mad24(local_x, 8, w + 4) < K) ? arow.s4 : 0.0f; arow.s5 = (mad24(local_x, 8, w + 5) < K) ? arow.s5 : 0.0f; arow.s6 = (mad24(local_x, 8, w + 6) < K) ? arow.s6 : 0.0f; arow.s7 = (mad24(local_x, 8, w + 7) < K) ? 
arow.s7 : 0.0f; _dot = mad( (Dtype8)(arow.s0), (Dtype8)(brow0.s0, brow1.s0, brow2.s0, brow3.s0, brow4.s0, brow5.s0, brow6.s0, brow7.s0), _dot ); _dot = mad( (Dtype8)(arow.s1), (Dtype8)(brow0.s1, brow1.s1, brow2.s1, brow3.s1, brow4.s1, brow5.s1, brow6.s1, brow7.s1), _dot ); _dot = mad( (Dtype8)(arow.s2), (Dtype8)(brow0.s2, brow1.s2, brow2.s2, brow3.s2, brow4.s2, brow5.s2, brow6.s2, brow7.s2), _dot ); _dot = mad( (Dtype8)(arow.s3), (Dtype8)(brow0.s3, brow1.s3, brow2.s3, brow3.s3, brow4.s3, brow5.s3, brow6.s3, brow7.s3), _dot ); _dot = mad( (Dtype8)(arow.s4), (Dtype8)(brow0.s4, brow1.s4, brow2.s4, brow3.s4, brow4.s4, brow5.s4, brow6.s4, brow7.s4), _dot ); _dot = mad( (Dtype8)(arow.s5), (Dtype8)(brow0.s5, brow1.s5, brow2.s5, brow3.s5, brow4.s5, brow5.s5, brow6.s5, brow7.s5), _dot ); _dot = mad( (Dtype8)(arow.s6), (Dtype8)(brow0.s6, brow1.s6, brow2.s6, brow3.s6, brow4.s6, brow5.s6, brow6.s6, brow7.s6), _dot ); _dot = mad( (Dtype8)(arow.s7), (Dtype8)(brow0.s7, brow1.s7, brow2.s7, brow3.s7, brow4.s7, brow5.s7, brow6.s7, brow7.s7), _dot );", // NOLINT +"MM_DOT_PRODUCT( 0, dot00 );", // NOLINT +"MM_DOT_PRODUCT( 1, dot01 );", // NOLINT +"MM_DOT_PRODUCT( 2, dot02 );", // NOLINT +"MM_DOT_PRODUCT( 3, dot03 );", // NOLINT +"MM_DOT_PRODUCT( 4, dot04 );", // NOLINT +"MM_DOT_PRODUCT( 5, dot05 );", // NOLINT +"MM_DOT_PRODUCT( 6, dot06 );", // NOLINT +"MM_DOT_PRODUCT( 7, dot07 );", // NOLINT +"#undef MM_DOT_PRODUCT", // NOLINT +"}", // NOLINT +"", // NOLINT +"#define REDUCE(_dot) _dot = as_Dtype8(intel_sub_group_shuffle(SHUFFLE_TYPE8(_dot), 0)) + as_Dtype8(intel_sub_group_shuffle(SHUFFLE_TYPE8(_dot), 1)) + as_Dtype8(intel_sub_group_shuffle(SHUFFLE_TYPE8(_dot), 2)) + as_Dtype8(intel_sub_group_shuffle(SHUFFLE_TYPE8(_dot), 3)) + as_Dtype8(intel_sub_group_shuffle(SHUFFLE_TYPE8(_dot), 4)) + as_Dtype8(intel_sub_group_shuffle(SHUFFLE_TYPE8(_dot), 5)) + as_Dtype8(intel_sub_group_shuffle(SHUFFLE_TYPE8(_dot), 6)) + as_Dtype8(intel_sub_group_shuffle(SHUFFLE_TYPE8(_dot), 7));", // NOLINT 
+"REDUCE(dot00);", // NOLINT +"REDUCE(dot01);", // NOLINT +"REDUCE(dot02);", // NOLINT +"REDUCE(dot03);", // NOLINT +"REDUCE(dot04);", // NOLINT +"REDUCE(dot05);", // NOLINT +"REDUCE(dot06);", // NOLINT +"REDUCE(dot07);", // NOLINT +"#undef REDUCE", // NOLINT +"", // NOLINT +"Dtype output = 0.0f;", // NOLINT +"#define OUTPUT( _dot) output = (local_x == 0) ? _dot.s0 : output; output = (local_x == 1) ? _dot.s1 : output; output = (local_x == 2) ? _dot.s2 : output; output = (local_x == 3) ? _dot.s3 : output; output = (local_x == 4) ? _dot.s4 : output; output = (local_x == 5) ? _dot.s5 : output; output = (local_x == 6) ? _dot.s6 : output; output = (local_x == 7) ? _dot.s7 : output; dst_write0[0] = mad(output, alpha, beta * dst_write0[0]); dst_write0 += N;", // NOLINT +"", // NOLINT +"if(global_x < N && global_y * 8 < M) {", // NOLINT +"OUTPUT(dot00);", // NOLINT +"if(mad24(global_y, 8, 1) < M) { OUTPUT(dot01); }", // NOLINT +"if(mad24(global_y, 8, 2) < M) { OUTPUT(dot02); }", // NOLINT +"if(mad24(global_y, 8, 3) < M) { OUTPUT(dot03); }", // NOLINT +"if(mad24(global_y, 8, 4) < M) { OUTPUT(dot04); }", // NOLINT +"if(mad24(global_y, 8, 5) < M) { OUTPUT(dot05); }", // NOLINT +"if(mad24(global_y, 8, 6) < M) { OUTPUT(dot06); }", // NOLINT +"if(mad24(global_y, 8, 7) < M) { OUTPUT(dot07); }", // NOLINT +"}", // NOLINT +"#undef OUTPUT", // NOLINT +"}", // NOLINT +"", // NOLINT +"#else", // NOLINT "", // NOLINT "__attribute__((reqd_work_group_size(8, LWG_HEIGHT, 1)))", // NOLINT "__attribute__((intel_reqd_sub_group_size(8)))", // NOLINT @@ -3952,6 +4107,7 @@ static std::vector> cl_kernels{ "}", // NOLINT "#undef OUTPUT", // NOLINT "}", // NOLINT +"#endif", // NOLINT "", // NOLINT "#undef VEC_SIZE", // NOLINT "#undef LWG_HEIGHT", // NOLINT diff --git a/src/caffe/greentea/cl_kernels/gemm.cl b/src/caffe/greentea/cl_kernels/gemm.cl index f3577bd9522..e2c297c1d7e 100644 --- a/src/caffe/greentea/cl_kernels/gemm.cl +++ b/src/caffe/greentea/cl_kernels/gemm.cl @@ -1209,13 +1209,219 @@ 
__kernel void TEMPLATE(gemm_buffer_NN, Dtype)( #undef TILE_K #undef TILE_N - #define VEC_SIZE 1 -#define LWG_HEIGHT 16 #define TILE_M 8 -#define TILE_K 32 #define TILE_N 8 -#define SLM_BLOCK 512 +#define SLM_BLOCK 128 + +#if TYPE == TYPE_HALF +#define LWG_HEIGHT 2 +#define TILE_K 64 +#else +#define LWG_HEIGHT 4 +#define TILE_K 32 +#endif + +#if TYPE == TYPE_HALF +__attribute__((reqd_work_group_size(8, LWG_HEIGHT, 1))) +__attribute__((intel_reqd_sub_group_size(8))) +__kernel void TEMPLATE(gemm_buffer_NT, Dtype)( + const __global Dtype *src0, int off0, + const __global Dtype *src1, int off1, + __global Dtype *dst, int offd, + int M, + int N, + int K, + KERNEL_ARG_DTYPE alpha_in, + KERNEL_ARG_DTYPE beta_in) +{ + const Dtype alpha = (Dtype)alpha_in; + const Dtype beta = (Dtype)beta_in; + const int group_x = get_group_id(0); + const int group_y = get_group_id(1); + const int local_x = get_local_id(0); + const int local_y = get_local_id(1); + const int global_x = get_global_id(0); + const int global_y = get_global_id(1); + + Dtype8 dot00 = 0.f; + Dtype8 dot01 = 0.f; + Dtype8 dot02 = 0.f; + Dtype8 dot03 = 0.f; + Dtype8 dot04 = 0.f; + Dtype8 dot05 = 0.f; + Dtype8 dot06 = 0.f; + Dtype8 dot07 = 0.f; + + Dtype8 brow0; + Dtype8 brow1; + Dtype8 brow2; + Dtype8 brow3; + Dtype8 brow4; + Dtype8 brow5; + Dtype8 brow6; + Dtype8 brow7; + + __global Dtype *dst_write0 = dst + local_x * VEC_SIZE + (group_x * TILE_N) + (group_y * LWG_HEIGHT * TILE_M + local_y * TILE_M) * N + offd; + + const __global Dtype *src0_read = src0 + local_x * (TILE_K / 8) + (group_y * LWG_HEIGHT * TILE_M + local_y * TILE_M) * K + off0; + + const __global Dtype *src1_read0 = src1 + (group_x * TILE_N) * K + off1; + + __local Dtype slm_brow[8 * SLM_BLOCK]; + __local Dtype* slm_brow0; + + int local_index = mad24(local_y, 8, local_x) * 8; + int w; + for(int b_tile = 0; b_tile < K; b_tile += SLM_BLOCK) { + barrier(CLK_LOCAL_MEM_FENCE); + vstore4(vload4(0, (__global float *)(src1_read0 + mad24(0, K, local_index))), 0, 
(__local float *)(slm_brow + mad24(0, SLM_BLOCK, local_index))); + vstore4(vload4(0, (__global float *)(src1_read0 + mad24(1, K, local_index))), 0, (__local float *)(slm_brow + mad24(1, SLM_BLOCK, local_index))); + vstore4(vload4(0, (__global float *)(src1_read0 + mad24(2, K, local_index))), 0, (__local float *)(slm_brow + mad24(2, SLM_BLOCK, local_index))); + vstore4(vload4(0, (__global float *)(src1_read0 + mad24(3, K, local_index))), 0, (__local float *)(slm_brow + mad24(3, SLM_BLOCK, local_index))); + vstore4(vload4(0, (__global float *)(src1_read0 + mad24(4, K, local_index))), 0, (__local float *)(slm_brow + mad24(4, SLM_BLOCK, local_index))); + vstore4(vload4(0, (__global float *)(src1_read0 + mad24(5, K, local_index))), 0, (__local float *)(slm_brow + mad24(5, SLM_BLOCK, local_index))); + vstore4(vload4(0, (__global float *)(src1_read0 + mad24(6, K, local_index))), 0, (__local float *)(slm_brow + mad24(6, SLM_BLOCK, local_index))); + vstore4(vload4(0, (__global float *)(src1_read0 + mad24(7, K, local_index))), 0, (__local float *)(slm_brow + mad24(7, SLM_BLOCK, local_index))); + barrier(CLK_LOCAL_MEM_FENCE); + + slm_brow0 = slm_brow + local_x * (TILE_K / 8); + w = b_tile; + int end_w = min(b_tile + SLM_BLOCK, K); + while( w + TILE_K <= end_w ) { + Dtype8 arow; + + brow0 = as_half8(vload4(0, (__local float *)(slm_brow0 + 0 * SLM_BLOCK))); + brow1 = as_half8(vload4(0, (__local float *)(slm_brow0 + 1 * SLM_BLOCK))); + brow2 = as_half8(vload4(0, (__local float *)(slm_brow0 + 2 * SLM_BLOCK))); + brow3 = as_half8(vload4(0, (__local float *)(slm_brow0 + 3 * SLM_BLOCK))); + brow4 = as_half8(vload4(0, (__local float *)(slm_brow0 + 4 * SLM_BLOCK))); + brow5 = as_half8(vload4(0, (__local float *)(slm_brow0 + 5 * SLM_BLOCK))); + brow6 = as_half8(vload4(0, (__local float *)(slm_brow0 + 6 * SLM_BLOCK))); + brow7 = as_half8(vload4(0, (__local float *)(slm_brow0 + 7 * SLM_BLOCK))); + +#define MM_DOT_PRODUCT( _row, _dot ) \ + arow = as_half8(vload4(0, (__global float 
*)(src0_read + _row * K))); \ + _dot = mad( (Dtype8)(arow.s0), (Dtype8)(brow0.s0, brow1.s0, brow2.s0, brow3.s0, brow4.s0, brow5.s0, brow6.s0, brow7.s0), _dot ); \ + _dot = mad( (Dtype8)(arow.s1), (Dtype8)(brow0.s1, brow1.s1, brow2.s1, brow3.s1, brow4.s1, brow5.s1, brow6.s1, brow7.s1), _dot ); \ + _dot = mad( (Dtype8)(arow.s2), (Dtype8)(brow0.s2, brow1.s2, brow2.s2, brow3.s2, brow4.s2, brow5.s2, brow6.s2, brow7.s2), _dot ); \ + _dot = mad( (Dtype8)(arow.s3), (Dtype8)(brow0.s3, brow1.s3, brow2.s3, brow3.s3, brow4.s3, brow5.s3, brow6.s3, brow7.s3), _dot ); \ + _dot = mad( (Dtype8)(arow.s4), (Dtype8)(brow0.s4, brow1.s4, brow2.s4, brow3.s4, brow4.s4, brow5.s4, brow6.s4, brow7.s4), _dot ); \ + _dot = mad( (Dtype8)(arow.s5), (Dtype8)(brow0.s5, brow1.s5, brow2.s5, brow3.s5, brow4.s5, brow5.s5, brow6.s5, brow7.s5), _dot ); \ + _dot = mad( (Dtype8)(arow.s6), (Dtype8)(brow0.s6, brow1.s6, brow2.s6, brow3.s6, brow4.s6, brow5.s6, brow6.s6, brow7.s6), _dot ); \ + _dot = mad( (Dtype8)(arow.s7), (Dtype8)(brow0.s7, brow1.s7, brow2.s7, brow3.s7, brow4.s7, brow5.s7, brow6.s7, brow7.s7), _dot ); \ + + MM_DOT_PRODUCT( 0, dot00 ); + MM_DOT_PRODUCT( 1, dot01 ); + MM_DOT_PRODUCT( 2, dot02 ); + MM_DOT_PRODUCT( 3, dot03 ); + MM_DOT_PRODUCT( 4, dot04 ); + MM_DOT_PRODUCT( 5, dot05 ); + MM_DOT_PRODUCT( 6, dot06 ); + MM_DOT_PRODUCT( 7, dot07 ); +#undef MM_DOT_PRODUCT + + src0_read += TILE_K; + slm_brow0 += TILE_K; + w += TILE_K; + } + src1_read0 += SLM_BLOCK; + } + + if(w < K) { + Dtype8 arow; + +#define READ_BROW(_brow, _row) \ + _brow = as_half8(vload4(0, (__local float *)(slm_brow0 + _row * SLM_BLOCK))); \ + _brow.s0 = (mad24(local_x, 8, w) < K) ? _brow.s0 : 0.0f; \ + _brow.s1 = (mad24(local_x, 8, w + 1) < K) ? _brow.s1 : 0.0f; \ + _brow.s2 = (mad24(local_x, 8, w + 2) < K) ? _brow.s2 : 0.0f; \ + _brow.s3 = (mad24(local_x, 8, w + 3) < K) ? _brow.s3 : 0.0f; \ + _brow.s4 = (mad24(local_x, 8, w + 4) < K) ? _brow.s4 : 0.0f; \ + _brow.s5 = (mad24(local_x, 8, w + 5) < K) ? 
_brow.s5 : 0.0f; \ + _brow.s6 = (mad24(local_x, 8, w + 6) < K) ? _brow.s6 : 0.0f; \ + _brow.s7 = (mad24(local_x, 8, w + 7) < K) ? _brow.s7 : 0.0f; \ + + READ_BROW(brow0, 0); + READ_BROW(brow1, 1); + READ_BROW(brow2, 2); + READ_BROW(brow3, 3); + READ_BROW(brow4, 4); + READ_BROW(brow5, 5); + READ_BROW(brow6, 6); + READ_BROW(brow7, 7); + +#define MM_DOT_PRODUCT( _row, _dot ) \ + arow = as_half8(vload4(0, (__global float *)(src0_read + _row * K))); \ + arow.s0 = (mad24(local_x, 8, w) < K) ? arow.s0 : 0.0f; \ + arow.s1 = (mad24(local_x, 8, w + 1) < K) ? arow.s1 : 0.0f; \ + arow.s2 = (mad24(local_x, 8, w + 2) < K) ? arow.s2 : 0.0f; \ + arow.s3 = (mad24(local_x, 8, w + 3) < K) ? arow.s3 : 0.0f; \ + arow.s4 = (mad24(local_x, 8, w + 4) < K) ? arow.s4 : 0.0f; \ + arow.s5 = (mad24(local_x, 8, w + 5) < K) ? arow.s5 : 0.0f; \ + arow.s6 = (mad24(local_x, 8, w + 6) < K) ? arow.s6 : 0.0f; \ + arow.s7 = (mad24(local_x, 8, w + 7) < K) ? arow.s7 : 0.0f; \ + _dot = mad( (Dtype8)(arow.s0), (Dtype8)(brow0.s0, brow1.s0, brow2.s0, brow3.s0, brow4.s0, brow5.s0, brow6.s0, brow7.s0), _dot ); \ + _dot = mad( (Dtype8)(arow.s1), (Dtype8)(brow0.s1, brow1.s1, brow2.s1, brow3.s1, brow4.s1, brow5.s1, brow6.s1, brow7.s1), _dot ); \ + _dot = mad( (Dtype8)(arow.s2), (Dtype8)(brow0.s2, brow1.s2, brow2.s2, brow3.s2, brow4.s2, brow5.s2, brow6.s2, brow7.s2), _dot ); \ + _dot = mad( (Dtype8)(arow.s3), (Dtype8)(brow0.s3, brow1.s3, brow2.s3, brow3.s3, brow4.s3, brow5.s3, brow6.s3, brow7.s3), _dot ); \ + _dot = mad( (Dtype8)(arow.s4), (Dtype8)(brow0.s4, brow1.s4, brow2.s4, brow3.s4, brow4.s4, brow5.s4, brow6.s4, brow7.s4), _dot ); \ + _dot = mad( (Dtype8)(arow.s5), (Dtype8)(brow0.s5, brow1.s5, brow2.s5, brow3.s5, brow4.s5, brow5.s5, brow6.s5, brow7.s5), _dot ); \ + _dot = mad( (Dtype8)(arow.s6), (Dtype8)(brow0.s6, brow1.s6, brow2.s6, brow3.s6, brow4.s6, brow5.s6, brow6.s6, brow7.s6), _dot ); \ + _dot = mad( (Dtype8)(arow.s7), (Dtype8)(brow0.s7, brow1.s7, brow2.s7, brow3.s7, brow4.s7, brow5.s7, brow6.s7, 
brow7.s7), _dot ); \ + + MM_DOT_PRODUCT( 0, dot00 ); + MM_DOT_PRODUCT( 1, dot01 ); + MM_DOT_PRODUCT( 2, dot02 ); + MM_DOT_PRODUCT( 3, dot03 ); + MM_DOT_PRODUCT( 4, dot04 ); + MM_DOT_PRODUCT( 5, dot05 ); + MM_DOT_PRODUCT( 6, dot06 ); + MM_DOT_PRODUCT( 7, dot07 ); +#undef MM_DOT_PRODUCT + } + +#define REDUCE(_dot) \ + _dot = as_Dtype8(intel_sub_group_shuffle(SHUFFLE_TYPE8(_dot), 0)) + as_Dtype8(intel_sub_group_shuffle(SHUFFLE_TYPE8(_dot), 1)) + as_Dtype8(intel_sub_group_shuffle(SHUFFLE_TYPE8(_dot), 2)) + as_Dtype8(intel_sub_group_shuffle(SHUFFLE_TYPE8(_dot), 3)) + \ + as_Dtype8(intel_sub_group_shuffle(SHUFFLE_TYPE8(_dot), 4)) + as_Dtype8(intel_sub_group_shuffle(SHUFFLE_TYPE8(_dot), 5)) + as_Dtype8(intel_sub_group_shuffle(SHUFFLE_TYPE8(_dot), 6)) + as_Dtype8(intel_sub_group_shuffle(SHUFFLE_TYPE8(_dot), 7)); \ + + REDUCE(dot00); + REDUCE(dot01); + REDUCE(dot02); + REDUCE(dot03); + REDUCE(dot04); + REDUCE(dot05); + REDUCE(dot06); + REDUCE(dot07); +#undef REDUCE + + Dtype output = 0.0f; +#define OUTPUT( _dot) \ + output = (local_x == 0) ? _dot.s0 : output; \ + output = (local_x == 1) ? _dot.s1 : output; \ + output = (local_x == 2) ? _dot.s2 : output; \ + output = (local_x == 3) ? _dot.s3 : output; \ + output = (local_x == 4) ? _dot.s4 : output; \ + output = (local_x == 5) ? _dot.s5 : output; \ + output = (local_x == 6) ? _dot.s6 : output; \ + output = (local_x == 7) ? 
_dot.s7 : output; \ + dst_write0[0] = mad(output, alpha, beta * dst_write0[0]); \ + dst_write0 += N; + + if(global_x < N && global_y * 8 < M) { + OUTPUT(dot00); + if(mad24(global_y, 8, 1) < M) { OUTPUT(dot01); } + if(mad24(global_y, 8, 2) < M) { OUTPUT(dot02); } + if(mad24(global_y, 8, 3) < M) { OUTPUT(dot03); } + if(mad24(global_y, 8, 4) < M) { OUTPUT(dot04); } + if(mad24(global_y, 8, 5) < M) { OUTPUT(dot05); } + if(mad24(global_y, 8, 6) < M) { OUTPUT(dot06); } + if(mad24(global_y, 8, 7) < M) { OUTPUT(dot07); } + } +#undef OUTPUT +} + +#else __attribute__((reqd_work_group_size(8, LWG_HEIGHT, 1))) __attribute__((intel_reqd_sub_group_size(8))) @@ -1398,6 +1604,7 @@ __kernel void TEMPLATE(gemm_buffer_NT, Dtype)( } #undef OUTPUT } +#endif #undef VEC_SIZE #undef LWG_HEIGHT diff --git a/src/caffe/layers/inner_product_layer.cu b/src/caffe/layers/inner_product_layer.cu index 0b2ecec35a2..f06db1a5782 100644 --- a/src/caffe/layers/inner_product_layer.cu +++ b/src/caffe/layers/inner_product_layer.cu @@ -511,7 +511,7 @@ static void greentea_gpu_fast_buffer_gemm(const int_tp ctx_id, const CBLAS_TRANS global[1] = 1; } else { size_t lx = sub_group_size; - size_t ly = (TransB != CblasNoTrans && TransA == CblasNoTrans) ? 16 : 4; + size_t ly = (TransB != CblasNoTrans && TransA == CblasNoTrans && halfPrecisionMode) ? 2 : 4; int dx = (TransB != CblasNoTrans && TransA == CblasNoTrans) ? 1 : 4; int dy = 8; size_t gx = (size_t)(N + dx - 1) / dx; From 683bd2f370767f19e41854eb2fea4c15b7407395 Mon Sep 17 00:00:00 2001 From: Zhigang Gong Date: Mon, 3 Jul 2017 14:18:06 +0800 Subject: [PATCH 25/33] Add negative slope support for relu fusion. To support yolo2, we need to add negative_slope support for relu fusion. 
Signed-off-by: Zhigang Gong --- include/caffe/layers/conv_spatial_layer.hpp | 16 +++++---- src/caffe/greentea/cl_kernels.cpp | 23 ++++++------ .../greentea/cl_kernels/conv_layer_spatial.cl | 36 +++++++------------ src/caffe/layers/conv_layer_spatial.cpp | 17 +++++++-- src/caffe/proto/caffe.proto | 1 + tools/inference-optimize/model_fuse.py | 42 ++++++++++++++++------ 6 files changed, 80 insertions(+), 55 deletions(-) diff --git a/include/caffe/layers/conv_spatial_layer.hpp b/include/caffe/layers/conv_spatial_layer.hpp index 5290fb65299..3c6bd671eca 100644 --- a/include/caffe/layers/conv_spatial_layer.hpp +++ b/include/caffe/layers/conv_spatial_layer.hpp @@ -193,11 +193,6 @@ class ConvolutionLayerSpatial : public BaseConvolutionLayer { return (this->layer_param_.convolution_param().fuse_type() != ConvolutionParameter_FuseType_UNFUSED); } - bool IsFusedWithReLU() const - { - return (this->layer_param_.convolution_param().fuse_type() == ConvolutionParameter_FuseType_FUSED_CONV_RELU); - } - bool IsFusedWithMaxPoolAndReLU() const { return (this->layer_param_.convolution_param().fuse_type() == ConvolutionParameter_FuseType_FUSED_CONV_MAX_POOLING_RELU); @@ -208,6 +203,12 @@ class ConvolutionLayerSpatial : public BaseConvolutionLayer { return (this->layer_param_.convolution_param().fuse_type() == ConvolutionParameter_FuseType_FUSED_CONV_ELTWISE_RELU); } + bool IsFusedWithReLU() const + { + return IsFusedWithEltwiseReLU() || + (this->layer_param_.convolution_param().fuse_type() == ConvolutionParameter_FuseType_FUSED_CONV_RELU); + } + #endif #endif @@ -262,8 +263,11 @@ class ConvolutionLayerSpatial : public BaseConvolutionLayer { kernelConfig* bestKernelConfig; // parameters for fused eltwise layer. 
- EltwiseParameter_EltwiseOp op_; vector coeffs_; + EltwiseParameter_EltwiseOp op_; + vector coeffs_; Blob max_idx_; + // parameter for relu + Dtype negative_slope_; bool stable_prod_grad_; diff --git a/src/caffe/greentea/cl_kernels.cpp b/src/caffe/greentea/cl_kernels.cpp index 9c87d5bcff1..95223472561 100644 --- a/src/caffe/greentea/cl_kernels.cpp +++ b/src/caffe/greentea/cl_kernels.cpp @@ -528,15 +528,19 @@ static std::vector> cl_kernels{ "}", // NOLINT "", // NOLINT "#ifdef FUSED_CONV_RELU", // NOLINT -"#define ACTIVATION_RELU_FUNCTION(x) max((Dtype)(x), (Dtype)0.0f)", // NOLINT +"#define ACTIVATION_RELU_FUNCTION(x) ((Dtype)(x) > 0 ? (Dtype)(x) : ((Dtype)(x) * (Dtype)(negative_slope)))", // NOLINT +"#define NEGATIVE_SLOPE_ARG KERNEL_ARG_DTYPE negative_slope,", // NOLINT "#else", // NOLINT "#define ACTIVATION_RELU_FUNCTION(x) (x)", // NOLINT +"#define NEGATIVE_SLOPE_ARG", // NOLINT "#endif", // NOLINT "", // NOLINT "#ifdef FUSED_CONV_ELTWISE", // NOLINT "#define ACTIVATION_FUNCTION(_dst_, _offset_, _data_) do { (_dst_)[(_offset_)] = ACTIVATION_RELU_FUNCTION(eltwise_data[(_offset_)] + (_data_));} while(0)", // NOLINT +"#define ELTWISE_DATA_ARG __global Dtype* eltwise_data,", // NOLINT "#else", // NOLINT "#define ACTIVATION_FUNCTION(_dst_, _offset_, _data_) do { (_dst_)[(_offset_)] = ACTIVATION_RELU_FUNCTION(_data_);} while(0)", // NOLINT +"#define ELTWISE_DATA_ARG", // NOLINT "#endif", // NOLINT "", // NOLINT "#define __CAT(x, y) x##y", // NOLINT @@ -562,9 +566,8 @@ static std::vector> cl_kernels{ "", // NOLINT "#ifdef MULTI", // NOLINT "__kernel void CFMultiNoPadding(", // NOLINT -"#ifdef FUSED_CONV_ELTWISE", // NOLINT -"__global Dtype* eltwise_data,", // NOLINT -"#endif", // NOLINT +"ELTWISE_DATA_ARG", // NOLINT +"NEGATIVE_SLOPE_ARG", // NOLINT "__global Dtype* image_data,", // NOLINT "int_tp image_offset,", // NOLINT "__global Dtype* kernel_data, int_tp kernel_offset,", // NOLINT @@ -683,9 +686,8 @@ static std::vector> cl_kernels{ "#endif", // NOLINT "__kernel 
void", // NOLINT "convolve_simd(", // NOLINT -"#ifdef FUSED_CONV_ELTWISE", // NOLINT -"__global Dtype* eltwise_data,", // NOLINT -"#endif", // NOLINT +"ELTWISE_DATA_ARG", // NOLINT +"NEGATIVE_SLOPE_ARG", // NOLINT "__global Dtype* inputs_base,", // NOLINT "filter_qualifier Dtype* weights_base,", // NOLINT "__global Dtype* biases_base,", // NOLINT @@ -933,11 +935,8 @@ static std::vector> cl_kernels{ "#define OUT_PITCH_X output_width", // NOLINT "#define ROW_PITCH input_width", // NOLINT "", // NOLINT -"#ifdef FUSED_CONV_ELTWISE", // NOLINT -"#define GEMM_LIKE_KERNEL_ARGS __global Dtype* eltwise_data, const __global Dtype *src0, const __global Dtype *src1, const __global Dtype *biases, __global Dtype *dst, const ushort input_width, const ushort input_height, const ushort output_width, const ushort output_height, const int_tp out_pitch_y, const int_tp out_pitch_z, const int_tp aligned_input_size, const int_tp slice_pitch", // NOLINT -"#else", // NOLINT -"#define GEMM_LIKE_KERNEL_ARGS const __global Dtype *src0, const __global Dtype *src1, const __global Dtype *biases, __global Dtype *dst, const ushort input_width, const ushort input_height, const ushort output_width, const ushort output_height, const int_tp out_pitch_y, const int_tp out_pitch_z, const int_tp aligned_input_size, const int_tp slice_pitch", // NOLINT -"#endif", // NOLINT +"", // NOLINT +"#define GEMM_LIKE_KERNEL_ARGS ELTWISE_DATA_ARG NEGATIVE_SLOPE_ARG const __global Dtype *src0, const __global Dtype *src1, const __global Dtype *biases, __global Dtype *dst, const ushort input_width, const ushort input_height, const ushort output_width, const ushort output_height, const int_tp out_pitch_y, const int_tp out_pitch_z, const int_tp aligned_input_size, const int_tp slice_pitch", // NOLINT "", // NOLINT "#endif", // NOLINT "", // NOLINT diff --git a/src/caffe/greentea/cl_kernels/conv_layer_spatial.cl b/src/caffe/greentea/cl_kernels/conv_layer_spatial.cl index 515ab46029e..f971593ca13 100644 --- 
a/src/caffe/greentea/cl_kernels/conv_layer_spatial.cl +++ b/src/caffe/greentea/cl_kernels/conv_layer_spatial.cl @@ -7,15 +7,19 @@ __kernel void TEMPLATE(conv_layer_spatial_phony,Dtype)(KERNEL_ARG_DTYPE arg) { } #ifdef FUSED_CONV_RELU -#define ACTIVATION_RELU_FUNCTION(x) max((Dtype)(x), (Dtype)0.0f) +#define ACTIVATION_RELU_FUNCTION(x) ((Dtype)(x) > 0 ? (Dtype)(x) : ((Dtype)(x) * (Dtype)(negative_slope))) +#define NEGATIVE_SLOPE_ARG KERNEL_ARG_DTYPE negative_slope, #else #define ACTIVATION_RELU_FUNCTION(x) (x) +#define NEGATIVE_SLOPE_ARG #endif #ifdef FUSED_CONV_ELTWISE #define ACTIVATION_FUNCTION(_dst_, _offset_, _data_) do { (_dst_)[(_offset_)] = ACTIVATION_RELU_FUNCTION(eltwise_data[(_offset_)] + (_data_));} while(0) +#define ELTWISE_DATA_ARG __global Dtype* eltwise_data, #else #define ACTIVATION_FUNCTION(_dst_, _offset_, _data_) do { (_dst_)[(_offset_)] = ACTIVATION_RELU_FUNCTION(_data_);} while(0) +#define ELTWISE_DATA_ARG #endif #define __CAT(x, y) x##y @@ -41,9 +45,8 @@ __kernel void TEMPLATE(conv_layer_spatial_phony,Dtype)(KERNEL_ARG_DTYPE arg) { #ifdef MULTI __kernel void CFMultiNoPadding( -#ifdef FUSED_CONV_ELTWISE - __global Dtype* eltwise_data, -#endif + ELTWISE_DATA_ARG + NEGATIVE_SLOPE_ARG __global Dtype* image_data, int_tp image_offset, __global Dtype* kernel_data, int_tp kernel_offset, @@ -162,9 +165,8 @@ __attribute__((reqd_work_group_size(1, 1, SIMD_SIZE))) #endif __kernel void convolve_simd( -#ifdef FUSED_CONV_ELTWISE - __global Dtype* eltwise_data, -#endif + ELTWISE_DATA_ARG + NEGATIVE_SLOPE_ARG __global Dtype* inputs_base, filter_qualifier Dtype* weights_base, __global Dtype* biases_base, @@ -412,23 +414,10 @@ typedef struct half0 { half s0; } half0; //never used but makes compiler happy. 
#define OUT_PITCH_X output_width #define ROW_PITCH input_width -#ifdef FUSED_CONV_ELTWISE -#define GEMM_LIKE_KERNEL_ARGS \ - __global Dtype* eltwise_data, \ - const __global Dtype *src0, \ - const __global Dtype *src1, \ - const __global Dtype *biases, \ - __global Dtype *dst, \ - const ushort input_width, \ - const ushort input_height, \ - const ushort output_width, \ - const ushort output_height, \ - const int_tp out_pitch_y, \ - const int_tp out_pitch_z, \ - const int_tp aligned_input_size, \ - const int_tp slice_pitch -#else + #define GEMM_LIKE_KERNEL_ARGS \ + ELTWISE_DATA_ARG \ + NEGATIVE_SLOPE_ARG \ const __global Dtype *src0, \ const __global Dtype *src1, \ const __global Dtype *biases, \ @@ -441,7 +430,6 @@ typedef struct half0 { half s0; } half0; //never used but makes compiler happy. const int_tp out_pitch_z, \ const int_tp aligned_input_size, \ const int_tp slice_pitch -#endif #endif diff --git a/src/caffe/layers/conv_layer_spatial.cpp b/src/caffe/layers/conv_layer_spatial.cpp index ad8c7c48a25..b00e49d379a 100644 --- a/src/caffe/layers/conv_layer_spatial.cpp +++ b/src/caffe/layers/conv_layer_spatial.cpp @@ -79,6 +79,11 @@ void ConvolutionLayerSpatial::LayerSetUp( CHECK(op_ == EltwiseParameter_EltwiseOp_SUM); } + if (IsFusedWithReLU()) + negative_slope_ = this->layer_param_.convolution_param().relu_param().negative_slope(); + else + negative_slope_ = 0; + if (std::getenv("CLCAFFE_CACHE_PATH")) cache_path_ << std::getenv("CLCAFFE_CACHE_PATH"); else if (std::getenv("VIENNACL_CACHE_PATH")) @@ -460,7 +465,7 @@ bool ConvolutionLayerSpatial::create_basic_kernel( << kernel_name_; if (IsFusedWithEltwiseReLU()) { - optionsString << " -DFUSED_CONV_RELU=1 -DFUSED_CONV_ELTWISE=1"; + optionsString << " -DFUSED_CONV_ELTWISE=1"; } if (IsFusedWithReLU()) { @@ -568,6 +573,8 @@ cl_int ConvolutionLayerSpatial::convolve( cl_uint argIdx = 0; if (IsFusedWithEltwiseReLU()) kernel.arg(argIdx++, WrapHandle((cl_mem) bottom[1]->gpu_data(), &ctx)); + if (IsFusedWithReLU()) + 
kernel.arg(argIdx++, fixup_arg_type(negative_slope_)); try { setBufferKernelArg(bottom, top, &kernel, argIdx++, &ctx, @@ -636,6 +643,8 @@ cl_int ConvolutionLayerSpatial::convolve( cl_uint argIdx = 0; if (IsFusedWithEltwiseReLU()) kernel.arg(argIdx++, WrapHandle((cl_mem) bottom[1]->gpu_data(), &ctx)); + if (IsFusedWithReLU()) + kernel.arg(argIdx++, fixup_arg_type(negative_slope_)); int_tp kernel_offset = kernel_h_ * kernel_w_ * (this->channels_ / this->group_) * M_ * g; @@ -727,6 +736,8 @@ cl_int ConvolutionLayerSpatial::convolve( if (IsFusedWithEltwiseReLU()) kernel.arg(argIdx++, WrapHandle((cl_mem) bottom[1]->gpu_data(), &ctx)); + if (IsFusedWithReLU()) + kernel.arg(argIdx++, fixup_arg_type(negative_slope_)); int_tp kernel_offset = kernel_h_ * kernel_w_ * (this->channels_ / this->group_) * M_ @@ -947,7 +958,7 @@ bool ConvolutionLayerSpatial::create_gemm_like_conv_kernel( " -DTILE_N_LAST_DIV8=" << (M_ % 32) / 8; if (IsFusedWithEltwiseReLU()) { - optionsString << " -DFUSED_CONV_RELU=1 -DFUSED_CONV_ELTWISE=1"; + optionsString << " -DFUSED_CONV_ELTWISE=1"; } if (IsFusedWithReLU()) { @@ -1059,7 +1070,7 @@ bool ConvolutionLayerSpatial::setup_IDLF( optionsString << " -DINPUT_PAD_W=" << pad_w_ << " -DINPUT_PAD_H=" << pad_h_; if (IsFusedWithEltwiseReLU()) { - optionsString << " -DFUSED_CONV_RELU=1 -DFUSED_CONV_ELTWISE=1"; + optionsString << " -DFUSED_CONV_ELTWISE=1"; } if (IsFusedWithReLU()) { diff --git a/src/caffe/proto/caffe.proto b/src/caffe/proto/caffe.proto index cc54e8c0491..ce0cebfe69d 100644 --- a/src/caffe/proto/caffe.proto +++ b/src/caffe/proto/caffe.proto @@ -637,6 +637,7 @@ message ConvolutionParameter { } optional FuseType fuse_type = 19 [default = UNFUSED]; // Whether to fuse convolution with other layers optional EltwiseParameter eltwise_param = 20; + optional ReLUParameter relu_param = 30; } message CropParameter { diff --git a/tools/inference-optimize/model_fuse.py b/tools/inference-optimize/model_fuse.py index 9d86d7817aa..fc9a27977fe 100644 --- 
a/tools/inference-optimize/model_fuse.py +++ b/tools/inference-optimize/model_fuse.py @@ -13,16 +13,30 @@ from pdb import set_trace def resnet_block_to_fuse_type(model, cur_conv_index): - actual = [model.layer[cur_conv_index+1].type, model.layer[cur_conv_index+2].type, model.layer[cur_conv_index+3].type, model.layer[cur_conv_index+4].type] + maxindex = len(model.layer)-1 + if cur_conv_index+1>maxindex: + return 0 #UNFUSED + elif cur_conv_index+2>maxindex: + actual = [model.layer[cur_conv_index+1].type, 'xxx', 'xxx', 'xxx'] + elif cur_conv_index+3>maxindex: + actual = [model.layer[cur_conv_index+1].type, model.layer[cur_conv_index+2].type, 'xxx', 'xxx'] + elif cur_conv_index+4>maxindex: + actual = [model.layer[cur_conv_index+1].type, model.layer[cur_conv_index+2].type, model.layer[cur_conv_index+3].type, 'xxx'] + else: + actual = [model.layer[cur_conv_index+1].type, model.layer[cur_conv_index+2].type, model.layer[cur_conv_index+3].type, model.layer[cur_conv_index+4].type] resnet = ['BatchNorm', 'Scale', 'ReLU'] resnet_merged = ['ReLU'] resnet_elt = ['BatchNorm', 'Scale', 'Eltwise', 'ReLU'] resnet_elt_merged = ['Eltwise', 'ReLU'] - if actual[:1] == resnet_merged or actual[:3] == resnet: - return 2 #FUSED_CONV_RELU TODO: not magic number - if actual[:2] == resnet_elt_merged or actual == resnet_elt: - return 3 #FUSED_CONV_ELTWISE_RELU - return 0 #UNFUSED + if actual[:1] == resnet_merged: + return (2, model.layer[cur_conv_index+1].relu_param.negative_slope) #FUSED_CONV_RELU TODO: not magic number + if actual[:3] == resnet: + return (2, model.layer[cur_conv_index+3].relu_param.negative_slope) #FUSED_CONV_RELU TODO: not magic number + if actual[:2] == resnet_elt_merged: + return (3, model.layer[cur_conv_index+2].relu_param.negative_slope)#FUSED_CONV_ELTWISE_RELU + if actual == resnet_elt: + return (3, model.layer[cur_conv_index+4].relu_param.negative_slope) #FUSED_CONV_ELTWISE_RELU + return 0,0 #UNFUSED def find_fused_blob_names(model, cur_conv_index): i = cur_conv_index 
+ 1 @@ -71,8 +85,8 @@ def is_fused_layer(model, layer_index): def fuse_conv_layer(in_model, in_index, out_model, new_index): if out_model.layer[new_index].type == 'Convolution': - fuse_mode = resnet_block_to_fuse_type(in_model, in_index) - if [in_model.layer[in_index+1].type, in_model.layer[in_index+2].type] == ['BatchNorm', 'Scale']: + (fuse_mode, negative_slope) = resnet_block_to_fuse_type(in_model, in_index) + if len(in_model.layer)>in_index+2 and [in_model.layer[in_index+1].type, in_model.layer[in_index+2].type] == ['BatchNorm', 'Scale']: out_model.layer[new_index].convolution_param.bias_term = True out_model.layer[new_index].convolution_param.fuse_type = fuse_mode if fuse_mode == 3: # FUSED_CONV_ELTWISE_RELU, need to change top name to orig ReLU's top name @@ -80,6 +94,8 @@ def fuse_conv_layer(in_model, in_index, out_model, new_index): out_model.layer[new_index].top.remove(out_model.layer[new_index].top[0]) out_model.layer[new_index].top.append(new_top) out_model.layer[new_index].bottom.append(elt_bottom) + if fuse_mode != 0: + out_model.layer[new_index].convolution_param.relu_param.negative_slope = negative_slope def fuse_lrn_layer(in_model, in_index, out_model, out_index): if out_model.layer[out_index].type == 'LRN' and in_model.layer[in_index + 1].type == 'Pooling' and in_model.layer[ @@ -182,8 +198,14 @@ def generate_weights(in_model, args): continue #TODO:Need fix conv+bn if (in_model.layer[k].type == 'Convolution'): # Assuming convolution is followed by bn and scale - next1type = in_model.layer[k + 1].type - next2type = in_model.layer[k + 2].type + if k + 1 < len(in_model.layer): + next1type = in_model.layer[k + 1].type + else: + next1type = 'end' + if k + 2 < len(in_model.layer): + next2type = in_model.layer[k + 2].type + else: + next2type = 'end' else: print 'Warning: ' + prm + ' has parameters but I can\'t infer its layer type.' 
continue From b7f36b3ffd68133f13b357eda50f537f5845d2be Mon Sep 17 00:00:00 2001 From: Zhigang Gong Date: Wed, 5 Jul 2017 17:18:30 +0800 Subject: [PATCH 26/33] We still need EXAMPLES_SOURCE_DIR for some test cases. Signed-off-by: Zhigang Gong --- cmake/Templates/caffe_config.h.in | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/cmake/Templates/caffe_config.h.in b/cmake/Templates/caffe_config.h.in index c1e3710c1e3..4ce33a2bb59 100644 --- a/cmake/Templates/caffe_config.h.in +++ b/cmake/Templates/caffe_config.h.in @@ -10,8 +10,9 @@ /* This is an absolute path so that we can run test from any build * directory */ #define ABS_TEST_DATA_DIR "${PROJECT_SOURCE_DIR}/src/caffe/test/test_data/" +#define EXAMPLES_SOURCE_DIR "${PROJECT_SOURCE_DIR}/examples/" /* Test device */ #define CUDA_TEST_DEVICE ${CUDA_TEST_DEVICE} -#endif // CAFFE_CONFIG_HPP_ \ No newline at end of file +#endif // CAFFE_CONFIG_HPP_ From 3733407479df03e560889b6264918712596f68f0 Mon Sep 17 00:00:00 2001 From: Zhigang Gong Date: Wed, 5 Jul 2017 17:40:19 +0800 Subject: [PATCH 27/33] Fix two OCL kernel compilation warnings. Signed-off-by: Zhigang Gong --- src/caffe/greentea/cl_kernels.cpp | 6 +++++- src/caffe/greentea/cl_kernels.sh | 2 +- src/caffe/greentea/cl_kernels/gemm.cl | 4 ++++ 3 files changed, 10 insertions(+), 2 deletions(-) diff --git a/src/caffe/greentea/cl_kernels.cpp b/src/caffe/greentea/cl_kernels.cpp index 95223472561..6f992063aff 100644 --- a/src/caffe/greentea/cl_kernels.cpp +++ b/src/caffe/greentea/cl_kernels.cpp @@ -3919,6 +3919,8 @@ static std::vector> cl_kernels{ "READ_BROW(brow6, 6);", // NOLINT "READ_BROW(brow7, 7);", // NOLINT "", // NOLINT +"#undef READ_BROW", // NOLINT +"", // NOLINT "#define MM_DOT_PRODUCT( _row, _dot ) arow = as_half8(vload4(0, (__global float *)(src0_read + _row * K))); arow.s0 = (mad24(local_x, 8, w) < K) ? arow.s0 : 0.0f; arow.s1 = (mad24(local_x, 8, w + 1) < K) ? arow.s1 : 0.0f; arow.s2 = (mad24(local_x, 8, w + 2) < K) ? 
arow.s2 : 0.0f; arow.s3 = (mad24(local_x, 8, w + 3) < K) ? arow.s3 : 0.0f; arow.s4 = (mad24(local_x, 8, w + 4) < K) ? arow.s4 : 0.0f; arow.s5 = (mad24(local_x, 8, w + 5) < K) ? arow.s5 : 0.0f; arow.s6 = (mad24(local_x, 8, w + 6) < K) ? arow.s6 : 0.0f; arow.s7 = (mad24(local_x, 8, w + 7) < K) ? arow.s7 : 0.0f; _dot = mad( (Dtype8)(arow.s0), (Dtype8)(brow0.s0, brow1.s0, brow2.s0, brow3.s0, brow4.s0, brow5.s0, brow6.s0, brow7.s0), _dot ); _dot = mad( (Dtype8)(arow.s1), (Dtype8)(brow0.s1, brow1.s1, brow2.s1, brow3.s1, brow4.s1, brow5.s1, brow6.s1, brow7.s1), _dot ); _dot = mad( (Dtype8)(arow.s2), (Dtype8)(brow0.s2, brow1.s2, brow2.s2, brow3.s2, brow4.s2, brow5.s2, brow6.s2, brow7.s2), _dot ); _dot = mad( (Dtype8)(arow.s3), (Dtype8)(brow0.s3, brow1.s3, brow2.s3, brow3.s3, brow4.s3, brow5.s3, brow6.s3, brow7.s3), _dot ); _dot = mad( (Dtype8)(arow.s4), (Dtype8)(brow0.s4, brow1.s4, brow2.s4, brow3.s4, brow4.s4, brow5.s4, brow6.s4, brow7.s4), _dot ); _dot = mad( (Dtype8)(arow.s5), (Dtype8)(brow0.s5, brow1.s5, brow2.s5, brow3.s5, brow4.s5, brow5.s5, brow6.s5, brow7.s5), _dot ); _dot = mad( (Dtype8)(arow.s6), (Dtype8)(brow0.s6, brow1.s6, brow2.s6, brow3.s6, brow4.s6, brow5.s6, brow6.s6, brow7.s6), _dot ); _dot = mad( (Dtype8)(arow.s7), (Dtype8)(brow0.s7, brow1.s7, brow2.s7, brow3.s7, brow4.s7, brow5.s7, brow6.s7, brow7.s7), _dot );", // NOLINT "MM_DOT_PRODUCT( 0, dot00 );", // NOLINT "MM_DOT_PRODUCT( 1, dot01 );", // NOLINT @@ -4068,6 +4070,8 @@ static std::vector> cl_kernels{ "READ_BROW(brow6, 6);", // NOLINT "READ_BROW(brow7, 7);", // NOLINT "", // NOLINT +"#undef READ_BROW", // NOLINT +"", // NOLINT "#define MM_DOT_PRODUCT( _row, _dot ) arow = vload4(0, src0_read + _row * K); arow.x = (mad24(local_x, 4, w) < K) ? arow.x : 0.0f; arow.y = (mad24(local_x, 4, w + 1) < K) ? arow.y : 0.0f; arow.z = (mad24(local_x, 4, w + 2) < K) ? arow.z : 0.0f; arow.w = (mad24(local_x, 4, w + 3) < K) ? 
arow.w : 0.0f; _dot = mad( (Dtype8)(arow.x), (Dtype8)(brow0.x, brow1.x, brow2.x, brow3.x, brow4.x, brow5.x, brow6.x, brow7.x), _dot ); _dot = mad( (Dtype8)(arow.y), (Dtype8)(brow0.y, brow1.y, brow2.y, brow3.y, brow4.y, brow5.y, brow6.y, brow7.y), _dot ); _dot = mad( (Dtype8)(arow.z), (Dtype8)(brow0.z, brow1.z, brow2.z, brow3.z, brow4.z, brow5.z, brow6.z, brow7.z), _dot ); _dot = mad( (Dtype8)(arow.w), (Dtype8)(brow0.w, brow1.w, brow2.w, brow3.w, brow4.w, brow5.w, brow6.w, brow7.w), _dot );", // NOLINT "MM_DOT_PRODUCT( 0, dot00 );", // NOLINT "MM_DOT_PRODUCT( 1, dot01 );", // NOLINT @@ -7723,7 +7727,7 @@ viennacl::ocl::program & RegisterKernels(viennacl::ocl::context *ctx) { options = " -DFFT "; #endif #ifdef HAS_HALF_SUPPORT - options += " -DHAS_HALF_SUPPORT; "; + options += " -DHAS_HALF_SUPPORT "; #endif bool is_beignet = ctx->devices()[0].opencl_c_version().find("beignet") != std::string::npos; diff --git a/src/caffe/greentea/cl_kernels.sh b/src/caffe/greentea/cl_kernels.sh index bea5c919666..ba10936d2ee 100755 --- a/src/caffe/greentea/cl_kernels.sh +++ b/src/caffe/greentea/cl_kernels.sh @@ -258,7 +258,7 @@ echo "#ifdef USE_FFT" >> $SOURCE echo " options = \" -DFFT \";" >> $SOURCE echo "#endif" >> $SOURCE echo "#ifdef HAS_HALF_SUPPORT" >> $SOURCE -echo " options += \" -DHAS_HALF_SUPPORT; \";" >> $SOURCE +echo " options += \" -DHAS_HALF_SUPPORT \";" >> $SOURCE echo "#endif" >> $SOURCE echo " bool is_beignet = ctx->devices()[0].opencl_c_version().find(\"beignet\")" >> $SOURCE echo " != std::string::npos;" >> $SOURCE diff --git a/src/caffe/greentea/cl_kernels/gemm.cl b/src/caffe/greentea/cl_kernels/gemm.cl index e2c297c1d7e..3712f4b7ca5 100644 --- a/src/caffe/greentea/cl_kernels/gemm.cl +++ b/src/caffe/greentea/cl_kernels/gemm.cl @@ -1351,6 +1351,8 @@ __kernel void TEMPLATE(gemm_buffer_NT, Dtype)( READ_BROW(brow6, 6); READ_BROW(brow7, 7); +#undef READ_BROW + #define MM_DOT_PRODUCT( _row, _dot ) \ arow = as_half8(vload4(0, (__global float *)(src0_read + _row * K))); 
\ arow.s0 = (mad24(local_x, 8, w) < K) ? arow.s0 : 0.0f; \ @@ -1543,6 +1545,8 @@ __kernel void TEMPLATE(gemm_buffer_NT, Dtype)( READ_BROW(brow6, 6); READ_BROW(brow7, 7); +#undef READ_BROW + #define MM_DOT_PRODUCT( _row, _dot ) \ arow = vload4(0, src0_read + _row * K); \ arow.x = (mad24(local_x, 4, w) < K) ? arow.x : 0.0f; \ From 5072f067758a3fa6f9d091aaac5c9906623bb04d Mon Sep 17 00:00:00 2001 From: Zhigang Gong Date: Wed, 5 Jul 2017 17:55:10 +0800 Subject: [PATCH 28/33] Fix inner product layer for non-intel platform. Signed-off-by: Zhigang Gong --- src/caffe/layers/inner_product_layer.cu | 31 ++++++++++++++++++++----------- 1 file changed, 20 insertions(+), 11 deletions(-) diff --git a/src/caffe/layers/inner_product_layer.cu b/src/caffe/layers/inner_product_layer.cu index f06db1a5782..b7049d7eaf4 100644 --- a/src/caffe/layers/inner_product_layer.cu +++ b/src/caffe/layers/inner_product_layer.cu @@ -783,7 +783,8 @@ void InnerProductLayer::Forward_gpu(const vector*>& bottom, K_ <= max_image_size && !std::is_same::value && this->device_->CheckCapability("cl_intel_subgroups")) { - if (!test_only_ || copied_weight_data_ != this->blobs_[0]->data().get()) { + if (!test_only_ || + copied_weight_data_ != this->blobs_[0]->data().get()) { int height = !transpose_ ? N_ : K_; int width = !transpose_ ? K_ : N_; if (weight_image_) { @@ -797,18 +798,26 @@ void InnerProductLayer::Forward_gpu(const vector*>& bottom, height, width, width, (int)0, NULL, NULL); copied_weight_data_ = this->blobs_[0]->data().get(); } - } - tune_innerprod_type(this->device_->id(), - transpose_ ? CblasNoTrans : CblasTrans, - (cl_mem) bottom_data, (cl_mem) weight, (cl_mem) weight_image_, - max_image_size); + tune_innerprod_type(this->device_->id(), + transpose_ ? CblasNoTrans : CblasTrans, + (cl_mem) bottom_data, (cl_mem) weight, + (cl_mem) weight_image_, + max_image_size); + + innerprod_common(this->device_->id(), + transpose_ ? 
CblasNoTrans : CblasTrans, + M_, N_, K_, (cl_mem) bottom_data, + (cl_mem) weight, (cl_mem) weight_image_, + (cl_mem) top_data, innerprod_type_, max_image_size); + } else { + greentea_gpu_gemm(this->device_->id(), CblasNoTrans, + transpose_ ? CblasNoTrans : CblasTrans, + M_, N_, K_, (Dtype) 1., + (cl_mem) bottom_data, 0, (cl_mem) weight, 0, + (Dtype) 0., (cl_mem) top_data, 0); + } - innerprod_common(this->device_->id(), - transpose_ ? CblasNoTrans : CblasTrans, - M_, N_, K_, (cl_mem) bottom_data, - (cl_mem) weight, (cl_mem) weight_image_, - (cl_mem) top_data, innerprod_type_, max_image_size); if (bias_term_) { // Execute kernel greentea_gpu_gemm(this->device_->id(), CblasNoTrans, From 3031f4e7b2e66240462f0a2a905227e8d73eeb91 Mon Sep 17 00:00:00 2001 From: Zhigang Gong Date: Thu, 6 Jul 2017 09:11:19 +0800 Subject: [PATCH 29/33] Fix kernel compilation issue for non-Intel Gen platforms. Signed-off-by: Zhigang Gong --- src/caffe/greentea/cl_kernels.cpp | 8 ++++++++ src/caffe/greentea/cl_kernels/gemm.cl | 4 ++++ src/caffe/greentea/cl_kernels/lrn.cl | 4 ++++ 3 files changed, 16 insertions(+) diff --git a/src/caffe/greentea/cl_kernels.cpp b/src/caffe/greentea/cl_kernels.cpp index 6f992063aff..7e827e9a50b 100644 --- a/src/caffe/greentea/cl_kernels.cpp +++ b/src/caffe/greentea/cl_kernels.cpp @@ -3262,6 +3262,9 @@ static std::vector> cl_kernels{ "#include \"header.cl\"", // NOLINT "#endif", // NOLINT "", // NOLINT +"#if defined(cl_intel_subgroups)", // NOLINT +"#pragma OPENCL EXTENSION cl_intel_subgroups : enable", // NOLINT +"", // NOLINT "#if TYPE != TYPE_DOUBLE", // NOLINT "", // NOLINT "#define TILE_M 32", // NOLINT @@ -5091,6 +5094,7 @@ static std::vector> cl_kernels{ "#undef SHUFFLE_TYPE8", // NOLINT "", // NOLINT "#endif", // NOLINT +"#endif", // NOLINT ""}, // NOLINT {"#ifndef __OPENCL_VERSION__", // NOLINT "#include \"header.cl\"", // NOLINT @@ -5525,6 +5529,9 @@ static std::vector> cl_kernels{ "}", // NOLINT "}", // NOLINT "", // NOLINT +"#if 
defined(cl_intel_subgroups)", // NOLINT +"#pragma OPENCL EXTENSION cl_intel_subgroups : enable", // NOLINT +"", // NOLINT "#define SIMD_WIDTH 16", // NOLINT "#define TILE_W SIMD_WIDTH", // NOLINT "#define TILE_H 8", // NOLINT @@ -5634,6 +5641,7 @@ static std::vector> cl_kernels{ "#undef TILE_W", // NOLINT "#undef TILE_H", // NOLINT "#undef SIMD_WIDTH", // NOLINT +"#endif", // NOLINT "", // NOLINT "__kernel void TEMPLATE(lrn_full_no_scale,Dtype)(const int_tp nthreads, __global const Dtype* in,", // NOLINT "const int_tp num, const int_tp channels,", // NOLINT diff --git a/src/caffe/greentea/cl_kernels/gemm.cl b/src/caffe/greentea/cl_kernels/gemm.cl index 3712f4b7ca5..d9a505637fc 100644 --- a/src/caffe/greentea/cl_kernels/gemm.cl +++ b/src/caffe/greentea/cl_kernels/gemm.cl @@ -2,6 +2,9 @@ #include "header.cl" #endif +#if defined(cl_intel_subgroups) +#pragma OPENCL EXTENSION cl_intel_subgroups : enable + #if TYPE != TYPE_DOUBLE #define TILE_M 32 @@ -2647,3 +2650,4 @@ __kernel void TEMPLATE(gemm_buffer_TT, Dtype)( #undef SHUFFLE_TYPE8 #endif +#endif diff --git a/src/caffe/greentea/cl_kernels/lrn.cl b/src/caffe/greentea/cl_kernels/lrn.cl index 93cd1c25e90..1aef0352b28 100644 --- a/src/caffe/greentea/cl_kernels/lrn.cl +++ b/src/caffe/greentea/cl_kernels/lrn.cl @@ -120,6 +120,9 @@ __kernel void TEMPLATE(lrn_compute_diff,Dtype)(const int_tp nthreads, } } +#if defined(cl_intel_subgroups) +#pragma OPENCL EXTENSION cl_intel_subgroups : enable + #define SIMD_WIDTH 16 #define TILE_W SIMD_WIDTH #define TILE_H 8 @@ -229,6 +232,7 @@ __kernel void TEMPLATE(lrn_fuse_pool_max,Dtype)( #undef TILE_W #undef TILE_H #undef SIMD_WIDTH +#endif __kernel void TEMPLATE(lrn_full_no_scale,Dtype)(const int_tp nthreads, __global const Dtype* in, const int_tp num, const int_tp channels, From 1a99aea8b21ea8a7df2395be7457630fc440f726 Mon Sep 17 00:00:00 2001 From: Zhigang Gong Date: Sun, 2 Jul 2017 14:19:56 +0800 Subject: [PATCH 30/33] Adjust test cases for half precision. 
Signed-off-by: Zhigang Gong --- src/caffe/test/test_blob.cpp | 2 +- src/caffe/test/test_convolution_layer_spatial.cpp | 2 +- src/caffe/test/test_infogain_loss_layer.cpp | 6 ++++-- src/caffe/test/test_reduction_layer.cpp | 4 +++- src/caffe/test/test_scale_layer.cpp | 2 +- 5 files changed, 10 insertions(+), 6 deletions(-) diff --git a/src/caffe/test/test_blob.cpp b/src/caffe/test/test_blob.cpp index 0c51af3be27..ea980bd0039 100644 --- a/src/caffe/test/test_blob.cpp +++ b/src/caffe/test/test_blob.cpp @@ -122,7 +122,7 @@ class BlobMathTest : public MultiDeviceTest { : blob_(new Blob(2, 3, 4, 5)), epsilon_(1e-6) { if (std::is_same::value) - epsilon_ = 1e-2; + epsilon_ = 5e-2; } virtual ~BlobMathTest() { delete blob_; } diff --git a/src/caffe/test/test_convolution_layer_spatial.cpp b/src/caffe/test/test_convolution_layer_spatial.cpp index 25a12b0e788..cfcc118ace5 100644 --- a/src/caffe/test/test_convolution_layer_spatial.cpp +++ b/src/caffe/test/test_convolution_layer_spatial.cpp @@ -450,7 +450,7 @@ TYPED_TEST(ConvolutionLayerTest_Spatial, top_data = this->blob_top_->cpu_data(); ref_top_data = this->ref_blob_top_->cpu_data(); Dtype delta = std::is_same::value ? - 5e-1 : 1e-4; + 8e-1 : 1e-4; for (int_tp i = 0; i < this->blob_top_->count(); ++i) { EXPECT_NEAR(top_data[i], ref_top_data[i], delta); } diff --git a/src/caffe/test/test_infogain_loss_layer.cpp b/src/caffe/test/test_infogain_loss_layer.cpp index 90f7e943c11..03dbb6c3328 100644 --- a/src/caffe/test/test_infogain_loss_layer.cpp +++ b/src/caffe/test/test_infogain_loss_layer.cpp @@ -103,7 +103,8 @@ TYPED_TEST(InfogainLossLayerTest, TestInfogainLoss) { } for ( int l = 0; l < this->num_labels_; l++ ) { EXPECT_NEAR(prob[i*this->num_labels_*this->inner_ + l*this->inner_ + j], - est_prob[l]/den, 1e-6); + est_prob[l]/den, + (std::is_same::value) ? 
1e-3 : 1e-6); } } } @@ -120,7 +121,8 @@ TYPED_TEST(InfogainLossLayerTest, TestInfogainLoss) { } } EXPECT_NEAR(this->blob_top_loss_->cpu_data()[0], - loss/(this->outer_*this->inner_), 1e-6); + loss/(this->outer_*this->inner_), + (std::is_same::value) ? 1e-3 : 1e-6); } TYPED_TEST(InfogainLossLayerTest, TestGradient) { diff --git a/src/caffe/test/test_reduction_layer.cpp b/src/caffe/test/test_reduction_layer.cpp index 152535cfa8c..32d5716199e 100644 --- a/src/caffe/test/test_reduction_layer.cpp +++ b/src/caffe/test/test_reduction_layer.cpp @@ -71,7 +71,9 @@ class ReductionLayerTest : public MultiDeviceTest { } expected_result *= coeff; const Dtype computed_result = this->blob_top_->cpu_data()[n]; - EXPECT_FLOAT_EQ(expected_result, computed_result) + Dtype eps = std::is_same::value ? + 5e-2 : 1e-4; + EXPECT_NEAR(expected_result, computed_result, eps * expected_result) << "Incorrect result computed with op " << ReductionParameter_ReductionOp_Name(op) << ", coeff " << coeff; } diff --git a/src/caffe/test/test_scale_layer.cpp b/src/caffe/test/test_scale_layer.cpp index cf3f538fbd3..fc4b134f5ac 100644 --- a/src/caffe/test/test_scale_layer.cpp +++ b/src/caffe/test/test_scale_layer.cpp @@ -355,7 +355,7 @@ TYPED_TEST(ScaleLayerTest, TestForwardBroadcastMiddleWithParamAndBias) { ASSERT_EQ(this->blob_bottom_->shape(), this->blob_top_->shape()); layer->Forward(this->blob_bottom_vec_, this->blob_top_vec_); const Dtype delta = std::is_same::value ? - 5e-3 : 1e-5; + 5e-2 : 1e-5; for (int_tp n = 0; n < this->blob_bottom_->num(); ++n) { for (int_tp c = 0; c < this->blob_bottom_->channels(); ++c) { for (int_tp h = 0; h < this->blob_bottom_->height(); ++h) { From 12f62b380ce90361401c7adc415cfe2ae267e031 Mon Sep 17 00:00:00 2001 From: Zhigang Gong Date: Thu, 6 Jul 2017 11:34:53 +0800 Subject: [PATCH 31/33] Fix lrn fusion for non-Intel Gen platform. 
Signed-off-by: Zhigang Gong --- src/caffe/layers/lrn_layer.cu | 42 +++++++++++++++++++++++------------------- 1 file changed, 23 insertions(+), 19 deletions(-) diff --git a/src/caffe/layers/lrn_layer.cu b/src/caffe/layers/lrn_layer.cu index 9ba52373ed3..19460de6cdb 100644 --- a/src/caffe/layers/lrn_layer.cu +++ b/src/caffe/layers/lrn_layer.cu @@ -259,27 +259,31 @@ void LRNLayer::CrossChannelForward_gpu( } else if (IsFusedWithPoolMax()) { // We can't make sure the fused kernel be the faster for all platforms. // have to apply a simple tuning here. - if (fuse_tuned_) - CrossChannelForward_fuse_pooling_gpu(bottom, top, tuned_use_fuse_); - else { - float elapsedTime[2]; - bool use_fuse[2] = {true, false}; - // warm up. - CrossChannelForward_fuse_pooling_gpu(bottom, top, true); - CrossChannelForward_fuse_pooling_gpu(bottom, top, false); - for (int i = 0; i < 2; i++) { - Timer timer; - timer.initted(); - timer.Start(); - int loop_cnt = 2; - for (int j = 0; j < loop_cnt; j++) { - CrossChannelForward_fuse_pooling_gpu(bottom, top, use_fuse[i]); + if (this->device_->CheckCapability("cl_intel_subgroups")) { + if (fuse_tuned_) + CrossChannelForward_fuse_pooling_gpu(bottom, top, tuned_use_fuse_); + else { + float elapsedTime[2]; + bool use_fuse[2] = {true, false}; + // warm up. 
+ CrossChannelForward_fuse_pooling_gpu(bottom, top, true); + CrossChannelForward_fuse_pooling_gpu(bottom, top, false); + for (int i = 0; i < 2; i++) { + Timer timer; + timer.initted(); + timer.Start(); + int loop_cnt = 2; + for (int j = 0; j < loop_cnt; j++) { + CrossChannelForward_fuse_pooling_gpu(bottom, top, use_fuse[i]); + } + timer.Stop(); + elapsedTime[i] = timer.MilliSeconds() / loop_cnt; } - timer.Stop(); - elapsedTime[i] = timer.MilliSeconds() / loop_cnt; + tuned_use_fuse_ = elapsedTime[0] < elapsedTime[1]; + fuse_tuned_ = true; } - tuned_use_fuse_ = elapsedTime[0] < elapsedTime[1]; - fuse_tuned_ = true; + } else { + CrossChannelForward_fuse_pooling_gpu(bottom, top, false); } } } From 03cd9249b93604685c5049d457a07f77ed970998 Mon Sep 17 00:00:00 2001 From: Zhigang Gong Date: Thu, 6 Jul 2017 15:48:35 +0800 Subject: [PATCH 32/33] Lint fix. Signed-off-by: Zhigang Gong --- cmake/lint.cmake | 6 + include/caffe/layers/conv_spatial_layer.hpp | 26 ++- include/caffe/layers/inner_product_layer.hpp | 2 + include/caffe/layers/lrn_layer.hpp | 19 +- include/caffe/test/test_gradient_check_util.hpp | 1 - include/caffe/util/fp16.hpp | 11 +- include/caffe/util/mkl_alternate.hpp | 1 - src/caffe/greentea/greentea_math_functions.cpp | 10 +- src/caffe/greentea/libdnn.cpp | 9 +- src/caffe/greentea/libdnn_conv_spatial.cpp | 33 ++-- src/caffe/greentea/libdnn_pool.cpp | 5 +- src/caffe/layers/batch_norm_layer.cu | 21 ++- src/caffe/layers/conv_layer_spatial.cpp | 42 +++-- src/caffe/layers/inner_product_layer.cu | 235 ++++++++++++++---------- src/caffe/layers/lrn_layer.cpp | 11 +- src/caffe/layers/lrn_layer.cu | 21 ++- src/caffe/layers/power_layer.cu | 6 +- src/caffe/layers/softmax_layer.cpp | 3 +- src/caffe/solvers/adagrad_solver.cu | 4 +- src/caffe/solvers/nesterov_solver.cu | 3 +- src/caffe/syncedmem.cpp | 3 +- src/caffe/test/test_bias_layer.cpp | 24 +-- src/caffe/test/test_gradient_based_solver.cpp | 9 +- src/caffe/test/test_inner_product_layer.cpp | 62 ++++--- 
src/caffe/test/test_lrn_layer.cpp | 21 ++- src/caffe/test/test_power_layer.cpp | 4 +- src/caffe/test/test_syncedmem.cpp | 2 +- src/caffe/test/test_tanh_layer.cpp | 3 +- src/caffe/util/hdf5.cpp | 7 +- src/caffe/util/math_functions.cpp | 32 ++-- 30 files changed, 363 insertions(+), 273 deletions(-) diff --git a/cmake/lint.cmake b/cmake/lint.cmake index 8cca27d5248..4be4b04680a 100644 --- a/cmake/lint.cmake +++ b/cmake/lint.cmake @@ -8,6 +8,7 @@ set(LINT_COMMAND ${python_executable} ${CMAKE_SOURCE_DIR}/scripts/cpp_lint.py) set(SRC_FILE_EXTENSIONS h hpp hu c cpp cu cc) set(EXCLUDE_FILE_EXTENSTIONS pb.h pb.cc) set(LINT_DIRS include src/caffe examples tools python matlab) +set(EXCLUDE_DIRS include/3rdparty) cmake_policy(SET CMP0009 NEW) # suppress cmake warning @@ -25,6 +26,11 @@ foreach(ext ${EXCLUDE_FILE_EXTENSTIONS}) set(EXCLUDED_FILES ${EXCLUDED_FILES} ${FOUND_FILES}) endforeach() +foreach(dir ${EXCLUDE_DIRS}) + file(GLOB_RECURSE FOUND_FILES ${CMAKE_SOURCE_DIR}/${dir}/*.*) + set(EXCLUDED_FILES ${EXCLUDED_SOURCES} ${FOUND_FILES}) +endforeach() + # exclude generated pb files if(EXCLUDED_FILES) list(REMOVE_ITEM LINT_SOURCES ${EXCLUDED_FILES}) diff --git a/include/caffe/layers/conv_spatial_layer.hpp b/include/caffe/layers/conv_spatial_layer.hpp index 3c6bd671eca..ebdac86c5cf 100644 --- a/include/caffe/layers/conv_spatial_layer.hpp +++ b/include/caffe/layers/conv_spatial_layer.hpp @@ -188,25 +188,25 @@ class ConvolutionLayerSpatial : public BaseConvolutionLayer { std::map, cl_mem> subBufferMap; std::vector tmpSubBuffers; - bool IsFused() const - { - return (this->layer_param_.convolution_param().fuse_type() != ConvolutionParameter_FuseType_UNFUSED); + bool IsFused() const { + return (this->layer_param_.convolution_param().fuse_type() + != ConvolutionParameter_FuseType_UNFUSED); } - bool IsFusedWithMaxPoolAndReLU() const - { - return (this->layer_param_.convolution_param().fuse_type() == ConvolutionParameter_FuseType_FUSED_CONV_MAX_POOLING_RELU); + bool 
IsFusedWithMaxPoolAndReLU() const { + return (this->layer_param_.convolution_param().fuse_type() + == ConvolutionParameter_FuseType_FUSED_CONV_MAX_POOLING_RELU); } - bool IsFusedWithEltwiseReLU() const - { - return (this->layer_param_.convolution_param().fuse_type() == ConvolutionParameter_FuseType_FUSED_CONV_ELTWISE_RELU); + bool IsFusedWithEltwiseReLU() const { + return (this->layer_param_.convolution_param().fuse_type() + == ConvolutionParameter_FuseType_FUSED_CONV_ELTWISE_RELU); } - bool IsFusedWithReLU() const - { + bool IsFusedWithReLU() const { return IsFusedWithEltwiseReLU() || - (this->layer_param_.convolution_param().fuse_type() == ConvolutionParameter_FuseType_FUSED_CONV_RELU); + (this->layer_param_.convolution_param().fuse_type() + == ConvolutionParameter_FuseType_FUSED_CONV_RELU); } #endif @@ -270,9 +270,7 @@ class ConvolutionLayerSpatial : public BaseConvolutionLayer { Dtype negative_slope_; bool stable_prod_grad_; - }; - } // namespace caffe #endif // CAFFE_CONV_SPATIAL_LAYER_HPP_ diff --git a/include/caffe/layers/inner_product_layer.hpp b/include/caffe/layers/inner_product_layer.hpp index 12a52adfb59..cd8731e07c7 100644 --- a/include/caffe/layers/inner_product_layer.hpp +++ b/include/caffe/layers/inner_product_layer.hpp @@ -1,6 +1,7 @@ #ifndef CAFFE_INNER_PRODUCT_LAYER_HPP_ #define CAFFE_INNER_PRODUCT_LAYER_HPP_ +#include #include #include "caffe/blob.hpp" @@ -75,6 +76,7 @@ class InnerProductLayer : public Layer { weight_image_ = NULL; } #endif + protected: virtual void Forward_cpu(const vector*>& bottom, const vector*>& top); diff --git a/include/caffe/layers/lrn_layer.hpp b/include/caffe/layers/lrn_layer.hpp index f059a7d60fd..b2696930004 100644 --- a/include/caffe/layers/lrn_layer.hpp +++ b/include/caffe/layers/lrn_layer.hpp @@ -33,18 +33,16 @@ class LRNLayer : public Layer { virtual inline int_tp ExactNumBottomBlobs() const { return 1; } virtual inline int_tp ExactNumTopBlobs() const { return 1; } - bool IsFused() const - { - return 
(this->layer_param_.lrn_param().fuse_type() != LRNParameter_FuseType_UNFUSED); + bool IsFused() const { + return (this->layer_param_.lrn_param().fuse_type() + != LRNParameter_FuseType_UNFUSED); } - - bool IsFusedWithPoolMax() const - { - return (this->layer_param_.lrn_param().fuse_type() == LRNParameter_FuseType_FUSED_POOL_MAX); + bool IsFusedWithPoolMax() const { + return (this->layer_param_.lrn_param().fuse_type() + == LRNParameter_FuseType_FUSED_POOL_MAX); } - protected: virtual void Forward_cpu(const vector*>& bottom, const vector*>& top); @@ -68,7 +66,8 @@ class LRNLayer : public Layer { virtual void WithinChannelBackward(const vector*>& top, const vector& propagate_down, const vector*>& bottom); - virtual void CrossChannelForward_fuse_pooling_gpu(const vector*>& bottom, + virtual void CrossChannelForward_fuse_pooling_gpu( + const vector*>& bottom, const vector*>& top, bool use_fuse); @@ -91,7 +90,7 @@ class LRNLayer : public Layer { bool fuse_tuned_; bool tuned_use_fuse_; Blob lrn_top_blob_; - vector*> lrn_top_vec_; // for pooling fusing + vector*> lrn_top_vec_; // for pooling fusing // Fields used for normalization ACROSS_CHANNELS // scale_ stores the int_tpermediate summing results diff --git a/include/caffe/test/test_gradient_check_util.hpp b/include/caffe/test/test_gradient_check_util.hpp index 4d85246b166..8425fa15d29 100644 --- a/include/caffe/test/test_gradient_check_util.hpp +++ b/include/caffe/test/test_gradient_check_util.hpp @@ -27,7 +27,6 @@ class GradientChecker { : stepsize_(stepsize), threshold_(threshold), seed_(seed), kink_(kink), kink_range_(kink_range) { if (std::is_same::value) { - //stepsize_ = 10 * stepsize; threshold_ = 100 * threshold; stepsize_ = stepsize; } diff --git a/include/caffe/util/fp16.hpp b/include/caffe/util/fp16.hpp index f2738ede3fa..fe3b80d6f75 100644 --- a/include/caffe/util/fp16.hpp +++ b/include/caffe/util/fp16.hpp @@ -1,20 +1,19 @@ #ifndef CAFFE_UTIL_FP16_H_ #define CAFFE_UTIL_FP16_H_ +#include #include 
"3rdparty/half/half.hpp" using half_float::half; -#define HALF_MAX 0x1.ffcp15f -#define HALF_MIN 0x1.0p-14f - -#include +#define HALF_MAX 0x1.ffcp15f +#define HALF_MIN 0x1.0p-14f inline float fixup_arg_type(float v) { return v; } inline float fixup_arg_type(half_float::half v) { - return float(v); + return static_cast(v); } inline double fixup_arg_type(double v) { @@ -46,7 +45,7 @@ inline unsigned long fixup_arg_type(unsigned long v) { } inline float fixup_arg_type(const half_float::detail::expr& expr) { - return float(expr); + return static_cast(expr); } inline const void * fixup_arg_type(const boost::shared_ptr& share_ptr) { diff --git a/include/caffe/util/mkl_alternate.hpp b/include/caffe/util/mkl_alternate.hpp index 188a4199f3a..2b6685b7abc 100644 --- a/include/caffe/util/mkl_alternate.hpp +++ b/include/caffe/util/mkl_alternate.hpp @@ -90,7 +90,6 @@ DEFINE_VSL_BINARY_FUNC(Div, y[i] = a[i] / b[i]) inline void cblas_haxpby(const int_tp N, const half alpha, const half* X, const int_tp incX, const half beta, half* Y, const int_tp incY) { - for (int_tp n = 0; n < N; n++) Y[n * incY] *= beta; diff --git a/src/caffe/greentea/greentea_math_functions.cpp b/src/caffe/greentea/greentea_math_functions.cpp index c2aca7db0bf..85b01e675ac 100644 --- a/src/caffe/greentea/greentea_math_functions.cpp +++ b/src/caffe/greentea/greentea_math_functions.cpp @@ -412,7 +412,8 @@ void greentea_gpu_gemv(const int_tp ctx_id, const CBLAS_TRANSPOSE TransA, uint row_size = M; uint col_size = N; size_t localsize = 128; - size_t globalsize = (isTransA ? col_size : (row_size + 3) / 4 * localsize); + size_t globalsize = (isTransA ? 
+ col_size : (row_size + 3) / 4 * localsize); uint argId = 0; k.arg(argId++, row_size); @@ -586,7 +587,7 @@ void greentea_gpu_axpy(const int_tp ctx_id, const int_tp N, const Dtype alpha, GREENTEA_CL_BLAS_CHECK( clblasSaxpy(N, alpha, X, offX, 1, Y, offY, 1, 1, &queue, 0, NULL, NULL)); - } else if (std::is_same::value){ + } else if (std::is_same::value) { GREENTEA_CL_BLAS_CHECK( clblasDaxpy(N, alpha, X, offX, 1, Y, offY, 1, 1, &queue, 0, NULL, NULL)); @@ -859,7 +860,7 @@ void greentea_gpu_dot(const int_tp ctx_id, const int_tp n, const cl_mem X, GREENTEA_CL_BLAS_CHECK( clblasSdot(n, gpuout, 0, X, offX, 1, Y, offY, 1, scratch, 1, &queue, 0, NULL, NULL)); - } else if (std::is_same::value){ + } else if (std::is_same::value) { GREENTEA_CL_BLAS_CHECK( clblasDdot(n, gpuout, 0, X, offX, 1, Y, offY, 1, scratch, 1, &queue, 0, NULL, NULL)); @@ -1214,7 +1215,8 @@ void greentea_gpu_add_scalar(const int_tp ctx_id, const int_tp N, viennacl::ocl::kernel &oclk_add_scalar = program.get_kernel( CL_KERNEL_SELECT("add_scalar")); - viennacl::ocl::enqueue(oclk_add_scalar(N, fixup_arg_type(alpha), WrapHandle(Y, &ctx), offY), + viennacl::ocl::enqueue(oclk_add_scalar(N, fixup_arg_type(alpha), + WrapHandle(Y, &ctx), offY), ctx.get_queue()); } diff --git a/src/caffe/greentea/libdnn.cpp b/src/caffe/greentea/libdnn.cpp index e7b90f0ef22..78443c344ba 100644 --- a/src/caffe/greentea/libdnn.cpp +++ b/src/caffe/greentea/libdnn.cpp @@ -69,7 +69,7 @@ std::string LibDNN::generate_header() { for (int_tp i = 2; i <= 16; i *= 2) { ss << "#define Dtype" << i << " double" << i << std::endl; } - } else if (std::is_same::value){ + } else if (std::is_same::value) { ss << "#define Dtype float" << std::endl; ss << "#define Dtype1 float" << std::endl; // float2, float4, float8, float16 @@ -207,8 +207,8 @@ std::string LibDNN::generate_header() { ss << "current.floatVal[1] = *(source + 1);" << std::endl; ss << "do {" << std::endl; ss << "expected.intVal = current.intVal;" << std::endl; - ss << "next.floatVal[0] = 
expected.floatVal[0] " << atomic_ops[i] << " operand;" - << std::endl; + ss << "next.floatVal[0] = expected.floatVal[0] " + << atomic_ops[i] << " operand;" << std::endl; if (std::is_same::value) { ss << "next.floatVal[1] = expected.floatVal[1]; " << std::endl; } @@ -230,7 +230,8 @@ std::string LibDNN::generate_header() { } // Memory set - ss << "__kernel void fill_memory(const int_tp n, const KERNEL_ARG_DTYPE alpha," + ss << "__kernel void fill_memory(const int_tp n, " + << "const KERNEL_ARG_DTYPE alpha," << "__global Dtype* x, const int_tp offx) {" << std::endl; ss << "for (int_tp index = get_global_id(0); index < n; " << "index += get_global_size(0)) {" << std::endl; diff --git a/src/caffe/greentea/libdnn_conv_spatial.cpp b/src/caffe/greentea/libdnn_conv_spatial.cpp index d168fc7169a..2f261bbb301 100644 --- a/src/caffe/greentea/libdnn_conv_spatial.cpp +++ b/src/caffe/greentea/libdnn_conv_spatial.cpp @@ -258,29 +258,39 @@ std::string LibDNNConvSpatial::generate_fw_kernels(int_tp kernelType, // SIMD16/8 mode will be used, // the compiler could choose to use two SIMD8 threads, // and if that happens the code will break. 
- ss << "#if defined(convolve_simd) || defined(Conv_Interleaved)" << std::endl; + ss << "#if defined(convolve_simd) || defined(Conv_Interleaved)" + << std::endl; ss << "#if TYPE == TYPE_HALF" << std::endl; ss << "#define INT_TYPE ushort" << std::endl; ss << "#define INT_TYPE2 ushort2" << std::endl; ss << "#define INT_TYPE4 ushort4" << std::endl; ss << "#define INT_TYPE8 ushort8" << std::endl; - ss << "#define SUB_GROUP_BLOCK_READ2 intel_sub_group_block_read_us2" << std::endl; - ss << "#define SUB_GROUP_BLOCK_READ4 intel_sub_group_block_read_us4" << std::endl; - ss << "#define SUB_GROUP_BLOCK_READ8 intel_sub_group_block_read_us8" << std::endl; - ss << "#define SUB_GROUP_BLOCK_READ intel_sub_group_block_read_us" << std::endl; + ss << "#define SUB_GROUP_BLOCK_READ2 intel_sub_group_block_read_us2" + << std::endl; + ss << "#define SUB_GROUP_BLOCK_READ4 intel_sub_group_block_read_us4" + << std::endl; + ss << "#define SUB_GROUP_BLOCK_READ8 intel_sub_group_block_read_us8" + << std::endl; + ss << "#define SUB_GROUP_BLOCK_READ intel_sub_group_block_read_us" + << std::endl; ss << "#else" << std::endl; ss << "#define INT_TYPE uint" << std::endl; ss << "#define INT_TYPE2 uint2" << std::endl; ss << "#define INT_TYPE4 uint4" << std::endl; ss << "#define INT_TYPE8 uint8" << std::endl; - ss << "#define SUB_GROUP_BLOCK_READ2 intel_sub_group_block_read2" << std::endl; - ss << "#define SUB_GROUP_BLOCK_READ4 intel_sub_group_block_read4" << std::endl; - ss << "#define SUB_GROUP_BLOCK_READ8 intel_sub_group_block_read8" << std::endl; - ss << "#define SUB_GROUP_BLOCK_READ intel_sub_group_block_read" << std::endl; + ss << "#define SUB_GROUP_BLOCK_READ2 intel_sub_group_block_read2" + << std::endl; + ss << "#define SUB_GROUP_BLOCK_READ4 intel_sub_group_block_read4" + << std::endl; + ss << "#define SUB_GROUP_BLOCK_READ8 intel_sub_group_block_read8" + << std::endl; + ss << "#define SUB_GROUP_BLOCK_READ intel_sub_group_block_read" + << std::endl; ss << "#endif" << std::endl; ss << "#endif" << 
std::endl; ss << "#define activation_function(x) (x)" << std::endl; - ss << "__attribute__((reqd_work_group_size(1, 1, SIMD_SIZE)))" << std::endl; + ss << "__attribute__((reqd_work_group_size(1, 1, SIMD_SIZE)))" + << std::endl; ss << "kernel void" << std::endl; ss << "convolve_simd(" << std::endl; ss << "__global Dtype* inputs_base," << std::endl; @@ -2244,7 +2254,8 @@ bool LibDNNConvSpatial::verify_result( dbgPrint(printf("test verification failed @ image %d group %d" "out_ch %d h %d w %d got %G expected %G\n", n, g, out_ch, h, w, - (double)data[offset], (double)verify_data[offset])); + static_cast(data[offset]), + static_cast(verify_data[offset]))); verificationFail = 1; goto out; } diff --git a/src/caffe/greentea/libdnn_pool.cpp b/src/caffe/greentea/libdnn_pool.cpp index 2074c69a9a8..c3be6b099b2 100755 --- a/src/caffe/greentea/libdnn_pool.cpp +++ b/src/caffe/greentea/libdnn_pool.cpp @@ -205,12 +205,13 @@ std::string LibDNNPool::generate_fw_kernels(std::string name, if (std::is_same::value) { ss << "#define DTYPE_MAX HALF_MAX" << std::endl; ss << "#define DTYPE_MIN HALF_MIN" << std::endl; - } else + } else { #endif - { ss << "#define DTYPE_MAX FLT_MAX" << std::endl; ss << "#define DTYPE_MIN FLT_MIN" << std::endl; +#ifdef HAS_HALF_SUPPORT } +#endif ss << "__kernel void " + name + "("; ss << "__global const Dtype* __restrict bottom_data, "; diff --git a/src/caffe/layers/batch_norm_layer.cu b/src/caffe/layers/batch_norm_layer.cu index 73747e0427d..75367088f11 100644 --- a/src/caffe/layers/batch_norm_layer.cu +++ b/src/caffe/layers/batch_norm_layer.cu @@ -10,10 +10,13 @@ namespace caffe { oclk_bn_use_global_stats.arg(argIdx++, num); \ oclk_bn_use_global_stats.arg(argIdx++, channels_); \ oclk_bn_use_global_stats.arg(argIdx++, spatial_dim); \ - oclk_bn_use_global_stats.arg(argIdx++, fixup_arg_type(scale_factor)); \ + oclk_bn_use_global_stats.arg(argIdx++, \ + fixup_arg_type(scale_factor)); \ oclk_bn_use_global_stats.arg(argIdx++, fixup_arg_type(eps_)); \ - 
oclk_bn_use_global_stats.arg(argIdx++, WrapHandle((cl_mem) this->blobs_[0]->gpu_data(), &ctx)); \ - oclk_bn_use_global_stats.arg(argIdx++, WrapHandle((cl_mem) this->blobs_[1]->gpu_data(), &ctx)); + oclk_bn_use_global_stats.arg(argIdx++, \ + WrapHandle((cl_mem) this->blobs_[0]->gpu_data(), &ctx)); \ + oclk_bn_use_global_stats.arg(argIdx++, \ + WrapHandle((cl_mem) this->blobs_[1]->gpu_data(), &ctx)); template @@ -131,16 +134,17 @@ void BatchNormLayer::Forward_gpu(const vector*>& bottom, OCL_CHECK(clEnqueueNDRangeKernel(ctx.get_queue().handle().get(), oclk_bn_use_global_stats.handle().get(), 3, NULL, global_work_size_, NULL, 0, NULL, NULL)); - } - else { + } else { viennacl::ocl::kernel &oclk_bn_use_global_stats = program.get_kernel( CL_KERNEL_SELECT("bn_use_global_stats_in_place")); SET_COMMON_KERNEL_PARAMS - oclk_bn_use_global_stats.arg(argIdx++, WrapHandle((cl_mem) top_data, &ctx)); + oclk_bn_use_global_stats.arg(argIdx++, + WrapHandle((cl_mem) top_data, &ctx)); OCL_CHECK(clEnqueueNDRangeKernel(ctx.get_queue().handle().get(), - oclk_bn_use_global_stats.handle().get(), 3, NULL, global_work_size_, NULL, 0, NULL, NULL)); + oclk_bn_use_global_stats.handle().get(), 3, NULL, + global_work_size_, NULL, 0, NULL, NULL)); } } else { if (fused_relu) { @@ -157,8 +161,7 @@ void BatchNormLayer::Forward_gpu(const vector*>& bottom, OCL_CHECK(clEnqueueNDRangeKernel(ctx.get_queue().handle().get(), oclk_bn_use_global_stats.handle().get(), 3, NULL, global_work_size_, NULL, 0, NULL, NULL)); - } - else { + } else { viennacl::ocl::kernel &oclk_bn_use_global_stats = program.get_kernel(CL_KERNEL_SELECT("bn_use_global_stats")); diff --git a/src/caffe/layers/conv_layer_spatial.cpp b/src/caffe/layers/conv_layer_spatial.cpp index b00e49d379a..95e10bb61a8 100644 --- a/src/caffe/layers/conv_layer_spatial.cpp +++ b/src/caffe/layers/conv_layer_spatial.cpp @@ -26,7 +26,7 @@ namespace caffe { -#define ALIGN(val,N) ( ( (val) + (N) - 1 ) & ~( (N) - 1 ) ) +#define ALIGN(val, N) (((val) + (N) - 1) & 
~((N) - 1)) template void ConvolutionLayerSpatial::compute_output_shape() { @@ -73,14 +73,17 @@ void ConvolutionLayerSpatial::LayerSetUp( bias_ = NULL; if (IsFusedWithEltwiseReLU()) { - CHECK(this->layer_param().eltwise_param().coeff_size() == 0); - CHECK(bottom.size() == 2); + CHECK_EQ( + this->layer_param().convolution_param().eltwise_param().coeff_size(), + 0); + CHECK_EQ(bottom.size(), 2); op_ = this->layer_param_.eltwise_param().operation(); - CHECK(op_ == EltwiseParameter_EltwiseOp_SUM); + CHECK_EQ(op_, EltwiseParameter_EltwiseOp_SUM); } if (IsFusedWithReLU()) - negative_slope_ = this->layer_param_.convolution_param().relu_param().negative_slope(); + negative_slope_ = + this->layer_param_.convolution_param().relu_param().negative_slope(); else negative_slope_ = 0; @@ -111,7 +114,6 @@ void ConvolutionLayerSpatial::LayerSetUp( template void ConvolutionLayerSpatial::Reshape(const vector*>& bottom, const vector*>& top) { - //printf("handle layer %s bottom size %ld \n", this->layer_param_.name().c_str(), bottom.size()); if (IsFusedWithEltwiseReLU()) { const vector*> bottom_image(bottom.begin(), bottom.end() - 1); BaseConvolutionLayer::Reshape(bottom_image, top); @@ -160,7 +162,8 @@ template void ConvolutionLayerSpatial::Forward_cpu( const vector*>& bottom, const vector*>& top) { const Dtype* weight = this->blobs_[0]->cpu_data(); - CHECK(IsFusedWithEltwiseReLU() == false && IsFusedWithReLU() == false); + CHECK_EQ(IsFusedWithEltwiseReLU() == false && IsFusedWithReLU() == false, + true); for (int_tp i = 0; i < bottom.size(); ++i) { const Dtype* bottom_data = bottom[i]->cpu_data(); Dtype* top_data = top[i]->mutable_cpu_data(); @@ -228,7 +231,7 @@ void ConvolutionLayerSpatial::Backward_cpu( template void ConvolutionLayerSpatial::generate_key() { - CHECK((!std::is_same::value)); + CHECK_EQ((std::is_same::value), false); std::stringstream keyBuilder; if (std::is_same::value) keyBuilder << "float_"; @@ -264,7 +267,7 @@ void ConvolutionLayerSpatial::generate_key() { 
template std::string ConvolutionLayerSpatial::generate_specific_key( int_tp type, int_tp blockWidth, int_tp blockHeight, int_tp blockDepth) { - CHECK((!std::is_same::value)); + CHECK_EQ((std::is_same::value), false); std::stringstream keyBuilder; keyBuilder << short_key_ << "_" << type @@ -401,7 +404,7 @@ void ConvolutionLayerSpatial::swizzleWeights( * kernel_w_ + c) * M_ + od] = weight_cpu[((od * this->channels_ + id) * kernel_h_ + r) * kernel_w_ + c ]; - interleaveMatrix( cpu_swizzled_weight, tmpSwizzledWeight, + interleaveMatrix(cpu_swizzled_weight, tmpSwizzledWeight, kernel_w_ * kernel_h_ * this->channels_, M_, interleavedRows, nonInterleavedRows, blockWidth, rowAlignment); free(tmpSwizzledWeight); @@ -413,7 +416,7 @@ void ConvolutionLayerSpatial::calculate_global_size(int_tp batch, int_tp* wio, // work item output size size_t* lSize, // local size size_t* gSize) { // global size - CHECK((!std::is_same::value)); + CHECK_EQ((std::is_same::value), false); gSize[0] = ceil( (fmax(static_cast(output_w_) / wio[0], 1.0)) / lSize[0]) * lSize[0]; @@ -431,7 +434,7 @@ bool ConvolutionLayerSpatial::create_basic_kernel( const vector*>& top, int_tp blockWidth, int_tp blockHeight, int_tp blockDepth) { - CHECK((!std::is_same::value)); + CHECK_EQ((std::is_same::value), false); // Standard spatial setup is done here std::stringstream keyBuilder; std::stringstream multFunctionBuilder; @@ -551,7 +554,7 @@ cl_int ConvolutionLayerSpatial::convolve( const vector*>& bottom, const vector*>& top, int_tp index, int_tp numImages, kernelConfig* config) { - CHECK((!std::is_same::value)); + CHECK_EQ((std::is_same::value), false); viennacl::ocl::context &ctx = viennacl::ocl::get_context(this->device_->id()); viennacl::ocl::program & program = ctx.get_program(config->kernelName); viennacl::ocl::kernel &kernel = program.get_kernel(config->kernelName); @@ -792,7 +795,7 @@ float ConvolutionLayerSpatial::timed_convolve( int_tp index, int_tp numImages, kernelConfig* config) { // warm up. 
- CHECK((!std::is_same::value)); + CHECK_EQ((std::is_same::value), false); bool saved_tuned = tuned_; tuned_ = false; convolve(bottom, top, index, this->num_, config); @@ -856,7 +859,6 @@ bool ConvolutionLayerSpatial::verify_result( return true; else if (config->tested) return false; - greentea_memset(this->device_->id(), top[index]->count() * sizeof(Dtype), 0xff, @@ -886,10 +888,13 @@ bool ConvolutionLayerSpatial::verify_result( if (fabs(data[offset] - verify_data[offset]) > 0.1 * fabs(verify_data[offset] * err_factor) && !(fabs(verify_data[offset]) < 1e-3 * err_factor - && fabs(data[offset] - verify_data[offset]) < 1e-4 * err_factor)) { + && fabs(data[offset] - verify_data[offset]) < + 1e-4 * err_factor)) { dbgPrint(printf("test verification failed @ image %d group %d" "out_ch %d h %d w %d got %G expected %G\n", - n, g, out_ch, h, w, float(data[offset]), float(verify_data[offset]))); + n, g, out_ch, h, w, + static_cast(data[offset]), + static_cast(verify_data[offset]))); verificationFail = 1; break; } @@ -1242,7 +1247,8 @@ void ConvolutionLayerSpatial::setup_convolution( if (this->group_ == 1 && ((M_ % 8 == 0) && (M_ % 32 != 24))) { create_convolution_kernel(bottom, top, 5, 1, 8, 32); create_convolution_kernel(bottom, top, 5, 2, 8, 32); - if ((kernel_w_ < 4 || (!std::is_same::value)) && M_ % 32 == 0) + if ((kernel_w_ < 4 || (!std::is_same::value)) + && M_ % 32 == 0) create_convolution_kernel(bottom, top, 5, 1, 16, 32); if (kernel_w_ < 4 && (!std::is_same::value)) create_convolution_kernel(bottom, top, 5, 2, 16, 32); diff --git a/src/caffe/layers/inner_product_layer.cu b/src/caffe/layers/inner_product_layer.cu index b7049d7eaf4..4efe9c9441c 100644 --- a/src/caffe/layers/inner_product_layer.cu +++ b/src/caffe/layers/inner_product_layer.cu @@ -1,12 +1,14 @@ +#include +#include #include - -#include "caffe/filler.hpp" -#include "caffe/layers/inner_product_layer.hpp" -#include "caffe/util/math_functions.hpp" #ifdef USE_GREENTEA #include "viennacl/tools/sha1.hpp" 
-#include "caffe/util/benchmark.hpp" #endif +#include "caffe/common.hpp" +#include "caffe/filler.hpp" +#include "caffe/layers/inner_product_layer.hpp" +#include "caffe/util/benchmark.hpp" +#include "caffe/util/math_functions.hpp" namespace caffe { @@ -15,15 +17,15 @@ struct gemm_callback_arg { std::vector imgs; }; -static void CL_CALLBACK gemm_callback (cl_event event, +static void CL_CALLBACK gemm_callback(cl_event event, cl_int event_command_exec_status, void *user_data) { struct gemm_callback_arg *arg = (struct gemm_callback_arg *) user_data; - for(int i = 0; i < arg->evs.size(); i++) { + for (int i = 0; i < arg->evs.size(); i++) { clReleaseEvent(arg->evs[i]); } - for(int i = 0; i < arg->imgs.size(); i++) { + for (int i = 0; i < arg->imgs.size(); i++) { clReleaseMemObject(arg->imgs[i]); } delete arg; @@ -42,7 +44,6 @@ static void greentea_gpu_gemm_copy_buffer_to_image(int_tp ctx_id, int width, int ld, int wait_list_size, cl_event *wait_list, cl_event *event) { - viennacl::ocl::context &ctx = viennacl::ocl::get_context(ctx_id); viennacl::ocl::program &program = (Caffe::Get().GetDevice(ctx_id, false)) ->program(); @@ -52,13 +53,13 @@ static void greentea_gpu_gemm_copy_buffer_to_image(int_tp ctx_id, bool halfPrecisionMode = !std::is_same::value; memset(&desc, 0, sizeof(desc)); - int src_offset = halfPrecisionMode ? sizeof(unsigned short) * offset : sizeof(float) * offset; + int src_offset = sizeof(Dtype) * offset; if (!is_matrix_a && transpose) { // For matrix B with transpose, we need to handle them differently. // As we can't use the sub group block read to get a row easily, // we have to use CL_FLOAT type with read_imagef to get the row. 
cl_int err; - if(halfPrecisionMode) { + if (halfPrecisionMode) { format.image_channel_data_type = CL_HALF_FLOAT; } else { format.image_channel_data_type = CL_FLOAT; @@ -79,7 +80,7 @@ static void greentea_gpu_gemm_copy_buffer_to_image(int_tp ctx_id, OCL_CHECK(err); } - if(ld == width) { + if (ld == width) { size_t origin[] = {0, 0, 0}; size_t region[] = {(size_t)desc.image_width, (size_t)desc.image_height, 1}; @@ -110,7 +111,7 @@ static void greentea_gpu_gemm_copy_buffer_to_image(int_tp ctx_id, } else { if (*image == NULL) { desc.image_type = CL_MEM_OBJECT_IMAGE2D; - if(halfPrecisionMode) { + if (halfPrecisionMode) { format.image_channel_data_type = CL_HALF_FLOAT; format.image_channel_order = CL_R; } else { @@ -142,7 +143,8 @@ static void greentea_gpu_gemm_copy_buffer_to_image(int_tp ctx_id, size_t region[] = {(size_t)width, (size_t)height, 1}; OCL_CHECK(clEnqueueCopyBufferToImage(ctx.get_queue().handle().get(), buffer, *image, src_offset, - origin, region, wait_list_size, wait_list, event)); + origin, region, wait_list_size, + wait_list, event)); } else { viennacl::ocl::kernel &oclk_gemm_copy = program.get_kernel( CL_KERNEL_SELECT("gemm_buffer_copy_image_no_transpose")); @@ -166,13 +168,15 @@ static void greentea_gpu_gemm_copy_buffer_to_image(int_tp ctx_id, } template -static void greentea_gpu_fast_image_gemm(const int_tp ctx_id, const CBLAS_TRANSPOSE TransA, +static void greentea_gpu_fast_image_gemm(const int_tp ctx_id, + const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB, const int_tp M, const int_tp N, const int_tp K, const Dtype alpha, const cl_mem A, const int_tp offA, const cl_mem B, const int_tp offB, const Dtype beta, cl_mem C, const int_tp offC, bool is_image_a, bool is_image_b, - enum gemm_type_t gemm_type, const size_t max_image_size) { + enum gemm_type_t gemm_type, + const size_t max_image_size) { CHECK_EQ(gemm_type == GEMM_TYPE_FAST_IMAGE_32_1 || gemm_type == GEMM_TYPE_FAST_IMAGE_32_2 || gemm_type == GEMM_TYPE_FAST_IMAGE_B_IMAGE, true) @@ 
-193,7 +197,8 @@ static void greentea_gpu_fast_image_gemm(const int_tp ctx_id, const CBLAS_TRANSP int ldB = widthB; int ldC = N; - int A_start_x = 0, A_start_y = 0, B_start_x = 0, B_start_y = 0, C_start_x = 0, C_start_y = 0; + int A_start_x = 0, A_start_y = 0, B_start_x = 0; + int B_start_y = 0, C_start_x = 0, C_start_y = 0; int blocksize = 1024; if (gemm_type == GEMM_TYPE_FAST_IMAGE_B_IMAGE) blocksize = max_image_size; @@ -230,9 +235,9 @@ static void greentea_gpu_fast_image_gemm(const int_tp ctx_id, const CBLAS_TRANSP else kernel_name += "T"; - if (TransB == CblasNoTrans) + if (TransB == CblasNoTrans) { kernel_name += "N_"; - else { + } else { kernel_name += "T_"; if (is_image_b || (K % use_buffer_indicator != 0)) { kernel_name += "SCALAR_"; @@ -251,19 +256,19 @@ static void greentea_gpu_fast_image_gemm(const int_tp ctx_id, const CBLAS_TRANSP else kernel_name += "1"; - if(halfPrecisionMode) { + if (halfPrecisionMode) { kernel_name += "_half"; } else { kernel_name += "_float"; } oclk_gemm_float = &program.get_kernel(kernel_name); - while(C_start_y < M) { - blockC_width = std::min((int)N - C_start_x, blocksize); - blockC_height = std::min((int)M - C_start_y, blocksize); + while (C_start_y < M) { + blockC_width = std::min(static_cast(N) - C_start_x, blocksize); + blockC_height = std::min(static_cast(M) - C_start_y, blocksize); int isFirstColBlock = 1; - for(int k = 0; k < K; k += blocksize) { + for (int k = 0; k < K; k += blocksize) { cl_event ev[5]; cl_uint ev_idx = 0; memset(ev, 0, sizeof(cl_event) * 5); @@ -273,9 +278,10 @@ static void greentea_gpu_fast_image_gemm(const int_tp ctx_id, const CBLAS_TRANSP blockA_height = std::min(heightA - A_start_y, blocksize); blockB_width = std::min(widthB - B_start_x, blocksize); blockB_height = std::min(heightB - B_start_y, blocksize); - int block_Ksize = std::min((int)K - k, blocksize); + int block_Ksize = std::min(static_cast(K) - k, blocksize); - int padded_k = block_Ksize + ((block_Ksize & 7) ? 
(8 - (block_Ksize & 7)) : 0); + int padded_k = block_Ksize + ((block_Ksize & 7) ? + (8 - (block_Ksize & 7)) : 0); int imageA_w = (TransA == CblasNoTrans) ? padded_k : blockA_width; int imageA_h = (TransA == CblasNoTrans) ? blockA_height : padded_k; int imageB_w = (TransB == CblasNoTrans) ? blockB_width : padded_k; @@ -288,7 +294,7 @@ static void greentea_gpu_fast_image_gemm(const int_tp ctx_id, const CBLAS_TRANSP bool padding_A = false; bool padding_B = false; - if(halfPrecisionMode && is_image_b) { + if (halfPrecisionMode && is_image_b) { padding_A = true; } @@ -300,18 +306,22 @@ static void greentea_gpu_fast_image_gemm(const int_tp ctx_id, const CBLAS_TRANSP } if (!is_image_a) { - greentea_gpu_gemm_copy_buffer_to_image(ctx_id, &ImA, A, blockA_offset, + greentea_gpu_gemm_copy_buffer_to_image(ctx_id, &ImA, + A, blockA_offset, true, TransA != CblasNoTrans, padding_A, imageA_h, imageA_w, - blockA_height, blockA_width, ldA, 0, NULL, &ev[ev_idx]); + blockA_height, blockA_width, ldA, 0, + NULL, &ev[ev_idx]); if (ev[ev_idx] != NULL) ev_idx++; } if (!is_image_b) { - greentea_gpu_gemm_copy_buffer_to_image(ctx_id, &ImB, B, blockB_offset, + greentea_gpu_gemm_copy_buffer_to_image(ctx_id, &ImB, + B, blockB_offset, false, false, padding_B, imageB_h, imageB_w, - blockB_height, blockB_width, ldB, 0, NULL, &ev[ev_idx]); + blockB_height, blockB_width, ldB, + 0, NULL, &ev[ev_idx]); if (ev[ev_idx] != NULL) ev_idx++; } @@ -321,18 +331,22 @@ static void greentea_gpu_fast_image_gemm(const int_tp ctx_id, const CBLAS_TRANSP if (!is_image_a) { bool padding; padding = !is_image_b || halfPrecisionMode; - greentea_gpu_gemm_copy_buffer_to_image(ctx_id, &ImA, A, blockA_offset, + greentea_gpu_gemm_copy_buffer_to_image(ctx_id, &ImA, + A, blockA_offset, true, TransA != CblasNoTrans, padding, imageA_h, imageA_w, - blockA_height, blockA_width, ldA, 0, NULL, &ev[ev_idx]); + blockA_height, blockA_width, ldA, + 0, NULL, &ev[ev_idx]); if (ev[ev_idx] != NULL) ev_idx++; } - if(!is_image_b && (K % 
use_buffer_indicator != 0)) { - greentea_gpu_gemm_copy_buffer_to_image(ctx_id, &ImB, B, blockB_offset, + if (!is_image_b && (K % use_buffer_indicator != 0)) { + greentea_gpu_gemm_copy_buffer_to_image(ctx_id, &ImB, + B, blockB_offset, false, true, false, imageB_h, imageB_w, - blockB_height, blockB_width, ldB, 0, NULL, &ev[ev_idx]); + blockB_height, blockB_width, ldB, 0, + NULL, &ev[ev_idx]); if (ev[ev_idx] != NULL) ev_idx++; } @@ -345,14 +359,13 @@ static void greentea_gpu_fast_image_gemm(const int_tp ctx_id, const CBLAS_TRANSP size_t global[2]; if (gemm_type == GEMM_TYPE_FAST_IMAGE_32_1 || gemm_type == GEMM_TYPE_FAST_IMAGE_B_IMAGE ) { - if(halfPrecisionMode) { + if (halfPrecisionMode) { global[0] = (size_t)( blockC_width + 15 ) & ~15; } else { global[0] = (size_t)( blockC_width + 7 ) & ~7; } - } - else { - if(halfPrecisionMode) { + } else { + if (halfPrecisionMode) { global[0] = (size_t)( (blockC_width / 2 ) + 15 ) ^ ~15; } else { global[0] = (size_t)( (blockC_width / 2 ) + 7 ) ^ ~7; @@ -371,9 +384,10 @@ static void greentea_gpu_fast_image_gemm(const int_tp ctx_id, const CBLAS_TRANSP cl_uint arg_idx = 0; oclk_gemm_float->arg(arg_idx++, WrapHandle(ImA, &ctx)); - if (TransB == CblasNoTrans || is_image_b || (K % use_buffer_indicator != 0)) + if (TransB == CblasNoTrans || is_image_b || + (K % use_buffer_indicator != 0)) { oclk_gemm_float->arg(arg_idx++, WrapHandle(ImB, &ctx)); - else { + } else { oclk_gemm_float->arg(arg_idx++, WrapHandle(B, &ctx)); oclk_gemm_float->arg(arg_idx++, blockB_offset); oclk_gemm_float->arg(arg_idx++, ldB); @@ -397,53 +411,55 @@ static void greentea_gpu_fast_image_gemm(const int_tp ctx_id, const CBLAS_TRANSP oclk_gemm_float->handle().get(), 2, NULL, global, local, ev_idx, wait_list, &ev[ev_idx])); - if(TransA == CblasNoTrans) + if (TransA == CblasNoTrans) A_start_x += blockA_width; else A_start_y += blockA_height; - if(TransB == CblasNoTrans) + if (TransB == CblasNoTrans) B_start_y += blockB_height; else B_start_x += blockB_width; 
isFirstColBlock = 0; arg->evs.assign(ev, ev + ev_idx + 1); - clSetEventCallback(ev[ev_idx], CL_COMPLETE, &gemm_callback, (void*)arg); + clSetEventCallback(ev[ev_idx], CL_COMPLETE, &gemm_callback, + static_cast(arg)); } C_start_x += blockC_width; - if(TransA == CblasNoTrans) + if (TransA == CblasNoTrans) A_start_x = 0; else A_start_y = 0; - if(TransB == CblasNoTrans) { + if (TransB == CblasNoTrans) { B_start_x += blockB_width; B_start_y = 0; } else { B_start_y += blockB_height; B_start_x = 0; } - if(C_start_x >= N) { + if (C_start_x >= N) { C_start_x = 0; B_start_x = 0; B_start_y = 0; C_start_y += blockC_height; - if(TransA == CblasNoTrans) + if (TransA == CblasNoTrans) A_start_y += blockA_height; else A_start_x += blockA_width; } } - if(ImA && !is_image_a) + if (ImA && !is_image_a) clReleaseMemObject(ImA); - if(ImB && !is_image_b) + if (ImB && !is_image_b) clReleaseMemObject(ImB); } template -static void greentea_gpu_fast_buffer_gemm(const int_tp ctx_id, const CBLAS_TRANSPOSE TransA, +static void greentea_gpu_fast_buffer_gemm(const int_tp ctx_id, + const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB, const int_tp M, const int_tp N, const int_tp K, const Dtype alpha, const cl_mem A, const int_tp offA, const cl_mem B, @@ -463,12 +479,12 @@ static void greentea_gpu_fast_buffer_gemm(const int_tp ctx_id, const CBLAS_TRANS bool is_small_batch = (M == 2 || M == 4 || M == 8); viennacl::ocl::kernel *oclk_gemm_float; std::string kernel_name("gemm_buffer_"); - if(TransA == CblasNoTrans && TransB == CblasNoTrans) { + if (TransA == CblasNoTrans && TransB == CblasNoTrans) { kernel_name += "NN"; - if(halfPrecisionMode) { + if (halfPrecisionMode) { sub_group_size = 16; } - } else if(TransA == CblasNoTrans && TransB != CblasNoTrans) { + } else if (TransA == CblasNoTrans && TransB != CblasNoTrans) { if (M == 2) kernel_name +="NT_M_2"; else if (M == 4) @@ -477,16 +493,16 @@ static void greentea_gpu_fast_buffer_gemm(const int_tp ctx_id, const CBLAS_TRANS kernel_name +="NT_M_8"; 
else kernel_name += "NT"; - } else if(TransA != CblasNoTrans && TransB == CblasNoTrans) { + } else if (TransA != CblasNoTrans && TransB == CblasNoTrans) { kernel_name += "TN"; - if(halfPrecisionMode) { + if (halfPrecisionMode) { sub_group_size = 16; } } else { kernel_name += "TT"; } - if(halfPrecisionMode) { + if (halfPrecisionMode) { kernel_name += "_half"; } else { kernel_name += "_float"; @@ -495,23 +511,24 @@ static void greentea_gpu_fast_buffer_gemm(const int_tp ctx_id, const CBLAS_TRANS oclk_gemm_float = &program.get_kernel(kernel_name); size_t local[2] = {}; size_t global[2] = {}; - if (TransA == CblasNoTrans && TransB != CblasNoTrans && is_small_batch ) { - if(M == 8) + if (TransA == CblasNoTrans && TransB != CblasNoTrans && is_small_batch) { + if (M == 8) local[0] = 16; - else if(M == 4) + else if (M == 4) local[0] = 32; else local[0] = 64; local[1] = 1; - if(M == 8) + if (M == 8) global[0] = N * local[0]; else global[0] = (N + 3) / 4 * local[0]; global[1] = 1; } else { size_t lx = sub_group_size; - size_t ly = (TransB != CblasNoTrans && TransA == CblasNoTrans && halfPrecisionMode) ? 2 : 4; + size_t ly = (TransB != CblasNoTrans && + TransA == CblasNoTrans && halfPrecisionMode) ? 2 : 4; int dx = (TransB != CblasNoTrans && TransA == CblasNoTrans) ? 
1 : 4; int dy = 8; size_t gx = (size_t)(N + dx - 1) / dx; @@ -535,18 +552,20 @@ static void greentea_gpu_fast_buffer_gemm(const int_tp ctx_id, const CBLAS_TRANS oclk_gemm_float->arg(arg_idx++, fixup_arg_type(alpha)); oclk_gemm_float->arg(arg_idx++, fixup_arg_type(beta)); - if(TransB == CblasNoTrans || TransA != CblasNoTrans) { + if (TransB == CblasNoTrans || TransA != CblasNoTrans) { int stride = 256; - for(int start_index = 0; start_index < K; start_index += stride) { + for (int start_index = 0; start_index < K; start_index += stride) { oclk_gemm_float->arg(arg_idx, start_index); OCL_CHECK(clEnqueueNDRangeKernel(ctx.get_queue().handle().get(), - oclk_gemm_float->handle().get(), 2, NULL, - global, local, 0, - NULL, &ev)); + oclk_gemm_float->handle().get(), + 2, NULL, + global, local, 0, + NULL, &ev)); } } else { OCL_CHECK(clEnqueueNDRangeKernel(ctx.get_queue().handle().get(), - oclk_gemm_float->handle().get(), 2, NULL, + oclk_gemm_float->handle().get(), + 2, NULL, global, local, 0, NULL, &ev)); } @@ -556,9 +575,10 @@ static void greentea_gpu_fast_buffer_gemm(const int_tp ctx_id, const CBLAS_TRANS template static void innerprod_common(const int_tp ctx_id, const CBLAS_TRANSPOSE TransB, const int_tp M, const int_tp N, const int_tp K, - const cl_mem A, const cl_mem B, const cl_mem B_image, - cl_mem C, gemm_type_t gemm_type, const size_t max_image_size) { - + const cl_mem A, const cl_mem B, + const cl_mem B_image, + cl_mem C, gemm_type_t gemm_type, + const size_t max_image_size) { if (gemm_type == GEMM_TYPE_FAST_IMAGE_32_1 || gemm_type == GEMM_TYPE_FAST_IMAGE_32_2) { greentea_gpu_fast_image_gemm(ctx_id, CblasNoTrans, TransB, M, N, K, @@ -567,15 +587,19 @@ static void innerprod_common(const int_tp ctx_id, const CBLAS_TRANSPOSE TransB, } else if (gemm_type == GEMM_TYPE_FAST_IMAGE_B_IMAGE) { greentea_gpu_fast_image_gemm(ctx_id, CblasNoTrans, TransB, M, N, K, (Dtype)1., A, 0, B_image, 0, (Dtype)0., C, - 0, false, true, GEMM_TYPE_FAST_IMAGE_B_IMAGE, max_image_size); + 0, 
false, true, + GEMM_TYPE_FAST_IMAGE_B_IMAGE, + max_image_size); } else if (gemm_type == GEMM_TYPE_FAST_BUFFER) { - greentea_gpu_fast_buffer_gemm(ctx_id, CblasNoTrans, TransB, M, N, K, + greentea_gpu_fast_buffer_gemm(ctx_id, CblasNoTrans, + TransB, M, N, K, 1.f, A, 0, B, 0, 0.f, C, 0, gemm_type); - } else + } else { greentea_gpu_gemm(ctx_id, CblasNoTrans, TransB, M, N, K, (Dtype)1., A, 0, B, 0, (Dtype)0., C, 0); + } } @@ -590,9 +614,10 @@ void InnerProductLayer::generate_key() { viennacl::ocl::context &ctx = viennacl::ocl::get_context (this->device_->id()); - std::string prefix = ctx.current_device().name() + ctx.current_device().vendor() - + ctx.current_device().driver_version() - + std::to_string(ctx.current_device().max_compute_units()); + std::string prefix = ctx.current_device().name() + + ctx.current_device().vendor() + + ctx.current_device().driver_version() + + std::to_string(ctx.current_device().max_compute_units()); key_ = viennacl::tools::sha1(prefix + keyBuilder.str()); // short_key_ = keyBuilder.str(); } @@ -604,9 +629,9 @@ template void InnerProductLayer::generate_key(); template bool InnerProductLayer::load_cache() { - if (tuned_) + if (tuned_) { return true; - else { + } else { generate_key(); // Find cached kernel configuration string outputFile; @@ -632,20 +657,21 @@ template bool InnerProductLayer::load_cache(); template void InnerProductLayer::tune_innerprod_type(const int_tp ctx_id, - const CBLAS_TRANSPOSE TransB, const cl_mem A, const cl_mem B, + const CBLAS_TRANSPOSE TransB, const cl_mem A, + const cl_mem B, const cl_mem B_image, const size_t max_image_size) { if (std::is_same::value) { innerprod_type_ = GEMM_TYPE_DEFAULT; return; } else { - //1. load cache + // 1. load cache if (load_cache()) { return; } else { - //2. if not cached generate tuning + // 2. 
if not cached generate tuning uint element_size = 0; bool halfPrecisionMode = !std::is_same::value; - if(halfPrecisionMode) { + if (halfPrecisionMode) { element_size = sizeof(uint16_t); } else { element_size = sizeof(float); @@ -653,7 +679,9 @@ void InnerProductLayer::tune_innerprod_type(const int_tp ctx_id, viennacl::ocl::context &ctx = viennacl::ocl::get_context(ctx_id); cl_int err; - cl_mem C = clCreateBuffer(ctx.handle().get(), CL_MEM_ALLOC_HOST_PTR, M_ * N_ * element_size, NULL, &err); + cl_mem C = clCreateBuffer(ctx.handle().get(), + CL_MEM_ALLOC_HOST_PTR, + M_ * N_ * element_size, NULL, &err); OCL_CHECK(err); std::vector gemm_tests; @@ -662,18 +690,18 @@ void InnerProductLayer::tune_innerprod_type(const int_tp ctx_id, if (B_image != NULL) gemm_tests.push_back(GEMM_TYPE_FAST_IMAGE_B_IMAGE); gemm_tests.push_back(GEMM_TYPE_FAST_BUFFER); - if(!halfPrecisionMode) + if (!halfPrecisionMode) gemm_tests.push_back(GEMM_TYPE_DEFAULT); // warm up. - for( int i = 0; i < gemm_tests.size(); i++ ) { + for ( int i = 0; i < gemm_tests.size(); i++ ) { innerprod_common(ctx_id, TransB, M_, N_, K_, A, B, B_image, C, gemm_tests[i], max_image_size); } float fastest_time = 1e10; int fastest_index = -1; clFinish(ctx.get_queue().handle().get()); - for( int i = 0; i < gemm_tests.size(); i++ ) { + for ( int i = 0; i < gemm_tests.size(); i++ ) { Timer timer; timer.initted(); timer.Start(); @@ -696,7 +724,7 @@ void InnerProductLayer::tune_innerprod_type(const int_tp ctx_id, if (fastest_index >= 0) { innerprod_type_ = gemm_tests[fastest_index]; } - //3. store cache. + // 3. store cache. 
string outputFile; outputFile = cache_path_.str() + key_; std::ofstream outputKernel; @@ -712,15 +740,19 @@ void InnerProductLayer::tune_innerprod_type(const int_tp ctx_id, #ifdef HAS_HALF_SUPPORT template void InnerProductLayer::tune_innerprod_type(const int_tp ctx_id, - const CBLAS_TRANSPOSE TransB, const cl_mem A, const cl_mem B, - const cl_mem B_image, const size_t max_image_size); + const CBLAS_TRANSPOSE TransB, const cl_mem A, + const cl_mem B, const cl_mem B_image, + const size_t max_image_size); #endif template void InnerProductLayer::tune_innerprod_type(const int_tp ctx_id, - const CBLAS_TRANSPOSE TransB, const cl_mem A, const cl_mem B, - const cl_mem B_image, const size_t max_image_size); -template void InnerProductLayer::tune_innerprod_type(const int_tp ctx_id, - const CBLAS_TRANSPOSE TransB, const cl_mem A, const cl_mem B, - const cl_mem B_image, const size_t max_image_size); + const CBLAS_TRANSPOSE TransB, const cl_mem A, + const cl_mem B, const cl_mem B_image, + const size_t max_image_size); +template void InnerProductLayer::tune_innerprod_type( + const int_tp ctx_id, + const CBLAS_TRANSPOSE TransB, const cl_mem A, + const cl_mem B, const cl_mem B_image, + const size_t max_image_size); template void InnerProductLayer::Forward_gpu(const vector*>& bottom, @@ -755,11 +787,14 @@ void InnerProductLayer::Forward_gpu(const vector*>& bottom, int width = !transpose_ ? K_ : N_; if (M_ != 1) { if (std::is_same::value) { - padded_height = !transpose_ ? height : (height + ((height & 7) ? 1 : 0)); + padded_height = !transpose_ ? height : + (height + ((height & 7) ? 1 : 0)); padded_width = !transpose_ ? width : (width + ((width & 7) ? 1 : 0)); } else { - padded_height = !transpose_ ? height : (height + ((height & 7) ? (8-(height%8)) : 0)); - padded_width = !transpose_ ? width : (width + ((width & 7) ? (8-(width%8)) : 0)); + padded_height = !transpose_ ? height : + (height + ((height & 7) ? (8-(height%8)) : 0)); + padded_width = !transpose_ ? 
width : + (width + ((width & 7) ? (8-(width%8)) : 0)); } } @@ -795,7 +830,7 @@ void InnerProductLayer::Forward_gpu(const vector*>& bottom, &weight_image_, (cl_mem) weight, 0, false, !transpose_, true, padded_height, padded_width, - height, width, width, (int)0, NULL, NULL); + height, width, width, static_cast(0), NULL, NULL); copied_weight_data_ = this->blobs_[0]->data().get(); } diff --git a/src/caffe/layers/lrn_layer.cpp b/src/caffe/layers/lrn_layer.cpp index aad981e4c6e..05809358152 100644 --- a/src/caffe/layers/lrn_layer.cpp +++ b/src/caffe/layers/lrn_layer.cpp @@ -69,10 +69,10 @@ void LRNLayer::LayerSetUp(const vector*>& bottom, CHECK(this->phase_ == caffe::TEST); CHECK(this->layer_param_.lrn_param().pooling_param().pool() == PoolingParameter_PoolMethod_MAX); - CHECK(this->layer_param_.lrn_param().pooling_param().kernel_size(0) - < 6); - CHECK(this->layer_param_.lrn_param().pooling_param().stride(0) < 4); - CHECK(this->layer_param_.lrn_param().pooling_param().dilation_size() == 0); + CHECK_LT(this->layer_param_.lrn_param().pooling_param().kernel_size(0), + 6); + CHECK_LT(this->layer_param_.lrn_param().pooling_param().stride(0), 4); + CHECK_EQ(this->layer_param_.lrn_param().pooling_param().dilation_size(), 0); pool_w_ = this->layer_param_.lrn_param().pooling_param().kernel_size(0); pool_h_ = this->layer_param_.lrn_param().pooling_param().kernel_size(0); pool_stride_w_ = this->layer_param_.lrn_param().pooling_param().stride(0); @@ -88,7 +88,8 @@ void LRNLayer::LayerSetUp(const vector*>& bottom, } lrn_top_vec_.push_back(&lrn_top_blob_); LayerParameter pooling_param; - pooling_param.mutable_pooling_param()->set_pool(PoolingParameter_PoolMethod_MAX); + pooling_param.mutable_pooling_param()->set_pool + (PoolingParameter_PoolMethod_MAX); pooling_param.mutable_pooling_param()->add_kernel_size(pool_w_); pooling_param.mutable_pooling_param()->add_stride(pool_stride_w_); pool_layer_.reset(new PoolingLayer(pooling_param)); diff --git a/src/caffe/layers/lrn_layer.cu 
b/src/caffe/layers/lrn_layer.cu index 19460de6cdb..a62bfa42c2a 100644 --- a/src/caffe/layers/lrn_layer.cu +++ b/src/caffe/layers/lrn_layer.cu @@ -1,8 +1,8 @@ #include #include "caffe/layers/lrn_layer.hpp" -#include "caffe/util/math_functions.hpp" #include "caffe/util/benchmark.hpp" +#include "caffe/util/math_functions.hpp" namespace caffe { @@ -107,8 +107,9 @@ void LRNLayer::CrossChannelForward_fuse_pooling_gpu( const int_tp tile_pooled_block_w = (TILE_W - pool_w_) / pool_stride_w_ + 1; const int tiled_width = (width_ + tile_pooled_block_w * pool_stride_w_ - 1) / (tile_pooled_block_w * pool_stride_w_); - const int tiled_height = (height_ + tile_pooled_block_h * pool_stride_h_ - 1) - / (tile_pooled_block_h * pool_stride_h_); + const int tiled_height = + (height_ + tile_pooled_block_h * pool_stride_h_ - 1) + / (tile_pooled_block_h * pool_stride_h_); int_tp n_threads = num_ * tiled_width * tiled_height; size_t global_work_size_[2] = {(size_t)n_threads, simd_size}; size_t local_work_size[2] = {1, simd_size}; @@ -133,8 +134,8 @@ void LRNLayer::CrossChannelForward_fuse_pooling_gpu( oclk_lrn_fill.arg(argIdx++, tile_pooled_block_w); OCL_CHECK(clEnqueueNDRangeKernel(ctx.get_queue().handle().get(), oclk_lrn_fill.handle().get(), 2, NULL, - global_work_size_, local_work_size, 0, NULL, - NULL)); + global_work_size_, local_work_size, 0, + NULL, NULL)); } else { Dtype* top_lrn_data = lrn_top_blob_.mutable_gpu_data(); // Do LRN firstly. @@ -171,7 +172,7 @@ void LRNLayer::CrossChannelForward_fuse_pooling_gpu( pool_w_, pool_stride_h_, pool_stride_w_, 0, 0, WrapHandle((cl_mem) top_data, &ctx)), ctx.get_queue()); - } + } } template @@ -234,7 +235,6 @@ void LRNLayer::CrossChannelForward_gpu( global_work_size_, NULL, 0, NULL, NULL)); } else { - if (!IsFused()) { cl_uint argIdx = 0; int_tp n_threads = num_ * height_ * width_; @@ -260,9 +260,9 @@ void LRNLayer::CrossChannelForward_gpu( // We can't make sure the fused kernel be the faster for all platforms. 
// have to apply a simple tuning here. if (this->device_->CheckCapability("cl_intel_subgroups")) { - if (fuse_tuned_) + if (fuse_tuned_) { CrossChannelForward_fuse_pooling_gpu(bottom, top, tuned_use_fuse_); - else { + } else { float elapsedTime[2]; bool use_fuse[2] = {true, false}; // warm up. @@ -304,7 +304,8 @@ template void LRNLayer::CrossChannelForward_gpu( template void LRNLayer::CrossChannelForward_fuse_pooling_gpu( const vector*>& bottom, const vector*>& top, bool); template void LRNLayer::CrossChannelForward_fuse_pooling_gpu( - const vector*>& bottom, const vector*>& top, bool); + const vector*>& bottom, + const vector*>& top, bool); template void LRNLayer::Backward_gpu(const vector*>& top, diff --git a/src/caffe/layers/power_layer.cu b/src/caffe/layers/power_layer.cu index 2d2a9d250ac..2c687cbd742 100644 --- a/src/caffe/layers/power_layer.cu +++ b/src/caffe/layers/power_layer.cu @@ -135,11 +135,13 @@ void PowerLayer::Backward_gpu(const vector*>& top, // -> dy/dx = 2 * scale * (shift + scale * x) // = diff_scale * shift + diff_scale * scale * x greentea_gpu_axpby(this->device_->id(), count, - Dtype(diff_scale_ * scale_), (cl_mem) bottom_data, 0, + Dtype(diff_scale_ * scale_), + (cl_mem) bottom_data, 0, Dtype(0), (cl_mem) bottom_diff, 0); if (shift_ != Dtype(0)) { greentea_gpu_add_scalar(this->device_->id(), count, - Dtype(diff_scale_ * shift_), (cl_mem) bottom_diff, + Dtype(diff_scale_ * shift_), + (cl_mem) bottom_diff, 0); } } else if (shift_ == Dtype(0)) { diff --git a/src/caffe/layers/softmax_layer.cpp b/src/caffe/layers/softmax_layer.cpp index 5015c84fe44..03d0a5c879b 100644 --- a/src/caffe/layers/softmax_layer.cpp +++ b/src/caffe/layers/softmax_layer.cpp @@ -43,7 +43,8 @@ void SoftmaxLayer::Forward_cpu(const vector*>& bottom, for (int_tp i = 0; i < outer_num_; ++i) { // initialize scale_data to the first plane caffe_cpu_copy(inner_num_, bottom_data + i * dim, scale_data); - // start max after the first inner_num values (j=1) since they were just copied 
+ // start max after the first inner_num values + // (j=1) since they were just copied for (int_tp j = 1; j < channels; j++) { for (int_tp k = 0; k < inner_num_; k++) { scale_data[k] = std::max(scale_data[k], diff --git a/src/caffe/solvers/adagrad_solver.cu b/src/caffe/solvers/adagrad_solver.cu index 18065a4e607..ee11bac5bb9 100644 --- a/src/caffe/solvers/adagrad_solver.cu +++ b/src/caffe/solvers/adagrad_solver.cu @@ -38,7 +38,9 @@ void adagrad_update_gpu(device* dev, int_tp N, Dtype* g, Dtype* h, Dtype delta, CL_KERNEL_SELECT("ada_grad_update")); viennacl::ocl::enqueue( oclk_ada_grad_update(N, WrapHandle((cl_mem) g, &ctx), - WrapHandle((cl_mem) h, &ctx), fixup_arg_type(delta), fixup_arg_type(local_rate)), + WrapHandle((cl_mem) h, &ctx), + fixup_arg_type(delta), + fixup_arg_type(local_rate)), ctx.get_queue()); #endif // USE_GREENTEA } diff --git a/src/caffe/solvers/nesterov_solver.cu b/src/caffe/solvers/nesterov_solver.cu index fcb30bd6920..7be2e7e38e5 100644 --- a/src/caffe/solvers/nesterov_solver.cu +++ b/src/caffe/solvers/nesterov_solver.cu @@ -38,7 +38,8 @@ void nesterov_update_gpu(device* dev, int_tp N, Dtype* g, Dtype* h, CL_KERNEL_SELECT("nesterov_update")); viennacl::ocl::enqueue( oclk_nesterov_update(N, WrapHandle((cl_mem) g, &ctx), - WrapHandle((cl_mem) h, &ctx), fixup_arg_type(momentum), + WrapHandle((cl_mem) h, &ctx), + fixup_arg_type(momentum), fixup_arg_type(local_rate)), ctx.get_queue()); #endif // USE_GREENTEA diff --git a/src/caffe/syncedmem.cpp b/src/caffe/syncedmem.cpp index d8f3e3d2afc..9e22732ab0d 100644 --- a/src/caffe/syncedmem.cpp +++ b/src/caffe/syncedmem.cpp @@ -49,7 +49,8 @@ void CaffeMallocHost(void** ptr, int_tp size, device* dev) { #ifndef USE_GREENTEA *ptr = mkl_malloc(size ? size:1, 64); #else - *ptr = mkl_malloc(size ? ALIGN(size, OPENCL_CACHE_ALIGN) : 64, OPENCL_PAGE_ALIGN); + *ptr = mkl_malloc(size ? 
ALIGN(size, OPENCL_CACHE_ALIGN) : + 64, OPENCL_PAGE_ALIGN); #endif #else CHECK_EQ(0, posix_memalign(ptr, OPENCL_PAGE_ALIGN, diff --git a/src/caffe/test/test_bias_layer.cpp b/src/caffe/test/test_bias_layer.cpp index 22df442ec23..91532baf384 100644 --- a/src/caffe/test/test_bias_layer.cpp +++ b/src/caffe/test/test_bias_layer.cpp @@ -82,7 +82,7 @@ TYPED_TEST(BiasLayerTest, TestForwardEltwise) { const int_tp count = this->blob_top_->count(); const Dtype* in_data_a = this->blob_bottom_->cpu_data(); const Dtype* in_data_b = this->blob_bottom_eltwise_->cpu_data(); - const Dtype delta = std::is_same::value ? + const Dtype delta = std::is_same::value ? 1e-2 : 1e-5; for (int_tp i = 0; i < count; ++i) { EXPECT_NEAR(data[i], in_data_a[i] + in_data_b[i], @@ -105,7 +105,7 @@ TYPED_TEST(BiasLayerTest, TestForwardEltwiseInPlace) { const int_tp count = this->blob_bottom_->count(); const Dtype* in_data_a = orig_bottom.cpu_data(); const Dtype* in_data_b = this->blob_bottom_eltwise_->cpu_data(); - const Dtype delta = std::is_same::value ? + const Dtype delta = std::is_same::value ? 1e-2 : 1e-5; for (int_tp i = 0; i < count; ++i) { EXPECT_NEAR(data[i], in_data_a[i] + in_data_b[i], @@ -149,7 +149,7 @@ TYPED_TEST(BiasLayerTest, TestBackwardEltwiseInPlace) { caffe_copy(top_diff.count(), top_diff.cpu_data(), this->blob_bottom_->mutable_cpu_diff()); layer->Backward(this->blob_top_vec_, propagate_down, this->blob_bottom_vec_); - const Dtype delta = std::is_same::value ? + const Dtype delta = std::is_same::value ? 1e-2 : 1e-5; for (int_tp i = 0; i < this->blob_bottom_->count(); ++i) { EXPECT_NEAR(orig_bottom_diff.cpu_diff()[i], @@ -176,7 +176,7 @@ TYPED_TEST(BiasLayerTest, TestForwardEltwiseWithParam) { const int_tp count = this->blob_top_->count(); const Dtype* in_data_a = this->blob_bottom_->cpu_data(); const Dtype* in_data_b = layer->blobs()[0]->cpu_data(); - const Dtype delta = std::is_same::value ? + const Dtype delta = std::is_same::value ? 
1e-2 : 1e-5; for (int_tp i = 0; i < count; ++i) { EXPECT_NEAR(data[i], in_data_a[i] + in_data_b[i], delta); @@ -192,7 +192,7 @@ TYPED_TEST(BiasLayerTest, TestForwardBroadcastBegin) { layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_); ASSERT_EQ(this->blob_bottom_->shape(), this->blob_top_->shape()); layer->Forward(this->blob_bottom_vec_, this->blob_top_vec_); - const Dtype delta = std::is_same::value ? + const Dtype delta = std::is_same::value ? 1e-2 : 1e-5; for (int_tp n = 0; n < this->blob_bottom_->num(); ++n) { for (int_tp c = 0; c < this->blob_bottom_->channels(); ++c) { @@ -217,7 +217,7 @@ TYPED_TEST(BiasLayerTest, TestForwardBroadcastMiddle) { layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_); ASSERT_EQ(this->blob_bottom_->shape(), this->blob_top_->shape()); layer->Forward(this->blob_bottom_vec_, this->blob_top_vec_); - const Dtype delta = std::is_same::value ? + const Dtype delta = std::is_same::value ? 1e-2 : 1e-5; for (int_tp n = 0; n < this->blob_bottom_->num(); ++n) { for (int_tp c = 0; c < this->blob_bottom_->channels(); ++c) { @@ -244,7 +244,7 @@ TYPED_TEST(BiasLayerTest, TestForwardBroadcastMiddleInPlace) { shared_ptr > layer(new BiasLayer(layer_param)); layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_); layer->Forward(this->blob_bottom_vec_, this->blob_top_vec_); - const Dtype delta = std::is_same::value ? + const Dtype delta = std::is_same::value ? 1e-2 : 1e-5; for (int_tp n = 0; n < this->blob_bottom_->num(); ++n) { for (int_tp c = 0; c < this->blob_bottom_->channels(); ++c) { @@ -296,7 +296,7 @@ TYPED_TEST(BiasLayerTest, TestBackwardBroadcastMiddleInPlace) { caffe_copy(top_diff.count(), top_diff.cpu_data(), this->blob_bottom_->mutable_cpu_diff()); layer->Backward(this->blob_top_vec_, propagate_down, this->blob_bottom_vec_); - const Dtype delta = std::is_same::value ? + const Dtype delta = std::is_same::value ? 
1e-2 : 1e-5; for (int_tp i = 0; i < this->blob_bottom_->count(); ++i) { EXPECT_NEAR(orig_bottom_diff.cpu_diff()[i], @@ -319,7 +319,7 @@ TYPED_TEST(BiasLayerTest, TestForwardBroadcastMiddleWithParam) { layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_); ASSERT_EQ(this->blob_bottom_->shape(), this->blob_top_->shape()); layer->Forward(this->blob_bottom_vec_, this->blob_top_vec_); - const Dtype delta = std::is_same::value ? + const Dtype delta = std::is_same::value ? 1e-2 : 1e-5; for (int_tp n = 0; n < this->blob_bottom_->num(); ++n) { for (int_tp c = 0; c < this->blob_bottom_->channels(); ++c) { @@ -343,7 +343,7 @@ TYPED_TEST(BiasLayerTest, TestForwardBroadcastEnd) { layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_); ASSERT_EQ(this->blob_bottom_->shape(), this->blob_top_->shape()); layer->Forward(this->blob_bottom_vec_, this->blob_top_vec_); - const Dtype delta = std::is_same::value ? + const Dtype delta = std::is_same::value ? 1e-2 : 1e-5; for (int_tp n = 0; n < this->blob_bottom_->num(); ++n) { for (int_tp c = 0; c < this->blob_bottom_->channels(); ++c) { @@ -371,7 +371,7 @@ TYPED_TEST(BiasLayerTest, TestForwardBias) { const int_tp count = this->blob_top_->count(); const Dtype* in_data = this->blob_bottom_->cpu_data(); const Dtype bias = *this->blob_bottom_bias_->cpu_data(); - const Dtype delta = std::is_same::value ? + const Dtype delta = std::is_same::value ? 1e-2 : 1e-5; for (int_tp i = 0; i < count; ++i) { EXPECT_NEAR(data[i], in_data[i] + bias, delta); @@ -391,7 +391,7 @@ TYPED_TEST(BiasLayerTest, TestForwardBiasAxis2) { const int_tp count = this->blob_top_->count(); const Dtype* in_data = this->blob_bottom_->cpu_data(); const Dtype bias = *this->blob_bottom_bias_->cpu_data(); - const Dtype delta = std::is_same::value ? + const Dtype delta = std::is_same::value ? 
1e-2 : 1e-5; for (int_tp i = 0; i < count; ++i) { EXPECT_NEAR(data[i], in_data[i] + bias, delta); diff --git a/src/caffe/test/test_gradient_based_solver.cpp b/src/caffe/test/test_gradient_based_solver.cpp index 081ab955131..9089a5b08fc 100644 --- a/src/caffe/test/test_gradient_based_solver.cpp +++ b/src/caffe/test/test_gradient_based_solver.cpp @@ -276,8 +276,10 @@ class GradientBasedSolverTest : public MultiDeviceTest { Dtype element = 0; for (int k = 0; k < N; ++k) { // (i, k) in X^T (== (k, i) in X) times (k, j) in X. - const Dtype element_i = (i == D) ? Dtype(1) : data.cpu_data()[k * D + i]; - const Dtype element_j = (j == D) ? Dtype(1) : data.cpu_data()[k * D + j]; + const Dtype element_i = (i == D) ? + Dtype(1) : data.cpu_data()[k * D + i]; + const Dtype element_j = (j == D) ? + Dtype(1) : data.cpu_data()[k * D + j]; element += element_i * element_j; } if (j == D) { @@ -287,7 +289,8 @@ class GradientBasedSolverTest : public MultiDeviceTest { } } for (int k = 0; k < N; ++k) { - const Dtype element_i = (i == D) ? Dtype(1) : data.cpu_data()[k * D + i]; + const Dtype element_i = (i == D) ? + Dtype(1) : data.cpu_data()[k * D + i]; grad -= element_i * targets.cpu_data()[k]; } // Scale the gradient over the N samples. 
diff --git a/src/caffe/test/test_inner_product_layer.cpp b/src/caffe/test/test_inner_product_layer.cpp index f9b37b32a19..fa8a0382c9c 100644 --- a/src/caffe/test/test_inner_product_layer.cpp +++ b/src/caffe/test/test_inner_product_layer.cpp @@ -18,7 +18,7 @@ template class InnerProductLayerTest : public MultiDeviceTest { typedef typename TypeParam::Dtype Dtype; -protected: + protected: InnerProductLayerTest() : blob_bottom_(new Blob(2, 3, 4, 5)), blob_bottom_nobatch_(new Blob(1, 2, 3, 4)), @@ -132,13 +132,13 @@ TYPED_TEST(InnerProductLayerTest, TestForward) { } } +#if 0 TYPED_TEST(InnerProductLayerTest, TestForwardVGGFC6) { typedef typename TypeParam::Dtype Dtype; FillerParameter filler_param; UniformFiller filler(filler_param); caffe::Caffe::SetDevice(0); - #if 0 - for(auto i = 1; i <= 64; i*=2) { + for (auto i = 1; i <= 64; i*=2) { Blob* const blob_bottom = new Blob(i, 392, 8, 8); Blob* const blob_top = new Blob(); filler.Fill(blob_bottom); @@ -168,7 +168,8 @@ TYPED_TEST(InnerProductLayerTest, TestForwardVGGFC6) { int_tp K = layer->blobs()[0]->shape(1); if (!std::is_same::value || i <= 2) { - caffe_cpu_gemm(CblasNoTrans, CblasTrans, M, N, K, (Dtype)1., A, B, (Dtype)0., C); + caffe_cpu_gemm(CblasNoTrans, CblasTrans, M, N, K, + (Dtype)1., A, B, (Dtype)0., C); const Dtype* data = blob_top->cpu_data(); const int_tp count = blob_top->count(); @@ -176,8 +177,7 @@ TYPED_TEST(InnerProductLayerTest, TestForwardVGGFC6) { EXPECT_NEAR(data[i], C[i], 1e-1); } } - if (Caffe::mode() == Caffe::GPU) - { + if (Caffe::mode() == Caffe::GPU) { Timer timer; timer.initted(); timer.Start(); @@ -188,15 +188,20 @@ TYPED_TEST(InnerProductLayerTest, TestForwardVGGFC6) { timer.Stop(); float elapsedTime = timer.MilliSeconds(); elapsedTime /= times; - std::cout << "MNK(" << M << ","<(M*K + K*N + M*N) * sizeof(Dtype) + /elapsedTime / 1e6 + << "GB/s" << std::endl; + std::cout << "FLOPS: " + << static_cast(M*N*(2*K-1)/elapsedTime/1e6) <<"GFLOPS" + << std::endl; } delete blob_bottom; delete 
blob_top; } - #endif } TYPED_TEST(InnerProductLayerTest, TestForwardVGGFC6_AddEdge) { @@ -204,8 +209,7 @@ TYPED_TEST(InnerProductLayerTest, TestForwardVGGFC6_AddEdge) { FillerParameter filler_param; UniformFiller filler(filler_param); caffe::Caffe::SetDevice(0); -#if 0 - for(auto i = 1; i <= 64; i*=2) { + for (auto i = 1; i <= 64; i*=2) { Blob* const blob_bottom = new Blob(i, 25088+1, 1, 1); Blob* const blob_top = new Blob(); filler.Fill(blob_bottom); @@ -235,7 +239,8 @@ TYPED_TEST(InnerProductLayerTest, TestForwardVGGFC6_AddEdge) { int_tp K = layer->blobs()[0]->shape(1); if (!std::is_same::value || i <= 2) { - caffe_cpu_gemm(CblasNoTrans, CblasTrans, M, N, K, (Dtype)1., A, B, (Dtype)0., C); + caffe_cpu_gemm(CblasNoTrans, CblasTrans, M, N, K, (Dtype)1., + A, B, (Dtype)0., C); const Dtype* data = blob_top->cpu_data(); const int_tp count = blob_top->count(); @@ -244,8 +249,7 @@ TYPED_TEST(InnerProductLayerTest, TestForwardVGGFC6_AddEdge) { EXPECT_NEAR(data[i], C[i], 1e-1); } } - if (Caffe::mode() == Caffe::GPU) - { + if (Caffe::mode() == Caffe::GPU) { Timer timer; timer.initted(); timer.Start(); @@ -256,16 +260,22 @@ TYPED_TEST(InnerProductLayerTest, TestForwardVGGFC6_AddEdge) { timer.Stop(); float elapsedTime = timer.MilliSeconds(); elapsedTime /= times; - std::cout << "MNK(" << M << ","<(M*K + K*N + M*N) * sizeof(Dtype) + /elapsedTime / 1e6 + << "GB/s" << std::endl; + std::cout << "FLOPS: " + << static_cast(M*N*(2*K-1)/elapsedTime/1e6) <<"GFLOPS" + << std::endl; } delete blob_bottom; delete blob_top; } -#endif } +#endif template void gemv(const vector > >& A, @@ -314,6 +324,7 @@ template void gemv(const vector > >& A, const float alpha, const float beta); +#if 0 TYPED_TEST(InnerProductLayerTest, TestForwardGemvFC6) { typedef typename TypeParam::Dtype Dtype; @@ -322,7 +333,6 @@ TYPED_TEST(InnerProductLayerTest, TestForwardGemvFC6) { FillerParameter filler_param; UniformFiller filler(filler_param); filler.Fill(blob_bottom); - this->blob_bottom_vec_.clear(); 
this->blob_bottom_vec_.push_back(blob_bottom); this->blob_top_vec_.clear(); @@ -462,6 +472,7 @@ TYPED_TEST(InnerProductLayerTest, TestForwardGemvFC8) { delete blob_bottom; delete blob_top; } +#endif TYPED_TEST(InnerProductLayerTest, TestForwardGemvFC_dev1) { typedef typename TypeParam::Dtype Dtype; @@ -497,7 +508,6 @@ TYPED_TEST(InnerProductLayerTest, TestForwardGemvFC_dev1) { for (int_tp i = 0; i < count; ++i) { EXPECT_NEAR(data[i], ref_data[i], 1e-1); } - Timer timer; timer.initted(); timer.Start(); @@ -515,7 +525,6 @@ TYPED_TEST(InnerProductLayerTest, TestForwardGemvFC_dev1) { TYPED_TEST(InnerProductLayerTest, TestGEMV) { typedef typename TypeParam::Dtype Dtype; if (Caffe::mode() == Caffe::GPU) { - Blob* const blob_bottom = new Blob(1, 4099, 1, 1); Blob* const blob_top = new Blob(); FillerParameter filler_param; @@ -543,7 +552,7 @@ TYPED_TEST(InnerProductLayerTest, TestGEMV) { Dtype beta = 2; unsigned int M = layer->blobs()[0]->shape(0); unsigned int N = layer->blobs()[0]->shape(1); - //add offset + // add offset unsigned int offA = M * N / 2; unsigned int offx = 0; unsigned int offy = M / 2; @@ -551,7 +560,8 @@ TYPED_TEST(InnerProductLayerTest, TestGEMV) { greentea_gpu_gemv(dc->id(), CblasNoTrans, M, N, alpha, - (cl_mem)layer->blobs()[0]->gpu_data(), offA, (cl_mem)x, + (cl_mem)layer->blobs()[0]->gpu_data(), + offA, (cl_mem)x, offx, beta, (cl_mem)y, offy); gemv(layer->blobs(), offA, M, N, diff --git a/src/caffe/test/test_lrn_layer.cpp b/src/caffe/test/test_lrn_layer.cpp index 67d02149647..4fdbdb8ceff 100644 --- a/src/caffe/test/test_lrn_layer.cpp +++ b/src/caffe/test/test_lrn_layer.cpp @@ -498,7 +498,8 @@ TYPED_TEST(LRNFuseLayerTest, TestForwardAcrossChannelsFusePoolMax) { lrnLayer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_); lrnLayer.Forward(this->blob_bottom_vec_, this->blob_top_vec_); LayerParameter pooling_param; - pooling_param.mutable_pooling_param()->set_pool(PoolingParameter_PoolMethod_MAX); + pooling_param.mutable_pooling_param()-> + 
set_pool(PoolingParameter_PoolMethod_MAX); pooling_param.mutable_pooling_param()->add_kernel_size(3); pooling_param.mutable_pooling_param()->add_stride(2); PoolingLayer pooling_layer(pooling_param); @@ -509,15 +510,20 @@ TYPED_TEST(LRNFuseLayerTest, TestForwardAcrossChannelsFusePoolMax) { // calculate result by lrn fused with pooling layer. LayerParameter fused_layer_param; fused_layer_param.set_phase(TEST); - fused_layer_param.mutable_lrn_param()->set_fuse_type(LRNParameter_FuseType_FUSED_POOL_MAX); + fused_layer_param.mutable_lrn_param()-> + set_fuse_type(LRNParameter_FuseType_FUSED_POOL_MAX); fused_layer_param.mutable_lrn_param()->set_unit_test_mode(true); - fused_layer_param.mutable_lrn_param()->mutable_pooling_param()->set_pool(PoolingParameter_PoolMethod_MAX); - fused_layer_param.mutable_lrn_param()->mutable_pooling_param()->add_kernel_size(3); - fused_layer_param.mutable_lrn_param()->mutable_pooling_param()->add_stride(2); + fused_layer_param.mutable_lrn_param()->mutable_pooling_param()-> + set_pool(PoolingParameter_PoolMethod_MAX); + fused_layer_param.mutable_lrn_param()->mutable_pooling_param()-> + add_kernel_size(3); + fused_layer_param.mutable_lrn_param()->mutable_pooling_param()-> + add_stride(2); bool test_fuse_kernel[2] = {true, false}; for (int_tp index = 0; index < 2; index++) { - fused_layer_param.mutable_lrn_param()->set_unit_test_fuse_kernel(test_fuse_kernel[index]); + fused_layer_param.mutable_lrn_param()-> + set_unit_test_fuse_kernel(test_fuse_kernel[index]); LRNLayer fused_layer(fused_layer_param); fused_layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_); fused_layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_); @@ -526,7 +532,8 @@ TYPED_TEST(LRNFuseLayerTest, TestForwardAcrossChannelsFusePoolMax) { EXPECT_NEAR(this->blob_top_->cpu_data()[i], top_reference.cpu_data()[i], this->epsilon_); } - memset(this->blob_top_->mutable_cpu_data(), 0, top_reference.count()); + caffe_set(top_reference.count(), TypeParam(0), + 
this->blob_top_->mutable_cpu_data()); } } diff --git a/src/caffe/test/test_power_layer.cpp b/src/caffe/test/test_power_layer.cpp index fe68870ed65..805243e7905 100644 --- a/src/caffe/test/test_power_layer.cpp +++ b/src/caffe/test/test_power_layer.cpp @@ -44,8 +44,8 @@ class PowerLayerTest : public MultiDeviceTest { const Dtype* top_data = this->blob_top_->cpu_data(); const Dtype min_precision = std::is_same::value ? 1e-3 : 1e-5; - const Dtype precision_factor = std::is_same::value ? - 1e-2 : 1e-4; + const Dtype precision_factor = + std::is_same::value ? 1e-2 : 1e-4; for (int_tp i = 0; i < this->blob_bottom_->count(); ++i) { Dtype expected_value = pow(shift + scale * bottom_data[i], power); if (power == Dtype(0) || power == Dtype(1) || power == Dtype(2)) { diff --git a/src/caffe/test/test_syncedmem.cpp b/src/caffe/test/test_syncedmem.cpp index f3db2fc6621..4c7030f8be1 100644 --- a/src/caffe/test/test_syncedmem.cpp +++ b/src/caffe/test/test_syncedmem.cpp @@ -6,8 +6,8 @@ #include "caffe/util/device_alternate.hpp" #include "caffe/util/math_functions.hpp" -#include "gtest/gtest.h" #include "caffe/test/test_caffe_main.hpp" +#include "gtest/gtest.h" #ifdef USE_GREENTEA #include "caffe/greentea/greentea.hpp" diff --git a/src/caffe/test/test_tanh_layer.cpp b/src/caffe/test/test_tanh_layer.cpp index 4568fd08208..5f053da187f 100644 --- a/src/caffe/test/test_tanh_layer.cpp +++ b/src/caffe/test/test_tanh_layer.cpp @@ -64,7 +64,8 @@ class TanHLayerTest : public MultiDeviceTest { for (int_tp i = 0; i < this->blob_bottom_->count(); ++i) { Dtype expected_value = tanh_naive(bottom_data[i]); Dtype precision = std::max( - Dtype(std::abs(expected_value * Dtype(precision_factor))), min_precision); + Dtype(std::abs(expected_value * Dtype(precision_factor))), + min_precision); EXPECT_NEAR(expected_value, top_data[i], precision); } } diff --git a/src/caffe/util/hdf5.cpp b/src/caffe/util/hdf5.cpp index b38f891d15f..d47a6a668e1 100755 --- a/src/caffe/util/hdf5.cpp +++ 
b/src/caffe/util/hdf5.cpp @@ -91,7 +91,7 @@ void hdf5_load_nd_dataset(hid_t file_id, const char* dataset_name_, hdf5_load_nd_dataset_helper(file_id, dataset_name_, min_dim, max_dim, blob, reshape); herr_t status = H5LTread_dataset_short( - file_id, dataset_name_, (short*)blob->mutable_cpu_data()); + file_id, dataset_name_, (int16_t*)(blob->mutable_cpu_data())); CHECK_GE(status, 0) << "Failed to read float dataset " << dataset_name_; } #endif @@ -121,7 +121,7 @@ template <> void hdf5_save_nd_dataset( const hid_t file_id, const string& dataset_name, const Blob& blob, bool write_diff) { - //FIXME + // FIXME int_tp num_axes = blob.num_axes(); hsize_t *dims = new hsize_t[num_axes]; for (int_tp i = 0; i < num_axes; ++i) { @@ -134,10 +134,9 @@ void hdf5_save_nd_dataset( data = blob.cpu_data(); } herr_t status = H5LTmake_dataset_short( - file_id, dataset_name.c_str(), num_axes, dims, (const short*)(data)); + file_id, dataset_name.c_str(), num_axes, dims, (const int16_t*)(data)); CHECK_GE(status, 0) << "Failed to make float dataset " << dataset_name; delete[] dims; - } #endif diff --git a/src/caffe/util/math_functions.cpp b/src/caffe/util/math_functions.cpp index 39501374327..c575a569197 100644 --- a/src/caffe/util/math_functions.cpp +++ b/src/caffe/util/math_functions.cpp @@ -88,7 +88,7 @@ void caffe_cpu_axpby(const int_tp N, const half alpha, const half* X, void vhAdd(const int_tp n, const half* a, const half* b, half* y) { - for(int i = 0; i < n; i++) { + for (int i = 0; i < n; i++) { y[i] = a[i] + b[i]; } } @@ -100,7 +100,7 @@ void caffe_add(const int_tp n, const half* a, const half* b, } void vhSub(const int_tp n, const half* a, const half* b, half* y) { - for(int i = 0; i < n; i++) { + for (int i = 0; i < n; i++) { y[i] = a[i] - b[i]; } } @@ -113,7 +113,7 @@ void caffe_sub(const int_tp n, const half* a, const half* b, void vhMul(const int_tp n, const half* a, const half* b, half* y) { - for(int i = 0; i < n; i++) { + for (int i = 0; i < n; i++) { y[i] = a[i] * b[i]; } 
} @@ -126,7 +126,7 @@ void caffe_mul(const int_tp n, const half* a, const half* b, void vhDiv(const int_tp n, const half* a, const half* b, half* y) { - for(int i = 0; i < n; i++) { + for (int i = 0; i < n; i++) { y[i] = a[i] / b[i]; } } @@ -137,9 +137,8 @@ void caffe_div(const int_tp n, const half* a, const half* b, vhDiv(n, a, b, y); } -void vhPowx(const int_tp n, const half*a, const half b, half* y) -{ - for( int i = 0; i < n; i++) +void vhPowx(const int_tp n, const half*a, const half b, half* y) { + for (int i = 0; i < n; i++) y[i] = pow(a[i], b); } @@ -150,7 +149,7 @@ void caffe_powx(const int_tp n, const half* a, const half b, } void vhSqr(const int_tp n, const half *a, half* y) { - for(int i = 0; i < n; i++) { + for (int i = 0; i < n; i++) { y[i] = sqrt(a[i]); } } @@ -161,7 +160,7 @@ void caffe_sqr(const int_tp n, const half* a, half* y) { } void vhExp(const int_tp n, const half* a, half* y) { - for(int i = 0; i < n; i++) { + for (int i = 0; i < n; i++) { y[i] = exp(a[i]); } } @@ -172,7 +171,7 @@ void caffe_exp(const int_tp n, const half* a, half* y) { } void vhLn(const int_tp n, const half* a, half* y) { - for(int i = 0; i < n; i++) { + for (int i = 0; i < n; i++) { y[i] = log(a[i]); } } @@ -183,7 +182,7 @@ void caffe_log(const int_tp n, const half* a, half* y) { } void vhAbs(const int_tp n, const half *a, half* y) { - for(int i = 0; i < n; i++) { + for (int i = 0; i < n; i++) { y[i] = fabs(a[i]); } } @@ -204,11 +203,13 @@ void caffe_sqrt(const int_tp n, const half* a, half* y) { } template<> -void caffe_rng_uniform(const int_tp n, const half a, const half b, half* r) { +void caffe_rng_uniform(const int_tp n, const half a, + const half b, half* r) { CHECK_GE(n, 0); CHECK(r); CHECK_LE(a, b); - boost::uniform_real random_distribution(float(a), caffe_nextafter(float(b))); + boost::uniform_real random_distribution(static_cast(a), + caffe_nextafter(static_cast(b))); boost::variate_generator> variate_generator( @@ -249,7 +250,7 @@ void caffe_rng_bernoulli(const 
int_tp n, const half p, boost::bernoulli_distribution> variate_generator( caffe_rng(), random_distribution); for (int_tp i = 0; i < n; ++i) { - //r[i] = static_cast(variate_generator()); + r[i] = static_cast(variate_generator()); } } template<> @@ -266,7 +267,7 @@ void caffe_rng_bernoulli(const int_tp n, const half p, boost::bernoulli_distribution> variate_generator( caffe_rng(), random_distribution); for (int_tp i = 0; i < n; ++i) { - //r[i] = static_cast(variate_generator()); + r[i] = static_cast(variate_generator()); } } @@ -275,7 +276,6 @@ void caffe_cpu_scale(const int_tp n, const half alpha, const half *x, half* y) { for (int_tp i = 0; i < n; i++) y[i] = x[i]; - //cblas_hcopy(n, x, 1, y, 1); caffe_scal(n, alpha, y); } From e38c6195c529e0a88c8b4b2c4281c59c76238124 Mon Sep 17 00:00:00 2001 From: Zhigang Gong Date: Thu, 6 Jul 2017 22:22:45 +0800 Subject: [PATCH 33/33] Move half.hpp's license to 3rdparty/half. --- include/3rdparty/{ => half}/LICENSE | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename include/3rdparty/{ => half}/LICENSE (100%) diff --git a/include/3rdparty/LICENSE b/include/3rdparty/half/LICENSE similarity index 100% rename from include/3rdparty/LICENSE rename to include/3rdparty/half/LICENSE