天天看点

task4-模型评估

【模型评估(2天)】 记录5个模型(逻辑回归、SVM、决策树、随机森林、XGBoost)关于 accuracy、precision、recall、F1-score 和 AUC 值的评分表格,并画出 ROC 曲线。

from sklearn.metrics import accuracy_score, recall_score, f1_score, roc_auc_score, roc_curve

from matplotlib import pyplot as plt

# 定义评估函数

def model_metrics(clf, X_train, X_test, y_train, y_test):
    """Print train/test evaluation metrics for a classifier and plot its ROC curves.

    Reports accuracy, recall, F1, ROC-AUC and the KS statistic for both the
    training and the test split, then shows one figure with both ROC curves
    and the chance diagonal.

    Args:
        clf: Classifier exposing ``predict`` and ``predict_proba``. It is
            only used for prediction here, so it must already be fitted.
        X_train: Feature matrix of the training split.
        X_test: Feature matrix of the test split.
        y_train: True labels of the training split.
        y_test: True labels of the test split.

    Returns:
        None. Results are printed to stdout and drawn with matplotlib.
    """
    # Hard class predictions feed accuracy / recall / F1.
    y_train_pred = clf.predict(X_train)
    y_test_pred = clf.predict(X_test)

    # Column 1 of predict_proba is used as the ranking score for ROC/AUC
    # (the positive class for a binary scikit-learn classifier).
    y_train_pred_proba = clf.predict_proba(X_train)[:, 1]
    y_test_pred_proba = clf.predict_proba(X_test)[:, 1]

    # Accuracy
    print('准确性:')
    print('Train:{:.4f}'.format(accuracy_score(y_train, y_train_pred)))
    print('Test:{:.4f}'.format(accuracy_score(y_test, y_test_pred)))

    # Recall
    print('召回率:')
    print('Train:{:.4f}'.format(recall_score(y_train, y_train_pred)))
    print('Test:{:.4f}'.format(recall_score(y_test, y_test_pred)))

    # F1 score
    print('f1_score:')
    print('Train:{:.4f}'.format(f1_score(y_train, y_train_pred)))
    print('Test:{:.4f}'.format(f1_score(y_test, y_test_pred)))

    # ROC-AUC: computed once and reused for both the printout and the legend.
    auc_tr = roc_auc_score(y_train, y_train_pred_proba)
    auc_te = roc_auc_score(y_test, y_test_pred_proba)
    print('roc_auc:')
    print('Train:{:.4f}'.format(auc_tr))
    print('Test:{:.4f}'.format(auc_te))

    # ROC curve points for each split.
    fpr_tr, tpr_tr, _ = roc_curve(y_train, y_train_pred_proba)
    fpr_te, tpr_te, _ = roc_curve(y_test, y_test_pred_proba)

    # KS statistic: the largest gap between TPR and FPR along the ROC curve.
    ks_tr = max(abs(fpr_tr - tpr_tr))
    ks_te = max(abs(fpr_te - tpr_te))
    print('KS:')
    print('Train:{:.4f}'.format(ks_tr))
    print('Test:{:.4f}'.format(ks_te))

    # Plot both curves; each legend entry carries that split's own AUC and KS
    # (the test entry must use the test KS, not the training one).
    plt.plot(fpr_tr, tpr_tr, 'r-',
             label="Train:AUC: {:.3f} KS:{:.3f}".format(auc_tr, ks_tr))
    plt.plot(fpr_te, tpr_te, 'g-',
             label="Test:AUC: {:.3f} KS:{:.3f}".format(auc_te, ks_te))
    # Black dashed diagonal = performance of a random classifier.
    plt.plot([0, 1], [0, 1], 'k--')
    plt.legend(loc='best')
    plt.title("ROC curve")
    plt.show()

参考:https://blog.csdn.net/l75326747/article/details/84233247 

最优:https://blog.csdn.net/cchengone/article/details/88366003

继续阅读