订阅
纠错
加入自媒体

NLP ——从0开始快速上手百度 ERNIE

2020-12-17 10:53
程序媛驿站
关注

三、具体实现过程

开始写代码!

ChnSentiCorp任务运行的shell脚本是 ERNIE/ernie/run_classifier.py,该文件定义了分类任务Fine-tuning 的详细过程,下面我们将通过如下几个步骤进行详细剖析:

环境准备。导入相关的依赖,解析命令行参数;

实例化ERNIE 模型,优化器以及Tokenizer, 并设置超参数

定义辅助函数

运行训练循环

1. 环境准备

import相关的依赖,解析命令行参数。

import syssys.path.append('./ERNIE')import numpy as npfrom sklearn.metrics import f1_scoreimport paddle as Pimport paddle.fluid as Fimport paddle.fluid.layers as Limport paddle.fluid.dygraph as D
from ernie.tokenizing_ernie import ErnieTokenizerfrom ernie.modeling_ernie import ErnieModelForSequenceClassification2. 实例化ERNIE 模型,优化器以及Tokenizer, 并设置超参数

设置好所有的超参数,对于ERNIE任务学习率推荐取 1e-5/2e-5/5e-5, 根据显存大小调节BATCH大小, 最大句子长度不超过512.

BATCH=32MAX_SEQLEN=300LR=5e-5EPOCH=10
D.guard().__enter__() # 为了让Paddle进入动态图模式,需要添加这一行在最前面
ernie = ErnieModelForSequenceClassification.from_pretrained('ernie-1.0', num_labels=3)optimizer = F.optimizer.Adam(LR, parameter_list=ernie.parameters())tokenizer = ErnieTokenizer.from_pretrained('ernie-1.0')3. 定义辅助函数

(1)定义函数 make_data,将文本数据读入内存并转换为numpy List存储。

def make_data(path):    data = []    for i, l in enumerate(open(path)):        if i == 0:            continue        l = l.strip().split(' ')        text, label = l[0], int(l[1])        text_id, _ = tokenizer.encode(text) # ErnieTokenizer 会自动添加ERNIE所需要的特殊token,如[CLS], [SEP]        text_id = text_id[:MAX_SEQLEN]        text_id = np.pad(text_id, [0, MAX_SEQLEN-len(text_id)], mode='constant') # 对所有句子都补长至300,这样会比较费显存;        label_id = np.array(label+1)        data.append((text_id, label_id))    return data
train_data = make_data('./chnsenticorp/train/part.0')test_data = make_data('./chnsenticorp/dev/part.0')

(2)定义函数get_batch_data,用于获取BATCH条样本并按照批处理维度stack到一起。

def get_batch_data(data, i):    d = data[i*BATCH: (i + 1) * BATCH]    feature, label = zip(*d)    feature = np.stack(feature)  # 将BATCH行样本整合在一个numpy.array中    label = np.stack(list(label))    feature = D.to_variable(feature) # 使用to_variable将numpy.array转换为paddle tensor    label = D.to_variable(label)    return feature, label4. 运行训练循环

队训练数据重复EPOCH遍训练循环;每次循环开头都会重新shuffle数据。在训练过程中每间隔100步在验证数据集上进行测试并汇报结果(acc)。

for i in range(EPOCH):    np.random.shuffle(train_data) # 每个epoch都shuffle数据以获得最佳训练效果;    #train    for j in range(len(train_data) // BATCH):        feature, label = get_batch_data(train_data, j)        loss, _ = ernie(feature, labels=label) # ernie模型的返回值包含(loss, logits);其中logits目前暂时不需要使用        loss.backward()        optimizer.minimize(loss)        ernie.clear_gradients()        if j % 10 == 0:            print('train %d: loss %.5f' % (j, loss.numpy()))        # evaluate        if j % 100 == 0:            all_pred, all_label = [], []            with D.base._switch_tracer_mode_guard_(is_train=False): # 在这个with域内ernie不会进行梯度计算;                ernie.eval() # 控制模型进入eval模式,这将会关闭所有的dropout;                for j in range(len(test_data) // BATCH):                    feature, label = get_batch_data(test_data, j)                    loss, logits = ernie(feature, labels=label)                     all_pred.extend(L.argmax(logits, -1).numpy())                    all_label.extend(label.numpy())                ernie.train()            f1 = f1_score(all_label, all_pred, average='macro')            acc = (np.array(all_label) == np.array(all_pred)).astype(np.float32).mean()            print('acc %.5f' % acc)

训练过程中单次迭代输出的日志如下所示:

train 0: loss 0.05833acc 0.91723train 10: loss 0.03602train 20: loss 0.00047train 30: loss 0.02403train 40: loss 0.01642train 50: loss 0.12958train 60: loss 0.04629train 70: loss 0.00942train 80: loss 0.00068train 90: loss 0.05485train 100: loss 0.01527acc 0.92821train 110: loss 0.00927train 120: loss 0.07236train 130: loss 0.01391train 140: loss 0.01612

包含了当前 batch 的训练得到的Loss(ave loss)和每个Epochde 精度(acc)信息。训练完成后用户可以参考快速运行中的方法使用模型体验推理功能。

其它特性

ERNIE 还提供了混合精度训练、模型蒸馏等高级功能,可以在 README 中获得这些功能的使用方法。

图片标题


<上一页  1  2  3  
声明: 本文由入驻维科号的作者撰写,观点仅代表作者本人,不代表OFweek立场。如有侵权或其他问题,请联系举报。

发表评论

0条评论,0人参与

请输入评论内容...

请输入评论/评论长度6~500个字

您提交的评论过于频繁,请输入验证码继续

暂无评论

暂无评论

    人工智能 猎头职位 更多
    扫码关注公众号
    OFweek人工智能网
    获取更多精彩内容
    文章纠错
    x
    *文字标题:
    *纠错内容:
    联系邮箱:
    *验 证 码:

    粤公网安备 44030502002758号