>>> from env_helper import info; info()
页面更新时间: 2024-01-19 23:15:35
运行环境:
Linux发行版本: Debian GNU/Linux 12 (bookworm)
操作系统内核: Linux-6.1.0-17-amd64-x86_64-with-glibc2.36
Python版本: 3.11.2
5.10. 估计回归拟合¶
许多数据集包含多个定量变量,分析的目标通常是将这些变量相互关联。 我们之前讨论过可以通过显示两个变量的联合分布来实现这一点的函数。 但是,使用统计模型来估计两组噪声观测值之间的简单关系可能非常有帮助。 本章讨论的函数将通过线性回归的通用框架来实现。
本着 Tukey 的精神,seaborn 中的回归图主要是为了添加一个视觉指南,以帮助在探索性数据分析期间强调数据集中的模式。 也就是说,seaborn本身并不是一个统计分析包。 若要获取与回归模型拟合度相关的定量度量,应使用 statsmodels。 然而,seaborn 的目标是通过可视化快速轻松地探索数据集,因为这样做与通过统计表探索数据集一样重要(如果不是更重要的话)。
5.10.1. 用于绘制线性回归模型的函数¶
可用于可视化线性拟合的两个函数是 regplot()
和 lmplot()
。
在最简单的调用中,两个函数绘制两个变量 x
和 y
的散点图,然后拟合回归模型 y ~ x
,并绘制结果回归线和该回归的
\(95%\) 置信区间:
>>> import seaborn as sns
>>>
>>> sns.set_theme()
>>>
>>> tips = sns.load_dataset("tips")
>>> sns.regplot(x="total_bill", y="tip", data=tips);
>>> sns.lmplot(x="total_bill", y="tip", data=tips);
这些函数绘制相似的图,但 regplot()
是轴级函数,而 lmplot()
是图形级函数。 此外, regplot()
接受各种格式的 x
和 y
变量,包括简单的numpy数组、 pandas.Series
对象, 或作为对
pandas.DataFrame
中变量的引用。 传递给数据的数据框架对象。 相反,
lmplot()
将 data
作为必需的参数,并且必须将 x
和 y
变量指定为字符串。 最后,只有 lmplot()
将 hue
作为参数。
但其核心功能在其他方面是相似的,因此本教程将重点介绍 lmplot()
。
当其中一个变量采用离散值时,可以拟合线性回归,但是,这种数据集生成的简单散点图通常不是最优的:
>>> sns.lmplot(x="size", y="tip", data=tips);
一种选择是在离散值中添加一些随机噪声(“抖动”),以使这些值的分布更加清晰。 请注意,抖动仅应用于散点图数据,不会影响回归线拟合本身:
>>> sns.lmplot(x="size", y="tip", data=tips, x_jitter=.05);
第二种选择是折叠每个离散条柱中的观测值,以绘制集中趋势的估计值以及置信区间:
>>> import numpy as np
>>> sns.lmplot(x="size", y="tip", data=tips, x_estimator=np.mean);
5.10.2. 适合不同种类的模型¶
上面使用的简单线性回归模型非常容易拟合,但是,它不适用于某些类型的数据集。 Anscombe 的四重数据集显示了几个示例,其中简单的线性回归提供了对关系的相同估计,其中简单的目视检查清楚地显示了差异。 例如,在第一种情况下,线性回归是一个很好的模型:
>>> anscombe = sns.load_dataset("anscombe",data_home='seaborn-data',cache=True)
>>> sns.lmplot(x="x", y="y", data=anscombe.query("dataset == 'I'"),
>>> ci=None, scatter_kws={"s": 80});
第二个数据集中的线性关系是相同的,但图清楚地表明这不是一个好的模型:
>>> sns.lmplot(x="x", y="y", data=anscombe.query("dataset == 'II'"),
>>> ci=None, scatter_kws={"s": 80});
在这些高阶关系存在的情况下, lmplot()
和 regplot()
可以拟合一个多项式回归模型来探索数据集中简单的非线性趋势:
>>> sns.lmplot(x="x", y="y", data=anscombe.query("dataset == 'II'"),
>>> order=2, ci=None, scatter_kws={"s": 80});
另一个问题是由“异常值”观测值引起的,这些观测值由于所研究的主要关系以外的某种原因而偏离:
>>> sns.lmplot(x="x", y="y", data=anscombe.query("dataset == 'III'"),
>>> ci=None, scatter_kws={"s": 80});
在存在异常值的情况下,拟合稳健回归可能很有用,该回归使用不同的损失函数来降低相对较大的残差的权重:
以下需要安装:
sudo apt install -y python3-statsmodels
>>> sns.lmplot(x="x", y="y", data=anscombe.query("dataset == 'III'"),
>>> robust=True, ci=None, scatter_kws={"s": 80});
当变量y
是二进制变量时,简单的线性回归也“有效”,但提供了难以置信的预测:
>>> tips["big_tip"] = (tips.tip / tips.total_bill) > .15
>>> sns.lmplot(x="total_bill", y="big_tip", data=tips,
>>> y_jitter=.03);
这种情况下的解决方案是拟合逻辑回归,使回归线显示给定x
值时y = 1
的估计概率:
>>> sns.lmplot(x="total_bill", y="big_tip", data=tips,
>>> logistic=True, y_jitter=.03);
请注意,逻辑回归估计的计算密集度要高得多(稳健回归也是如此)。由于回归线周围的置信区间是使用引导过程计算的,因此您可能希望将其关闭以加快迭代速度(使用ci=None
)。
一种完全不同的方法是使用更平滑的 lowess 拟合非参数回归。这种方法具有最少的假设,尽管它是计算密集型的,因此目前根本不计算置信区间:
>>> sns.lmplot(x="total_bill", y="tip", data=tips,
>>> lowess=True, line_kws={"color": "C1"});
residplot()
函数是检查简单回归模型是否适合于数据集的有用工具。它拟合并去除一个简单的线性回归,然后绘制每个观测值的残差值。理想情况下,这些值应该随机分布在y = 0
周围:
>>> sns.residplot(x="x", y="y", data=anscombe.query("dataset == 'I'"),
>>> scatter_kws={"s": 80});
如果残差中存在结构,则表明简单的线性回归是不合适的:
>>> sns.residplot(x="x", y="y", data=anscombe.query("dataset == 'II'"),
>>> scatter_kws={"s": 80});
5.10.3. 其他变量的条件作用¶
上面的图表显示了探索一对变量之间关系的许多方法。然而,一个更有趣的问题通常是“这两个变量之间的关系作为第三个变量的函数是如何变化的?”这就是regplot()
和lmplot()
之间的主要区别所在。虽然regplot()
总是显示单个关系,但lmplot()
将regplot()
与FacetGrid
结合起来,使用hue
映射或faceting显示多个拟合。
分离关系的最佳方法是在同一轴上绘制两个级别,并使用颜色来区分它们:
>>> sns.lmplot(x="total_bill", y="tip", hue="smoker", data=tips);
与relplot()
不同,它不可能将一个不同的变量映射到散点图的样式属性,但你可以用标记形状冗余地编码hue
变量:
>>> sns.lmplot(x="total_bill", y="tip", hue="smoker", data=tips,
>>> markers=["o", "x"], palette="Set1");
要添加另一个变量,您可以绘制多个“分面”,每个级别变量都出现在网格的行或列中:
>>> sns.lmplot(x="total_bill", y="tip", hue="smoker", col="time", data=tips);
>>> sns.lmplot(x="total_bill", y="tip", hue="smoker",
>>> col="time", row="sex", data=tips, height=3);
5.10.4. 在其他上下文中绘制回归图¶
其他一些派生函数在更大、更复杂的绘图上下文中使用regplot()
。第一个是我们在发行版教程中介绍的jointplot()
函数。除了前面讨论的绘图样式,jointplot()
可以使用regplot()
通过传递kind="reg"
来显示关节轴上的线性回归拟合:
>>> sns.jointplot(x="total_bill", y="tip", data=tips, kind="reg");
使用带有kind="reg"
的pairplot()
函数将regplot()
和PairGrid
结合起来,以显示数据集中变量之间的线性关系。请注意这与lmplot()
的不同之处。在下面的图中,这两个轴并没有显示出以第三个变量的两个层次为条件的相同关系;相反,PairGrid()
用于显示数据集中变量的不同配对之间的多种关系:
>>> sns.pairplot(tips, x_vars=["total_bill", "size"], y_vars=["tip"],
>>> height=5, aspect=.8, kind="reg");
这两个函数都使用hue
参数对一个额外的分类变量进行调节:
>>> sns.pairplot(tips, x_vars=["total_bill", "size"], y_vars=["tip"],
>>> hue="smoker", height=5, aspect=.8, kind="reg");