# Beam search

## 引入

**Beam search**是经常应用在Seq2seq任务中的解码方法, 算法本身复杂度不高, 但往往能大幅提高解码质量.

Beam search只用在**测试样本的解码过程中**, 在训练时是用不到的.

## 对比

Seq2seq任务中, 对于一个输入, 我们的任务是具有最大概率的解码序列. 可以用公式如下表示:

$$\arg \max *{y} P\left(y^{<1>}, y^{<2>}, y^{<3>}, \ldots y^{\<T*{y}>} | x^{<1>}, x^{<2>}, \ldots x^{\<T\_{x}>}\right)$$

如果不使用任何解码方法, 在seq2seq模型中我们使用的就是**贪心搜索**(greedy search). 具体来说, 在$$t$$时刻, 根据所有的输入(encoder的输入, 来自上一个元素的输出作为当前decoder的输入), 在字典中挑选出条件概率最大的词$$y^{<t>}$$, 之后的每个时刻一次类推.

因此, 贪心算法在每个时刻, 始终是选择最大概率的词, 但这样选择出的词拼接在一起组成的序列, 往往不是上式概率最大的情况, 甚至相差很大, 而我们需要的, 正是上式概率最大的序列. 这种情况的反面例子可以参考参考资料中的第一篇的`Why not a greedy search?`部分.

如果完全从序列的角度出发, 一定能找到最优的序列, 但是字典中所有单词组成的序列数量, 是无法枚举判断得到最优序列的.

Beam search就是一种折中的方案, 在$$t$$时刻确定输出时, 会考虑之前时刻的输出序列, 并且会考虑多种前置序列, 但也使用了超参数`beam width`, 保证了搜索范围的高效.

Beam search虽然可能不会找到最优的方案, 但已经能够保证在高效的前提下, 找到接近于最优答案, 甚至就是最优答案的结果.

## 原理

具体的原理可以参考资料中第一篇的`Beam Search`部分, 以及第二篇. 这两篇文章介绍算法时都集合了具体的例子进行推导, 形象易懂.

## 代码

```python
def beam_search(self, seq, topk=3):
    if not self.decode_build:
        self._build_decode_model()

    seq = np.repeat(seq, topk, axis=0)
    encoder_output = self.encoder_model.predict_on_batch([seq])

    final_results = []
    topk_prob = np.zeros((topk,), dtype=np.float32)
    decode_tokens = [[] for _ in range(topk)]

    target_seq = np.zeros((topk, self.length_limit), dtype=np.int32)
    target_seq[:, 0] = 2

    last_k = 1

    for i in range(self.length_limit - 1):
        if last_k == 0 or len(final_results) > topk * 3:
            break  # stop conditions

        target_output = self.decoder_model.predict_on_batch([seq, target_seq, encoder_output])
        output = np.exp(target_output[:, i, :])
        output = output / np.sum(output, axis=-1, keepdims=True)
        output = np.log(output + 1e-8)  # use `log` transformation to avoid tiny probability

        candidates = []

        for k, probs in zip(range(last_k), output):
            if target_seq[k, i] == 3:
                continue

            word_p_sort = sorted(list(enumerate(probs)), key=lambda x: x[1], reverse=True)
            for ind, wp in word_p_sort[:topk]:
                candidates.append((k, ind, topk_prob[k] + wp))

        candidates = sorted(candidates,key=lambda x: x[-1], reverse=True)
        candidates = candidates[:topk]

        target_seq_bk = target_seq.copy()

        for new_k, cand in enumerate(candidates):
            k, ind, seq_p = cand
            target_seq[new_k] = target_seq_bk[k]
            target_seq[new_k, i + 1] = ind
            topk_prob[new_k] = seq_p
            decode_tokens.append(decode_tokens[k] + [self.tar_token_dict[ind]])
            if ind == 3:
                final_results.append((decode_tokens[k], seq_p))

        decode_tokens = decode_tokens[topk:]
        last_k = len(decode_tokens)

    final_results = [(x, y / (len(x) + 1)) for x, y in final_results]
    final_results = sorted(final_results, key=lambda x: x[1], reverse=True)
    return final_results
```

这是结合在seq2seq模型中的一个beam search算法的代码. Beam search算法的唯一超参数就是上文中提到的`beam width`, 也就是代码中的`topK`.

从投入开始标志`<s>`(代码对应的字典中的index为2)为起始, 逐个预测之后的每个元素. 在预测过程中, 除了需要保存现有的`topK`个候选序列(保存在变量`target_seq`中), 还要存储对应序列的整体概率(保存在`topk_prob`中).

另外需要注意的是, 代码中关于**概率计算**和`终止符</s>`的处理方法.

## 参考资料

* [Seq2Seq中的beam search算法](https://zhuanlan.zhihu.com/p/36029811?group_id=972420376412762112)
* [seq2seq中的beam search算法过程](https://zhuanlan.zhihu.com/p/28048246)


---

# Agent Instructions: Querying This Documentation

If you need additional information that is not directly available in this page, you can query the documentation dynamically by asking a question.

Perform an HTTP GET request on the current page URL with the `ask` query parameter:

```
GET https://blessbingo.gitbook.io/garnet/zi-ran-yu-yan-chu-li/jie-ma/beam-search.md?ask=<question>
```

The question should be specific, self-contained, and written in natural language.
The response will contain a direct answer to the question and relevant excerpts and sources from the documentation.

Use this mechanism when the answer is not explicitly present in the current page, you need clarification or additional context, or you want to retrieve related documentation sections.
