import torch from flask import Flask, request import torch.nn as nn import pickle import jieba app = Flask(__name__) # 初始化模型 result = "" class Model(nn.Module): def __init__(self): super(Model, self).__init__() self.embedding = nn.Embedding(10002, 300, padding_idx=10001) self.lstm = nn.LSTM(300, 128, 2, bidirectional=True, batch_first=True, dropout=0.5) self.fc = nn.Linear(128 * 2, 2) def forward(self, x): x, _ = x out = self.embedding(x) # [batch_size, seq_len, embeding]=[128, 32, 300] out, _ = self.lstm(out) out = self.fc(out[:, -1, :]) # 句子最后时刻的 hidden state return out model = Model() model.load_state_dict(torch.load('./THUCNews/saved_dict/TextRNN.ckpt', map_location=torch.device("cpu"))) model.eval() stopwords = open('./THUCNews/data/hit_stopwords.txt', encoding='utf8').read().split('\n')[:-1] vocab = pickle.load(open('./THUCNews/data/vocab.pkl', 'rb')) classes = ['negative', 'positive'] s = '空调吵,住在电梯旁,电梯门口放垃圾箱,极臭,布草间没关门,也臭,臭到房间里,门下塞毛巾也挡不住臭味,开窗外面吵,关窗空调吵,楼下早餐桌子上摆满垃圾没人整理,不能再差的体验了' # s = '东东还算不错。重装系统时,网上查不到怎么修改BIOS,才能安装?问题请帮忙解决!' @app.route('/api/content', methods=["POST"]) def content(): get_json = request.get_json() global model global result global classes result = "" s = get_json.get("content")[0] print(6777, s) try: s = list(jieba.lcut(s)) s = [i for i in s if i not in stopwords] s = [vocab.get(i, 10000) for i in s] if len(s) > 64: s = s[:64] else: for i in range(64 - len(s)): s.append(vocab['']) outputs = model((torch.LongTensor(s).unsqueeze(0), None)) print(torch.argmax(outputs)) result = classes[torch.argmax(outputs)] except Exception as e: # 未捕获到异常,程序直接报错 result = e return "pridicting" @app.route('/api/model_res', methods=['GET']) def model_res(): global result return str(result) if __name__ == '__main__': app.run(host="127.0.0.1", port=8006)