记录码代码中常用的一些pytorch函数,提升写代码的效率。
对张量a按索引张量b取值。则最后得到的张量c维度应和张量b维度相同。
a = torch.tensor([[3, 4, 5],[6, 7, 8],[9, 10, 11]])
b = torch.tensor([[2, 1, 0]])
# dim=0 按列取值(分别取第一列索引2、第二列索引1、第三列索引0的值)
c = torch.gather(a, dim=0, index=b)
print(c) #c=tensor([[9, 7, 5]])
# dim=-1 按行取值(分别取第一行索引2、第一行索引1、第一行索引0的值)
c = torch.gather(a, dim=-1, index=b)
print(c) #c=tensor([[5, 4, 3]])
对两个张量Tensor进行逐元素的比较,若相同位置的两个元素相同,则返回True;若不同,返回False。输入的第二个张量可以是数字或张量,可以和第一个维度不同(广播扩展维度)
与torch.equal不同,torch.eq()是逐元素对比,torch.equal()是整个张量和张量是否相等。
print(torch.eq(torch.tensor([[1, 2], [3, 4]]), torch.tensor([[1, 1], [4, 4]])))
#tensor([[ True, False],# [False, True]])
ge(input, other): 逐元素对比input>=other 返回一个 对角线上为1,其他地方为0 的二维张量。 a.masked_fill(mask, value), 其中mask必须是一个二值张量(ByteTensor),且大小维度必须和a一样。
gt(input, other): 逐元素对比input>other
le(input, other): 逐元素对比input<=other
lt(input, other): 逐元素对比inputprint(torch.lt(torch.tensor([[1, 2], [3, 4]]), torch.tensor([[1, 1], [4, 4]])))
#tensor([[False, False],#[ True, False]])
torch.eye(raw, col)
torch.eye(raw, col), 其中raw是必须给出的,col可不给,默认为rawtorch.eye(3)
# tensor([[ 1., 0., 0.],# [ 0., 1., 0.],# [ 0., 0., 1.]])
torch.masked_fill()
该函数将a中对应mask为1的值替换为value。a = torch.tensor([1,2,3,5,2,1])
a = a[:,None]
mask = torch.eq(a, a.t()).bool()
print(mask)
#tensor([[ True, False, False, False, False, True],# [False, True, False, False, True, False],# [False, False, True, False, False, False],# [False, False, False, True, False, False],# [False, True, False, False, True, False],# [ True, False, False, False, False, True]])
eye = torch.eye(mask.shape[0], mask.shape[1]).bool()
print(eye)
#tensor([[ True, False, False, False, False, False],# [False, True, False, False, False, False],# [False, False, True, False, False, False],# [False, False, False, True, False, False],# [False, False, False, False, True, False],# [False, False, False, False, False, True]])
mask_pos = mask.masked_fill(eye, 0)
print(mask_pos)
#tensor([[False, False, False, False, False, True],# [False, False, False, False, True, False],# [False, False, False, False, False, False],# [False, False, False, False, False, False],# [False, True, False, False, False, False],# [ True, False, False, False, False, False]])