【Pytorch】常用函数及其用法总结
创始人
2025-05-31 21:31:33
0

目录

    • python 张量信息提取、交互有关函数
      • torch.gather()
      • torch.eq()
      • torch.ge()/torch.gt() & torch.le()/torch.lt()
      • torch.eye(raw, col)
      • torch.masked_fill()

python 张量信息提取、交互有关函数

记录码代码中常用的一些pytorch函数,提升写代码的效率。

torch.gather()

对张量a按索引张量b取值。则最后得到的张量c维度应和张量b维度相同。

a = torch.tensor([[3,  4,  5],[6,  7,  8],[9, 10, 11]])
b = torch.tensor([[2, 1, 0]])
# dim=0 按列取值(分别取第一列索引2、第二列索引1、第三列索引0的值)
c = torch.gather(a, dim=0, index=b)
print(c) #c=tensor([[9, 7, 5]])
# dim=-1 按行取值(分别取第一行索引2、第一行索引1、第一行索引0的值)
c = torch.gather(a, dim=-1, index=b)
print(c) #c=tensor([[5, 4, 3]])

torch.eq()

对两个张量Tensor进行逐元素的比较,若相同位置的两个元素相同,则返回True;若不同,返回False。输入的第二个张量可以是数字或张量,可以和第一个维度不同(广播扩展维度)
与torch.equal不同,torch.eq()是逐元素对比,torch.equal()是整个张量和张量是否相等。

print(torch.eq(torch.tensor([[1, 2], [3, 4]]), torch.tensor([[1, 1], [4, 4]])))
#tensor([[ True, False],# [False,  True]])

torch.ge()/torch.gt() & torch.le()/torch.lt()

ge(input, other): 逐元素对比input>=other
gt(input, other): 逐元素对比input>other
le(input, other): 逐元素对比input<=other
lt(input, other): 逐元素对比input

print(torch.lt(torch.tensor([[1, 2], [3, 4]]), torch.tensor([[1, 1], [4, 4]])))
#tensor([[False, False],#[ True, False]])

torch.eye(raw, col)

返回一个 对角线上为1,其他地方为0 的二维张量。
torch.eye(raw, col), 其中raw是必须给出的,col可不给,默认为raw

torch.eye(3)
# tensor([[ 1.,  0.,  0.],# [ 0.,  1.,  0.],# [ 0.,  0.,  1.]])

torch.masked_fill()

a.masked_fill(mask, value), 其中mask必须是一个二值张量(ByteTensor),且大小维度必须和a一样。
该函数将a中对应mask为1的值替换为value。

a = torch.tensor([1,2,3,5,2,1])
a = a[:,None]
mask = torch.eq(a, a.t()).bool()
print(mask)
#tensor([[ True, False, False, False, False,  True],# [False,  True, False, False,  True, False],# [False, False,  True, False, False, False],# [False, False, False,  True, False, False],# [False,  True, False, False,  True, False],# [ True, False, False, False, False,  True]])
eye = torch.eye(mask.shape[0], mask.shape[1]).bool()
print(eye)
#tensor([[ True, False, False, False, False, False],# [False,  True, False, False, False, False],# [False, False,  True, False, False, False],# [False, False, False,  True, False, False],# [False, False, False, False,  True, False],# [False, False, False, False, False,  True]])
mask_pos = mask.masked_fill(eye, 0)
print(mask_pos)
#tensor([[False, False, False, False, False,  True],# [False, False, False, False,  True, False],# [False, False, False, False, False, False],# [False, False, False, False, False, False],# [False,  True, False, False, False, False],# [ True, False, False, False, False, False]])

相关内容

热门资讯

保存时出现了1个错误,导致这篇... 当保存文章时出现错误时,可以通过以下步骤解决问题:查看错误信息:查看错误提示信息可以帮助我们了解具体...
汇川伺服电机位置控制模式参数配... 1. 基本控制参数设置 1)设置位置控制模式   2)绝对值位置线性模...
不能访问光猫的的管理页面 光猫是现代家庭宽带网络的重要组成部分,它可以提供高速稳定的网络连接。但是,有时候我们会遇到不能访问光...
表格中数据未显示 当表格中的数据未显示时,可能是由于以下几个原因导致的:HTML代码问题:检查表格的HTML代码是否正...
本地主机上的图像未显示 问题描述:在本地主机上显示图像时,图像未能正常显示。解决方法:以下是一些可能的解决方法,具体取决于问...
表格列调整大小出现问题 问题描述:表格列调整大小出现问题,无法正常调整列宽。解决方法:检查表格的布局方式是否正确。确保表格使...
不一致的条件格式 要解决不一致的条件格式问题,可以按照以下步骤进行:确定条件格式的规则:首先,需要明确条件格式的规则是...
Android|无法访问或保存... 这个问题可能是由于权限设置不正确导致的。您需要在应用程序清单文件中添加以下代码来请求适当的权限:此外...
【NI Multisim 14...   目录 序言 一、工具栏 🍊1.“标准”工具栏 🍊 2.视图工具...
银河麒麟V10SP1高级服务器... 银河麒麟高级服务器操作系统简介: 银河麒麟高级服务器操作系统V10是针对企业级关键业务...