gpt2解码参数解析


使用gpt2进行文本生成时,有几个参数可调,本文简要总结了一下这些个参数的作用。

默认的解码方式greedy search。后续的参数调整都是围绕着这种解码的缺点进行的。

1、 num_beams和early_stopping

greedy search的缺点在于,一个低概率的词,可能会使得后面高概率的词无法解码出来。 num_beams参数就是通过维持几个当前最优解,来规避掉这个问题。 理论上,该参数越大,效果越好,但是计算性能会越差。 一般配合参数early_stopping=True使用。该参数的作用是当解码的候选句子中有num_beams个已经到达结束时,结束beam search。

2、no_repeat_ngram_size

gpt2解码时容易出现重复的词语或者短句。该参数就是用来限制重复的。 no_repeat_ngram_size=2,表示在解码时,2 gram只允许出现一次,不允许重复解码出来。 有一点需要注意,目前我使用的闻仲gpt2模型,这个模型的tokenizer还是bpe的,也就是解码的token不是中文字符。 此时,如果想达到汉字意义上2gram不重复,no_repeat_ngram_size应该设为4。 使用该参数要小心,比如如果句子中有北京,设置了no_repeat_ngram_size=4时,解码的句子中就只允许出现一次北京了,这有时未必是一个好的选择。

3、num_return_sequences

当使用num_beams时,我们可能想把排名靠前的几个候选句子都返回,这可以通过num_return_sequences来实现,有一个前提条件是: num_return_sequences <= num_beams

4、do_sample

beam search每次都是从概率高的词中选择,导致概率低的词很难被选中。do_sample就是解决这个问题的。它的作用是从词分布中抽样一个词出来,这样,无论词的概率大小,都是有机会被选中的,这增加了生成的多样性。

5、temperature

do_sample带来的一个问题时,很多奇怪的词也可能被挑选出来。为了缓解这个问题,可以使用temperature。该参数的目的是,让概率高的词概率更高,概率低的词概率更低。 取值越小,分布越陡峭。取值为0时,就等同于greedy search了。

6、top_k

一般是配合do_sample,当sample时,仅从概率最高的k个词中进行sample。

7、top_p

也是为了配合do_sample,当sample时,仅从前面概率相加大于等于top_p的词中进行sample。该参数和top_k也可以配合使用。

下面给出一个经典的参数配置:

# set seed to reproduce results. Feel free to change the seed though to get different results
tf.random.set_seed(0)

# set top_k = 50 and set top_p = 0.95 and num_return_sequences = 3
sample_outputs = model.generate(
    input_ids,
    do_sample=True, 
    max_length=50, 
    top_k=50, 
    top_p=0.95, 
    num_return_sequences=3
)

print("Output:\n" + 100 * '-')
for i, sample_output in enumerate(sample_outputs):
  print("{}: {}".format(i, tokenizer.decode(sample_output, skip_special_tokens=True)))


原创文章,转载请注明出处,否则拒绝转载!
本文链接:抬头看浏览器地址栏

上篇: NLP模型训练时数据预处理的教训
下篇: pytorch进程间通信