You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
75 lines
2.3 KiB
75 lines
2.3 KiB
import torch
|
|
from flask import Flask, request
|
|
import torch.nn as nn
|
|
import pickle
|
|
import jieba
|
|
|
|
app = Flask(__name__)
|
|
# 初始化模型
|
|
result = ""
|
|
|
|
|
|
class Model(nn.Module):
|
|
def __init__(self):
|
|
super(Model, self).__init__()
|
|
self.embedding = nn.Embedding(10002, 300, padding_idx=10001)
|
|
self.lstm = nn.LSTM(300, 128, 2, bidirectional=True, batch_first=True, dropout=0.5)
|
|
self.fc = nn.Linear(128 * 2, 2)
|
|
|
|
def forward(self, x):
|
|
x, _ = x
|
|
out = self.embedding(x) # [batch_size, seq_len, embeding]=[128, 32, 300]
|
|
out, _ = self.lstm(out)
|
|
out = self.fc(out[:, -1, :]) # 句子最后时刻的 hidden state
|
|
return out
|
|
|
|
|
|
model = Model()
|
|
model.load_state_dict(torch.load('./THUCNews/saved_dict/TextRNN.ckpt', map_location=torch.device("cpu")))
|
|
model.eval()
|
|
stopwords = open('./THUCNews/data/hit_stopwords.txt', encoding='utf8').read().split('\n')[:-1]
|
|
|
|
vocab = pickle.load(open('./THUCNews/data/vocab.pkl', 'rb'))
|
|
classes = ['negative', 'positive']
|
|
|
|
s = '空调吵,住在电梯旁,电梯门口放垃圾箱,极臭,布草间没关门,也臭,臭到房间里,门下塞毛巾也挡不住臭味,开窗外面吵,关窗空调吵,楼下早餐桌子上摆满垃圾没人整理,不能再差的体验了'
|
|
|
|
|
|
# s = '东东还算不错。重装系统时,网上查不到怎么修改BIOS,才能安装?问题请帮忙解决!'
|
|
|
|
|
|
@app.route('/api/content', methods=["POST"])
|
|
def content():
|
|
get_json = request.get_json()
|
|
global model
|
|
global result
|
|
global classes
|
|
result = ""
|
|
s = get_json.get("content")[0]
|
|
print(6777, s)
|
|
try:
|
|
s = list(jieba.lcut(s))
|
|
s = [i for i in s if i not in stopwords]
|
|
s = [vocab.get(i, 10000) for i in s]
|
|
if len(s) > 64:
|
|
s = s[:64]
|
|
else:
|
|
for i in range(64 - len(s)):
|
|
s.append(vocab['<PAD>'])
|
|
|
|
outputs = model((torch.LongTensor(s).unsqueeze(0), None))
|
|
print(torch.argmax(outputs))
|
|
result = classes[torch.argmax(outputs)]
|
|
except Exception as e: # 未捕获到异常,程序直接报错
|
|
result = e
|
|
return "pridicting"
|
|
|
|
|
|
@app.route('/api/model_res', methods=['GET'])
|
|
def model_res():
|
|
global result
|
|
return str(result)
|
|
|
|
|
|
if __name__ == '__main__':
|
|
app.run(host="127.0.0.1", port=8006)
|
|
|