回归
简单线性回归
简单线性回归是指, 对于一系列的离散数据点, 用一条直线来对其进行模拟.如图
根据直线的解析式我们可以知道, 其实只要找到β0和β1两个参数, 直线就被确定下来了. 那么在无数条接近数据点的直线中, 哪一条会是最有可能适合的呢? 这里首先要对误差进行定义, 如图中所示, 真实数据点与回归直线上数据点的距离就是误差, 距离越远误差就越大, 我们的目的就是找一条直线, 使得所有点误差的总和最小. 由于有些点在直线上方, 有些点在直线下方, 因此用y的差值的平方来衡量误差的大小. 可以记为:
SSE=i∑ϵi2=i∑(yi−yi)2
代入直线方程y^i=β0+β1xi 有:
SSE=i∑(yi−β0−β1xi)2
于是, 问题转化为求β0和β1, 使得SSE取得最小值的问题. 首先求偏导数, 另其为零
∂β0∂∑iϵi2=−2i∑(yi−β0−β1xi)=0
∂β1∂∑iϵi2=−2i∑(yi−β0−β1xi)xi=0
虽然上面的公式比较复杂, 但实际上未知数只有β0和β1两个, 假设一共有N个数据点, 对上式进行拆分
i=1∑Nyi=β0⋅N+β1i=1∑Nxi
i=1∑Nyixi=β0i=1∑Nxi+β1i=1∑Nxi2
写成矩阵形式:
[∑yi∑yixi]=[N∑xi∑xi∑xi2][β0β1]
matlab提供内置函数polyfit()
来解上面的方程来获得β1和β0,
1 2 3
| x =[-1.2 -0.5 0.3 0.9 1.8 2.6 3.0 3.5]; y =[-15.6 -8.5 2.2 4.5 6.6 8.2 8.9 10.0]; fit = polyfit(x,y,1);
|
计算结果为:
其中fit(1)
存储的是β1, fit(2)
存储的是β0, 可以使用下面代码来绘制回归直线:
1 2 3
| xfit = [x(1):0.1:x(end)]; yfit = fit(1)*xfit + fit(2); plot(x,y,'ro',xfit,yfit); set(gca,'FontSize',14); legend(2,'data points','best-fit');
|
非线性回归
对于给定的离散数据点而言, 其实二者未必具有线性关系. matlab中可以使用scatter()
函数来画出数据点的图像 , 使用corrcoef()
函数可以用来求数据的相关系数. 相关系数的值在[-1,1]区间内, 大于0则称正相关, 小于0称负相关.
1 2 3 4
| x =[-1.2 -0.5 0.3 0.9 1.8 2.6 3.0 3.5]; y =[-15.6 -8.5 2.2 4.5 6.6 8.2 8.9 10.0]; scatter(x,y); box on; axis square; corrcoef(x,y)
|
绘图结果:
相关系数:
此结果是一个对称矩阵, 其中ans(1,1)
和ans(2,2)
位置表示的是x与x, y与y自身的相关系数, 所以为1. ans(1,2)
和ans(2,1)
则表示y 与x 的相关系数. 从相关系数的计算结果来看, 其实还是比较线性相关的.但实际上也可以通过曲线来做回归, 可以参考下面代码:
1 2 3 4 5 6 7 8 9
| x =[-1.2 -0.5 0.3 0.9 1.8 2.6 3.0 3.5]; y =[-15.6 -8.5 2.2 4.5 6.6 8.2 8.9 10.0]; figure('Position', [50 50 1500 400]); for i=1:3 subplot(1,3,i); p = polyfit(x,y,i); xfit = x(1):0.1:x(end); yfit = polyval(p,xfit); plot(x,y,'ro',xfit,yfit); set(gca,'FontSize',14); ylim([-17, 11]); legend(4,'Data points','Fitted curve'); end
|
绘图结果:
从结果上看, 似乎三次曲线的效果比较好. 但是需要了解的是, 并不是使用曲线的次数越高越好.
三维空间回归
如果数据点是三维的, 需要使用做的是回归平面或者回归曲面, 需要用到matlab的内置函数regress()
, 例如
1 2 3 4 5 6 7 8 9 10 11 12 13 14
| load carsmall; y = MPG; x1 = Weight; x2 = Horsepower; X = [ones(length(x1),1) x1 x2]; b = regress(y,X); x1fit = min(x1):100:max(x1); x2fit = min(x2):10:max(x2); [X1FIT,X2FIT]=meshgrid(x1fit,x2fit); YFIT=b(1)+b(2)*X1FIT+b(3)*X2FIT; scatter3(x1,x2,y,'filled'); hold on; mesh(X1FIT,X2FIT,YFIT); hold off; xlabel('Weight'); ylabel('Horsepower'); zlabel('MPG'); view(50,10);
|
绘图结果:
插值
插值主要用到的函数的超链接:
下面分别给出interp1
的使用示例和spline
的使用示例, 关于插值的内容, 我的学习效果不是很好, 如果以后会用到的话, 在进一步学习.
1 2 3 4 5 6 7 8 9 10 11 12 13
| x = linspace(0, 2*pi, 40); x_m = x; x_m([11:13, 28:30]) = NaN; y_m = sin(x_m); plot(x_m, y_m,'ro','MarkerFaceColor', 'r'); xlim([0, 2*pi]); ylim([-1.2, 1.2]); box on; set(gca,'FontName', 'symbol','FontSize', 16); set(gca,'XTick', 0:pi/2:2*pi); m_i = ~isnan(x_m); y_i = interp1(x_m(m_i),y_m(m_i), x); hold on; plot(x,y_i,'-b','LineWidth', 2); hold off;
|
执行结果:
1 2 3 4 5 6
| m_i = ~isnan(x_m); y_i = spline(x_m(m_i), y_m(m_i), x); hold on; plot(x,y_i,'-g','LineWidth', 2); hold off; h = legend('Original', 'Linear', 'Spline'); set(h,'FontName', 'Times New Roman');
|
执行结果:
资料链接
参考视频
参考讲义
完结
这门课就算是完整的听完了, 其实除了一些基础知识, 基本上也没记住什么知识点. 但是学习和记忆最好的方式就是把知识讲给别人听, 所以还是把课程里的大部分知识点流水账式的记录了下来, 以后如果用到的话也算是有据可查.