教程:遮罩阵列

先决条件

在阅读本教程之前,您应该了解一点Python。如果你想刷新你的记忆,看看 Python tutorial .

如果您想运行本教程中的示例,还应该 matplotlib 安装在您的计算机上。

学习者简介

本教程适用于对NumPy有基本了解并希望了解如何屏蔽数组和 numpy.ma 模块可以在实践中使用。

学习目标

完成本教程后,您应该能够:

  • 了解什么是掩蔽数组以及如何创建它们

  • 了解如何访问和修改屏蔽数组的数据

  • 决定在某些应用程序中何时使用屏蔽数组是合适的

什么是掩蔽阵列?

考虑以下问题。您的数据集缺少或无效的条目。如果你对这些数据做任何处理 skip 或者标记这些不需要的条目而不只是删除它们,您可能需要使用条件或以某种方式过滤数据。这个 numpy.ma 模块提供了与 NumPy ndarrays 添加了结构以确保在计算中不使用无效条目。

Reference Guide

屏蔽数组是标准的组合 numpy.ndarray 和A mask . 面具是 nomask ,表示关联数组的任何值都无效,或者是一个布尔数组,用于确定关联数组的每个元素的值是否有效。当遮罩的元素是 False ,关联数组的相应元素是有效的,并被称为未屏蔽。当遮罩的元素是 True ,关联数组的相应元素被称为屏蔽(无效)。

我们可以想出一个 MaskedArray 作为以下各项的组合:

  • 数据,作为常规 numpy.ndarray 任何形状或数据类型的;

  • 与数据形状相同的布尔掩码;

  • A fill_value ,可用于替换无效条目以返回标准的值。 numpy.ndarray .

什么时候有用?

在某些情况下,屏蔽数组可能比仅消除数组中的无效项更有用:

  • 当您想保留为以后处理而屏蔽的值,而不复制数组时;

  • 当您必须处理许多数组时,每个数组都有自己的掩码。如果掩码是数组的一部分,则可以避免bug,而且代码可能更紧凑;

  • 当您对丢失或无效的值有不同的标志,并且希望保留这些标志而不在原始数据集中替换它们,但将它们从计算中排除时;

  • 如果无法避免或消除缺失值,但又不想处理 NaN (不是数字)值。

掩蔽阵列也是一个好主意,因为 numpy.ma 模块还附带了most的具体实现 NumPy universal functions (ufuncs) ,这意味着您仍然可以对屏蔽数据应用快速矢量化函数和操作。然后输出一个屏蔽数组。我们将在下面看到一些在实践中如何工作的例子。

使用掩蔽数组查看COVID-19数据

Kaggle 有可能下载一个数据集,其中包含2020年初爆发的COVID-19病毒的初始数据。我们将查看这个文件中包含的数据的一小部分 who_covid_19_sit_rep_time_series.csv .

In [1]: import numpy as np

In [2]: import os

# The os.getcwd() function returns the current folder; you can change
# the filepath variable to point to the folder where you saved the .csv file
In [3]: filepath = os.getcwd()

In [4]: filename = os.path.join(filepath, "who_covid_19_sit_rep_time_series.csv")

数据文件包含不同类型的数据,其组织结构如下:

  • 第一行是一个标题行,它(主要)描述了下一行后面每列中的数据,从第四列开始,标题是观察的日期。

  • 第二行到第七行包含的摘要数据的类型与我们将要检查的不同,因此我们需要将其从我们将要处理的数据中排除。

  • 我们希望处理的数字数据从第4列第8行开始,并从那里延伸到最右边的一列和最下面的一行。

让我们研究一下这个文件中记录的前14天的数据。从数据库收集数据 .csv 文件,我们将使用 numpy.genfromtxt 函数,确保只选择具有实际数字的列,而不选择包含位置数据的前三列。我们还跳过了该文件的前7行,因为它们包含我们不感兴趣的其他数据。另外,我们将为这些数据提取有关日期和位置的信息。

# Note we are using skip_header and usecols to read only portions of the
# data file into each variable.
# Read just the dates for columns 3-7 from the first row
In [5]: dates = np.genfromtxt(filename, dtype=np.unicode_, delimiter=",",
   ...:                       max_rows=1, usecols=range(3, 17),
   ...:                       encoding="utf-8-sig")
   ...: 

# Read the names of the geographic locations from the first two
# columns, skipping the first seven rows
In [6]: locations = np.genfromtxt(filename, dtype=np.unicode_, delimiter=",",
   ...:                           skip_header=7, usecols=(0, 1),
   ...:                           encoding="utf-8-sig")
   ...: 

# Read the numeric data from just the first 14 days
In [7]: nbcases = np.genfromtxt(filename, dtype=np.int_, delimiter=",",
   ...:                         skip_header=7, usecols=range(3, 17),
   ...:                         encoding="utf-8-sig")
   ...: 

包含在 numpy.genfromtxt 函数调用,我们选择了 numpy.dtype 对于每个数据子集(整数或- numpy.int_ -或者一个字符串- numpy.unicode_ ). 我们还使用了 encoding argument to select utf-8-sig as the encoding for the file (read more about encoding in the official Python documentation ). 你可以阅读更多关于 numpy.genfromtxt 函数来自 Reference Documentation 或者从 Basic IO tutorial .

探索数据

首先,我们可以绘制出我们拥有的全部数据集,看看它是什么样子。为了得到一个可读的绘图,我们只选择几个日期显示在我们的 x-axis ticks . 还要注意,在plot命令中,我们使用 nbcases.T (这个词的转置) nbcases 数组)因为这意味着我们将把文件的每一行作为一个单独的行来绘制。我们选择绘制一条虚线(使用 '--' 线条样式)。看到了吗 matplotlib 有关此的详细信息,请参阅文档。

In [8]: import matplotlib.pyplot as plt

In [9]: selected_dates = [0, 3, 11, 13]

In [10]: plt.plot(dates, nbcases.T, '--');

In [11]: plt.xticks(selected_dates, dates[selected_dates]);

In [12]: plt.title("COVID-19 cumulative cases from Jan 21 to Feb 3 2020");
../_images/plot_covid_1.png

注解

如果在ipythonshell中执行上述命令,则可能需要使用该命令 plt.show() 显示图像窗口。还要注意,我们在行的末尾使用分号来抑制其输出,但这是可选的。

从1月24日到2月1日,这张图的形状很奇怪。很想知道这些数据是从哪里来的。如果我们看看 locations 我们从 .csv 文件中,我们可以看到有两列,第一列包含地区,第二列包含国家的名称。但是,只有前几行包含第一列(中国的省名)的数据。在那之后,我们只有国家名称。因此,将来自中国的所有数据归为一行是有意义的。为此,我们将从 nbcases 仅数组 locations 数组对应于中国。接下来,我们将使用 numpy.sum 函数对所有选定行求和 (axis=0 ):

In [13]: china_total = nbcases[locations[:, 1] == 'China'].sum(axis=0)

In [14]: china_total
Out[14]: 
array([  247,   288,   556,   817,   -22,   -22,   -15,   -10,    -9,
          -7,    -4, 11820, 14410, 17237])

这个数据有问题-我们不应该在累积数据集中有负值。发生什么事?

丢失的数据

看看这些数据,我们发现:有一个周期 丢失的数据

In [15]: nbcases
Out[15]: 
array([[  258,   270,   375, ...,  7153,  9074, 11177],
       [   14,    17,    26, ...,   520,   604,   683],
       [   -1,     1,     1, ...,   422,   493,   566],
       ...,
       [   -1,    -1,    -1, ...,    -1,    -1,    -1],
       [   -1,    -1,    -1, ...,    -1,    -1,    -1],
       [   -1,    -1,    -1, ...,    -1,    -1,    -1]])

所有的 -1 我们看到的价值观来自 numpy.genfromtxt 试图从原始数据中读取丢失的数据 .csv 文件。显然,我们不想把丢失的数据当作 -1 -我们只想跳过这个值,这样它就不会干扰我们的分析。导入后 numpy.ma 模块中,我们将创建一个新数组,这次将屏蔽无效值:

In [16]: from numpy import ma

In [17]: nbcases_ma = ma.masked_values(nbcases, -1)

如果我们看看 nbcases_ma 蒙面阵,我们有:

In [18]: nbcases_ma
Out[18]: 
masked_array(
  data=[[258, 270, 375, ..., 7153, 9074, 11177],
        [14, 17, 26, ..., 520, 604, 683],
        [--, 1, 1, ..., 422, 493, 566],
        ...,
        [--, --, --, ..., --, --, --],
        [--, --, --, ..., --, --, --],
        [--, --, --, ..., --, --, --]],
  mask=[[False, False, False, ..., False, False, False],
        [False, False, False, ..., False, False, False],
        [ True, False, False, ..., False, False, False],
        ...,
        [ True,  True,  True, ...,  True,  True,  True],
        [ True,  True,  True, ...,  True,  True,  True],
        [ True,  True,  True, ...,  True,  True,  True]],
  fill_value=-1)

我们可以看到这是一种不同的数组。正如引言中提到的,它有三个属性 (datamaskfill_value ). 请记住 mask 属性具有 True 对应于的元素的值 无效 数据(在 data 属性)。

注解

添加 -1 丢失数据不是问题 numpy.genfromtxt ;在这种情况下,将缺少的值替换为 0 可能还不错,但我们稍后会发现这远不是一个普遍的解决方案。此外,还可以调用 numpy.genfromtxt 函数使用 usemask 参数。如果 usemask=Truenumpy.genfromtxt 自动返回屏蔽数组。

让我们试着看看除去第一行数据(来自中国湖北省的数据)后的数据是什么样子的,这样我们可以更仔细地查看缺失的数据:

In [19]: plt.plot(dates, nbcases_ma[1:].T, '--');

In [20]: plt.xticks(selected_dates, dates[selected_dates]);

In [21]: plt.title("COVID-19 cumulative cases from Jan 21 to Feb 3 2020");
../_images/plot_covid_2.png

现在我们的数据已经被掩盖了,让我们试着总结一下中国的所有案例:

In [22]: china_masked = nbcases_ma[locations[:, 1] == 'China'].sum(axis=0)

In [23]: china_masked
Out[23]: 
masked_array(data=[278, 309, 574, 835, 10, 10, 17, 22, 23, 25, 28, 11821,
                   14411, 17238],
             mask=[False, False, False, False, False, False, False, False,
                   False, False, False, False, False, False],
       fill_value=999999)

注意 china_masked 是掩蔽数组,因此它的数据结构与常规NumPy数组不同。现在,我们可以使用 .data 属性:

In [24]: china_total = china_masked.data

In [25]: china_total
Out[25]: 
array([  278,   309,   574,   835,    10,    10,    17,    22,    23,
          25,    28, 11821, 14411, 17238])

这样更好:没有更多的负值。然而,我们仍然可以看到,在一些日子里,病例的累积数量似乎在下降(例如,从835例下降到10例),这与“累积数据”的定义不符。如果我们更仔细地看数据,我们可以看到,在中国大陆缺失数据的时期,香港、台湾、澳门和中国“未指定”地区都有有效数据。或许我们可以从中国的病例总数中剔除这些数据,以便更好地了解数据。

首先,我们将确定中国大陆地区的位置指数:

In [26]: china_mask = ((locations[:, 1] == 'China') &
   ....:               (locations[:, 0] != 'Hong Kong') &
   ....:               (locations[:, 0] != 'Taiwan') &
   ....:               (locations[:, 0] != 'Macau') &
   ....:               (locations[:, 0] != 'Unspecified*'))
   ....: 

现在, china_mask 是一个布尔值数组 (TrueFalse );我们可以使用 ma.nonzero 掩蔽阵列的方法:

In [27]: china_mask.nonzero()
Out[27]: 
(array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
        17, 18, 19, 20, 21, 22, 23, 25, 26, 27, 28, 29, 31, 33]),)

现在我们可以对中国大陆的条目进行正确的求和:

In [28]: china_total = nbcases_ma[china_mask].sum(axis=0)

In [29]: china_total
Out[29]: 
masked_array(data=[278, 308, 440, 446, --, --, --, --, --, --, --, 11791,
                   14380, 17205],
             mask=[False, False, False, False,  True,  True,  True,  True,
                    True,  True,  True, False, False, False],
       fill_value=999999)

我们可以用这些信息替换数据,绘制一个新的图表,重点放在中国大陆:

In [30]: plt.plot(dates, china_total.T, '--');

In [31]: plt.xticks(selected_dates, dates[selected_dates]);

In [32]: plt.title("COVID-19 cumulative cases from Jan 21 to Feb 3 2020 - Mainland China");
../_images/plot_covid_3.png

很明显,掩蔽阵列是正确的解决方案。如果不正确描述曲线的演变,我们就无法表示缺失的数据。

拟合数据

我们可以考虑的一种可能性是,插入缺失的数据,以估计1月下旬的病例数。注意,我们可以使用 .mask 属性:

In [33]: china_total.mask
Out[33]: 
array([False, False, False, False,  True,  True,  True,  True,  True,
        True,  True, False, False, False])

In [34]: invalid = china_total[china_total.mask]

In [35]: invalid
Out[35]: 
masked_array(data=[--, --, --, --, --, --, --],
             mask=[ True,  True,  True,  True,  True,  True,  True],
       fill_value=999999,
            dtype=int64)

我们还可以使用此掩码的逻辑求反来访问有效项:

In [36]: valid = china_total[~china_total.mask]

In [37]: valid
Out[37]: 
masked_array(data=[278, 308, 440, 446, 11791, 14380, 17205],
             mask=[False, False, False, False, False, False, False],
       fill_value=999999)

现在,如果我们想为这些数据创建一个非常简单的近似值,我们应该考虑无效项周围的有效项。所以首先让我们选择数据有效的日期。请注意,我们可以使用来自 china_total 用于索引日期数组的掩码数组:

In [38]: dates[~china_total.mask]
Out[38]: 
array(['1/21/20', '1/22/20', '1/23/20', '1/24/20', '2/1/20', '2/2/20',
       '2/3/20'], dtype='<U7')

最后,我们可以使用 numpy.polyfitnumpy.polyval 函数创建一个三次多项式以尽可能地拟合数据:

In [39]: t = np.arange(len(china_total))

In [40]: params = np.polyfit(t[~china_total.mask], valid, 3)

In [41]: cubic_fit = np.polyval(params, t)

In [42]: plt.plot(t, china_total);

In [43]: plt.plot(t, cubic_fit, '--');
../_images/plot_covid_4.png

这一情节不是那么可读,因为线似乎是在对方,所以让我们总结在一个更详细的情节。我们将在可用时绘制真实数据,并显示不可用数据的立方拟合,使用此拟合计算2020年1月28日(记录开始后7天)观测病例数的估计值:

In [44]: plt.plot(t, china_total);

In [45]: plt.plot(t[china_total.mask], cubic_fit[china_total.mask], '--', color='orange');

In [46]: plt.plot(7, np.polyval(params, 7), 'r*');

In [47]: plt.xticks([0, 7, 13], dates[[0, 7, 13]]);

In [48]: plt.yticks([0, np.polyval(params, 7), 10000, 17500]);

In [49]: plt.legend(['Mainland China', 'Cubic estimate', '7 days after start']);

In [50]: plt.title("COVID-19 cumulative cases from Jan 21 to Feb 3 2020 - Mainland China\n"
   ....:           "Cubic estimate for 7 days after start");
   ....: 
../_images/plot_covid_5.png

更多阅读

本教程中未涉及的主题可以在文档中找到: