"""Fine-tune ERNIE for text classification with PaddleHub (1.x API).

Before running, update the dataset location and file names in ``DemoDataset``
(train / dev / test / predict files) to match your data.
"""
import subprocess

import paddlehub as hub
from paddlehub.dataset.base_nlp_dataset import BaseNLPDataset

# Used by both the reader and the module context below; the two must agree,
# so the value is defined once.
MAX_SEQ_LEN = 128

# Plain-Python equivalent of the notebook magic `!hub install ernie`.
# Requires the `hub` CLI on PATH. As with the notebook `!` magic, a non-zero
# exit status is not treated as an error here.
subprocess.run(["hub", "install", "ernie"], check=False)
module = hub.Module(name="ernie")


class DemoDataset(BaseNLPDataset):
    """Text-classification dataset read from the ``work`` directory.

    Every split file (train/dev/test/predict ``.tsv``) is declared to have a
    header row, and the labels are restricted to ``label_list``.
    """

    def __init__(self):
        # Directory where the dataset files are stored.
        self.dataset_dir = "work"
        super().__init__(
            base_path=self.dataset_dir,
            train_file="train.tsv",
            dev_file="dev.tsv",
            test_file="test.tsv",
            # Put any additional data to be predicted in predict.tsv.
            predict_file="predict.tsv",
            train_file_with_header=True,
            dev_file_with_header=True,
            test_file_with_header=True,
            predict_file_with_header=True,
            # Class labels of the dataset.
            label_list=["0", "1"],
        )


dataset = DemoDataset()

reader = hub.reader.ClassifyReader(
    dataset=dataset,
    vocab_path=module.get_vocab_path(),
    sp_model_path=module.get_spm_path(),
    word_dict_path=module.get_word_dict_path(),
    max_seq_len=MAX_SEQ_LEN,
)

strategy = hub.AdamWeightDecayStrategy(
    weight_decay=0.01,
    warmup_proportion=0.1,
    learning_rate=5e-5,
)

config = hub.RunConfig(
    use_cuda=True,
    num_epoch=1,
    checkpoint_dir="ernie_txt_cls_turtorial_demo",
    batch_size=128,  # typically a power of two
    eval_interval=10,
    strategy=strategy,
)

inputs, outputs, program = module.context(trainable=True, max_seq_len=MAX_SEQ_LEN)

# Use "pooled_output" for classification tasks on an entire sentence.
pooled_output = outputs["pooled_output"]

feed_list = [
    inputs["input_ids"].name,
    inputs["position_ids"].name,
    inputs["segment_ids"].name,
    inputs["input_mask"].name,
]

cls_task = hub.TextClassifierTask(
    data_reader=reader,
    feature=pooled_output,
    feed_list=feed_list,
    num_classes=dataset.num_labels,
    config=config,
)

run_states = cls_task.finetune_and_eval()

# Prediction: feed only the text of each predict example (labels unused).
data = [[d.text_a] for d in dataset.get_predict_examples()]
run_states = cls_task.predict(data=data)
results = [run_state.run_results for run_state in run_states]