回归

简单线性回归

简单线性回归是指, 对于一系列的离散数据点, 用一条直线来对其进行模拟.如图

image-20200717104254373

根据直线的解析式我们可以知道, 其实只要找到β0\beta_0β1\beta_1两个参数, 直线就被确定下来了. 那么在无数条接近数据点的直线中, 哪一条会是最有可能适合的呢? 这里首先要对误差进行定义, 如图中所示, 真实数据点与回归直线上数据点的距离就是误差, 距离越远误差就越大, 我们的目的就是找一条直线, 使得所有点误差的总和最小. 由于有些点在直线上方, 有些点在直线下方, 因此用y的差值的平方来衡量误差的大小. 可以记为:

SSE=iϵi2=i(yiy^i)2S S E=\sum_{i} \epsilon_{i}^{2}=\sum_{i}\left(y_{i}-\widehat{y}_{i}\right)^{2}

代入直线方程y^i=β0+β1xi\hat{y}_{i}=\beta_{0}+\beta_{1} x_{i} 有:

SSE=i(yiβ0β1xi)2S S E=\sum_{i}\left(y_{i}-\beta_{0}-\beta_{1} x_{i}\right)^{2}

于是, 问题转化为求β0\beta_0β1\beta_1, 使得SSE取得最小值的问题. 首先求偏导数, 另其为零

iϵi2β0=2i(yiβ0β1xi)=0\frac{\partial \sum_{i} \epsilon_{i}^{2}}{\partial \beta_{0}}=-2 \sum_{i}\left(y_{i}-\beta_{0}-\beta_{1} x_{i}\right)=0

iϵi2β1=2i(yiβ0β1xi)xi=0\frac{\partial \sum_{i} \epsilon_{i}^{2}}{\partial \beta_{1}}=-2 \sum_{i}\left(y_{i}-\beta_{0}-\beta_{1} x_{i}\right) x_{i}=0

虽然上面的公式比较复杂, 但实际上未知数只有β0\beta_0β1\beta_1两个, 假设一共有N个数据点, 对上式进行拆分

i=1Nyi=β0N+β1i=1Nxi\sum_{i=1}^{N} y_{i}=\beta_{0} \cdot N+\beta_{1} \sum_{i=1}^{N} x_{i}

i=1Nyixi=β0i=1Nxi+β1i=1Nxi2\sum_{i=1}^{N} y_{i} x_{i}=\beta_{0} \sum_{i=1}^{N} x_{i}+\beta_{1} \sum_{i=1}^{N} x_{i}^{2}

写成矩阵形式:

[yiyixi]=[Nxixixi2][β0β1]\left[\begin{array}{l} \sum y_{i} \\ \sum y_{i} x_{i} \end{array}\right]=\left[\begin{array}{cc} N & \sum x_{i} \\ \sum x_{i} & \sum x_{i}^{2} \end{array}\right]\left[\begin{array}{l} \beta_{0} \\ \beta_{1} \end{array}\right]

matlab提供内置函数polyfit()来解上面的方程来获得β1\beta_1β0\beta_0,

1
2
3
x =[-1.2 -0.5 0.3 0.9 1.8 2.6 3.0 3.5];
y =[-15.6 -8.5 2.2 4.5 6.6 8.2 8.9 10.0];
fit = polyfit(x,y,1);

计算结果为:

image-20200717110832385

其中fit(1)存储的是β1\beta_1, fit(2)存储的是β0\beta_0, 可以使用下面代码来绘制回归直线:

1
2
3
xfit = [x(1):0.1:x(end)]; yfit = fit(1)*xfit + fit(2);
plot(x,y,'ro',xfit,yfit); set(gca,'FontSize',14);
legend(2,'data points','best-fit');

image-20200717111313167

非线性回归

对于给定的离散数据点而言, 其实二者未必具有线性关系. matlab中可以使用scatter()函数来画出数据点的图像 , 使用corrcoef()函数可以用来求数据的相关系数. 相关系数的值在[-1,1]区间内, 大于0则称正相关, 小于0称负相关.

1
2
3
4
x =[-1.2 -0.5 0.3 0.9 1.8 2.6 3.0 3.5];
y =[-15.6 -8.5 2.2 4.5 6.6 8.2 8.9 10.0];
scatter(x,y); box on; axis square;
corrcoef(x,y)

绘图结果:

image-20200717112607774

相关系数:

image-20200717112647111

此结果是一个对称矩阵, 其中ans(1,1)ans(2,2)位置表示的是x与x, y与y自身的相关系数, 所以为1. ans(1,2)ans(2,1)则表示y 与x 的相关系数. 从相关系数的计算结果来看, 其实还是比较线性相关的.但实际上也可以通过曲线来做回归, 可以参考下面代码:

1
2
3
4
5
6
7
8
9
x =[-1.2 -0.5 0.3 0.9 1.8 2.6 3.0 3.5];
y =[-15.6 -8.5 2.2 4.5 6.6 8.2 8.9 10.0];
figure('Position', [50 50 1500 400]);
for i=1:3
subplot(1,3,i); p = polyfit(x,y,i); % i代表回归曲线的次数, i等于1时则返回一个一次多项式的矩阵形式, i等于2时则返回2次多项式的矩阵形式...
xfit = x(1):0.1:x(end); yfit = polyval(p,xfit); % 利用xfit的值, 求多项式p的函数值
plot(x,y,'ro',xfit,yfit); set(gca,'FontSize',14);
ylim([-17, 11]); legend(4,'Data points','Fitted curve');
end

绘图结果:

image-20200717113402412

从结果上看, 似乎三次曲线的效果比较好. 但是需要了解的是, 并不是使用曲线的次数越高越好.

三维空间回归

如果数据点是三维的, 需要使用做的是回归平面或者回归曲面, 需要用到matlab的内置函数regress(), 例如

1
2
3
4
5
6
7
8
9
10
11
12
13
14
load carsmall; % 使用matlab内置数据集
y = MPG;
x1 = Weight; x2 = Horsepower;
X = [ones(length(x1),1) x1 x2];
b = regress(y,X);
x1fit = min(x1):100:max(x1);
x2fit = min(x2):10:max(x2);
[X1FIT,X2FIT]=meshgrid(x1fit,x2fit);
YFIT=b(1)+b(2)*X1FIT+b(3)*X2FIT;
scatter3(x1,x2,y,'filled'); hold on;
mesh(X1FIT,X2FIT,YFIT); hold off;
xlabel('Weight');
ylabel('Horsepower');
zlabel('MPG'); view(50,10);

绘图结果:

image-20200717115633706

插值

插值主要用到的函数的超链接:

  • interp1: 一维数据插值(表查找)
  • pchip : 分段三次 Hermite 插值多项式 (PCHIP)
  • spline : 三次方样条数据插值
  • mkpp: 生成分段多项式

下面分别给出interp1的使用示例和spline的使用示例, 关于插值的内容, 我的学习效果不是很好, 如果以后会用到的话, 在进一步学习.

1
2
3
4
5
6
7
8
9
10
11
12
13
%% interp1
x = linspace(0, 2*pi, 40); x_m = x;
x_m([11:13, 28:30]) = NaN; % 把数据中指定位置的数据设为NaN 接下来进行插值
y_m = sin(x_m);
plot(x_m, y_m,'ro','MarkerFaceColor', 'r');
xlim([0, 2*pi]); ylim([-1.2, 1.2]); box on;
set(gca,'FontName', 'symbol','FontSize', 16);
set(gca,'XTick', 0:pi/2:2*pi);
m_i = ~isnan(x_m);
y_i = interp1(x_m(m_i),y_m(m_i), x);
hold on;
plot(x,y_i,'-b','LineWidth', 2);
hold off;

执行结果:
image-20200717120658128

1
2
3
4
5
6
%% spline
m_i = ~isnan(x_m);
y_i = spline(x_m(m_i), y_m(m_i), x);
hold on; plot(x,y_i,'-g','LineWidth', 2); hold off;
h = legend('Original', 'Linear', 'Spline');
set(h,'FontName', 'Times New Roman');

执行结果:

image-20200717120908266

资料链接

参考视频 参考讲义

完结

这门课就算是完整的听完了, 其实除了一些基础知识, 基本上也没记住什么知识点. 但是学习和记忆最好的方式就是把知识讲给别人听, 所以还是把课程里的大部分知识点流水账式的记录了下来, 以后如果用到的话也算是有据可查.