@ -1236,19 +1236,13 @@ static llama_vocab::id llama_sample_top_p_top_k(
}
}
if ( top_k > 0 & & top_k < n_logits ) {
sample_top_k ( logits_id , top_k ) ;
}
float maxl = - std : : numeric_limits < float > : : infinity ( ) ;
for ( const auto & kv : logits_id ) {
maxl = Max ( maxl , kv . first ) ;
}
sample_top_k ( logits_id , top_k > 0 ? Min ( top_k , n_logits ) : n_logits ) ;
// compute probs for the top k tokens
std : : vector < float > probs ;
probs . reserve ( logits_id . size ( ) ) ;
float maxl = logits_id [ 0 ] . first ;
double sum = 0.0 ;
for ( const auto & kv : logits_id ) {
const float p = expf ( kv . first - maxl ) ;
@ -1271,16 +1265,11 @@ static llama_vocab::id llama_sample_top_p_top_k(
break ;
}
}
cumsum = 1.0 / cumsum ;
for ( int i = 0 ; i < ( int ) probs . size ( ) ; i + + ) {
probs [ i ] * = cumsum ;
}
}
//printf("\n");
//for (int i = 0; i < (int) 10; i++) {
// printf("%d: '%s' %f\n", i, vocab.id_to_token.at(logits_id[i].second).c_str(), probs[i]);
// printf("%d: '%s' %f\n", i, lctx.vocab.id_to_token.at(logits_id[i].second).tok.c_str(), probs[i]);
//}
//printf("\n\n");
//exit(0);