输入“/”快速插入内容

Transformer 简单上手:文本二分类

2024年8月23日修改
作者:王几行XING
案例目标:根据影评的内容,来判断该影评是积极还是消极。二分类预测。
在过去,此类 文本分类 任务,可以典型的使用 RNN、LSTM 和 GRU 等序列模型完成。这里所谓的序列,指的是一段影评,加上最后 negative or positive 的结尾。
在有了 Transformer 这一神器之后,我们可以尝试用它。(从本案例结果来看,Transformer并没有完成超越。原因之一可能在于数据量过少。)
需要注意的是,在使用 Transformer 和 LSTM 等序列模型进行文本分类(如情感分析)时,标签并不是放在影评数据序列的末尾作为最后一个要预测的词汇。相反,标签是独立于文本序列的,并且在训练过程中用作监督信号。下面是这个过程的详细说明:
1.
输入数据(影评) :这些是模型的输入,通常是 文本序列 。在预处理阶段,这些文本序列会被转换为数字表示,比如通过词嵌入(word embeddings)。在 LSTM 网络中,这些序列作为输入被逐步处理。
2.
标签( 情感分类 :这些是训练数据的一部分,但它们不是序列的一部分,而是独立的信息。对于情感分析,每个影评会有一个标签,比如“正面”或“负面”。这些标签在训练过程中用来告诉模型每个输入序列所对应的正确输出。
3.
训练过程 :在训练 LSTM 网络时,网络会尝试学习如何从输入的文本序列中提取特征,并基于这些特征做出分类决策。网络的输出是一个分类结果(比如“正面”或“负面”),这个结果会与真实的标签进行比较,通过计算损失(如 交叉熵损失 )来评估模型的性能。然后使用 反向传播算法 来调整网络的权重,以减少预测和真实标签之间的差异。
4.
输出层 :LSTM 网络的最后通常会有一个全连接层(Dense Layer),该层的输出是分类决策(例如,在情感分析中是“正面”或“负面”)。这个输出是基于整个输入序列的信息,而不是基于序列中的最后一个词汇。
小结:标签在训练 LSTM 网络时用作外部监督信号,而不是作为输入序列的一部分。这样可以让模型学习如何从整个文本序列中提取有用的信息,并做出正确的分类决策。
第一步,下载数据
这里我们下载并探索IMDB数据集的,它使用了TensorFlow框架。IMDB数据集包含电影评论,这些评论已经被转换成了一系列数字,每个数字代表一个单词。这个数据集常用于 自然语言处理 和情感分析任务。
代码块
# 导入 TensorFlow 和 Keras 相关的库
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.datasets import imdb
# 下载 IMDB 数据集
# 这里只加载了前 10,000 个最常出现的单词(基于频率)
(train_data, train_labels), (test_data, test_labels) = imdb.load_data(num_words=10000)
# 探索数据集
# 查看训练数据的形状(即数据的维度)
train_data_shape = train_data.shape
# 查看测试数据的形状
test_data_shape = test_data.shape
# 查看训练数据中的第一条评论(以数字列表的形式表示,每个数字代表一个单词)
sample_review = train_data[0]
# 查看与第一条评论对应的标签(0代表负面评论,1代表正面评论)
sample_label = train_labels[0]
# 显示训练数据和测试数据的形状
print("Training data shape:", train_data.shape)
print("Test data shape:", test_data.shape)
# 打印出第一条评论和其对应的标签
print("Sample review:", train_data[0])
print("Sample label:", train_labels[0])
为了将数字展示为本文,我们需要先基于 index map 进行转换(decode)。
代码块
# 获取 IMDB 数据集的词汇索引映射(单词到整数索引的映射)
word_index = imdb.get_word_index()
# 生成反向词汇索引映射(整数索引到单词的映射)
# 这里将 word_index 中的键值对调换位置,创建一个新的字典
reverse_word_index = dict([(value, key) for (key, value) in word_index.items()])
# 定义一个函数,用于将编码后的评论解码为文本
# 编码评论是指使用整数代替单词的评论
def decode_review(encoded_review):
# 将每个整数索引转换回对应的单词,如果索引不在 reverse_word_index 中,则返回 '?'
# 这里减去 3 是因为 0、1、2 是为 "padding"、"start of sequence" 和 "unknown" 保留的索引
return ' '.join([reverse_word_index.get(i - 3, '?') for i in encoded_review])
# 打印出训练集中第一条评论的解码文本
print(decode_review(train_data[0])) # 解码训练集中的第一条评论