You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
154 lines
4.9 KiB
154 lines
4.9 KiB
from flask import Flask, request, jsonify
|
|
from flask_cors import CORS
|
|
import pandas as pd
|
|
import numpy as np
|
|
from sklearn.model_selection import train_test_split
|
|
from sklearn.tree import DecisionTreeClassifier
|
|
from sklearn.ensemble import RandomForestClassifier
|
|
from sklearn.preprocessing import LabelEncoder
|
|
from sklearn.metrics import classification_report
|
|
from sklearn.metrics import roc_curve, auc
|
|
import matplotlib.pyplot as plt
|
|
import seaborn as sns
|
|
import os
|
|
|
|
# Flask application object; CORS is enabled on it.
app = Flask(__name__)
CORS(app)

# Module-level state: `data` is set by load_data(), the two models by
# train_models(). All three start out unset.
data = rf_model = dt_model = None
|
|
|
|
def load_data(file_path):
    """Load the dataset from a CSV file and run the initial processing.

    Reads ``file_path`` into the module-level ``data`` frame, renames the
    abbreviated CSV headers to the column names used by the rest of this
    module, then calls ``preprocess_data()``.
    """
    global data

    # Abbreviated CSV header -> column name used throughout this module.
    column_names = dict(
        age='年龄',
        bp='血压',
        sg='比重',
        al='白蛋白',
        su='糖',
        rbc='红细胞',
        pc='脓细胞',
        pcc='脓细胞团',
        ba='细菌',
        bgr='随机血糖',
        bu='血尿素',
        sc='血清肌酐',
        sod='钠',
        pot='钾',
        hemo='血红蛋白',
        pcv='红细胞压积',
        wc='白细胞计数',
        rc='红细胞计数',
        htn='高血压',
        dm='糖尿病',
        cad='冠心病',
        appet='食欲',
        pe='肢端水肿',
        ane='贫血',
        classification='分类',
    )

    # Read and rename in one expression; headers missing from the map are
    # left unchanged by rename().
    data = pd.read_csv(file_path).rename(columns=column_names)

    # Clean, impute and encode the freshly loaded frame.
    preprocess_data()
|
|
|
|
def preprocess_data():
    """Clean the module-level ``data`` frame so it can be used for training.

    Steps, in order:

    1. Normalise known malformed labels (stray tab / space characters) in
       the diabetes, coronary-artery-disease and classification columns.
    2. Coerce the packed-cell-volume, white-cell-count and red-cell-count
       columns to numeric; values that cannot be parsed become NaN.
    3. Fill missing numeric values with the column mean and missing
       categorical values with the column mode.
    4. Label-encode every remaining object-typed column to integer codes.

    Raises:
        ValueError: if a categorical column contains no values at all, so
            no mode exists to fill it with.
    """
    global data

    # Fix known malformed labels.
    data['糖尿病'] = data['糖尿病'].replace({'\tno': 'no', '\tyes': 'yes', ' yes': 'yes'})
    data['冠心病'] = data['冠心病'].replace('\tno', 'no')
    data['分类'] = data['分类'].replace('ckd\t', 'ckd')

    # These columns are read as text; force them to numbers.
    for name in ('红细胞压积', '白细胞计数', '红细胞计数'):
        data[name] = pd.to_numeric(data[name], errors='coerce')

    # Impute numeric columns with the mean. The result is assigned back to
    # the frame: calling fillna(inplace=True) on ``data[name]`` is chained
    # assignment, which warns on recent pandas and does not modify ``data``
    # once Copy-on-Write is enabled.
    numeric_cols = data.select_dtypes(include=['number']).columns
    for name in numeric_cols:
        data[name] = data[name].fillna(data[name].mean())

    # Impute categorical columns with the most frequent value.
    cat_cols = data.select_dtypes(include=['object']).columns
    for name in cat_cols:
        modes = data[name].mode()
        if modes.empty:
            # mode() ignores NaN, so an empty result means the whole column
            # is missing; fail with a clear message instead of a KeyError.
            raise ValueError(
                f"column {name!r} has no non-missing values to impute from"
            )
        data[name] = data[name].fillna(modes.iloc[0])

    # Label-encode each categorical column with its own encoder.
    for name in cat_cols:
        data[name] = LabelEncoder().fit_transform(data[name])
|
|
|
|
def train_models():
    """Fit the decision-tree and random-forest classifiers on ``data``.

    Every column except ``id`` and the classification column is used as a
    feature. 70% of the rows (fixed split seed 0) form the training set.
    The fitted estimators are stored in the module-level ``dt_model`` and
    ``rf_model``.
    """
    global rf_model, dt_model

    features = data.drop(columns=['id', '分类'], errors='ignore')
    target = data['分类']

    # Only the training part of the split is needed here.
    train_features, _, train_target, _ = train_test_split(
        features,
        target,
        test_size=0.30,
        random_state=0,
    )

    # Decision tree.
    dt_model = DecisionTreeClassifier(max_depth=10)
    dt_model.fit(train_features, train_target)

    # Random forest.
    rf_model = RandomForestClassifier(n_estimators=100, max_depth=10)
    rf_model.fit(train_features, train_target)
|
|
|
|
@app.route('/predict', methods=['POST'])
def predict():
    """Predict the class of one sample with both trained models.

    Expects a JSON object in the request body mapping feature names to
    values. Values are passed to the models as-is, so categorical features
    must already be in the numeric form the models were trained on.

    Returns:
        200 with ``decision_tree_prediction`` and
        ``random_forest_prediction`` (ints) on success;
        400 if the body is not a JSON object, a trained feature is missing,
        or the models reject the values;
        503 if the models have not been trained yet.
    """
    if dt_model is None or rf_model is None:
        return jsonify({'error': 'models are not trained yet'}), 503

    # silent=True yields None instead of raising on a missing/invalid body.
    input_data = request.get_json(silent=True)
    if not isinstance(input_data, dict):
        return jsonify({'error': 'request body must be a JSON object'}), 400

    input_df = pd.DataFrame([input_data])

    # scikit-learn checks that feature names match the ones seen during fit,
    # in the same order. JSON key order is up to the client, so align the
    # columns with the training order and report anything missing.
    expected = getattr(dt_model, 'feature_names_in_', None)
    if expected is not None:
        missing = [str(name) for name in expected if name not in input_df.columns]
        if missing:
            return jsonify({'error': 'missing features', 'missing': missing}), 400
        input_df = input_df[list(expected)]

    try:
        dt_prediction = dt_model.predict(input_df)
        rf_prediction = rf_model.predict(input_df)
    except (ValueError, TypeError) as exc:
        # Raised by the estimators for non-numeric or otherwise invalid input.
        return jsonify({'error': f'invalid feature values: {exc}'}), 400

    return jsonify({
        'decision_tree_prediction': int(dt_prediction[0]),
        'random_forest_prediction': int(rf_prediction[0])
    })
|
|
|
|
@app.route('/evaluate', methods=['GET'])
def evaluate():
    """Report ROC AUC and a classification report for both models.

    Re-creates the 70/30 split with the same seed (0) that train_models()
    uses and scores the models on the held-out 30%.
    """
    features = data.drop(columns=['id', '分类'], errors='ignore')
    target = data['分类']

    # Only the held-out part of the split is needed here.
    _, held_out_x, _, held_out_y = train_test_split(
        features,
        target,
        test_size=0.30,
        random_state=0,
    )

    def _roc_auc(model):
        # Area under the ROC curve, using the probability in column 1 of
        # predict_proba as the score.
        scores = model.predict_proba(held_out_x)[:, 1]
        false_pos_rate, true_pos_rate, _ = roc_curve(held_out_y, scores)
        return auc(false_pos_rate, true_pos_rate)

    tree_auc = _roc_auc(dt_model)
    forest_auc = _roc_auc(rf_model)

    payload = {
        'decision_tree_auc': tree_auc,
        'random_forest_auc': forest_auc,
        'decision_tree_report': classification_report(
            held_out_y, dt_model.predict(held_out_x), output_dict=True
        ),
        'random_forest_report': classification_report(
            held_out_y, rf_model.predict(held_out_x), output_dict=True
        ),
    }
    return jsonify(payload)
|
|
|
|
if __name__ == '__main__':
    # Load the dataset that sits next to this script and train both models
    # before the server starts. abspath() keeps the location stable when
    # __file__ is a relative path.
    dataset_path = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), "kidney_disease.csv"
    )
    load_data(dataset_path)
    train_models()

    # Debug mode enables the Werkzeug interactive debugger, which lets anyone
    # who can reach the server execute arbitrary code. Keep it off unless it
    # is explicitly requested through the FLASK_DEBUG environment variable.
    debug_enabled = os.environ.get('FLASK_DEBUG', '').strip().lower() in (
        '1', 'true', 'yes', 'on'
    )
    app.run(debug=debug_enabled)
|
|
|