一、tensor.sum
为了更好地理解 `torch. sum ` 函数中 `dim` 参数的作用,我们可以将三维张量的求和过程分解,并通过具体的例子来说明不同 `dim` 参数的效果。
假设我们有一个 3x2x2 的张量,如下所示:
import torch
tensor = torch. tensor( [ [ [ 1 , 2 ] ,
[ 3 , 4 ] ] ,
[ [ 5 , 6 ] ,
[ 7 , 8 ] ] ,
[ [ 9 , 10 ] ,
[ 11 , 12 ] ] ] )
print ( tensor)
这个张量可以看作是包含三个 2x2 矩阵 的集合:
[
[ [ 1 , 2 ] ,
[ 3 , 4 ] ] ,
[ [ 5 , 6 ] ,
[ 7 , 8 ] ] ,
[ [ 9 , 10 ] ,
[ 11 , 12 ] ]
]
默认情况下,`torch. sum ` 会对所有元素求和:
total_sum = torch. sum ( tensor)
print ( total_sum)
解释:1 + 2 + 3 + 4 + 5 + 6 + 7 + 8 + 9 + 10 + 11 + 12 = 78
`dim= 0 ` 表示沿最外层维度求和,即对每个 2x2 矩阵 的对应位置元素求和:
即制定dim= x, 表示沿着z方向求和( 即消灭z方向)
sum_dim0 = torch. sum ( tensor, dim= 0 )
print ( sum_dim0)
输出:
tensor( [ [ 15 , 18 ] ,
[ 21 , 24 ] ] )
解释:
第一个位置:[ 1 + 5 + 9 , 2 + 6 + 10 ] = [ 15 , 18 ]
第二个位置:[ 3 + 7 + 11 , 4 + 8 + 12 ] = [ 21 , 24 ]
`dim= 1 ` 表示沿每个 2x2 矩阵 的行方向求和:
即制定dim= y, 表示沿着y方向求和( 即消灭y方向)
sum_dim1 = torch. sum ( tensor, dim= 1 )
print ( sum_dim1)
输出:
tensor( [ [ 4 , 6 ] ,
[ 12 , 14 ] ,
[ 20 , 22 ] ] )
解释:
对第一个二维矩阵 :行和 [ 1 + 3 , 2 + 4 ] = [ 4 , 6 ]
对第二个二维矩阵 :行和 [ 5 + 7 , 6 + 8 ] = [ 12 , 14 ]
对第三个二维矩阵 :行和 [ 9 + 11 , 10 + 12 ] = [ 20 , 22 ]
`dim= 2 ` 表示沿每个 2x2 矩阵 的列方向求和:
即制定dim= z, 表示沿着x方向求和( 即消灭x方向)
sum_dim2 = torch. sum ( tensor, dim= 2 )
print ( sum_dim2)
输出:
tensor( [ [ 3 , 7 ] ,
[ 11 , 15 ] ,
[ 19 , 23 ] ] )
解释:
对第一个二维矩阵 :列和 [ 1 + 2 , 3 + 4 ] = [ 3 , 7 ]
对第二个二维矩阵 :列和 [ 5 + 6 , 7 + 8 ] = [ 11 , 15 ]
对第三个二维矩阵 :列和 [ 9 + 10 , 11 + 12 ] = [ 19 , 23 ]
dim= 0 :沿最外层维度求和,结果是一个 2x2 矩阵 ,每个元素是对应位置上所有二维矩阵 元素的和。
dim= 1 :沿每个二维矩阵 的行方向求和,结果是一个 3x2 矩阵 ,每个元素是对应位置上行的和。
dim= 2 :沿每个二维矩阵 的列方向求和,结果是一个 3x2 矩阵 ,每个元素是对应位置上列的和。