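% Random-forest regression with TreeBagger: train/test split, accuracy metrics,
% out-of-bag error curve and feature importance
close all
clear
clc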
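% Load the data: all columns except the last are features, the last column is the target
res = readmatrix('sj3.xlsx');
X = res(:, 1:end-1);
Y = res(:, end);

% Scale features and target to [0, 1]. mapminmax treats each ROW as a variable,
% hence the transposes. The scaling parameters are computed from the full data
% set, so the test rows also influence them.
[x, psin]  = mapminmax(X', 0, 1);
[y, psout] = mapminmax(Y', 0, 1);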
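% Illustration (toy data, hypothetical demo_* names; not part of the workflow):
% a minimal sketch of what mapminmax(X', 0, 1) computes. Each feature is mapped
% linearly so that its minimum becomes 0 and its maximum becomes 1:
%   x_scaled = (x - min(x)) / (max(x) - min(x))
% Written column-wise here (samples in rows), so no transpose is needed.
% Assumes every feature has max > min; a constant column would divide by zero here.
demo_X = [1 10; 2 20; 4 40];                 % 3 samples, 2 features
demo_min = min(demo_X, [], 1);               % per-feature minimum
demo_max = max(demo_X, [], 1);               % per-feature maximum
demo_Xs = (demo_X - demo_min) ./ (demo_max - demo_min);
% demo_Xs is [0 0; 1/3 1/3; 1 1]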
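% Random 70/30 train/test split (call rng(...) beforehand for a reproducible split)
num = size(res, 1);
state = randperm(num);
ratio = 0.7;
train_num = floor(num * ratio);

% Rows = samples, columns = features (TreeBagger expects observations in rows)
x_train = x(:, state(1:train_num))';
y_train = y(state(1:train_num))';
x_test  = x(:, state(train_num+1:end))';
y_test  = y(state(train_num+1:end))';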
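% Illustration (toy size, hypothetical demo_* names): the randperm-based split
% used above yields disjoint index sets that together cover every row exactly once.
demo_n = 10;
demo_idx = randperm(demo_n);
demo_ntr = floor(demo_n * 0.7);              % 7 training rows
demo_tr = demo_idx(1:demo_ntr);
demo_te = demo_idx(demo_ntr+1:end);          % remaining 3 rows
demo_disjoint = isempty(intersect(demo_tr, demo_te));
demo_complete = isequal(sort([demo_tr, demo_te]), 1:demo_n);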
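% TreeBagger hyperparameters
trees = 100;          % number of bagged regression trees
leaf = 3;             % minimum number of observations per leaf
wuc = 'on';           % store out-of-bag information (needed by oobError)
Importance = 'on';    % compute out-of-bag permutation importance

net = TreeBagger(trees, x_train, y_train, ...
    'Method', 'regression', ...
    'OOBPrediction', wuc, ...
    'OOBPredictorImportance', Importance, ...
    'MinLeafSize', leaf);

% Permutation importance: increase in OOB error when each feature is shuffled
import = net.OOBPermutedPredictorDeltaError;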
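% Illustration (hypothetical demo_* names): each bagged tree is grown on a
% bootstrap sample, i.e. n rows drawn with replacement. Rows never drawn are
% "out-of-bag" (OOB) for that tree. With 'OOBPrediction','on', those rows give
% an error estimate without a separate validation set. The expected OOB fraction
% is (1 - 1/n)^n, which approaches exp(-1) (about 0.368) for large n.
demo_n = 10000;
demo_boot = randi(demo_n, demo_n, 1);        % bootstrap row indices
demo_oob = ~ismember(1:demo_n, demo_boot);   % rows that were never drawn
demo_oob_frac = mean(demo_oob);              % close to 0.368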
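% Predict on both sets (outputs are on the normalized [0, 1] scale)
re1 = predict(net, x_train);
re2 = predict(net, x_test);

% True targets in original units, in the same order as the split
Y_train = Y(state(1:train_num));
Y_test  = Y(state(train_num+1:end));

% Map predictions back to original units. psout was fitted on a 1-by-N row
% (one variable), so pass the predictions as a row and transpose the result
% back to a column.
pre1 = mapminmax('reverse', re1', psout)';
pre2 = mapminmax('reverse', re2', psout)';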
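% Illustration (toy numbers, hypothetical demo_* names): a minimal sketch of the
% 'reverse' step. Undoing a [0, 1] min-max scaling needs only the minimum and
% maximum from the forward pass, which is what the settings struct returned by
% mapminmax (psout above) keeps:
%   y = y_scaled * (max - min) + min
demo_Y = [12; 15; 30];                                % targets in original units
demo_lo = min(demo_Y);
demo_hi = max(demo_Y);
demo_s = (demo_Y - demo_lo) / (demo_hi - demo_lo);   % forward: [0; 1/6; 1]
demo_back = demo_s * (demo_hi - demo_lo) + demo_lo;  % reverse: recovers demo_Y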
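% Accuracy metrics on the original scale
% RMSE: square the errors before averaging; sqrt(mean(e).^2) would only give
% the absolute mean error (the bias)
error1 = sqrt(mean((pre1 - Y_train).^2));
error2 = sqrt(mean((pre2 - Y_test).^2));

% Coefficient of determination: R^2 = 1 - SS_res / SS_tot
R1 = 1 - norm(Y_train - pre1)^2 / norm(Y_train - mean(Y_train))^2;
R2 = 1 - norm(Y_test - pre2)^2 / norm(Y_test - mean(Y_test))^2;

% Mean absolute error
mae1 = mean(abs(Y_train - pre1));
mae2 = mean(abs(Y_test - pre2));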
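% Illustration (toy numbers, hypothetical demo_* names): why the errors must be
% squared before averaging. The errors below alternate +1 and -1, so their mean
% is 0 and sqrt(mean(e).^2) reports 0, while the true RMSE is 1.
demo_y = [3; 5; 7; 9];                       % actual values
demo_p = [4; 4; 8; 8];                       % predictions
demo_e = demo_p - demo_y;                    % [1; -1; 1; -1]
demo_bias_only = sqrt(mean(demo_e).^2);      % 0 (absolute mean error only)
demo_rmse = sqrt(mean(demo_e.^2));           % 1
demo_mae = mean(abs(demo_e));                % 1
demo_r2 = 1 - sum(demo_e.^2) / sum((demo_y - mean(demo_y)).^2);   % 1 - 4/20 = 0.8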
disp('Training-set accuracy:')
disp(['Training R2: ', num2str(R1)])
disp(['Training MAE: ', num2str(mae1)])
disp(['Training RMSE: ', num2str(error1)])
disp('Test-set accuracy:')
disp(['Test R2: ', num2str(R2)])
disp(['Test MAE: ', num2str(mae2)])
disp(['Test RMSE: ', num2str(error2)])
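% Actual vs. predicted, training set
figure
plot(1:train_num, Y_train, 'r-o', 1:train_num, pre1, 'b-+', 'LineWidth', 1)
legend('Actual', 'Predicted')
xlabel('Sample')
ylabel('Value')
title('Training set: actual vs. predicted')

% Actual vs. predicted, test set
figure
plot(1:num-train_num, Y_test, 'r-o', 1:num-train_num, pre2, 'b-+', 'LineWidth', 1)
legend('Actual', 'Predicted')
xlabel('Sample')
ylabel('Value')
title('Test set: actual vs. predicted')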
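% Cumulative out-of-bag error (MSE for regression) as trees are added
figure
plot(1:trees, oobError(net), 'r--o', 'LineWidth', 1)
legend('OOB error')
xlabel('Number of grown trees')
ylabel('Out-of-bag MSE')
grid on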
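% Illustration (idealized, hypothetical demo_* names): why the OOB error curve
% usually drops and then flattens as trees are added. Each "tree" is modeled
% here as an unbiased prediction with independent unit-variance noise, so the
% squared error of the running average falls like 1/B after B trees. Real trees
% are correlated and can be biased, so in practice the curve levels off above
% zero rather than following 1/B exactly.
demo_reps = 2000;                                  % repeated experiments
demo_B = 100;                                      % number of "trees"
demo_pred = randn(demo_reps, demo_B);              % per-tree errors around a true value of 0
demo_avg = cumsum(demo_pred, 2) ./ (1:demo_B);     % running ensemble average
demo_mse = mean(demo_avg.^2, 1);                   % roughly 1 ./ (1:demo_B)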
% Out-of-bag permutation importance of each feature
figure
bar(import, 'green')
yticks([])
xlabel('Feature')
ylabel('Importance')
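% Illustration (toy linear model, hypothetical demo_* names): a minimal sketch of
% permutation importance, the idea behind OOBPermutedPredictorDeltaError.
% Shuffle one feature at a time and measure how much the prediction error grows.
% TreeBagger does this on the out-of-bag rows of each tree, then averages and
% normalizes over the ensemble. This sketch uses a single fixed model instead.
demo_n = 50;
demo_X = rand(demo_n, 2);
demo_w = [3; 0];                                   % only feature 1 affects the target
demo_y = demo_X * demo_w;                          % noise-free target
demo_base = mean((demo_X * demo_w - demo_y).^2);   % baseline MSE (0 here)
demo_delta = zeros(1, 2);
for demo_j = 1:2
    demo_Xp = demo_X;
    demo_Xp(:, demo_j) = demo_Xp(randperm(demo_n), demo_j);   % shuffle one column
    demo_delta(demo_j) = mean((demo_Xp * demo_w - demo_y).^2) - demo_base;
end
% demo_delta(1) > 0 (relevant feature), demo_delta(2) == 0 (irrelevant feature)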