Code & Stream



Top-P (Nucleus Sampling) & the Curious Case of Neural Text Degeneration

One of my favorite parts of my job is getting to answer random questions from our MVPs (Microsoft Most Valuable Professionals). For context, Microsoft has internal distribution lists, monitored by folks across the product and engineering teams, that MVPs get access to so they can ask questions. Often the questions and answers fall into the NDA-only category, so I can't talk about them publicly. Every once in a while, though, a fun technical question comes along that is in no way confidential but isn't something we would ever cover in the official Microsoft documentation. This was one of those questions. I will scrub out any identifying details, and I may do some light editing, but for the most part, below are my responses to the MVP.

Background

If you aren't familiar, Top-P is one of the settings, like temperature, that APIs from LLM providers allow you to modify to change how a model responds.

OpenAI:
"An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered. We generally recommend altering this or temperature but not both."


Anthropic:
In nucleus sampling, we compute the cumulative distribution over all the options for each subsequent token in decreasing probability order and cut it off once it reaches a particular probability specified by top_p. You should either alter temperature or top_p, but not both.

Recommended for advanced use cases only. You usually only need to use temperature.

Range: minimum 0, maximum 1
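The "probability mass" wording in these descriptions becomes concrete with a hand-picked five-token distribution. The sketch below is a minimal illustration with made-up probabilities, not output from any real model. It uses the same cutoff rule as my toy code further down: walk the tokens in descending probability order, and keep everything up to and including the first token whose cumulative probability exceeds p.

```python
import numpy as np

# Hypothetical next-token probabilities, already sorted from highest to lowest
probs = np.array([0.5, 0.2, 0.15, 0.1, 0.05])
cumulative = np.cumsum(probs)  # approximately [0.5, 0.7, 0.85, 0.95, 1.0]

nucleus_sizes = {}
for p in [0.0, 0.1, 0.6, 0.9]:
    # First position where the running total exceeds p; keep everything up to it
    cutoff_idx = int(np.where(cumulative > p)[0][0])
    nucleus_sizes[p] = cutoff_idx + 1
    print(f"p = {p}: keep the top {cutoff_idx + 1} token(s)")
```

With these made-up numbers, p = 0 and p = 0.1 both keep only the top token, p = 0.6 keeps two tokens, and p = 0.9 keeps four.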
MVP Question:

I am trying to understand more about top_p. Why does setting top_p to 0 still generate tokens? I thought setting it to zero would result in no response, but instead I get results.
My Answer:

If you are interested in understanding how top_p works at a deeper level as well as a bit of the history behind it, I would start with this paper where top_p (also known as nucleus sampling) was first introduced back in 2019:

The Curious Case of Neural Text Degeneration

Each LLM provider may have its own slightly different internal implementation. Providers may also handle edge cases, and the way top_p interacts with other parameters, differently. Fundamentally, though, top_p (nucleus sampling) features are based on the approach described in the paper.

To make things more concrete for myself, I found it helpful to put together a toy version adapted from some of the open source implementations of the algorithm described in the paper. I then stepped through the execution with different values for p:

Top-p sampling code overview

Full code with comments included at the end of my message.

Code output example 1
Code output example 2

At least for me, this made it easier to see how setting top_p to 0 essentially turns nucleus sampling into greedy sampling, assuming an implementation similar to the one above is used.
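One quick way to check the greedy claim is to run the same nucleus steps with p = 0 on a few random distributions. You can then compare the pick with np.argmax, which is what greedy decoding would choose. This is a minimal sketch in the spirit of my toy code below. The function name nucleus_sample, the use of np.random.default_rng, and the seeds are my own choices for illustration.

```python
import numpy as np

def nucleus_sample(logits, p, rng):
    """Toy nucleus (top-p) sampling following the same steps as the toy code."""
    probs = np.exp(logits - logits.max())  # softmax (shifted by the max for stability)
    probs /= probs.sum()
    sorted_indices = np.argsort(probs)[::-1]  # highest probability first
    cumulative_probs = np.cumsum(probs[sorted_indices])
    cutoff_idx = np.where(cumulative_probs > p)[0][0]
    # Keep tokens up to and including the cutoff (the complement of the removed tail)
    nucleus = sorted_indices[:cutoff_idx + 1]
    nucleus_probs = probs[nucleus] / probs[nucleus].sum()
    return rng.choice(nucleus, p=nucleus_probs)

matches = []
for seed in range(5):
    rng = np.random.default_rng(seed)
    logits = rng.uniform(-10, 10, 1000)
    token = nucleus_sample(logits, p=0.0, rng=rng)
    # Greedy decoding takes the highest-probability token; softmax preserves the argmax
    greedy = np.argmax(logits)
    matches.append(token == greedy)
    print(f"seed {seed}: nucleus(p=0) -> {token}, greedy -> {greedy}")
```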

Jeremy Howard touches on the Curious Case paper in this part of one of his lectures. He shares a different implementation of nucleus sampling, one that takes a dependency on PyTorch, which might be helpful to look at as well:

https://youtu.be/3oEb_fFmPnY?si=l2GU1PMnFnzen1Id&t=960

Jeremy Howard lecture screenshot

My toy code from above:

#########################################
import numpy as np

# p must be between 0 and 1
p = .95
assert 0.0 <= p <= 1.0, "p must be between 0 and 1"

# Create sample logits (1000 random values between -10 and 10)
np.random.seed(42)  # for reproducibility
logits = np.random.uniform(-10, 10, 1000)

# Convert logits to probabilities
probs = np.exp(logits) / np.sum(np.exp(logits)) #softmax(x_i) = exp(x_i) / Σ exp(x_j)

# Sort probabilities in descending order
sorted_indices = np.argsort(probs)[::-1]
sorted_probs = probs[sorted_indices]

# Calculate cumulative probabilities
cumulative_probs = np.cumsum(sorted_probs)

# Find the first position where the cumulative probability exceeds p
cutoff_idx = np.where(cumulative_probs > p)[0][0]
# Everything after that position is the low-probability tail outside the nucleus
sorted_indices_to_remove = sorted_indices[cutoff_idx+1:]

# Zero out probabilities for tokens outside the nucleus
probs[sorted_indices_to_remove] = 0.0

# Re-normalize probabilities within the nucleus
probs /= np.sum(probs)

# Sample index according to the truncated+renormalized probability distribution
sampled_index = np.random.choice(len(probs), p=probs)

print(f"p = {p}")
print(f"Number of tokens in nucleus: {len(probs) - len(sorted_indices_to_remove)}")
print(f"Sampled index: {sampled_index}")
print(f"Probability of sampled index: {probs[sampled_index]:.4f}")
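The toy code above doesn't guard against one edge case. At p = 1, floating-point rounding can leave every cumulative probability at or just below 1.0. In that case np.where(cumulative_probs > p)[0] is empty, and the [0] lookup raises an IndexError. The sketch below puts the same steps into a function with a fallback that keeps every token when nothing exceeds p. The function name top_p_filter and the fallback are my own additions, not taken from any provider's implementation.

```python
import numpy as np

def top_p_filter(probs, p):
    """Zero out tokens outside the nucleus and renormalize (toy sketch, not a provider's code)."""
    if not 0.0 <= p <= 1.0:
        raise ValueError("p must be between 0 and 1")
    sorted_indices = np.argsort(probs)[::-1]
    cumulative_probs = np.cumsum(probs[sorted_indices])
    exceeds_p = np.where(cumulative_probs > p)[0]
    # Fallback (my assumption): if no cumulative value exceeds p, e.g. p = 1, keep every token
    cutoff_idx = exceeds_p[0] if exceeds_p.size > 0 else len(probs) - 1
    filtered = probs.copy()
    filtered[sorted_indices[cutoff_idx + 1:]] = 0.0
    return filtered / filtered.sum()

np.random.seed(42)
logits = np.random.uniform(-10, 10, 1000)
probs = np.exp(logits) / np.sum(np.exp(logits))

counts = {}
for p in [0.0, 0.5, 0.95, 1.0]:
    filtered = top_p_filter(probs, p)
    counts[p] = np.count_nonzero(filtered)
    print(f"p = {p}: {counts[p]} tokens in nucleus, sum = {filtered.sum():.4f}")
```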

E-mail #2:

The MVP responded with a follow-up question, as they were still confused:

Back to my original question: There appears to be consensus that Top-P = 0.0 empirically behaves like greedy sampling, but with no specific mechanism identified.

My response:

I think the best way to think about it is that the behavior is a side effect of how the mathematical expressions get evaluated in common implementations of nucleus sampling when p = 0.

The intuition that top_p = 0 would exclude all tokens is a good one. When you look at what actually happens to p = 0 inside an implementation of the nucleus sampling algorithm, though, the reality is different: in common implementations, p = 0 always selects the highest probability token.

To understand what is happening, focus on this section of code, where things start to get fun. You could also write the next three lines as cutoff_idx = np.where(cumulative_probs > p)[0][0], but I broke them up to hopefully make it clearer what is happening:

indices_where_true = np.where(cumulative_probs > p)
first_array = indices_where_true[0]
cutoff_idx = first_array[0]

The catch is that when p = 0, every single cumulative_probs value will be greater than 0. The first cumulative value is already the highest token's (positive) probability, and the running total can only grow from there. As a result, indices_where_true will contain every possible index of the probability values/tokens in the distribution. Because we normalized and presorted from highest to lowest probability, our highest probability token sits at index 0.

In other words, when given a p of 0, the lines of code that are supposed to find the point where we trim the tail of lower probability tokens set the cutoff index to 0, the position of the highest probability token.

Then:

sorted_indices_to_remove = sorted_indices[cutoff_idx+1:]

We take the highest probability token's index, go one index further [cutoff_idx+1], and trim off the tail, which consists of every other, lower probability token. That leaves only one token in the nucleus. The result is greedy sampling, since the only remaining token is at index 0, where the highest probability token lives.
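Here are those same lines run on a tiny, hand-picked example (made-up probabilities, already sorted). You can see the full index array survive np.where at p = 0, and then the slice remove every token except the first one:

```python
import numpy as np

# Hypothetical 5-token vocabulary: original token ids in descending probability order
sorted_indices = np.array([3, 0, 4, 1, 2])
sorted_probs = np.array([0.4, 0.3, 0.15, 0.1, 0.05])
cumulative_probs = np.cumsum(sorted_probs)

p = 0.0
indices_where_true = np.where(cumulative_probs > p)
first_array = indices_where_true[0]
cutoff_idx = first_array[0]

sorted_indices_to_remove = sorted_indices[cutoff_idx + 1:]
print(first_array)               # [0 1 2 3 4]  (every position survives)
print(cutoff_idx)                # 0
print(sorted_indices_to_remove)  # [0 4 1 2]  (every token id except 3)
```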

I will, however, still hedge and say I can't guarantee how any given LLM provider's exact internal implementation will behave. OpenAI hasn't publicly released their implementation of nucleus sampling, so by extension we can't share the one used by Azure OpenAI either. I can only say that my toy example, Jeremy Howard's PyTorch implementation, and the original implementation from the shared repo for the Curious Case paper would all handle p = 0 in the way I described above.
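A hypothetical alternative formulation (my own, not any provider's code) shows why the exact implementation matters. Keeping tokens up to and including the first cumulative value above p is the same as keeping a token when the probability mass *before* it is <= p. Change that to a strict < and you get a rule that agrees everywhere except where the mass before a token equals p exactly. At p = 0, that is always the case for the top token, because nothing comes before it. So the strict variant keeps zero tokens, which is the behavior you originally expected. An implementation written that way would need an explicit guard to keep at least one token.

```python
import numpy as np

def nucleus_size(sorted_probs, p, strict):
    """Count tokens kept when a token survives if the mass *before* it is below p.

    strict=False (<=) matches the toy code's cutoff; strict=True (<) is a hypothetical variant.
    """
    cumulative_probs = np.cumsum(sorted_probs)
    # Probability mass of all higher-ranked tokens; 0.0 for the top token
    mass_before = np.concatenate(([0.0], cumulative_probs[:-1]))
    keep = mass_before < p if strict else mass_before <= p
    return int(np.count_nonzero(keep))

sorted_probs = np.array([0.4, 0.3, 0.15, 0.1, 0.05])  # made-up, already sorted
for p in [0.0, 0.5, 0.9]:
    print(f"p = {p}: toy-code rule keeps {nucleus_size(sorted_probs, p, strict=False)}, "
          f"strict variant keeps {nucleus_size(sorted_probs, p, strict=True)}")
```

With these made-up numbers, the two rules agree at p = 0.5 (2 tokens) and p = 0.9 (4 tokens). At p = 0, the toy-code rule keeps 1 token and the strict variant keeps 0.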

Since others have found this useful, I'll quickly walk through this step by step using my toy code; feel free to treat it as optional reading. There are also others on this DL who I am sure can explain this much better than I can:

First, let's look at an example with p set to .9 to get a better sense of what things can look like with a GPT-2-sized vocabulary. GPT-2's vocab has 50257 tokens.

We'll generate 50257 logits between -10 and 10, using a seed so anyone can re-run this and get the same results:

Setting up logits generation

The logits variable now contains values that look like:

Logits array values

If it helps, we can visualize the random logits to get a better sense of their distribution:

Logits distribution visualization
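For anyone following along without the screenshots, here is a sketch of this setup step. It assumes the same np.random.seed(42) and np.random.uniform(-10, 10, ...) approach as my toy code; the screenshot code may differ in the details. Instead of a plot, np.histogram gives a text-only look at the distribution:

```python
import numpy as np

p = 0.9
vocab_size = 50257  # GPT-2 vocabulary size

np.random.seed(42)  # fixed seed so re-runs give the same logits
logits = np.random.uniform(-10, 10, vocab_size)

# Uniform logits should spread roughly evenly across 10 equal-width bins
counts, bin_edges = np.histogram(logits, bins=10, range=(-10, 10))
for count, left, right in zip(counts, bin_edges[:-1], bin_edges[1:]):
    print(f"[{left:6.1f}, {right:6.1f}): {count}")
```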

Then we will convert the logits to probabilities using the softmax function:

Softmax conversion code

This would then look as follows:

Probabilities array

Probabilities summary
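Here is the softmax step, using the same setup as my toy code. The property that matters later for p = 0 is that every probability is strictly positive and that they all sum to 1:

```python
import numpy as np

np.random.seed(42)
logits = np.random.uniform(-10, 10, 50257)

# softmax(x_i) = exp(x_i) / Σ exp(x_j)
probs = np.exp(logits) / np.sum(np.exp(logits))

print(f"sum of probs: {probs.sum():.6f}")
print(f"smallest prob: {probs.min():.3e}, largest prob: {probs.max():.3e}")
print(f"every prob > 0: {np.all(probs > 0)}")
```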

Next, we sort the probabilities from highest to lowest. Note that the max probability is now at index 0:

Max probability at index 0

This sorting creates the following distribution:

Sorted probability distribution
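Here is a sketch of the sorting step, matching the two sorting lines in my toy code. np.argsort sorts ascending, which is why the result is reversed with [::-1]:

```python
import numpy as np

np.random.seed(42)
logits = np.random.uniform(-10, 10, 50257)
probs = np.exp(logits) / np.sum(np.exp(logits))

# argsort is ascending, so reverse it to get the highest-probability token ids first
sorted_indices = np.argsort(probs)[::-1]
sorted_probs = probs[sorted_indices]

print(f"top token id: {sorted_indices[0]}, prob: {sorted_probs[0]:.6f}")
print(f"descending order: {np.all(np.diff(sorted_probs) <= 0)}")
```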

Now we take the cumulative sum of our probabilities:

Cumulative sum code

Cumulative probabilities array
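Here is a sketch of the cumulative sum, under the same setup. Three properties are worth checking. The first value is exactly the top token's probability. The values never decrease. The last value is (up to rounding) 1:

```python
import numpy as np

np.random.seed(42)
logits = np.random.uniform(-10, 10, 50257)
probs = np.exp(logits) / np.sum(np.exp(logits))
sorted_probs = probs[np.argsort(probs)[::-1]]

# Running total of the sorted probabilities
cumulative_probs = np.cumsum(sorted_probs)

print(f"first 5 cumulative values: {np.round(cumulative_probs[:5], 6)}")
print(f"cumulative_probs[0] == sorted_probs[0]: {cumulative_probs[0] == sorted_probs[0]}")
print(f"last cumulative value: {cumulative_probs[-1]:.6f}")
```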

Next, find the indices where cumulative_probs is greater than p. Note that this is a subset of the indices for our 50257-token vocabulary:

Finding indices where cumulative > p

Indices array result

Cutoff index calculation

Additional index information
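Here is a sketch of the cutoff step for p = 0.9, under the same setup. A useful sanity check is that the cutoff is the *first* position past p: the cumulative probability there exceeds p, while the one just before it does not. Because the cumulative sum never decreases, every position from the cutoff onward also exceeds p:

```python
import numpy as np

p = 0.9
np.random.seed(42)
logits = np.random.uniform(-10, 10, 50257)
probs = np.exp(logits) / np.sum(np.exp(logits))
sorted_indices = np.argsort(probs)[::-1]
cumulative_probs = np.cumsum(probs[sorted_indices])

indices_where_true = np.where(cumulative_probs > p)[0]
cutoff_idx = indices_where_true[0]

print(f"{len(indices_where_true)} of {len(cumulative_probs)} positions exceed p")
print(f"cutoff_idx = {cutoff_idx}")
print(f"cumulative just before cutoff: {cumulative_probs[cutoff_idx - 1]:.6f}, "
      f"at cutoff: {cumulative_probs[cutoff_idx]:.6f}")
```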

So we take our sorted distribution, set the cutoff index (dotted line), move one index forward, and remove the low probability tail (grey):

Distribution with cutoff visualization

Zero out probabilities outside of the nucleus:

Zero out code

Zeroed out distribution visualization

If we sum probs after zeroing out the values outside the nucleus, you can see they are no longer normalized, since they no longer add up to 1:

Sum before renormalization

So we then renormalize our probabilities:

Renormalization code

Renormalized probabilities

Final probabilities array

Number of tokens in nucleus: 5751
Sampled index: 2243
Probability of sampled index: 0.0004
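Here is a sketch of the zero-out, renormalize, and sample steps for p = 0.9, under the same setup as my toy code. I can't guarantee it matches the screenshot code line for line, so your printed numbers may not match the output above exactly:

```python
import numpy as np

p = 0.9
np.random.seed(42)
logits = np.random.uniform(-10, 10, 50257)
probs = np.exp(logits) / np.sum(np.exp(logits))
sorted_indices = np.argsort(probs)[::-1]
cumulative_probs = np.cumsum(probs[sorted_indices])
cutoff_idx = np.where(cumulative_probs > p)[0][0]

# Zero out everything after the cutoff (the low-probability tail)
probs[sorted_indices[cutoff_idx + 1:]] = 0.0
mass_kept = probs.sum()  # roughly p (just over it), so no longer 1
print(f"tokens in nucleus: {np.count_nonzero(probs)}, sum before renormalizing: {mass_kept:.4f}")

# Renormalize so the nucleus probabilities sum to 1 again
probs /= mass_kept
print(f"sum after renormalizing: {probs.sum():.4f}")

sampled_index = np.random.choice(len(probs), p=probs)
print(f"sampled index: {sampled_index}, probability: {probs[sampled_index]:.4f}")
```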


Now let's run all the same code but with p = 0. Everything up to this point in the code is the same, because we haven't used p yet:

Same sorted distribution for p=0

We'll still compute the same cumulative_sum:

Cumulative sum code

Cumulative probabilities

Note that index 0 of our cumulative sum corresponds to our highest probability token; its cumulative value is just that token's probability.

Now when we calculate:

indices_where_true = np.where(cumulative_probs > p)

We end up with an array that hasn't shrunk in size at all, because every index in cumulative_probs has a value greater than 0. This is unlike the p = .9 case, where we ended up with a subset of our original array.

All indices included for p=0

And now when we calculate our cutoff index, we get index 0, and then we remove everything after it with [cutoff_idx+1:].

Cutoff at index 0

Only one token remains

From this point on, the remaining code is essentially a formality: we are guaranteed to sample this token, since the only token in the nucleus is our highest probability token.


To show how the split between in-nucleus and dropped/truncated tokens might look for other p values, I did a couple more runs and plotted the results with matplotlib. (Keep in mind we are using randomly generated logits, so this is not what you would get from a real model's logits, but it still helps with understanding.)

p = .1
Number of tokens in nucleus: 264 | Sampled index: 2243 | Probability of sampled index: 0.0036

Distribution for p=0.1

p = .3
Number of tokens in nucleus: 887 | Sampled index: 2078 | Probability of sampled index: 0.0010

Distribution for p=0.3

p = .75
Number of tokens in nucleus: 3470 | Sampled index: 2200 | Probability of sampled index: 0.0003

Distribution for p=0.75
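To reproduce this kind of comparison without plots, the sketch below sweeps several p values over the same style of random logits. For each p it reports the nucleus size and the probability mass kept. Your exact counts depend on the seed and on how the logits are generated, so they may not match the numbers above:

```python
import numpy as np

np.random.seed(42)
logits = np.random.uniform(-10, 10, 50257)
probs = np.exp(logits) / np.sum(np.exp(logits))
sorted_probs = np.sort(probs)[::-1]  # highest probability first
cumulative_probs = np.cumsum(sorted_probs)

sizes = []
for p in [0.0, 0.1, 0.3, 0.75, 0.9]:
    cutoff_idx = int(np.where(cumulative_probs > p)[0][0])
    size = cutoff_idx + 1  # tokens up to and including the cutoff
    sizes.append(size)
    print(f"p = {p:<4}: {size:>5} tokens in nucleus ({size / len(probs):.1%} of vocab), "
          f"mass kept {cumulative_probs[cutoff_idx]:.3f}")
```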


Cheers,

M