迭代数组

注解

数组支持迭代器协议,可以像Python列表一样迭代。看到了吗 索引、切片和迭代 快速入门指南中有关基本用法和示例的部分。本文档的其余部分介绍 nditer 对象并涵盖更高级的用法。

迭代器对象 nditer 在numpy 1.6中引入,提供了许多灵活的方法来系统地访问一个或多个数组的所有元素。本页介绍了在python中使用对象对数组进行计算的一些基本方法,然后总结了如何在cython中加速内部循环。因为 Python 的暴露 nditer 是C数组迭代器API的一种相对简单的映射,这些思想也将帮助从C或C++中进行数组迭代。

单阵列迭代

最基本的任务 nditer 是访问数组的每个元素。使用标准的python迭代器接口逐个提供每个元素。

例子

>>> a = np.arange(6).reshape(2,3)
>>> for x in np.nditer(a):
...     print(x, end=' ')
...
0 1 2 3 4 5

对于这个迭代,需要注意的一件重要事情是,选择顺序来匹配数组的内存布局,而不是使用标准的C或Fortran顺序。这样做是为了提高访问效率,反映了这样一种观点:默认情况下,一个人只想访问每个元素,而不关心特定的顺序。我们可以通过迭代上一个数组的转置来看到这一点,而不是按照C顺序获取该转置的副本。

例子

>>> a = np.arange(6).reshape(2,3)
>>> for x in np.nditer(a.T):
...     print(x, end=' ')
...
0 1 2 3 4 5
>>> for x in np.nditer(a.T.copy(order='C')):
...     print(x, end=' ')
...
0 3 1 4 2 5

两者的要素 aa.T 以相同的顺序遍历,即它们存储在内存中的顺序,而 a.T.copy(order='C') 以不同的顺序访问,因为它们已放入不同的内存布局中。

控制迭代顺序

有时,不管内存中元素的布局如何,以特定的顺序访问数组元素都是很重要的。这个 nditer 对象提供 order 用于控制此迭代方面的参数。具有上述行为的默认值是order='k'以保持现有的顺序。对于C命令,可以用order='c'覆盖,对于Fortran命令,可以用order='f'覆盖。

例子

>>> a = np.arange(6).reshape(2,3)
>>> for x in np.nditer(a, order='F'):
...     print(x, end=' ')
...
0 3 1 4 2 5
>>> for x in np.nditer(a.T, order='C'):
...     print(x, end=' ')
...
0 3 1 4 2 5

修改数组值

默认情况下, nditer 将输入操作数视为只读对象。要能够修改数组元素,必须使用 'readwrite''writeonly' 每个操作数标志。

然后,nditer将生成可写缓冲区数组,您可以对其进行修改。但是,由于完成迭代后,nditer必须将此缓冲区数据复制回原始数组,因此必须通过两种方法之一在迭代结束时发出信号。你也可以:

  • 使用nditer作为上下文管理器 with 语句,当上下文退出时,临时数据将被写回。

  • 调用迭代器 close 方法一旦完成迭代,将触发回写。

nditer也不能再重复一次 close 调用或退出其上下文。

例子

>>> a = np.arange(6).reshape(2,3)
>>> a
array([[0, 1, 2],
       [3, 4, 5]])
>>> with np.nditer(a, op_flags=['readwrite']) as it:
...    for x in it:
...        x[...] = 2 * x
...
>>> a
array([[ 0,  2,  4],
       [ 6,  8, 10]])

如果您编写的代码需要支持旧版本的numpy,请注意在1.15之前, nditer 不是上下文管理器,也没有 close 方法。相反,它依赖析构函数来启动缓冲区的写回。

使用外部循环

到目前为止,在所有的例子中, a 由迭代器一次提供一个,因为所有循环逻辑都是迭代器内部的。虽然这是简单和方便的,但它不是很有效。更好的方法是将一维最里面的循环移动到代码中,在迭代器外部。这样,numpy的矢量化操作可以用于访问的较大的元素块。

这个 nditer 将尝试向内部循环提供尽可能大的块。通过强制“c”和“f”顺序,我们得到不同的外部循环大小。通过指定迭代器标志启用此模式。

注意,在默认情况下保持本机内存顺序的情况下,迭代器能够提供一个一维块,而在强制fortran顺序时,它必须提供三个由两个元素组成的块。

例子

>>> a = np.arange(6).reshape(2,3)
>>> for x in np.nditer(a, flags=['external_loop']):
...     print(x, end=' ')
...
[0 1 2 3 4 5]
>>> for x in np.nditer(a, flags=['external_loop'], order='F'):
...     print(x, end=' ')
...
[0 3] [1 4] [2 5]

跟踪索引或多索引

在迭代过程中,您可能希望在计算中使用当前元素的索引。例如,您可能希望按内存顺序访问数组的元素,但使用C顺序、Fortran顺序或多维索引查找不同数组中的值。

索引由迭代器对象本身跟踪,并可通过 indexmulti_index 属性,具体取决于请求的内容。以下示例显示了显示索引进程的打印输出:

例子

>>> a = np.arange(6).reshape(2,3)
>>> it = np.nditer(a, flags=['f_index'])
>>> for x in it:
...     print("%d <%d>" % (x, it.index), end=' ')
...
0 <0> 1 <2> 2 <4> 3 <1> 4 <3> 5 <5>
>>> it = np.nditer(a, flags=['multi_index'])
>>> for x in it:
...     print("%d <%s>" % (x, it.multi_index), end=' ')
...
0 <(0, 0)> 1 <(0, 1)> 2 <(0, 2)> 3 <(1, 0)> 4 <(1, 1)> 5 <(1, 2)>
>>> with np.nditer(a, flags=['multi_index'], op_flags=['writeonly']) as it:
...     for x in it:
...         x[...] = it.multi_index[1] - it.multi_index[0]
...
>>> a
array([[ 0,  1,  2],
       [-1,  0,  1]])

跟踪索引或多索引与使用外部循环不兼容,因为每个元素需要不同的索引值。如果您尝试组合这些标志, nditer 对象将引发异常。

例子

>>> a = np.zeros((2,3))
>>> it = np.nditer(a, flags=['c_index', 'external_loop'])
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
ValueError: Iterator flag EXTERNAL_LOOP cannot be used if an index or multi-index is being tracked

可选循环和元素访问

为了使其属性在迭代过程中更容易访问, nditer 有一个用于迭代的替代语法,它显式地与迭代器对象本身一起工作。使用这个循环构造,可以通过索引到迭代器中来访问当前值。其他属性(如跟踪的索引)保持不变。下面的例子产生了与上一节相同的结果。

例子

>>> a = np.arange(6).reshape(2,3)
>>> it = np.nditer(a, flags=['f_index'])
>>> while not it.finished:
...     print("%d <%d>" % (it[0], it.index), end=' ')
...     is_not_finished = it.iternext()
...
0 <0> 1 <2> 2 <4> 3 <1> 4 <3> 5 <5>
>>> it = np.nditer(a, flags=['multi_index'])
>>> while not it.finished:
...     print("%d <%s>" % (it[0], it.multi_index), end=' ')
...     is_not_finished = it.iternext()
...
0 <(0, 0)> 1 <(0, 1)> 2 <(0, 2)> 3 <(1, 0)> 4 <(1, 1)> 5 <(1, 2)>
>>> with np.nditer(a, flags=['multi_index'], op_flags=['writeonly']) as it:
...     while not it.finished:
...         it[0] = it.multi_index[1] - it.multi_index[0]
...         is_not_finished = it.iternext()
...
>>> a
array([[ 0,  1,  2],
       [-1,  0,  1]])

缓冲数组元素

当强制执行迭代顺序时,我们观察到外部循环选项可能以较小的块提供元素,因为无法以适当的顺序以恒定的步幅访问元素。在编写C代码时,这通常是很好的,但是在纯Python代码中,这会导致性能显著降低。

通过启用缓冲模式,迭代器向内部循环提供的块可以变大,从而显著降低了Python解释器的开销。在强制FORTRAN迭代顺序的示例中,当启用缓冲时,内部循环将一次看到所有元素。

例子

>>> a = np.arange(6).reshape(2,3)
>>> for x in np.nditer(a, flags=['external_loop'], order='F'):
...     print(x, end=' ')
...
[0 3] [1 4] [2 5]
>>> for x in np.nditer(a, flags=['external_loop','buffered'], order='F'):
...     print(x, end=' ')
...
[0 3 1 4 2 5]

作为特定数据类型进行迭代

有时需要将数组视为与存储时不同的数据类型。例如,即使被操作的数组是32位浮点,也可能需要对64位浮点进行所有计算。除了在编写低级C代码时,通常最好让迭代器处理复制或缓冲,而不是在内部循环中自己强制转换数据类型。

有两种机制可以做到这一点:临时拷贝和缓冲模式。对于临时副本,使用新的数据类型复制整个数组,然后在副本中进行迭代。允许通过在所有迭代完成后更新原始数组的模式进行写访问。临时副本的主要缺点是临时副本可能消耗大量内存,特别是当迭代数据类型的项大小大于原始数据类型时。

缓冲模式可以缓解内存使用问题,并且比制作临时副本更易于缓存。除特殊情况外,如果在迭代器外部同时需要整个数组,建议缓冲而不是临时复制。在numpy中,ufuncs和其他函数使用缓冲来支持灵活的输入,并且内存开销最小。

在我们的示例中,我们将使用复杂的数据类型处理输入数组,这样我们就可以取负数的平方根。在不启用复制或缓冲模式的情况下,如果数据类型不精确匹配,迭代器将引发异常。

例子

>>> a = np.arange(6).reshape(2,3) - 3
>>> for x in np.nditer(a, op_dtypes=['complex128']):
...     print(np.sqrt(x), end=' ')
...
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
TypeError: Iterator operand required copying or buffering, but neither copying nor buffering was enabled

在复制模式下,将“copy”指定为每个操作数标志。这样做是为了以每个操作数的方式提供控制。缓冲模式被指定为迭代器标志。

例子

>>> a = np.arange(6).reshape(2,3) - 3
>>> for x in np.nditer(a, op_flags=['readonly','copy'],
...                 op_dtypes=['complex128']):
...     print(np.sqrt(x), end=' ')
...
1.7320508075688772j 1.4142135623730951j 1j 0j (1+0j) (1.4142135623730951+0j)
>>> for x in np.nditer(a, flags=['buffered'], op_dtypes=['complex128']):
...     print(np.sqrt(x), end=' ')
...
1.7320508075688772j 1.4142135623730951j 1j 0j (1+0j) (1.4142135623730951+0j)

迭代器使用numpy的强制转换规则来确定是否允许特定的转换。默认情况下,它强制执行“安全”强制转换。例如,这意味着,如果试图将64位浮点数组视为32位浮点数组,它将引发异常。在许多情况下,“同类”规则是最合理的规则,因为它允许从64位浮点转换为32位浮点,但不允许从浮点转换为int或从复杂转换为浮点。

例子

>>> a = np.arange(6.)
>>> for x in np.nditer(a, flags=['buffered'], op_dtypes=['float32']):
...     print(x, end=' ')
...
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
TypeError: Iterator operand 0 dtype could not be cast from dtype('float64') to dtype('float32') according to the rule 'safe'
>>> for x in np.nditer(a, flags=['buffered'], op_dtypes=['float32'],
...                 casting='same_kind'):
...     print(x, end=' ')
...
0.0 1.0 2.0 3.0 4.0 5.0
>>> for x in np.nditer(a, flags=['buffered'], op_dtypes=['int32'], casting='same_kind'):
...     print(x, end=' ')
...
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
TypeError: Iterator operand 0 dtype could not be cast from dtype('float64') to dtype('int32') according to the rule 'same_kind'

要注意的一件事是,在使用读写或只写操作数时转换回原始数据类型。一种常见的情况是以64位浮点实现内部循环,并使用“同类”转换来允许处理其他浮点类型。在只读模式下,可以提供整数数组,但读写模式将引发异常,因为返回到数组的转换将违反强制转换规则。

例子

>>> a = np.arange(6)
>>> for x in np.nditer(a, flags=['buffered'], op_flags=['readwrite'],
...                 op_dtypes=['float64'], casting='same_kind'):
...     x[...] = x / 2.0
...
Traceback (most recent call last):
  File "<stdin>", line 2, in <module>
TypeError: Iterator requested dtype could not be cast from dtype('float64') to dtype('int64'), the operand 0 dtype, according to the rule 'same_kind'

广播阵列迭代

numpy有一套处理形状不同的数组的规则,每当函数采用多个操作数组合元素时,都会应用这些规则。这叫做 broadcasting . 这个 nditer 当需要编写这样的函数时,对象可以为您应用这些规则。

作为一个例子,我们将广播一维和二维数组的结果打印出来。

例子

>>> a = np.arange(3)
>>> b = np.arange(6).reshape(2,3)
>>> for x, y in np.nditer([a,b]):
...     print("%d:%d" % (x,y), end=' ')
...
0:0 1:1 2:2 0:3 1:4 2:5

当广播错误发生时,迭代器会引发一个异常,其中包括有助于诊断问题的输入形状。

例子

>>> a = np.arange(2)
>>> b = np.arange(6).reshape(2,3)
>>> for x, y in np.nditer([a,b]):
...     print("%d:%d" % (x,y), end=' ')
...
Traceback (most recent call last):
...
ValueError: operands could not be broadcast together with shapes (2,) (2,3)

迭代器分配的输出数组

numpy函数中的一个常见情况是,根据输入的广播分配输出,另外还有一个名为“out”的可选参数,当提供结果时,结果将放置在该参数中。这个 nditer 对象提供了一个方便的习惯用法,使支持此机制变得非常容易。

我们将通过创建一个函数来演示这是如何工作的 square 它的输入平方。让我们从一个最小的函数定义开始,不包括“out”参数支持。

例子

>>> def square(a):
...     with np.nditer([a, None]) as it:
...         for x, y in it:
...             y[...] = x*x
...         return it.operands[1]
...
>>> square([1,2,3])
array([1, 4, 9])

默认情况下, nditer 将“allocate”和“writeonly”标志用于作为none传入的操作数。这意味着我们只能向迭代器提供两个操作数,它处理了其余的操作数。

当添加“out”参数时,我们必须显式地提供这些标志,因为如果有人将数组作为“out”传入,迭代器将默认为“readonly”,并且我们的内部循环将失败。之所以“readonly”是输入数组的默认值,是为了防止无意中触发缩减操作。如果默认值为“读写”,任何广播操作也会触发一个缩减,本文档稍后将讨论这个主题。

在这里,我们还将介绍“no-u broadcast”标志,它将阻止广播输出。这一点很重要,因为我们只需要为每个输出输入一个值。聚合多个输入值是一个缩减操作,需要特殊处理。它已经引发了一个错误,因为必须在迭代器标志中显式地启用缩减,但是对于最终用户来说,禁用广播导致的错误消息更容易理解。要了解如何将平方函数归纳为约简,请查看Cython部分中的平方和函数。

为了完整性起见,我们还将添加“external_loop”和“buffered”标志,因为出于性能原因,这些标志通常是您想要的。

例子

>>> def square(a, out=None):
...     it = np.nditer([a, out],
...             flags = ['external_loop', 'buffered'],
...             op_flags = [['readonly'],
...                         ['writeonly', 'allocate', 'no_broadcast']])
...     with it:
...         for x, y in it:
...             y[...] = x*x
...         return it.operands[1]
...
>>> square([1,2,3])
array([1, 4, 9])
>>> b = np.zeros((3,))
>>> square([1,2,3], out=b)
array([ 1.,  4.,  9.])
>>> b
array([ 1.,  4.,  9.])
>>> square(np.arange(6).reshape(2,3), out=b)
Traceback (most recent call last):
  ...
ValueError: non-broadcastable output operand with shape (3,) doesn't
match the broadcast shape (2,3)

外部产品迭代

任何二进制操作都可以像在 outernditer 对象提供了一种通过显式映射操作数的轴来实现此目的的方法。也可以用 newaxis 索引,但我们将向您展示如何直接使用nditer op_axes 参数以在没有中间视图的情况下完成此操作。

我们将做一个简单的外积,将第一个操作数的维数放在第二个操作数的维数之前。这个 op_axes 参数需要每个操作数的轴列表,并提供从迭代器轴到操作数轴的映射。

假设第一个操作数是一维的,第二个操作数是二维的。迭代器将具有三维,因此 op_axes 将有两个三元素列表。第一个列表选择第一个操作数的一个轴,对于迭代器轴的其余部分为-1,最终结果为 [0, -1, -1] . 第二个列表选择第二个操作数的两个轴,但不应与第一个操作数中选择的轴重叠。它的名单是 [-1, 0, 1] . 输出操作数以标准方式映射到迭代器轴上,因此我们可以不提供任何操作数,而不是构造另一个列表。

内部循环中的操作是一个简单的乘法。与外部产品有关的一切都由迭代器设置处理。

例子

>>> a = np.arange(3)
>>> b = np.arange(8).reshape(2,4)
>>> it = np.nditer([a, b, None], flags=['external_loop'],
...             op_axes=[[0, -1, -1], [-1, 0, 1], None])
>>> with it:
...     for x, y, z in it:
...         z[...] = x*y
...     result = it.operands[2]  # same as z
...
>>> result
array([[[ 0,  0,  0,  0],
        [ 0,  0,  0,  0]],
       [[ 0,  1,  2,  3],
        [ 4,  5,  6,  7]],
       [[ 0,  2,  4,  6],
        [ 8, 10, 12, 14]]])

注意,一旦迭代器关闭,我们就不能访问 operands 并且必须使用在上下文管理器中创建的引用。

约简迭代

每当可写操作数的元素少于完整迭代空间时,该操作数就会减少。这个 nditer 对象要求将任何缩减操作数标记为读写操作数,并且仅当“reduce_ok”作为迭代器标志提供时才允许缩减。

对于一个简单的例子,考虑取数组中所有元素的和。

例子

>>> a = np.arange(24).reshape(2,3,4)
>>> b = np.array(0)
>>> with np.nditer([a, b], flags=['reduce_ok'],
...                     op_flags=[['readonly'], ['readwrite']]) as it:
...     for x,y in it:
...         y[...] += x
...
>>> b
array(276)
>>> np.sum(a)
276

当结合约简和分配的操作数时,事情会稍微复杂一点。在开始迭代之前,必须将任何约简操作数初始化为其起始值。我们可以这样做,沿着最后一个轴求和 a .

例子

>>> a = np.arange(24).reshape(2,3,4)
>>> it = np.nditer([a, None], flags=['reduce_ok'],
...             op_flags=[['readonly'], ['readwrite', 'allocate']],
...             op_axes=[None, [0,1,-1]])
>>> with it:
...     it.operands[1][...] = 0
...     for x, y in it:
...         y[...] += x
...     result = it.operands[1]
...
>>> result
array([[ 6, 22, 38],
       [54, 70, 86]])
>>> np.sum(a, axis=2)
array([[ 6, 22, 38],
       [54, 70, 86]])

要进行缓冲还原,需要在设置过程中进行另一个调整。通常,迭代器构造涉及将第一个数据缓冲区从可读数组复制到缓冲区中。任何还原操作数都是可读的,因此可以将其读取到缓冲区中。不幸的是,此缓冲操作完成后操作数的初始化将不会反映在迭代开始时的缓冲区中,并且将生成垃圾结果。

迭代器标志“delay-bufalloc”允许迭代器分配的还原操作数与缓冲一起存在。设置此标志后,迭代器将保持其缓冲区未初始化,直到收到重置为止,之后它将准备好进行常规迭代。下面是前一个示例的外观,如果我们还启用了缓冲。

例子

>>> a = np.arange(24).reshape(2,3,4)
>>> it = np.nditer([a, None], flags=['reduce_ok',
...                                  'buffered', 'delay_bufalloc'],
...             op_flags=[['readonly'], ['readwrite', 'allocate']],
...             op_axes=[None, [0,1,-1]])
>>> with it:
...     it.operands[1][...] = 0
...     it.reset()
...     for x, y in it:
...         y[...] += x
...     result = it.operands[1]
...
>>> result
array([[ 6, 22, 38],
       [54, 70, 86]])

把内环放进赛通

那些想要从低级操作中获得良好性能的人应该强烈地考虑直接使用C中提供的迭代API,但是对于那些不适应C或C++的人来说,Cython是一个很好的中间层,具有合理的性能折衷。对于 nditer 对象,这意味着让迭代器负责广播、数据类型转换和缓冲,同时为cython提供内部循环。

对于我们的例子,我们将创建一个平方和函数。首先,让我们用简单的Python实现这个函数。我们希望支持类似于numpy的“axis”参数 sum 函数,因此我们需要为 op_axes 参数。这就是这个样子。

例子

>>> def axis_to_axeslist(axis, ndim):
...     if axis is None:
...         return [-1] * ndim
...     else:
...         if type(axis) is not tuple:
...             axis = (axis,)
...         axeslist = [1] * ndim
...         for i in axis:
...             axeslist[i] = -1
...         ax = 0
...         for i in range(ndim):
...             if axeslist[i] != -1:
...                 axeslist[i] = ax
...                 ax += 1
...         return axeslist
...
>>> def sum_squares_py(arr, axis=None, out=None):
...     axeslist = axis_to_axeslist(axis, arr.ndim)
...     it = np.nditer([arr, out], flags=['reduce_ok',
...                                       'buffered', 'delay_bufalloc'],
...                 op_flags=[['readonly'], ['readwrite', 'allocate']],
...                 op_axes=[None, axeslist],
...                 op_dtypes=['float64', 'float64'])
...     with it:
...         it.operands[1][...] = 0
...         it.reset()
...         for x, y in it:
...             y[...] += x*x
...         return it.operands[1]
...
>>> a = np.arange(6).reshape(2,3)
>>> sum_squares_py(a)
array(55.0)
>>> sum_squares_py(a, axis=-1)
array([  5.,  50.])

为了实现这个功能,我们替换了内部循环(y [...] +=x*x)带有专门用于float64数据类型的cython代码。启用“外部_循环”标志后,提供给内部循环的数组将始终是一维的,因此几乎不需要进行检查。

这是求和平方的列表。Pyx::

import numpy as np
cimport numpy as np
cimport cython

def axis_to_axeslist(axis, ndim):
    if axis is None:
        return [-1] * ndim
    else:
        if type(axis) is not tuple:
            axis = (axis,)
        axeslist = [1] * ndim
        for i in axis:
            axeslist[i] = -1
        ax = 0
        for i in range(ndim):
            if axeslist[i] != -1:
                axeslist[i] = ax
                ax += 1
        return axeslist

@cython.boundscheck(False)
def sum_squares_cy(arr, axis=None, out=None):
    cdef np.ndarray[double] x
    cdef np.ndarray[double] y
    cdef int size
    cdef double value

    axeslist = axis_to_axeslist(axis, arr.ndim)
    it = np.nditer([arr, out], flags=['reduce_ok', 'external_loop',
                                      'buffered', 'delay_bufalloc'],
                op_flags=[['readonly'], ['readwrite', 'allocate']],
                op_axes=[None, axeslist],
                op_dtypes=['float64', 'float64'])
    with it:
        it.operands[1][...] = 0
        it.reset()
        for xarr, yarr in it:
            x = xarr
            y = yarr
            size = x.shape[0]
            for i in range(size):
               value = x[i]
               y[i] = y[i] + value * value
        return it.operands[1]

在这台机器上,将.pyx文件构建成一个模块的过程如下所示,但您可能需要找到一些Cython教程来告诉您系统配置的具体情况。::

$ cython sum_squares.pyx
$ gcc -shared -pthread -fPIC -fwrapv -O2 -Wall -I/usr/include/python2.7 -fno-strict-aliasing -o sum_squares.so sum_squares.c

从python解释器运行这个程序会产生与我们的原生python/numpy代码相同的答案。

例子

>>> from sum_squares import sum_squares_cy
>>> a = np.arange(6).reshape(2,3)
>>> sum_squares_cy(a)
array(55.0)
>>> sum_squares_cy(a, axis=-1)
array([  5.,  50.])

在ipython中进行一点计时显示,Cython内部循环的开销和内存分配减少,这为使用numpy内置的sum函数的简单python代码和表达式提供了非常好的加速。::

>>> a = np.random.rand(1000,1000)

>>> timeit sum_squares_py(a, axis=-1)
10 loops, best of 3: 37.1 ms per loop

>>> timeit np.sum(a*a, axis=-1)
10 loops, best of 3: 20.9 ms per loop

>>> timeit sum_squares_cy(a, axis=-1)
100 loops, best of 3: 11.8 ms per loop

>>> np.all(sum_squares_cy(a, axis=-1) == np.sum(a*a, axis=-1))
True

>>> np.all(sum_squares_py(a, axis=-1) == np.sum(a*a, axis=-1))
True