草庐IT

【深度学习前沿应用】文本分类Fine-Tunning

灵彧universe 2023-03-28 原文

【深度学习前沿应用】文本分类Fine-Tunning


作者简介:在校大学生一枚,华为云享专家,阿里云星级博主,腾云先锋(TDP)成员,云曦智划项目总负责人,全国高等学校计算机教学与产业实践资源建设专家委员会(TIPCC)志愿者,以及编程爱好者,期待和大家一起学习,一起进步~ . 博客主页ぃ灵彧が的学习日志 . 本文专栏机器学习 . 专栏寄语:若你决定灿烂,山无遮,海无拦 .

(文章目录)


前言

应用BERT模型做短文本情绪分类

#导入相关的模块 import paddle import paddlenlp as ppnlp from paddlenlp.data import Stack, Pad, Tuple import paddle.nn.functional as F import numpy as np from functools import partial #partial()函数可以用来固定某些参数值,并返回一个新的callable对象 ppnlp.__version__

一、数据加载及预处理

(一)、数据导入

数据集为公开中文情感分析数据集ChnSenticorp。使用PaddleNLP的.datasets.ChnSentiCorp.get_datasets方法即可以加载该数据集。

#采用paddlenlp内置的ChnSentiCorp语料,该语料主要可以用来做情感分类。训练集用来训练模型,验证集用来选择模型,测试集用来评估模型泛化性能。 train_ds, dev_ds, test_ds = ppnlp.datasets.ChnSentiCorp.get_datasets(['train','dev','test']) #获得标签列表 label_list = train_ds.get_labels() #看看数据长什么样子,分别打印训练集、验证集、测试集的前3条数据。 print("训练集数据:{}\n".format(train_ds[0:1])) print("验证集数据:{}\n".format(dev_ds[0:1])) print("测试集数据:{}\n".format(test_ds[0:1])) print("训练集样本个数:{}".format(len(train_ds))) print("验证集样本个数:{}".format(len(dev_ds))) print("测试集样本个数:{}".format(len(test_ds))) 输出结果如下图1所示:


(二)、数据预处理

#调用ppnlp.transformers.BertTokenizer进行数据处理,tokenizer可以把原始输入文本转化成模型model可接受的输入数据格式。 tokenizer = ppnlp.transformers.BertTokenizer.from_pretrained("bert-base-chinese") #数据预处理 def convert_example(example,tokenizer,label_list,max_seq_length=256,is_test=False): if is_test: text = example else: text, label = example #tokenizer.encode方法能够完成切分token,映射token ID以及拼接特殊token encoded_inputs = tokenizer.encode(text=text, max_seq_len=max_seq_length) # print('===================') # print(encoded_inputs) input_ids = encoded_inputs["input_ids"] segment_ids = encoded_inputs["token_type_ids"] if not is_test: label_map = {} for (i, l) in enumerate(label_list): label_map[l] = i label = label_map[label] label = np.array([label], dtype="int64") return input_ids, segment_ids, label else: return input_ids, segment_ids #数据迭代器构造方法 def create_dataloader(dataset, trans_fn=None, mode='train', batch_size=1, use_gpu=False, pad_token_id=0, batchify_fn=None): if trans_fn: dataset = dataset.apply(trans_fn, lazy=True) if mode == 'train' and use_gpu: sampler = paddle.io.DistributedBatchSampler(dataset=dataset, batch_size=batch_size, shuffle=True) else: shuffle = True if mode == 'train' else False #如果不是训练集,则不打乱顺序 sampler = paddle.io.BatchSampler(dataset=dataset, batch_size=batch_size, shuffle=shuffle) #生成一个取样器 dataloader = paddle.io.DataLoader(dataset, batch_sampler=sampler, return_list=True, collate_fn=batchify_fn) return dataloader #使用partial()来固定convert_example函数的tokenizer, label_list, max_seq_length, is_test等参数值 trans_fn = partial(convert_example, tokenizer=tokenizer, label_list=label_list, max_seq_length=128, is_test=False) batchify_fn = lambda samples, fn=Tuple(Pad(axis=0,pad_val=tokenizer.pad_token_id), Pad(axis=0, pad_val=tokenizer.pad_token_id), Stack(dtype="int64")):[data for data in fn(samples)] #训练集迭代器 train_loader = create_dataloader(train_ds, mode='train', batch_size=64, batchify_fn=batchify_fn, trans_fn=trans_fn) #验证集迭代器 dev_loader = create_dataloader(dev_ds, mode='dev', batch_size=64, batchify_fn=batchify_fn, trans_fn=trans_fn) #测试集迭代器 test_loader = create_dataloader(test_ds, mode='test', batch_size=64, batchify_fn=batchify_fn, trans_fn=trans_fn)

二、BERT预训练模型加载

#加载预训练模型Bert用于文本分类任务的Fine-tune网络BertForSequenceClassification, 它在BERT模型后接了一个全连接层进行分类。 #由于本任务中的情感分类是二分类问题,设定num_classes为2 model = ppnlp.transformers.BertForSequenceClassification.from_pretrained("bert-base-chinese", num_classes=2)

三、训练模型

(一)、设置训练超参数

#设置训练超参数 #学习率 learning_rate = 1e-5 #训练轮次 epochs = 8 #学习率预热比率 warmup_proption = 0.1 #权重衰减系数 weight_decay = 0.01 num_training_steps = len(train_loader) * epochs num_warmup_steps = int(warmup_proption * num_training_steps) def get_lr_factor(current_step): if current_step < num_warmup_steps: return float(current_step) / float(max(1, num_warmup_steps)) else: return max(0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps))) #学习率调度器 lr_scheduler = paddle.optimizer.lr.LambdaDecay(learning_rate, lr_lambda=lambda current_step: get_lr_factor(current_step)) #优化器 optimizer = paddle.optimizer.AdamW( learning_rate=lr_scheduler, parameters=model.parameters(), weight_decay=weight_decay, apply_decay_param_fun=lambda x: x in [ p.name for n, p in model.named_parameters() if not any(nd in n for nd in ["bias", "norm"]) ]) #损失函数 criterion = paddle.nn.loss.CrossEntropyLoss() #评估函数 metric = paddle.metric.Accuracy()

(二)、评估函数

#评估函数 def evaluate(model, criterion, metric, data_loader): model.eval() metric.reset() losses = [] for batch in data_loader: input_ids, segment_ids, labels = batch logits = model(input_ids, segment_ids) loss = criterion(logits, labels) losses.append(loss.numpy()) correct = metric.compute(logits, labels) metric.update(correct) accu = metric.accumulate() print("eval loss: %.5f, accu: %.5f" % (np.mean(losses), accu)) model.train() metric.reset()

(三)、模型训练

#开始训练 global_step = 0 for epoch in range(1, epochs + 1): for step, batch in enumerate(train_loader): #从训练数据迭代器中取数据 # print(batch) input_ids, segment_ids, labels = batch logits = model(input_ids, segment_ids) loss = criterion(logits, labels) #计算损失 probs = F.softmax(logits, axis=1) correct = metric.compute(probs, labels) metric.update(correct) acc = metric.accumulate() global_step += 1 if global_step % 50 == 0 : print("global step %d, epoch: %d, batch: %d, loss: %.5f, acc: %.5f" % (global_step, epoch, step, loss, acc)) loss.backward() optimizer.step() lr_scheduler.step() optimizer.clear_gradients() evaluate(model, criterion, metric, dev_loader)

四、模型预测

def predict(model, data, tokenizer, label_map, batch_size=1): examples = [] for text in data: input_ids, segment_ids = convert_example(text, tokenizer, label_list=label_map.values(), max_seq_length=128, is_test=True) examples.append((input_ids, segment_ids)) batchify_fn = lambda samples, fn=Tuple(Pad(axis=0, pad_val=tokenizer.pad_token_id), Pad(axis=0, pad_val=tokenizer.pad_token_id)): fn(samples) batches = [] one_batch = [] for example in examples: one_batch.append(example) if len(one_batch) == batch_size: batches.append(one_batch) one_batch = [] if one_batch: batches.append(one_batch) results = [] model.eval() for batch in batches: input_ids, segment_ids = batchify_fn(batch) input_ids = paddle.to_tensor(input_ids) segment_ids = paddle.to_tensor(segment_ids) logits = model(input_ids, segment_ids) probs = F.softmax(logits, axis=1) idx = paddle.argmax(probs, axis=1).numpy() idx = idx.tolist() labels = [label_map[i] for i in idx] results.extend(labels) return results data = ['这个商品虽然看着样式挺好看的,但是不耐用。', '这个老师讲课水平挺高的。'] label_map = {0: '负向情绪', 1: '正向情绪'} predictions = predict(model, data, tokenizer, label_map, batch_size=32) for idx, text in enumerate(data): print('预测文本: {} \n情绪标签: {}'.format(text, predictions[idx])) 输出结果如下图2所示:


总结

本系列文章内容为根据清华社出版的《机器学习实践》所作的相关笔记和感悟,其中代码均为基于百度飞桨开发,若有任何侵权和不妥之处,请私信于我,定积极配合处理,看到必回!!!

最后,引用本次活动的一句话,来作为文章的结语~( ̄▽ ̄~)~:

【**学习的最大理由是想摆脱平庸,早一天就多一份人生的精彩;迟一天就多一天平庸的困扰。**】

有关【深度学习前沿应用】文本分类Fine-Tunning的更多相关文章

  1. ruby - 将差异补丁应用于字符串/文件 - 2

    对于具有离线功能的智能手机应用程序,我正在为Xml文件创建单向文本同步。我希望我的服务器将增量/差异(例如GNU差异补丁)发送到目标设备。这是计划:Time=0Server:hasversion_1ofXmlfile(~800kiB)Client:hasversion_1ofXmlfile(~800kiB)Time=1Server:hasversion_1andversion_2ofXmlfile(each~800kiB)computesdeltaoftheseversions(=patch)(~10kiB)sendspatchtoClient(~10kiBtransferred)Cl

  2. ruby - 使用 ruby​​ 将 HTML 转换为纯文本并维护结构/格式 - 2

    我想将html转换为纯文本。不过,我不想只删除标签,我想智能地保留尽可能多的格式。为插入换行符标签,检测段落并格式化它们等。输入非常简单,通常是格式良好的html(不是整个文档,只是一堆内容,通常没有anchor或图像)。我可以将几个正则表达式放在一起,让我达到80%,但我认为可能有一些现有的解决方案更智能。 最佳答案 首先,不要尝试为此使用正则表达式。很有可能你会想出一个脆弱/脆弱的解决方案,它会随着HTML的变化而崩溃,或者很难管理和维护。您可以使用Nokogiri快速解析HTML并提取文本:require'nokogiri'h

  3. ruby-on-rails - Rails 应用程序之间的通信 - 2

    我构建了两个需要相互通信和发送文件的Rails应用程序。例如,一个Rails应用程序会发送请求以查看其他应用程序数据库中的表。然后另一个应用程序将呈现该表的json并将其发回。我还希望一个应用程序将存储在其公共(public)目录中的文本文件发送到另一个应用程序的公共(public)目录。我从来没有做过这样的事情,所以我什至不知道从哪里开始。任何帮助,将不胜感激。谢谢! 最佳答案 无论Rails是什么,几乎所有Web应用程序都有您的要求,大多数现代Web应用程序都需要相互通信。但是有一个小小的理解需要你坚持下去,网站不应直接访问彼此

  4. ruby - 无法运行 Rails 2.x 应用程序 - 2

    我尝试运行2.x应用程序。我使用rvm并为此应用程序设置其他版本的ruby​​:$rvmuseree-1.8.7-head我尝试运行服务器,然后出现很多错误:$script/serverNOTE:Gem.source_indexisdeprecated,useSpecification.Itwillberemovedonorafter2011-11-01.Gem.source_indexcalledfrom/Users/serg/rails_projects_terminal/work_proj/spohelp/config/../vendor/rails/railties/lib/r

  5. ruby-on-rails - Rails 应用程序中的 Rails : How are you using application_controller. rb 是新手吗? - 2

    刚入门rails,开始慢慢理解。有人可以解释或给我一些关于在application_controller中编码的好处或时间和原因的想法吗?有哪些用例。您如何为Rails应用程序使用应用程序Controller?我不想在那里放太多代码,因为据我了解,每个请求都会调用此Controller。这是真的? 最佳答案 ApplicationController实际上是您应用程序中的每个其他Controller都将从中继承的类(尽管这不是强制性的)。我同意不要用太多代码弄乱它并保持干净整洁的态度,尽管在某些情况下ApplicationContr

  6. ruby-on-rails - 如何在我的 Rails 应用程序 View 中打印 ruby​​ 变量的内容? - 2

    我是一个Rails初学者,但我想从我的RailsView(html.haml文件)中查看Ruby变量的内容。我试图在ruby​​中打印出变量(认为它会在终端中出现),但没有得到任何结果。有什么建议吗?我知道Rails调试器,但更喜欢使用inspect来打印我的变量。 最佳答案 您可以在View中使用puts方法将信息输出到服务器控制台。您应该能够在View中的任何位置使用Haml执行以下操作:-puts@my_variable.inspect 关于ruby-on-rails-如何在我的R

  7. ruby-on-rails - 如何在 Gem 中获取 Rails 应用程序的根目录 - 2

    是否可以在应用程序中包含的gem代码中知道应用程序的Rails文件系统根目录?这是gem来源的示例:moduleMyGemdefself.included(base)putsRails.root#returnnilendendActionController::Base.send:include,MyGem谢谢,抱歉我的英语不好 最佳答案 我发现解决类似问题的解决方案是使用railtie初始化程序包含我的模块。所以,在你的/lib/mygem/railtie.rbmoduleMyGemclassRailtie使用此代码,您的模块将在

  8. 世界前沿3D开发引擎HOOPS全面讲解——集3D数据读取、3D图形渲染、3D数据发布于一体的全新3D应用开发工具 - 2

    无论您是想搭建桌面端、WEB端或者移动端APP应用,HOOPSPlatform组件都可以为您提供弹性的3D集成架构,同时,由工业领域3D技术专家组成的HOOPS技术团队也能为您提供技术支持服务。如果您的客户期望有一种在多个平台(桌面/WEB/APP,而且某些客户端是“瘦”客户端)快速、方便地将数据接入到3D应用系统的解决方案,并且当访问数据时,在各个平台上的性能和用户体验保持一致,HOOPSPlatform将帮助您完成。利用HOOPSPlatform,您可以开发在任何环境下的3D基础应用架构。HOOPSPlatform可以帮您打造3D创新型产品,HOOPSSDK包含的技术有:快速且准确的CAD

  9. 叮咚买菜基于 Apache Doris 统一 OLAP 引擎的应用实践 - 2

    导读:随着叮咚买菜业务的发展,不同的业务场景对数据分析提出了不同的需求,他们希望引入一款实时OLAP数据库,构建一个灵活的多维实时查询和分析的平台,统一数据的接入和查询方案,解决各业务线对数据高效实时查询和精细化运营的需求。经过调研选型,最终引入ApacheDoris作为最终的OLAP分析引擎,Doris作为核心的OLAP引擎支持复杂地分析操作、提供多维的数据视图,在叮咚买菜数十个业务场景中广泛应用。作者|叮咚买菜资深数据工程师韩青叮咚买菜创立于2017年5月,是一家专注美好食物的创业公司。叮咚买菜专注吃的事业,为满足更多人“想吃什么”而努力,通过美好食材的供应、美好滋味的开发以及美食品牌的孵

  10. 【鸿蒙应用开发系列】- 获取系统设备信息以及版本API兼容调用方式 - 2

    在应用开发中,有时候我们需要获取系统的设备信息,用于数据上报和行为分析。那在鸿蒙系统中,我们应该怎么去获取设备的系统信息呢,比如说获取手机的系统版本号、手机的制造商、手机型号等数据。1、获取方式这里分为两种情况,一种是设备信息的获取,一种是系统信息的获取。1.1、获取设备信息获取设备信息,鸿蒙的SDK包为我们提供了DeviceInfo类,通过该类的一些静态方法,可以获取设备信息,DeviceInfo类的包路径为:ohos.system.DeviceInfo.具体的方法如下:ModifierandTypeMethodDescriptionstatic StringgetAbiList​()Obt

随机推荐