介绍
随着机器学习(Machine Learning, ML)和自然语言处理(Natural Language Processing, NLP)技术的快速进展,新算法具备生成文本的能力,这些文本也变得越来越接近人类写出的内容。GPT21就是其中一个算法,它被应用在很多开源项目2中。GPT2以WebText为训练集,WebText包含4500万条来自Reddit(一个对新闻进行评论的网络社区)的外链。其中占据外链内容前10的主要数据3来自Google,Archive,Blogspot,Github,NYTimes,WordPress,Washington Post,Wikia,BBC以及The Guardian。受过训练的GPT2模型能根据具体数据集再被进一步调校,比如说最终能够抓取某个数据集的风格或者能够做文档分类。
这一功能基于迁移学习的实现,即一种从源设定中提取知识并应用到另一个不同目标设置上的技术4。如果想要了解GTP2算法更详细的解释以及算法构架,请参考原始文献5,OpenAI的博客6和Jay Alammar的指导说明7。
数据集
用来训练GPT2的数据集来自《瑞克和莫蒂》前三季的台词。我提前过滤了其中不属于Rick,Morty,Summer,Beth和Jerry的对话。这些数据下载后以生文本格式进行存储。每一行数据代表一位主角的发言,同时包含了对他们语气/动作及对话场景的描述。数据集被分为训练集和测试集,分别有6905行和1454行。原文件在此可供查看(https://github.com/e-tony/Story_Generator/tree/main/data)。训练集是用来训练模型的,测试集则用来评估模型效果。
训练模型
Hugging Face’s Transformers库提供了一个简单的GPT2模型训练脚本(https://github.com/huggingface/transformers/tree/master/examples/language-modeling#gpt-2gpt-and-causal-language-modeling)。
接下来,你可以在Google Colab notebook(https://colab.research.google.com/drive/1opXtwhZ02DjdyoVlafiF3Niec4GqPJvC?usp=sharing)环境下开始训练自己的模型。一旦完成了模型训练,你需要将训练输出文件夹下载下来,文件夹里包含了所有相关模型的文件,这一步对之后加载模型至关重要。你还可以将自己的模型上传到Hugging Face的模型中心8,让其他人也能看到它。这个训练好的模型在使用测试数据评估时,会获得17分左右的复杂度得分。
搭建应用
首先,我们新建一个叫做Story_Generator的项目文件夹 ,并在Python 3.7的环境下开始试验:
mkdir Story_Generator
cd Story_Generator
python3.7 -m venv venv
source venv/bin/activate
复制
下一步,安装所有需要的依赖:
pip install streamlit-nightly==0.69.3.dev20201025
pip install torch==1.6.0+cpu torchvision==0.7.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
pip install git+git://github.com/huggingface/transformers@59b5953d89544a66d73
复制
整个应用通过app.py实现。我们需要创建python文件并导入我们新安装的依赖:
import urllib
import streamlit as st
import torch
from transformers import pipeline
复制
在更进一步之前,需要加载训练好的模型。利用@st_cache的装饰器,执行一次load_model()函数并将结果存到本地缓存。这个操作能够增幅程序性能。接着用pipeline()函数加载文本生成器模型即可(将代码中的模型路径换成你自己的模型或者也可以直接用模型中心里我预先训练过的mypre-trainedmodel,https://huggingface.co/e-tony/gpt2-rnm):
@st.cache(allow_output_mutation=True, suppress_st_warning=True)
def load_model():
return pipeline("text-generation", model="e-tony/gpt2-rnm")model = load_model()
复制
使用Streamlit中的text_area()函数可以生成一个简单文本框。我们可以额外定义文本框的高度及其容纳的最大字符数(需要注意的是越大的文本生成时间越长):
textbox = st.text_area('Start your story:', '', height=200, max_chars=1000)
复制
搞定了代码的开头部分,我们现在可以运行程序,看看目前的进度(也可以通过刷新页面获取及时进度)
streamlit run app.py
复制
接下来,加入一个自由调节插件,用来实现用户自定义模型需要生成的角色数:
slider = st.slider('Max story length (in characters)', 50, 200)
复制
现在我们可以开始生成文本啦!来做一个执行生成命令的按钮吧:
button = st.button('Generate')
复制
我们的应用要感知“按下按钮”的动作,此功能借助一个简单的条件语句实现。文本生成后会打印到屏幕上:
if button:
output_text = model(textbox, max_length=slider)[0]['generated_text']
for i, line in enumerate(output_text.split("\n")):
if ":" in line:
speaker, speech = line.split(':')
st.markdown(f'__{speaker}__: {speech}')
else:
st.markdown(line)
复制
下面向之前定义的文本框输入提示语生成故事:
Rick: Come on, flip the pickle, Morty. You're not gonna regret it. The payoff is huge.
瑞克:莫蒂,快呀,把泡菜黄瓜翻过来,你不会后悔的。你会得到巨大回报的。
输出:
Rick: Come on, flip the pickle, Morty. You're not gonna regret it. The payoff is huge. You don't have to be bad, Morty.
瑞克:莫蒂,快呀,把泡菜黄瓜翻过来,你不会后悔的。你会得到巨大回报的。莫蒂,你不用扮演坏人的。
(瑞克台词结束)
【换景退出。莫蒂在家里】
很棒的输出!模型根据提示输出了新内容,而且看上去不错。我们还能通过调整decoding方法的参数来进一步提升输出的质量。Hugging Face的帖子里有关于不同方法解码的更多概述9。现在,让我们替换模型函数并给更多参数赋值:
output_text = model(textbox, do_sample=True, max_length=slider, top_k=50, top_p=0.95, num_returned_sequences=1)[0]['generated_text']
复制
简而言之,do_sample会随机挑选下一个词语,top_k过滤控制最有可能在下一个词出现的词汇的个数,top_p允许后面生成词语数量的动态增加或减少,num_returned_sequences参数负责输出多个相互独立的样本以供进一步筛选和评估(在我们的案例出只输出了一组样本)。
通过调节这些参数,你就能获得不同类型的输出结果。让我们用这种解码方法生成另一个输出吧。
输出:
Rick: Come on, flip the pickle, Morty. You're not gonna regret it. The payoff is huge.
瑞克:莫蒂,快呀,把腌黄瓜翻过来,你不会后悔的。这是个大惊喜。
Morty: Ew, no, Rick! Where are you?
莫蒂:呃,不,瑞克!你在哪?
Rick: Morty, just do it! [laughing] Just flip the pickle!
瑞克:莫蒂,就这样做【大笑】。快把腌黄瓜翻过来!
Morty: I'm a Morty, okay?
莫蒂:我是莫蒂,好吗?
Rick: Come on, Morty. Don't be ashamed to be a Morty. Just flip the pickle.
瑞克:别介,莫蒂。你是莫蒂没什么好羞耻的。快把腌黄瓜翻过来。
现在我们的输出看起来更像样了。尽管模型还会输出一些不合逻辑甚至无意义的语句,但新模型配合解码方法能够解决问题。不巧的是,由于模型受到网络数据文本的训练,有时会生成具有伤害性的、粗鲁的、暴力的或者带有歧视性意味的用词。针对这一问题我们通过应用“坏词”过滤器来解决,过滤器根据一个含有451个词汇的列表对暴力词汇进行简单检查以发现伤害性用词。我强烈建议读者考虑再增加别的过滤器,比如针对仇恨言论的过滤器。过滤器功能的增加方法如下:
def load_bad_words() -> list:
res_list = []file = urllib.request.urlopen("https://raw.githubusercontent.com/RobertJGabriel/Google-profanity-words/master/list.txt")
for line in file:
dline = line.decode("utf-8")
res_list.append(dline.split("\n")[0])
return res_listBAD_WORDS = load_bad_words()
def filter_bad_words(text):
explicit = False
res_text = text.lower()
for word in BAD_WORDS:
if word in res_text:
res_text = res_text.replace(word, word[0]+"*"*len(word[1:]))
explicit = Trueif not explicit:
return textoutput_text = ""
for oword,rword in zip(text.split(" "), res_text.split(" ")):
if oword.lower() == rword:
output_text += oword+" "
else:
output_text += rword+" "return output_textoutput_text = filter_bad_words(model(textbox, do_sample=True, max_length=slider, top_k=50, top_p=0.95, num_returned_sequences=1)[0]['generated_text'])
复制
最终的app.py文件如下:
import urllib
import streamlit as st
import torch
from transformers import pipelinedef load_bad_words() -> list:
res_list = []file = urllib.request.urlopen("https://raw.githubusercontent.com/RobertJGabriel/Google-profanity-words/master/list.txt")
for line in file:
dline = line.decode("utf-8")
res_list.append(dline.split("\n")[0])
return res_listBAD_WORDS = load_bad_words()
@st.cache(allow_output_mutation=True, suppress_st_warning=True)
def load_model():
return pipeline("text-generation", model="e-tony/gpt2-rnm")def filter_bad_words(text):
explicit = False
res_text = text.lower()
for word in BAD_WORDS:
if word in res_text:
res_text = res_text.replace(word, word[0]+"*"*len(word[1:]))
explicit = Trueif not explicit:
return textoutput_text = ""
for oword,rword in zip(text.split(" "), res_text.split(" ")):
if oword.lower() == rword:
output_text += oword+" "
else:
output_text += rword+" "return output_textmodel = load_model()
textbox = st.text_area('Start your story:', '', height=200, max_chars=1000)
slider = slider = st.slider('Max text length (in characters)', 50, 1000)
button = st.button('Generate')if button:
output_text = filter_bad_words(model(textbox, do_sample=True, max_length=slider, top_k=50, top_p=0.95, num_returned_sequences=1)[0]['generated_text'])
for i, line in enumerate(output_text.split("\n")):
if ":" in line:
speaker, speech = line.split(':')
st.markdown(f'__{speaker}__: {speech}')
else:
st.markdown(line)
复制
此外,欢迎大家到Github库上查看我的演示案例,其中还包括对此应用的外观效果及其他功能性修饰相关的实用代码。
https://github.com/e-tony/Story_Generator
https://share.streamlit.io/e-tony/story_generator/main/app.py
现在我们已经可以上线功能了!
应用部署
我们使用Streamlit Sharing部署应用。准备一个公开的Github库以及相关需求文档(requirement.txt文件),再加上app.py文件就,就基本完成了。你的需求文档大致将包含这些内容:
-f https://download.pytorch.org/whl/torch_stable.html
streamlit-nightly==0.69.3.dev20201025
torch==1.6.0+cpu
torchvision==0.7.0+cpu
transformers @ git+git://github.com/huggingface/transformers@59b5953d89544a66d73
复制
在Streamlit Sharing网站上(https://share.streamlit.io/)和Github库完成连接,模型很快就能上线了。
伦理考量
需要提醒大家,我们这里讨论的应用仅限个人娱乐使用!在其他场景下使用GPT2模型之前请谨慎考量。尽管我们移除了原始训练集中涉及到某些领域的数据,GPT2模型仍然大量使用了网络上未经筛选的内容进行预先训练,其中就会包含很多偏见和歧视意味的言论。
OpenAI的模型卡片说明也指出了这些担忧:
我们认为某些二次使用的案例大概率包含以下情景:
- 写作辅助:语法助手,自动补充填词(常规文本或是代码文本)
- 创业写作和艺术创作:对创意性、科幻性文本生成进行探索;帮助诗歌和其他文学艺术作品的创作
- 娱乐用途:用于游戏,聊天机器人,讲笑话等
超出应用案例的范畴:
类似GPT2这样基于大范围语言训练的模型并不会辨别语句的真伪,因此我们不支持需要保证文本真实性的使用案例。此外,由于训练数据特性,GPT2系统本身就反映出一定程度的偏见,所以我们也不推荐直接将模型用于人类交互,除非部署者提前对预期使用案例进行了偏见问题的相关研究。在我们进行过的分析中,针对性别、种族以及宗教问题,使用774M的数据量和1.5B数据量的训练集并未发现统计学意义上的显著差异,这意味着使用所有版本的GPT2模型都应该同样谨慎,尤其是要处理与人类社会属性偏见敏感的案例时。
以下这个案例也说明该模型可能生成具有偏见性的结果:
>>> from transformers import pipeline, set_seed
>>> generator = pipeline('text-generation', model='gpt2')
>>> set_seed(42)
>>> generator("The man worked as a", max_length=10, num_return_sequences=5)[{'generated_text': 'The man worked as a waiter at a Japanese restaurant'},
{'generated_text': 'The man worked as a bouncer and a boun'},
{'generated_text': 'The man worked as a lawyer at the local firm'},
{'generated_text': 'The man worked as a waiter in a cafe near'},
{'generated_text': 'The man worked as a chef in a strip mall'}]>>> set_seed(42)
>>> generator("The woman worked as a", max_length=10, num_return_sequences=5)[{'generated_text': 'The woman worked as a waitress at a Japanese restaurant'},
{'generated_text': 'The woman worked as a waitress at a local restaurant'},
{'generated_text': 'The woman worked as a waitress at the local supermarket'},
{'generated_text': 'The woman worked as a nurse in a health center'},
{'generated_text': 'The woman worked as a maid in Daphne'}]
复制
我强烈建议读者谨慎考虑相关模型在现实场景中的应用。关于机器学习在伦理道德方面的问题,例如EML,AINow等社区还有很多可供参考的内容。
结论
恭喜!看到这里,你的应用已经具备上线能力了。
借助一些开源的框架,我们得以实现GPT2模型的快速调教,并制作出有趣的应用原型,接着部署使用模型。模型生成的故事还能被进一步润色,借助其他具有更多高级功能的模型,解码方法甚至语言预测构架,都能让故事变得更精彩。
欢迎大家到Github测试和查看关于此项目的更多信息。
https://github.com/e-tony/Story_Generator
https://share.streamlit.io/e-tony/story_generator/main/app.py
相关参考:
[1]: GPT2(https://github.com/openai/gpt-2)
[2]: The Top 30 Gpt2 Open Source Projects(https://awesomeopensource.com/projects/gpt-2)
[3]: The State of Transfer Learning in NLP(https://ruder.io/state-of-transfer-learning-in-nlp/)
[4]: Top 1,000 domains present in WebText(https://github.com/openai/gpt-2/blob/master/domains.txt)
[5]: A. Radford, Jeffrey Wu, R. Child, David Luan, Dario Amodei, and Ilya Sutskever 2019. Language Models are Unsupervised Multitask Learners.
[6]: Better Language Models and Their Implications(https://openai.com/blog/better-language-models/)
[7]: The Illustrated GPT-2 (Visualizing Transformer Language Models)(http://jalammar.github.io/illustrated-gpt2/)
[8]: Model sharing and uploading(https://huggingface.co/transformers/model_sharing.html)
[9]: How to generate text: using different decoding methods for language generation with Transformers(https://huggingface.co/blog/how-to-generate)
[10]: Deploy an app(https://docs.streamlit.io/en/stable/deploy_streamlit_app.html)
[11]: The Institute for Ethical AI & Machine Learning(https://ethical.institute/)
[12]: AI Now Institute (https://ainowinstitute.org/ )