最后更新于
最后更新于
Beam search是经常应用在Seq2seq任务中的解码方法, 算法本身复杂度不高, 但往往能大幅提高解码质量.
Beam search只用在测试样本的解码过程中, 在训练时是用不到的.
Seq2seq任务中, 对于一个输入, 我们的任务是具有最大概率的解码序列. 可以用公式如下表示:
如果不使用任何解码方法, 在seq2seq模型中我们使用的就是贪心搜索(greedy search). 具体来说, 在时刻, 根据所有的输入(encoder的输入, 来自上一个元素的输出作为当前decoder的输入), 在字典中挑选出条件概率最大的词, 之后的每个时刻一次类推.
因此, 贪心算法在每个时刻, 始终是选择最大概率的词, 但这样选择出的词拼接在一起组成的序列, 往往不是上式概率最大的情况, 甚至相差很大, 而我们需要的, 正是上式概率最大的序列. 这种情况的反面例子可以参考参考资料中的第一篇的Why not a greedy search?
部分.
如果完全从序列的角度出发, 一定能找到最优的序列, 但是字典中所有单词组成的序列数量, 是无法枚举判断得到最优序列的.
Beam search就是一种折中的方案, 在时刻确定输出时, 会考虑之前时刻的输出序列, 并且会考虑多种前置序列, 但也使用了超参数beam width
, 保证了搜索范围的高效.
Beam search虽然可能不会找到最优的方案, 但已经能够保证在高效的前提下, 找到接近于最优答案, 甚至就是最优答案的结果.
具体的原理可以参考资料中第一篇的Beam Search
部分, 以及第二篇. 这两篇文章介绍算法时都集合了具体的例子进行推导, 形象易懂.
这是结合在seq2seq模型中的一个beam search算法的代码. Beam search算法的唯一超参数就是上文中提到的beam width
, 也就是代码中的topK
.
从投入开始标志<s>
(代码对应的字典中的index为2)为起始, 逐个预测之后的每个元素. 在预测过程中, 除了需要保存现有的topK
个候选序列(保存在变量target_seq
中), 还要存储对应序列的整体概率(保存在topk_prob
中).
另外需要注意的是, 代码中关于概率计算和终止符</s>
的处理方法.