Contents
1. Project Setup and Environment Dependencies
2. Data Preprocessing
   2.1 Loading and Splitting the Dataset
   2.2 Building a Custom Tokenizer
   2.3 Vocabulary Construction and Text Encoding
3. Building the DataLoader
4. Building the Transformer Translation Model
   4.1 The Positional Encoding Layer
   4.2 The Full Translation Model
5. Model Training
6. Model Inference
7. Complete Code

The Transformer broke the monopoly that recurrent architectures such as the RNN once held over sequence modeling. Built around self-attention, it has become the foundational architecture of NLP, computer vision, and beyond; well-known models such as GPT, BERT, and ViT are all built on it. Here is a brief overview of its structure and main strengths.

Core structure. The encoder is a stack of identical layers, each containing multi-head self-attention and a feed-forward network, wrapped with residual connections and layer normalization. Input embeddings first turn tokens into vectors, positional encodings are added to preserve word order, multi-head self-attention then captures global dependencies between tokens, and the feed-forward network processes the features further. The decoder is also a stack of layers; besides modules similar to the encoder's, each layer adds masked multi-head self-attention and encoder-decoder attention. The mask prevents the model from peeking at future tokens while generating, and encoder-decoder attention lets the decoder draw on the encoder's context.

Core strengths:
- High parallelism. A traditional RNN processes a sequence one element at a time, with each output depending on the previous step, so it cannot exploit GPU parallelism. Self-attention computes the dependencies between all positions of a sequence at once — for a 100-word sentence, every pairwise attention score is computed simultaneously — so training is commonly one to two orders of magnitude faster than with an RNN; models that would take an RNN weeks to train can finish in days.
- Strong long-range dependency modeling. RNNs and LSTMs suffer from vanishing gradients on long sequences, which makes it hard to connect distant elements, such as a pronoun and its antecedent at opposite ends of a long sentence. In a Transformer every position attends directly to every other position, with attention weights quantifying the strength of each link, so even the two ends of a sequence are connected in a single attention step, which largely removes the long-range dependency problem.
- Versatility and transferability. Multi-head attention captures different kinds of relations in separate subspaces — in machine translation, for example, both syntactic structure and semantic association — and masked attention and cross-attention variants adapt the architecture to translation, text generation, summarization, and other tasks. The same design also transfers across domains: splitting an image into token-like patches lets a Transformer model the relations between patches, and such models outperform traditional convolutional networks on tasks like image classification.

Within NLP, machine translation is one of the most representative tasks, and the Transformer, with self-attention at its core, has become its dominant architecture. This article walks through building a simple Chinese-to-English translation model from scratch with PyTorch, covering data preprocessing, model construction, training, and inference.
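To make the self-attention computation just described concrete, here is a minimal, self-contained sketch of scaled dot-product attention, the building block behind multi-head attention. It is illustrative only and not part of the translation model built below; the toy tensor shapes are arbitrary.

```python
import math
import torch

def scaled_dot_product_attention(q, k, v, mask=None):
    # q, k, v: [batch, seq_len, d_model]; optional boolean mask marks positions to block
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1])  # pairwise similarity scores
    if mask is not None:
        scores = scores.masked_fill(mask, float("-inf"))       # blocked positions get zero weight
    weights = torch.softmax(scores, dim=-1)                    # each row sums to 1
    return weights @ v                                         # weighted sum of value vectors

q = k = v = torch.randn(1, 5, 16)  # a toy "sentence" of 5 tokens with 16-dim features
out = scaled_dot_product_attention(q, k, v)
print(out.shape)  # torch.Size([1, 5, 16]) -- every token now mixes information from all tokens
```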
1. Project Setup and Environment Dependencies

First set up the development environment and make sure the following core dependencies are installed:

```python
import torch
import pandas as pd
from sklearn.model_selection import train_test_split
from typing import List
from torch import nn, optim
from torch.nn.utils.rnn import pad_sequence
from tqdm import tqdm
import json
from nltk.tokenize.treebank import TreebankWordDetokenizer, TreebankWordTokenizer
from torch.utils.data import Dataset, DataLoader
import math
```

2. Data Preprocessing

2.1 Loading and Splitting the Dataset

This article uses the Chinese-English parallel corpus cmn.txt, which contains 29,155 sentence pairs; each line holds an English sentence and its Chinese translation separated by a tab. Load the data and split it into training and test sets:

```python
# Load the data
data = pd.read_csv("./data/cmn.txt", sep="\t", header=None, usecols=[0, 1], names=["en", "zh"])
# Split into training and test sets (8:2)
train_df, test_df = train_test_split(data, test_size=0.2)
```

2.2 Building a Custom Tokenizer

To turn text into the numeric sequences the model can consume, we define a base Tokenizer class and implement the tokenization logic separately for each language: the Chinese tokenizer splits at the character level (Chinese has no natural word delimiter), while the English tokenizer does word-level tokenization with NLTK's TreebankWordTokenizer and can also decode index sequences back into text. The core code:

```python
class BaseTokenizer:
    # Base tokenizer: logic shared by both languages
    pad_index = 0
    unk_index = 1
    start_index = 2
    end_index = 3

    def __init__(self, vocab_list):
        self.vocab_list = vocab_list
        self.vocab_size = len(vocab_list)
        self.world2index = {value: index for index, value in enumerate(vocab_list)}
        self.index2world = {index: value for index, value in enumerate(vocab_list)}

    @staticmethod
    def tokenize(text: str) -> List[str]:
        pass

    def encode(self, text: str, is_mark=False) -> List[int]:
        tokens = self.tokenize(text)
        tokens_index = [self.world2index.get(token, self.unk_index) for token in tokens]
        if is_mark:
            tokens_index.insert(0, self.start_index)
            tokens_index.append(self.end_index)
        return tokens_index

    @classmethod
    def build_vocab(cls, sentences: List[str], unk_token="unknown", pad_token="padding",
                    start_token="start", end_token="end", vocab_path="./vocab.json"):
        vocab_set = set()
        for sentence in tqdm(sentences, desc="Building vocab"):
            vocab_set.update(cls.tokenize(sentence))
        vocab_list = [pad_token, unk_token, start_token, end_token] + sorted(list(vocab_set))
        vocab_dict = {index: value for index, value in enumerate(vocab_list)}
        with open(vocab_path, "w", encoding="utf-8") as f:
            json.dump(vocab_dict, f, ensure_ascii=False, indent=2)

    @classmethod
    def read_vocab(cls, vocab_path="./vocab.json"):
        with open(vocab_path, "r", encoding="utf-8") as f:
            json_dict = json.load(f)
        sentences = [value for key, value in json_dict.items()]
        return cls(sentences)


class ChinseeTokenizer(BaseTokenizer):
    # Chinese tokenizer: character-level
    @staticmethod
    def tokenize(text: str) -> List[str]:
        return list(text)


class EnglishTokenizer(BaseTokenizer):
    # English tokenizer: word-level
    tokenizer = TreebankWordTokenizer()
    detokenizer = TreebankWordDetokenizer()

    @classmethod
    def tokenize(cls, text: str) -> List[str]:
        return cls.tokenizer.tokenize(text)

    def decode(self, indexs: List[int]) -> str:
        tokens = [self.index2world.get(index, "unknown") for index in indexs]
        return self.detokenizer.detokenize(tokens)
```

2.3 Vocabulary Construction and Text Encoding

Build the Chinese and English vocabularies from the training set, convert all text into index sequences, and save the result in JSONL format:

```python
# Build the vocabularies
ChinseeTokenizer.build_vocab(sentences=train_df["zh"].tolist(), vocab_path="./zh_vocab.json")
EnglishTokenizer.build_vocab(sentences=train_df["en"].tolist(), vocab_path="./en_vocab.json")
# Load the vocabularies
cn_tokenizer = ChinseeTokenizer.read_vocab("./zh_vocab.json")
en_tokenizer = EnglishTokenizer.read_vocab("./en_vocab.json")
# Encode the text (English sequences get start/end markers)
train_df["en"] = train_df["en"].apply(lambda x: en_tokenizer.encode(x, is_mark=True))
train_df["zh"] = train_df["zh"].apply(lambda x: cn_tokenizer.encode(x))
test_df["en"] = test_df["en"].apply(lambda x: en_tokenizer.encode(x, is_mark=True))
test_df["zh"] = test_df["zh"].apply(lambda x: cn_tokenizer.encode(x))
# Save the encoded data
train_df.to_json("./train.jsonl", orient="records", lines=True)
test_df.to_json("./test.jsonl", orient="records", lines=True)
```

3. Building the DataLoader

Define a custom Dataset that loads the encoded data, plus a collate_fn that pads every sequence within a batch to the same length:

```python
class TranslationDataset(Dataset):
    def __init__(self, path):
        self.data = pd.read_json(path, orient="records", lines=True).to_dict(orient="records")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        input_tensor = torch.tensor(self.data[index]["zh"], dtype=torch.long)
        target_tensor = torch.tensor(self.data[index]["en"], dtype=torch.long)
        return input_tensor, target_tensor


# Custom collate_fn: pad within each batch
def collate_fn(batch):
    input_tensor = [tensor[0] for tensor in batch]
    target_tensor = [tensor[1] for tensor in batch]
    input_tensor = pad_sequence(sequences=input_tensor, batch_first=True, padding_value=0)
    target_tensor = pad_sequence(sequences=target_tensor, batch_first=True, padding_value=0)
    return input_tensor, target_tensor


# Build the DataLoaders
train_dataset = TranslationDataset("./train.jsonl")
test_dataset = TranslationDataset("./test.jsonl")
train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)
test_dataloader = DataLoader(test_dataset, batch_size=32, collate_fn=collate_fn)
```

4. Building the Transformer Translation Model

The model is built around a Transformer encoder and decoder; a positional encoding layer is added to inject position information into the sequences.

4.1 The Positional Encoding Layer

```python
class PositionalEncoding(nn.Module):
    def __init__(self, max_len, dim_model):
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros([max_len, dim_model], dtype=torch.float)
        for pos in range(max_len):
            for i in range(0, dim_model, 2):
                pe[pos, i] = math.sin(pos / (10000 ** (i / dim_model)))
                pe[pos, i + 1] = math.cos(pos / (10000 ** (i / dim_model)))
        self.register_buffer("pe", pe)

    def forward(self, x):
        seq_len = x.shape[1]
        part_pe = self.pe[0:seq_len]
        return x + part_pe
```
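The nested loop above is easy to read but slow to build for a large max_len. As an optional, equivalent alternative (a sketch, not the class used in the rest of this article), the same sine/cosine table can be built without Python loops; the sanity check assumes the PositionalEncoding class above is in scope:

```python
import torch
from torch import nn

class VectorizedPositionalEncoding(nn.Module):
    """Builds the same sin/cos table as PositionalEncoding, without Python loops."""
    def __init__(self, max_len, dim_model):
        super().__init__()
        pos = torch.arange(max_len, dtype=torch.float).unsqueeze(1)                   # [max_len, 1]
        div = torch.pow(10000.0, torch.arange(0, dim_model, 2).float() / dim_model)   # [dim_model/2]
        pe = torch.zeros(max_len, dim_model)
        pe[:, 0::2] = torch.sin(pos / div)   # even dimensions
        pe[:, 1::2] = torch.cos(pos / div)   # odd dimensions
        self.register_buffer("pe", pe)

    def forward(self, x):
        return x + self.pe[: x.shape[1]]

# Sanity check: matches the loop-based version up to float32 rounding
print(torch.allclose(PositionalEncoding(50, 128).pe,
                     VectorizedPositionalEncoding(50, 128).pe, atol=1e-4))  # True
```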
4.2 The Full Translation Model

```python
class TranslationModel(nn.Module):
    def __init__(self, zh_vocab_size, en_vocab_size, zh_padding_idx, en_padding_idx):
        super(TranslationModel, self).__init__()
        # Embedding layers
        self.zh_embedding = nn.Embedding(num_embeddings=zh_vocab_size, embedding_dim=128, padding_idx=zh_padding_idx)
        self.en_embedding = nn.Embedding(num_embeddings=en_vocab_size, embedding_dim=128, padding_idx=en_padding_idx)
        # Positional encoding
        self.position_encoding = PositionalEncoding(max_len=500, dim_model=128)
        # Transformer core
        self.transformer = nn.Transformer(
            d_model=128,
            nhead=8,
            num_encoder_layers=6,
            num_decoder_layers=6,
            batch_first=True,
            dropout=0.1,
        )
        # Output projection
        self.linear = nn.Linear(in_features=128, out_features=en_vocab_size)

    def forward(self, src, tgt, src_pad_mask, tgt_mask):
        memory = self.encode(src, src_pad_mask)
        outputs = self.decode(tgt, memory, tgt_mask, src_pad_mask)
        return outputs

    def encode(self, src, src_pad_mask):
        zh_embed = self.zh_embedding(src)
        zh_embed = self.position_encoding(zh_embed)
        memory = self.transformer.encoder(src=zh_embed, src_key_padding_mask=src_pad_mask)
        return memory

    def decode(self, tgt, memory, tgt_mask, memory_pad_mask):
        en_embed = self.en_embedding(tgt)
        en_embed = self.position_encoding(en_embed)
        output = self.transformer.decoder(tgt=en_embed, memory=memory, tgt_mask=tgt_mask,
                                          memory_key_padding_mask=memory_pad_mask)
        outputs = self.linear(output)
        return outputs
```

5. Model Training

Set the training hyperparameters, define the loss function and optimizer, then alternate between a training phase and a validation phase each epoch, keeping the checkpoint with the best validation loss:

```python
# Device selection
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Instantiate the model
model = TranslationModel(
    zh_vocab_size=cn_tokenizer.vocab_size,
    en_vocab_size=en_tokenizer.vocab_size,
    zh_padding_idx=0,
    en_padding_idx=0,
).to(device)

# Training configuration
epochs = 5
lr = 1e-4
loss_fn = nn.CrossEntropyLoss(ignore_index=en_tokenizer.pad_index)  # ignore padding positions in the loss
optimizer = optim.Adam(model.parameters(), lr=lr)

# Training loop
best_loss = float("inf")
for epoch in range(epochs):
    print(f"Epoch {epoch + 1}")
    # Training phase
    model.train()
    train_total_loss = 0.0
    for train_x, train_y in tqdm(train_dataloader, desc="Training"):
        src, tgt = train_x.to(device), train_y.to(device)
        decoder_inputs = tgt[:, :-1]   # decoder input: drop the last token
        decoder_targets = tgt[:, 1:]   # decoder target: drop the first token
        src_pad_mask = (src == model.zh_embedding.padding_idx)
        tgt_mask = model.transformer.generate_square_subsequent_mask(sz=decoder_inputs.shape[1]).to(device)
        pred_y = model(src, decoder_inputs, src_pad_mask, tgt_mask)
        loss = loss_fn(pred_y.reshape(-1, pred_y.shape[-1]), decoder_targets.reshape(-1))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_total_loss += loss.item()

    # Validation phase
    model.eval()
    test_total_loss = 0.0
    with torch.no_grad():
        for test_x, test_y in tqdm(test_dataloader, desc="Validating"):
            src, tgt = test_x.to(device), test_y.to(device)
            decoder_inputs = tgt[:, :-1]
            decoder_targets = tgt[:, 1:]
            src_pad_mask = (src == model.zh_embedding.padding_idx)
            tgt_mask = model.transformer.generate_square_subsequent_mask(sz=decoder_inputs.shape[1]).to(device)
            pred_y = model(src, decoder_inputs, src_pad_mask, tgt_mask)
            loss = loss_fn(pred_y.reshape(-1, pred_y.shape[-1]), decoder_targets.reshape(-1))
            test_total_loss += loss.item()

    # Average losses
    avg_train_loss = train_total_loss / len(train_dataloader)
    avg_test_loss = test_total_loss / len(test_dataloader)
    print(f"avg train loss: {avg_train_loss}, avg validation loss: {avg_test_loss}")

    # Save the best model
    if test_total_loss < best_loss:
        best_loss = test_total_loss
        torch.save(model.state_dict(), "./best_model.pt")
```

Over the five epochs the loss drops steadily on both the training and validation sets, and the resulting model translates reasonably well.
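Before moving on, if the two masks built inside the training loop above feel abstract, this small standalone snippet (illustrative only, not part of the training code) shows what each one contains: the padding mask is boolean and flags positions attention should ignore, while the causal mask is additive and uses -inf to block attention to future positions, matching what generate_square_subsequent_mask returns.

```python
import torch

# One padded source sentence; 0 is the padding index, as in the vocabularies above
toy_src = torch.tensor([[5, 8, 2, 0, 0]])
src_pad_mask = (toy_src == 0)
print(src_pad_mask)
# tensor([[False, False, False,  True,  True]])

# Causal mask for a 3-token decoder input (same values as generate_square_subsequent_mask(3))
tgt_mask = torch.triu(torch.full((3, 3), float("-inf")), diagonal=1)
print(tgt_mask)
# tensor([[0., -inf, -inf],
#         [0., 0., -inf],
#         [0., 0., 0.]])
```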
6. Model Inference

The prediction function encodes a Chinese sentence, runs the encoder once, and then generates the English translation token by token with greedy decoding until the end token appears:

```python
def predict(model, text, device):
    # Encode the input text
    text = cn_tokenizer.encode(text=text)
    model.eval()
    with torch.no_grad():
        src = torch.tensor(text, dtype=torch.long).unsqueeze(0).to(device)
        src_pad_mask = (src == model.zh_embedding.padding_idx)
        # Run the encoder
        memory = model.encode(src, src_pad_mask)
        batch_size = src.shape[0]
        # The decoder starts from the start token
        decoder_input = torch.full([batch_size, 1], en_tokenizer.start_index, device=device)
        generated = []
        is_finished = torch.full([batch_size], False, device=device)
        # Generate token by token
        for i in range(500):
            tgt_mask = model.transformer.generate_square_subsequent_mask(sz=decoder_input.shape[1]).to(device)
            decoder_output = model.decode(decoder_input, memory, tgt_mask, src_pad_mask)
            # Greedy decoding: take the most likely next token
            next_token_index = torch.argmax(decoder_output[:, -1, :], dim=-1, keepdim=True)
            generated.append(next_token_index)
            decoder_input = torch.cat([decoder_input, next_token_index], dim=-1)
            # Stop once every sequence has produced the end token
            is_finished |= (next_token_index.squeeze(1) == en_tokenizer.end_index)
            if is_finished.all():
                break
        # Trim each generated sequence at its end token
        generated_tensor = torch.cat(generated, dim=-1)
        generated_list = generated_tensor.tolist()
        for index, value in enumerate(generated_list):
            if en_tokenizer.end_index in value:
                end_pos = value.index(en_tokenizer.end_index)
                generated_list[index] = value[:end_pos]
        # Decode back to text
        return en_tokenizer.decode(generated_list[0])


# Load the best checkpoint and translate a sentence
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = TranslationModel(
    zh_vocab_size=cn_tokenizer.vocab_size,
    en_vocab_size=en_tokenizer.vocab_size,
    zh_padding_idx=0,
    en_padding_idx=0,
).to(device)
model.load_state_dict(torch.load("./best_model.pt"))

# Test prediction
text = "我是你爸爸"
result = predict(model, text, device)
print(f"Input: {text}")
print(f"Output: {result}")
# Output: I'm your father.
```
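Beyond a single hand-picked sentence, a quick way to eyeball translation quality is to sample a few raw pairs from cmn.txt and compare the model's output with the reference. This is an optional sketch that reuses the predict function and the loaded model above:

```python
# Optional spot check: translate a few random sentences from the corpus
sample = pd.read_csv("./data/cmn.txt", sep="\t", header=None,
                     usecols=[0, 1], names=["en", "zh"]).sample(5)
for _, row in sample.iterrows():
    print(f"zh:             {row['zh']}")
    print(f"en (reference): {row['en']}")
    print(f"en (model):     {predict(model, row['zh'], device)}")
    print()
```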
7. Complete Code

```python
import torch
import pandas as pd
from sklearn.model_selection import train_test_split
from typing import List
from torch import nn, optim
from torch.nn.utils.rnn import pad_sequence
from tqdm import tqdm
import json
from nltk.tokenize.treebank import TreebankWordDetokenizer, TreebankWordTokenizer
from torch.utils.data import Dataset, DataLoader
import math

# Load the data
data = pd.read_csv("./data/cmn.txt", sep="\t", header=None, usecols=[0, 1], names=["en", "zh"])
# Split into training and validation sets
train_df, test_df = train_test_split(data, test_size=0.2)


# Tokenizers
class BaseTokenizer:
    # Base tokenizer
    pad_index = 0
    unk_index = 1
    start_index = 2
    end_index = 3

    def __init__(self, vocab_list):
        self.vocab_list = vocab_list
        self.vocab_size = len(vocab_list)
        self.world2index = {value: index for index, value in enumerate(vocab_list)}
        self.index2world = {index: value for index, value in enumerate(vocab_list)}

    @staticmethod
    def tokenize(text: str) -> List[str]:
        pass

    def encode(self, text: str, is_mark=False) -> List[int]:
        tokens = self.tokenize(text)
        tokens_index = [self.world2index.get(token, self.unk_index) for token in tokens]
        if is_mark:
            tokens_index.insert(0, self.start_index)
            tokens_index.append(self.end_index)
        return tokens_index

    @classmethod
    def build_vocab(
        cls, sentences: List[str],
        unk_token: str = "unknown",
        pad_token: str = "padding",
        start_token: str = "start",
        end_token: str = "end",
        vocab_path: str = "./vocab.json",
    ):
        vocab_set = set()
        for sentence in tqdm(sentences, desc="Building vocab"):
            vocab_set.update(cls.tokenize(sentence))
        vocab_list = [pad_token, unk_token, start_token, end_token] + sorted(list(vocab_set))
        vocab_dict = {index: value for index, value in enumerate(vocab_list)}
        with open(vocab_path, "w", encoding="utf-8") as f:
            json.dump(vocab_dict, f, ensure_ascii=False, indent=2)

    @classmethod
    def read_vocab(cls, vocab_path: str = "./vocab.json"):
        with open(vocab_path, "r", encoding="utf-8") as f:
            json_dict = json.load(f)
        sentences = [value for key, value in json_dict.items()]
        return cls(sentences)


# Chinese tokenizer
class ChinseeTokenizer(BaseTokenizer):
    @staticmethod
    def tokenize(text: str) -> List[str]:
        return list(text)


# English tokenizer
class EnglishTokenizer(BaseTokenizer):
    tokenizer = TreebankWordTokenizer()
    detokenizer = TreebankWordDetokenizer()

    @classmethod
    def tokenize(cls, text: str) -> List[str]:
        return cls.tokenizer.tokenize(text)

    def decode(self, indexs: List[int]) -> str:
        tokens = [self.index2world.get(index, "unknown") for index in indexs]
        return self.detokenizer.detokenize(tokens)


# Create the vocabularies and encode the data
ChinseeTokenizer.build_vocab(sentences=train_df["zh"].tolist(), vocab_path="./zh_vocab.json")
EnglishTokenizer.build_vocab(sentences=train_df["en"].tolist(), vocab_path="./en_vocab.json")
cn_tokenizer = ChinseeTokenizer.read_vocab("./zh_vocab.json")
en_tokenizer = EnglishTokenizer.read_vocab("./en_vocab.json")
train_df["en"] = train_df["en"].apply(lambda x: en_tokenizer.encode(x, is_mark=True))
train_df["zh"] = train_df["zh"].apply(lambda x: cn_tokenizer.encode(x))
test_df["en"] = test_df["en"].apply(lambda x: en_tokenizer.encode(x, is_mark=True))
test_df["zh"] = test_df["zh"].apply(lambda x: cn_tokenizer.encode(x))
train_df.to_json("./train.jsonl", orient="records", lines=True)
test_df.to_json("./test.jsonl", orient="records", lines=True)


# DataLoaders
class TranslationDataset(Dataset):
    def __init__(self, path):
        self.data = pd.read_json(path, orient="records", lines=True).to_dict(orient="records")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        input_tensor = torch.tensor(self.data[index]["zh"], dtype=torch.long)
        target_tensor = torch.tensor(self.data[index]["en"], dtype=torch.long)
        return input_tensor, target_tensor


def collate_fn(batch):
    input_tensor = [tensor[0] for tensor in batch]
    target_tensor = [tensor[1] for tensor in batch]
    input_tensor = pad_sequence(sequences=input_tensor, batch_first=True, padding_value=0)
    target_tensor = pad_sequence(sequences=target_tensor, batch_first=True, padding_value=0)
    return input_tensor, target_tensor


train_dataset = TranslationDataset("./train.jsonl")
test_dataset = TranslationDataset("./test.jsonl")
train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)
test_dataloader = DataLoader(test_dataset, batch_size=32, collate_fn=collate_fn)


# Positional encoding
class PositionalEncoding(nn.Module):
    def __init__(self, max_len, dim_model):
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros([max_len, dim_model], dtype=torch.float)
        for pos in range(max_len):
            for i in range(0, dim_model, 2):
                pe[pos, i] = math.sin(pos / (10000 ** (i / dim_model)))
                pe[pos, i + 1] = math.cos(pos / (10000 ** (i / dim_model)))
        self.register_buffer("pe", pe)

    def forward(self, x):
        seq_len = x.shape[1]
        part_pe = self.pe[0:seq_len]
        return x + part_pe


# Model
class TranslationModel(nn.Module):
    def __init__(self, zh_vocab_size, en_vocab_size, zh_padding_idx, en_padding_idx):
        super(TranslationModel, self).__init__()
        self.zh_embedding = nn.Embedding(num_embeddings=zh_vocab_size, embedding_dim=128, padding_idx=zh_padding_idx)
        self.en_embedding = nn.Embedding(num_embeddings=en_vocab_size, embedding_dim=128, padding_idx=en_padding_idx)
        self.position_encoding = PositionalEncoding(max_len=500, dim_model=128)
        self.transformer = nn.Transformer(
            d_model=128,
            nhead=8,
            num_encoder_layers=6,
            num_decoder_layers=6,
            batch_first=True,
            dropout=0.1,
        )
        self.linear = nn.Linear(in_features=128, out_features=en_vocab_size)

    def forward(self, src, tgt, src_pad_mask, tgt_mask):
        memory = self.encode(src, src_pad_mask)
        outputs = self.decode(tgt, memory, tgt_mask, src_pad_mask)
        return outputs

    def encode(self, src, src_pad_mask):
        zh_embed = self.zh_embedding(src)
        zh_embed = self.position_encoding(zh_embed)
        memory = self.transformer.encoder(src=zh_embed, src_key_padding_mask=src_pad_mask)
        return memory

    def decode(self, tgt, memory, tgt_mask, memory_pad_mask):
        en_embed = self.en_embedding(tgt)
        en_embed = self.position_encoding(en_embed)
        output = self.transformer.decoder(tgt=en_embed, memory=memory, tgt_mask=tgt_mask,
                                          memory_key_padding_mask=memory_pad_mask)
        outputs = self.linear(output)
        return outputs

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = TranslationModel(
    zh_vocab_size=cn_tokenizer.vocab_size,
    en_vocab_size=en_tokenizer.vocab_size,
    zh_padding_idx=0,
    en_padding_idx=0,
).to(device)

# Hyperparameters
epochs = 5
lr = 1e-4
loss_fn = nn.CrossEntropyLoss(ignore_index=en_tokenizer.pad_index)
optimizer = optim.Adam(model.parameters(), lr=lr)

# Training and checkpointing
best_loss = float("inf")
for epoch in range(epochs):
    print(f"Epoch {epoch + 1}")
    model.train()
    train_total_loss = 0.0
    for train_x, train_y in tqdm(train_dataloader, desc="Training"):
        src, tgt = train_x.to(device), train_y.to(device)
        decoder_inputs = tgt[:, :-1]
        decoder_targets = tgt[:, 1:]
        src_pad_mask = (src == model.zh_embedding.padding_idx)
        tgt_mask = model.transformer.generate_square_subsequent_mask(sz=decoder_inputs.shape[1]).to(device)
        pred_y = model(src, decoder_inputs, src_pad_mask, tgt_mask)
        loss = loss_fn(pred_y.reshape(-1, pred_y.shape[-1]), decoder_targets.reshape(-1))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_total_loss += loss.item()

    model.eval()
    test_total_loss = 0.0
    with torch.no_grad():
        for test_x, test_y in tqdm(test_dataloader, desc="Validating"):
            src, tgt = test_x.to(device), test_y.to(device)
            decoder_inputs = tgt[:, :-1]
            decoder_targets = tgt[:, 1:]
            src_pad_mask = (src == model.zh_embedding.padding_idx)
            tgt_mask = model.transformer.generate_square_subsequent_mask(sz=decoder_inputs.shape[1]).to(device)
            pred_y = model(src, decoder_inputs, src_pad_mask, tgt_mask)
            loss = loss_fn(pred_y.reshape(-1, pred_y.shape[-1]), decoder_targets.reshape(-1))
            test_total_loss += loss.item()

    avg_train_loss = train_total_loss / len(train_dataloader)
    avg_test_loss = test_total_loss / len(test_dataloader)
    print(f"avg train loss: {avg_train_loss}, avg validation loss: {avg_test_loss}")

    if test_total_loss < best_loss:
        best_loss = test_total_loss
        torch.save(model.state_dict(), "./best_model.pt")


# Inference
def predict(model, text, device):
    text = cn_tokenizer.encode(text=text)
    model.eval()
    with torch.no_grad():
        src = torch.tensor(text, dtype=torch.long).unsqueeze(0).to(device)
        src_pad_mask = (src == model.zh_embedding.padding_idx)
        memory = model.encode(src, src_pad_mask)
        batch_size = src.shape[0]
        decoder_input = torch.full([batch_size, 1], en_tokenizer.start_index, device=device)
        generated = []
        is_finished = torch.full([batch_size], False, device=device)
        for i in range(500):
            tgt_mask = model.transformer.generate_square_subsequent_mask(sz=decoder_input.shape[1]).to(device)
            decoder_output = model.decode(decoder_input, memory, tgt_mask, src_pad_mask)
            next_token_index = torch.argmax(decoder_output[:, -1, :], dim=-1, keepdim=True)
            generated.append(next_token_index)
            decoder_input = torch.cat([decoder_input, next_token_index], dim=-1)
            is_finished |= (next_token_index.squeeze(1) == en_tokenizer.end_index)
            if is_finished.all():
                break
        generated_tensor = torch.cat(generated, dim=-1)
        generated_list = generated_tensor.tolist()
        for index, value in enumerate(generated_list):
            if en_tokenizer.end_index in value:
                end_pos = value.index(en_tokenizer.end_index)
                generated_list[index] = value[:end_pos]
        return en_tokenizer.decode(generated_list[0])


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = TranslationModel(
    zh_vocab_size=cn_tokenizer.vocab_size,
    en_vocab_size=en_tokenizer.vocab_size,
    zh_padding_idx=0,
    en_padding_idx=0,
).to(device)
model.load_state_dict(torch.load("./best_model.pt"))

text = "我是你爸爸"
result = predict(model, text, device)
print(result)
```
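One practical note: the script above rebuilds the vocabularies and retrains the model on every run. For reuse, an inference-only entry point can load the saved vocab files and checkpoint instead. The sketch below assumes the classes and the predict function defined above are available in the same module, and the input sentence is just a placeholder:

```python
# Inference-only sketch: reuse the artifacts written by the training script
cn_tokenizer = ChinseeTokenizer.read_vocab("./zh_vocab.json")
en_tokenizer = EnglishTokenizer.read_vocab("./en_vocab.json")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = TranslationModel(
    zh_vocab_size=cn_tokenizer.vocab_size,
    en_vocab_size=en_tokenizer.vocab_size,
    zh_padding_idx=0,
    en_padding_idx=0,
).to(device)
model.load_state_dict(torch.load("./best_model.pt", map_location=device))

print(predict(model, "你好", device))  # placeholder input; output depends on training
```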