本文提出了一种去除 LLM reasoning 的 CoT 中冗余 token 的方法,发现提升了推理能力

token pruning 应该指的是选择性的保留一部分的 token,来减少 attention 计算量 (同时保持效果不变或更好)

Observation

Redundancy

作者发现,CoT 中经常存在 redundancy. 尤其是对那些没有得出正确答案的 CoT, 其 attention sparsity 更严重。

这样的 token 可能会导致后续思考被 distract.

image not found

这里做了一个 visualization. 可以看到 poor reasoning 对应的 attention map 非常单一,但是 good reasoning 就展现出更多的 pattern

于是作者想到,可以用 attention weights 来判断/筛选那些冗余的思考步骤

看到一个有意思的 related work, Efficient Streaming Language Models with Attention Sinks

说的是 attention 经常喜欢 strong focus on initial tokens. 于是提出只对 initial tokens & recent tokens 做 attention. 虽然看上去对 reasoning 比较唐

Attention score to </think>

</think> token,代表思考过程的结束的特殊 token。作者发现这个 token 对之前 tokens 的 attention score 是有意义的

一般来说,它会更注意那些 对最终结果有帮助的思考过程的开始部分 / 关键结论的总结

image not found

Method

Prompt to summarize periodically

每隔一定长度,我们在模型生成的 token 后加入下面一段 prompt, 让模型总结前面的思考:

Time is up. Given the time I’ve spent and the approaches I’ve tried, I should stop thinking and now write summarization in one sentence.</think>

但是注意,我们并不让模型继续生成,而是就到这里 forward 一边然后看 </think> token 的 attention scores

Prune by importance score

注意我们每层有很多头 (文章没写怎么处理不同的层,我猜可能是选了一个效果最好的层)

记对 token $t$, head $h$, $s_t^{(l, h)}$ 表示 </think> token 对 $t$ 的 attention weight.

为了检索 思考段,我们需要先给思考分段。分段的方法说是 follow 之前的做法,也就是每当遇到下面的词之一的时候分段

"Wait" "Alternatively" "Another angle" "Another approach" "But wait" "Hold on" "Hmm" "Maybe" "Looking back" "Okay" "Let me" "First" "Then" "Alright" "Compute" "Correct" "Good" "Got it" "I don’t see any errors" "I think" "Let me double-check" "Let’s see" "Now" "Remember" "Seems solid" "Similarly" "So" "Starting" "That’s correct" "That seems right" "Therefore" "Thus"

接下来对每一段 $r$, 我们计算 $r$ 的 importance score 为 $r$ 中所有 token 对所有 head 的 $s$ 的平均值。

最后,对一个超参数 $k$ (eviction budget),我们先把 importance score 对所有段排序,从小到大,然后一个个删除,但是保证删除的 importance score 之和不超过 $k$.

Ablation: Does eviction alone improve performance?

作者还提出了一个疑问,就是是否删除一些 token 本身就能提高 performance. 答案是 是的.

image not found

作者另外考虑了两种 eviction 测量:

  1. random. 完全随机的删掉一些 token
  2. H2O, 指的是删掉最低 accumulated attention score. (似乎是一个常见的 baseline)

发现二者都有提升。不过没有本文的方法效果好