@ -656,10 +656,11 @@ static_assert(sizeof(block_q4_3) == 2 * sizeof(ggml_fp16_t) + QK4_3 / 2, "wrong
# define QK8_0 32
typedef struct {
float d ; // delta
float s ; // d * sum(qs[i])
float s0 ; // d * sum(qs[i]) low
float s1 ; // d * sum(qs[i]) high
int8_t qs [ QK8_0 ] ; // quants
} block_q8_0 ;
static_assert ( sizeof ( block_q8_0 ) = = 2 * sizeof ( float ) + QK8_0 , " wrong q8_0 block size/padding " ) ;
static_assert ( sizeof ( block_q8_0 ) = = 3 * sizeof ( float ) + QK8_0 , " wrong q8_0 block size/padding " ) ;
// reference implementation for deterministic creation of model files
@ -1299,13 +1300,22 @@ static void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * r
y [ i ] . d = d ;
int sum = 0 ;
for ( int l = 0 ; l < QK8_0 ; + + l ) {
const float v = x [ i * QK8_0 + l ] * id ;
y [ i ] . qs [ l ] = roundf ( v ) ;
sum + = y [ i ] . qs [ l ] ;
int sum0 = 0 ;
int sum1 = 0 ;
for ( int l = 0 ; l < QK8_0 / 2 ; + + l ) {
const float v0 = x [ i * QK8_0 + l ] * id ;
const float v1 = x [ i * QK8_0 + QK8_0 / 2 + l ] * id ;
y [ i ] . qs [ l ] = roundf ( v0 ) ;
y [ i ] . qs [ QK8_0 / 2 + l ] = roundf ( v1 ) ;
sum0 + = y [ i ] . qs [ l ] ;
sum1 + = y [ i ] . qs [ QK8_0 / 2 + l ] ;
}
y [ i ] . s = d * sum ;
y [ i ] . s0 = d * sum0 ;
y [ i ] . s1 = d * sum1 ;
}
}
@ -1335,9 +1345,24 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
y [ i ] . d = d ;
int32x4_t accv = vdupq_n_s32 ( 0 ) ;
int32x4_t accv0 = vdupq_n_s32 ( 0 ) ;
int32x4_t accv1 = vdupq_n_s32 ( 0 ) ;
for ( int l = 0 ; l < 8 ; l + + ) {
// low half
for ( int l = 0 ; l < 4 ; l + + ) {
const float32x4_t v = vmulq_n_f32 ( srcv [ l ] , id ) ;
const int32x4_t vi = vcvtnq_s32_f32 ( v ) ;
y [ i ] . qs [ 4 * l + 0 ] = vgetq_lane_s32 ( vi , 0 ) ;
y [ i ] . qs [ 4 * l + 1 ] = vgetq_lane_s32 ( vi , 1 ) ;
y [ i ] . qs [ 4 * l + 2 ] = vgetq_lane_s32 ( vi , 2 ) ;
y [ i ] . qs [ 4 * l + 3 ] = vgetq_lane_s32 ( vi , 3 ) ;
accv0 = vaddq_s32 ( accv0 , vi ) ;
}
// high half
for ( int l = 4 ; l < 8 ; l + + ) {
const float32x4_t v = vmulq_n_f32 ( srcv [ l ] , id ) ;
const int32x4_t vi = vcvtnq_s32_f32 ( v ) ;
@ -1346,12 +1371,17 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
y [ i ] . qs [ 4 * l + 2 ] = vgetq_lane_s32 ( vi , 2 ) ;
y [ i ] . qs [ 4 * l + 3 ] = vgetq_lane_s32 ( vi , 3 ) ;
accv = vaddq_s32 ( accv , vi ) ;
accv 1 = vaddq_s32 ( accv 1 , vi ) ;
}
int32_t sum = vaddvq_s32 ( accv ) ;
y [ i ] . s = d * sum ;
const int32_t sum0 = vaddvq_s32 ( accv0 ) ;
const int32_t sum1 = vaddvq_s32 ( accv1 ) ;
y [ i ] . s0 = d * sum0 ;
y [ i ] . s1 = d * sum1 ;
}
# elif defined(__AVX2__) || defined(__AVX__)
// TODO !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
for ( int i = 0 ; i < nb ; i + + ) {
// Load elements into 4 AVX vectors
__m256 v0 = _mm256_loadu_ps ( x ) ;
@ -1398,7 +1428,9 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
# if defined(__AVX2__)
// Compute the sum of the quants and set y[i].s
y [ i ] . s = d * hsum_i32_8 ( _mm256_add_epi32 ( _mm256_add_epi32 ( i0 , i1 ) , _mm256_add_epi32 ( i2 , i3 ) ) ) ;
//y[i].s = d * hsum_i32_8(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3)));
y [ i ] . s0 = d * hsum_i32_8 ( _mm256_add_epi32 ( i0 , i1 ) ) ;
y [ i ] . s1 = d * hsum_i32_8 ( _mm256_add_epi32 ( i2 , i3 ) ) ;
// Convert int32 to int16
i0 = _mm256_packs_epi32 ( i0 , i1 ) ; // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15
@ -2395,7 +2427,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
const block_q8_0 * restrict y0 = & y [ i + 0 ] ;
const block_q8_0 * restrict y1 = & y [ i + 1 ] ;
sum8 + = x0 - > d * y0 - > s + x1 - > d * y1 - > s ;
sum8 + = x0 - > d * ( y0 - > s 0 + y0 - > s1 ) + x1 - > d * ( y1 - > s 0 + y1 - > s1 ) ;
const uint8x16_t m4b = vdupq_n_u8 ( 0xf ) ;
@ -2562,7 +2594,7 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
const block_q8_0 * restrict y0 = & y [ i + 0 ] ;
const block_q8_0 * restrict y1 = & y [ i + 1 ] ;
summs + = x0 - > m * y0 - > s + x1 - > m * y1 - > s ;
summs + = x0 - > m * ( y0 - > s 0 + y0 - > s1 ) + x1 - > m * ( y1 - > s 0 + y1 - > s1 ) ;
const uint8x16_t m4b = vdupq_n_u8 ( 0xf ) ;
@ -2575,22 +2607,22 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
const int8x16_t v0_1l = vreinterpretq_s8_u8 ( vandq_u8 ( v0_1 , m4b ) ) ;
const int8x16_t v0_1h = vreinterpretq_s8_u8 ( vshrq_n_u8 ( v0_1 , 4 ) ) ;
// interleave
const int8x16_t v0_0lz = vzip1q_s8 ( v0_0l , v0_0h ) ;
const int8x16_t v0_0hz = vzip2q_s8 ( v0_0l , v0_0h ) ;
const int8x16_t v0_1lz = vzip1q_s8 ( v0_1l , v0_1h ) ;
const int8x16_t v0_1hz = vzip2q_s8 ( v0_1l , v0_1h ) ;
// load y
const int8x16_t v1_0l = vld1q_s8 ( y0 - > qs ) ;
const int8x16_t v1_0h = vld1q_s8 ( y0 - > qs + 16 ) ;
const int8x16_t v1_1l = vld1q_s8 ( y1 - > qs ) ;
const int8x16_t v1_1h = vld1q_s8 ( y1 - > qs + 16 ) ;
// interleave
const int8x16_t v1_0ls = vuzp1q_s8 ( v1_0l , v1_0h ) ;
const int8x16_t v1_0hs = vuzp2q_s8 ( v1_0l , v1_0h ) ;
const int8x16_t v1_1ls = vuzp1q_s8 ( v1_1l , v1_1h ) ;
const int8x16_t v1_1hs = vuzp2q_s8 ( v1_1l , v1_1h ) ;
# if defined(__ARM_FEATURE_DOTPROD)
// dot product into int32x4_t
const int32x4_t p_0 = vdotq_s32 ( vdotq_s32 ( vdupq_n_s32 ( 0 ) , v0_0l , v1_0l s ) , v0_0h , v1_0h s ) ;
const int32x4_t p_1 = vdotq_s32 ( vdotq_s32 ( vdupq_n_s32 ( 0 ) , v0_1l , v1_1l s ) , v0_1h , v1_1h s ) ;
const int32x4_t p_0 = vdotq_s32 ( vdotq_s32 ( vdupq_n_s32 ( 0 ) , v0_0lz , v1_0l ) , v0_0hz , v1_0h ) ;
const int32x4_t p_1 = vdotq_s32 ( vdotq_s32 ( vdupq_n_s32 ( 0 ) , v0_1lz , v1_1l ) , v0_1hz , v1_1h ) ;
sumv0 = vmlaq_n_f32 ( sumv0 , vcvtq_f32_s32 ( p_0 ) , x0 - > d * y0 - > d ) ;
sumv1 = vmlaq_n_f32 ( sumv1 , vcvtq_f32_s32 ( p_1 ) , x1 - > d * y1 - > d ) ;
@ -2627,7 +2659,7 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
const float * d0 = & x [ i ] . d ;
const float * d1 = & y [ i ] . d ;
summs + = x [ i ] . m * y [ i ] . s ;
summs + = x [ i ] . m * ( y [ i ] . s 0 + y [ i ] . s1 ) ;
const __m256 d0v = _mm256_broadcast_ss ( d0 ) ;
const __m256 d1v = _mm256_broadcast_ss ( d1 ) ;
@ -2845,88 +2877,53 @@ static void ggml_vec_dot_q4_3_q8_0(const int n, float * restrict s, const void *
float32x4_t sumv0 = vdupq_n_f32 ( 0.0f ) ;
float32x4_t sumv1 = vdupq_n_f32 ( 0.0f ) ;
for ( int i = 0 ; i < nb ; i + = 2 ) {
float summs0 = 0.0f ;
float summs1 = 0.0f ;
for ( int i = 0 ; i < nb ; + + i ) {
const block_q4_3 * restrict x0_0 = & x [ 2 * ( i + 0 ) + 0 ] ;
const block_q4_3 * restrict x0_1 = & x [ 2 * ( i + 0 ) + 1 ] ;
const block_q4_3 * restrict x1_0 = & x [ 2 * ( i + 1 ) + 0 ] ;
const block_q4_3 * restrict x1_1 = & x [ 2 * ( i + 1 ) + 1 ] ;
const block_q8_0 * restrict y0 = & y [ i + 0 ] ;
const block_q8_0 * restrict y1 = & y [ i + 1 ] ;
const uint8x16_t m4b = vdupq_n_u8 ( 0xf ) ;
const float x0_0d = GGML_FP16_TO_FP32 ( x0_0 - > d ) ;
const float x0_1d = GGML_FP16_TO_FP32 ( x0_1 - > d ) ;
const float x1_0d = GGML_FP16_TO_FP32 ( x1_0 - > d ) ;
const float x1_1d = GGML_FP16_TO_FP32 ( x1_1 - > d ) ;
const float x0_0m = GGML_FP16_TO_FP32 ( x0_0 - > m ) ;
const float x0_1m = GGML_FP16_TO_FP32 ( x0_1 - > m ) ;
const float x1_0m = GGML_FP16_TO_FP32 ( x1_0 - > m ) ;
const float x1_1m = GGML_FP16_TO_FP32 ( x1_1 - > m ) ;
summs0 + = GGML_FP16_TO_FP32 ( x0_0 - > m ) * y0 - > s0 ;
summs1 + = GGML_FP16_TO_FP32 ( x0_1 - > m ) * y0 - > s1 ;
const uint8x16_t v0_0 = vcombine_u8 ( vld1_u8 ( x0_0 - > qs ) , vld1_u8 ( x0_1 - > qs ) ) ;
const uint8x16_t v0_1 = vcombine_u8 ( vld1_u8 ( x1_0 - > qs ) , vld1_u8 ( x1_1 - > qs ) ) ;
// 4-bit -> 8-bit
const int8x16_t v0_0l = vreinterpretq_s8_u8 ( vandq_u8 ( v0_0 , m4b ) ) ;
const int8x16_t v0_0l = vreinterpretq_s8_u8 ( vandq_u8 ( v0_0 , vdupq_n_u8 ( 0xf ) ) ) ;
const int8x16_t v0_0h = vreinterpretq_s8_u8 ( vshrq_n_u8 ( v0_0 , 4 ) ) ;
const int8x16_t v0_1l = vreinterpretq_s8_u8 ( vandq_u8 ( v0_1 , m4b ) ) ;
const int8x16_t v0_1h = vreinterpretq_s8_u8 ( vshrq_n_u8 ( v0_1 , 4 ) ) ;
// interleave
const int8x16_t v0_0lz = vzip1q_s8 ( v0_0l , v0_0h ) ;
const int8x16_t v0_0hz = vzip2q_s8 ( v0_0l , v0_0h ) ;
const int8x16_t v0_1lz = vzip1q_s8 ( v0_1l , v0_1h ) ;
const int8x16_t v0_1hz = vzip2q_s8 ( v0_1l , v0_1h ) ;
// load y
const int8x16_t v1_0l = vld1q_s8 ( y0 - > qs ) ;
const int8x16_t v1_0h = vld1q_s8 ( y0 - > qs + 16 ) ;
const int8x16_t v1_1l = vld1q_s8 ( y1 - > qs ) ;
const int8x16_t v1_1h = vld1q_s8 ( y1 - > qs + 16 ) ;
const int16x8_t sy0_0 = vaddq_s16 ( vmovl_s8 ( vget_low_s8 ( v1_0l ) ) , vmovl_s8 ( vget_high_s8 ( v1_0l ) ) ) ;
const int16x8_t sy0_1 = vaddq_s16 ( vmovl_s8 ( vget_low_s8 ( v1_0h ) ) , vmovl_s8 ( vget_high_s8 ( v1_0h ) ) ) ;
const int16x8_t sy1_0 = vaddq_s16 ( vmovl_s8 ( vget_low_s8 ( v1_1l ) ) , vmovl_s8 ( vget_high_s8 ( v1_1l ) ) ) ;
const int16x8_t sy1_1 = vaddq_s16 ( vmovl_s8 ( vget_low_s8 ( v1_1h ) ) , vmovl_s8 ( vget_high_s8 ( v1_1h ) ) ) ;
sumv0 = vmlaq_n_f32 ( sumv0 , vcvtq_f32_s32 ( vaddl_s16 ( vget_low_s16 ( sy0_0 ) , vget_high_s16 ( sy0_0 ) ) ) , x0_0m * y0 - > d ) ;
sumv0 = vmlaq_n_f32 ( sumv0 , vcvtq_f32_s32 ( vaddl_s16 ( vget_low_s16 ( sy0_1 ) , vget_high_s16 ( sy0_1 ) ) ) , x0_1m * y0 - > d ) ;
sumv1 = vmlaq_n_f32 ( sumv1 , vcvtq_f32_s32 ( vaddl_s16 ( vget_low_s16 ( sy1_0 ) , vget_high_s16 ( sy1_0 ) ) ) , x1_0m * y1 - > d ) ;
sumv1 = vmlaq_n_f32 ( sumv1 , vcvtq_f32_s32 ( vaddl_s16 ( vget_low_s16 ( sy1_1 ) , vget_high_s16 ( sy1_1 ) ) ) , x1_1m * y1 - > d ) ;
const float x0_0d = GGML_FP16_TO_FP32 ( x0_0 - > d ) ;
const float x0_1d = GGML_FP16_TO_FP32 ( x0_1 - > d ) ;
# if defined(__ARM_FEATURE_DOTPROD)
sumv0 = vmlaq_n_f32 ( sumv0 , vcvtq_f32_s32 ( vdotq_s32 ( vdupq_n_s32 ( 0 ) , v0_0lz , v1_0l ) ) , x0_0d * y0 - > d ) ;
sumv0 = vmlaq_n_f32 ( sumv0 , vcvtq_f32_s32 ( vdotq_s32 ( vdupq_n_s32 ( 0 ) , v0_0hz , v1_0h ) ) , x0_1d * y0 - > d ) ;
sumv1 = vmlaq_n_f32 ( sumv1 , vcvtq_f32_s32 ( vdotq_s32 ( vdupq_n_s32 ( 0 ) , v0_1lz , v1_1l ) ) , x1_0d * y1 - > d ) ;
sumv1 = vmlaq_n_f32 ( sumv1 , vcvtq_f32_s32 ( vdotq_s32 ( vdupq_n_s32 ( 0 ) , v0_1hz , v1_1h ) ) , x1_1d * y1 - > d ) ;
sumv1 = vmlaq_n_f32 ( sumv1 , vcvtq_f32_s32 ( vdotq_s32 ( vdupq_n_s32 ( 0 ) , v0_0hz , v1_0h ) ) , x0_1d * y0 - > d ) ;
# else
const int16x8_t pl0l = vmull_s8 ( vget_low_s8 ( v0_0lz ) , vget_low_s8 ( v1_0l ) ) ;
const int16x8_t pl0h = vmull_s8 ( vget_high_s8 ( v0_0lz ) , vget_high_s8 ( v1_0l ) ) ;
const int16x8_t ph0l = vmull_s8 ( vget_low_s8 ( v0_0hz ) , vget_low_s8 ( v1_0h ) ) ;
const int16x8_t ph0h = vmull_s8 ( vget_high_s8 ( v0_0hz ) , vget_high_s8 ( v1_0h ) ) ;
const int16x8_t pl1l = vmull_s8 ( vget_low_s8 ( v0_1lz ) , vget_low_s8 ( v1_1l ) ) ;
const int16x8_t pl1h = vmull_s8 ( vget_high_s8 ( v0_1lz ) , vget_high_s8 ( v1_1l ) ) ;
const int16x8_t ph1l = vmull_s8 ( vget_low_s8 ( v0_1hz ) , vget_low_s8 ( v1_1h ) ) ;
const int16x8_t ph1h = vmull_s8 ( vget_high_s8 ( v0_1hz ) , vget_high_s8 ( v1_1h ) ) ;
const int32x4_t pl0 = vaddq_s32 ( vpaddlq_s16 ( pl0l ) , vpaddlq_s16 ( pl0h ) ) ;
const int32x4_t ph0 = vaddq_s32 ( vpaddlq_s16 ( ph0l ) , vpaddlq_s16 ( ph0h ) ) ;
const int32x4_t pl1 = vaddq_s32 ( vpaddlq_s16 ( pl1l ) , vpaddlq_s16 ( pl1h ) ) ;
const int32x4_t ph1 = vaddq_s32 ( vpaddlq_s16 ( ph1l ) , vpaddlq_s16 ( ph1h ) ) ;
sumv0 = vmlaq_n_f32 ( sumv0 , vcvtq_f32_s32 ( pl0 ) , x0_0d * y0 - > d ) ;
sumv0 = vmlaq_n_f32 ( sumv0 , vcvtq_f32_s32 ( ph0 ) , x0_1d * y0 - > d ) ;
sumv1 = vmlaq_n_f32 ( sumv1 , vcvtq_f32_s32 ( pl1 ) , x1_0d * y1 - > d ) ;
sumv1 = vmlaq_n_f32 ( sumv1 , vcvtq_f32_s32 ( ph1 ) , x1_1d * y1 - > d ) ;
sumv1 = vmlaq_n_f32 ( sumv1 , vcvtq_f32_s32 ( ph0 ) , x0_1d * y0 - > d ) ;
# endif
}
* s = vaddvq_f32 ( sumv0 ) + vadd v q_f32( sumv 1) ;
* s = vaddvq_f32 ( vadd q_f32( sumv 0, sumv 1) ) + summs0 + summs1 ;
# elif defined(__AVX2__)
// Initialize accumulator with zeros
__m256 acc = _mm256_setzero_ps ( ) ;
@ -2971,9 +2968,6 @@ static void ggml_vec_dot_q4_3_q8_0(const int n, float * restrict s, const void *
const float d1 = GGML_FP16_TO_FP32 ( x [ 2 * i + 1 ] . d ) ;
const float m1 = GGML_FP16_TO_FP32 ( x [ 2 * i + 1 ] . m ) ;
int sy_0 = 0 ;
int sy_1 = 0 ;
int sxy_0 = 0 ;
int sxy_1 = 0 ;
@ -2993,15 +2987,11 @@ static void ggml_vec_dot_q4_3_q8_0(const int n, float * restrict s, const void *
const int y0_1 = y0 [ 2 * ( j + QK8_0 / 4 ) + 0 ] ;
const int y1_1 = y0 [ 2 * ( j + QK8_0 / 4 ) + 1 ] ;
sy_0 + = y0_0 + y1_0 ;
sy_1 + = y0_1 + y1_1 ;
sxy_0 + = x0_0 * y0_0 + x1_0 * y1_0 ;
sxy_1 + = x0_1 * y0_1 + x1_1 * y1_1 ;
}
sumf + = ( d0 * sxy_0 + m0 * sy_0 ) * y [ i ] . d ;
sumf + = ( d1 * sxy_1 + m1 * sy_1 ) * y [ i ] . d ;
sumf + = ( d0 * sxy_0 + d1 * sxy_1 ) * y [ i ] . d + m0 * y [ i ] . s0 + m1 * y [ i ] . s1 ;
}
* s = sumf ;
# endif