diff --git a/ggml.c b/ggml.c index ffd54ec..8e051dd 100644 --- a/ggml.c +++ b/ggml.c @@ -1833,7 +1833,7 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest const block_q4_0 * restrict x = vx; const block_q4_0 * restrict y = vy; - ggml_float sumf = 0.0; + float sumf = 0.0; #if defined(__ARM_NEON) float sum0 = 0.0f; @@ -1928,7 +1928,7 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest #endif } - sumf = (ggml_float)(sum0 + sum1); + sumf = sum0 + sum1; #elif defined(__AVX512F__) // Initialize accumulator with zeros __m512 acc0 = _mm512_setzero_ps(); @@ -1962,6 +1962,10 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest __m256 acc = _mm256_setzero_ps(); // Main loop + // TODO: figure a way to do this in a portable way + #ifdef __GNUC__ + #pragma GCC unroll 16 + #endif for (int i = 0; i < nb; ++i) { // Compute combined scale for the block const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), _mm256_broadcast_ss( &y[i].d ) ); @@ -1975,20 +1979,21 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest bx = _mm256_sub_epi8( bx, off ); by = _mm256_sub_epi8( by, off ); - // Sign-extend first 16 signed bytes into int16_t - __m256i x16 = _mm256_cvtepi8_epi16( _mm256_castsi256_si128( bx ) ); - __m256i y16 = _mm256_cvtepi8_epi16( _mm256_castsi256_si128( by ) ); - // Compute products of int16_t integers, add pairwise - __m256i i32 = _mm256_madd_epi16( x16, y16 ); + // Get absolute values of x vectors + const __m256i ax = _mm256_sign_epi8(bx, bx); - // Sign-extend last 16 signed bytes into int16_t vectors - x16 = _mm256_cvtepi8_epi16( _mm256_extracti128_si256( bx, 1 ) ); - y16 = _mm256_cvtepi8_epi16( _mm256_extracti128_si256( by, 1 ) ); - // Accumulate products of int16_t integers - i32 = _mm256_add_epi32( i32, _mm256_madd_epi16( x16, y16 ) ); + // Sign the values of the y vectors + const __m256i sy = _mm256_sign_epi8(by, bx); + + // Perform multiplication and create 16-bit values + const __m256i dot = _mm256_maddubs_epi16(ax, sy); + + const __m256i ones = _mm256_set1_epi16(1); + const __m256i i32 = _mm256_madd_epi16(ones, dot); // Convert int32_t to float - __m256 p = _mm256_cvtepi32_ps( i32 ); + const __m256 p = _mm256_cvtepi32_ps( i32 ); + // Apply the scale, and accumulate acc = _mm256_fmadd_ps( d, p, acc ); }