@ -1359,8 +1359,8 @@ inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void
const int8x16_t v0_1hs = vsubq_s8 ( v0_1h , s8b ) ;
const int8x16_t v0_1hs = vsubq_s8 ( v0_1h , s8b ) ;
const int8x16_t v1_1hs = vsubq_s8 ( v1_1h , s8b ) ;
const int8x16_t v1_1hs = vsubq_s8 ( v1_1h , s8b ) ;
# if defined(__ARM_FEATURE_DOTPROD)
// dot product into int32x4_t
// dot product into int32x4_t
// assume that vdotq_s32 is always available, if not, should check for __ARM_FEATURE_DOTPROD
int32x4_t p_0 = vdotq_s32 ( vdupq_n_s32 ( 0 ) , v0_0ls , v1_0ls ) ;
int32x4_t p_0 = vdotq_s32 ( vdupq_n_s32 ( 0 ) , v0_0ls , v1_0ls ) ;
int32x4_t p_1 = vdotq_s32 ( vdupq_n_s32 ( 0 ) , v0_1ls , v1_1ls ) ;
int32x4_t p_1 = vdotq_s32 ( vdupq_n_s32 ( 0 ) , v0_1ls , v1_1ls ) ;
@ -1374,6 +1374,37 @@ inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void
# else
# else
sum0 + = d0_0 * d1_0 * ( vgetq_lane_s32 ( p_0 , 0 ) + vgetq_lane_s32 ( p_0 , 1 ) + vgetq_lane_s32 ( p_0 , 2 ) + vgetq_lane_s32 ( p_0 , 3 ) ) ;
sum0 + = d0_0 * d1_0 * ( vgetq_lane_s32 ( p_0 , 0 ) + vgetq_lane_s32 ( p_0 , 1 ) + vgetq_lane_s32 ( p_0 , 2 ) + vgetq_lane_s32 ( p_0 , 3 ) ) ;
sum1 + = d0_1 * d1_1 * ( vgetq_lane_s32 ( p_1 , 0 ) + vgetq_lane_s32 ( p_1 , 1 ) + vgetq_lane_s32 ( p_1 , 2 ) + vgetq_lane_s32 ( p_1 , 3 ) ) ;
sum1 + = d0_1 * d1_1 * ( vgetq_lane_s32 ( p_1 , 0 ) + vgetq_lane_s32 ( p_1 , 1 ) + vgetq_lane_s32 ( p_1 , 2 ) + vgetq_lane_s32 ( p_1 , 3 ) ) ;
# endif
# else
const int16x8_t pl0l = vmull_s8 ( vget_low_s8 ( v0_0ls ) , vget_low_s8 ( v1_0ls ) ) ;
const int16x8_t pl0h = vmull_s8 ( vget_high_s8 ( v0_0ls ) , vget_high_s8 ( v1_0ls ) ) ;
const int16x8_t ph0l = vmull_s8 ( vget_low_s8 ( v0_0hs ) , vget_low_s8 ( v1_0hs ) ) ;
const int16x8_t ph0h = vmull_s8 ( vget_high_s8 ( v0_0hs ) , vget_high_s8 ( v1_0hs ) ) ;
const int16x8_t pl1l = vmull_s8 ( vget_low_s8 ( v0_1ls ) , vget_low_s8 ( v1_1ls ) ) ;
const int16x8_t pl1h = vmull_s8 ( vget_high_s8 ( v0_1ls ) , vget_high_s8 ( v1_1ls ) ) ;
const int16x8_t ph1l = vmull_s8 ( vget_low_s8 ( v0_1hs ) , vget_low_s8 ( v1_1hs ) ) ;
const int16x8_t ph1h = vmull_s8 ( vget_high_s8 ( v0_1hs ) , vget_high_s8 ( v1_1hs ) ) ;
const int16x8_t pl_0 = vaddq_s16 ( pl0l , pl0h ) ;
const int16x8_t ph_0 = vaddq_s16 ( ph0l , ph0h ) ;
const int16x8_t pl_1 = vaddq_s16 ( pl1l , pl1h ) ;
const int16x8_t ph_1 = vaddq_s16 ( ph1l , ph1h ) ;
const int16x8_t p_0 = vaddq_s16 ( pl_0 , ph_0 ) ;
const int16x8_t p_1 = vaddq_s16 ( pl_1 , ph_1 ) ;
// scalar
# if defined(__ARM_FEATURE_QRDMX)
sum0 + = d0_0 * d1_0 * vaddvq_s16 ( p_0 ) ;
sum1 + = d0_1 * d1_1 * vaddvq_s16 ( p_1 ) ;
# else
sum0 + = d0_0 * d1_0 * ( vgetq_lane_s16 ( p_0 , 0 ) + vgetq_lane_s16 ( p_0 , 1 ) + vgetq_lane_s16 ( p_0 , 2 ) + vgetq_lane_s16 ( p_0 , 3 ) + vgetq_lane_s16 ( p_0 , 4 ) + vgetq_lane_s16 ( p_0 , 5 ) + vgetq_lane_s16 ( p_0 , 6 ) + vgetq_lane_s16 ( p_0 , 7 ) ) ;
sum1 + = d0_1 * d1_1 * ( vgetq_lane_s16 ( p_1 , 0 ) + vgetq_lane_s16 ( p_1 , 1 ) + vgetq_lane_s16 ( p_1 , 2 ) + vgetq_lane_s16 ( p_1 , 3 ) + vgetq_lane_s16 ( p_1 , 4 ) + vgetq_lane_s16 ( p_1 , 5 ) + vgetq_lane_s16 ( p_1 , 6 ) + vgetq_lane_s16 ( p_1 , 7 ) ) ;
# endif
# endif
# endif
}
}