【Pytorch】常用函数及其用法总结
创始人
2025-05-31 21:31:33
0

目录

    • python 张量信息提取、交互有关函数
      • torch.gather()
      • torch.eq()
      • torch.ge()/torch.gt() & torch.le()/torch.lt()
      • torch.eye(raw, col)
      • torch.masked_fill()

python 张量信息提取、交互有关函数

记录码代码中常用的一些pytorch函数,提升写代码的效率。

torch.gather()

对张量a按索引张量b取值。则最后得到的张量c维度应和张量b维度相同。

a = torch.tensor([[3,  4,  5],[6,  7,  8],[9, 10, 11]])
b = torch.tensor([[2, 1, 0]])
# dim=0 按列取值(分别取第一列索引2、第二列索引1、第三列索引0的值)
c = torch.gather(a, dim=0, index=b)
print(c) #c=tensor([[9, 7, 5]])
# dim=-1 按行取值(分别取第一行索引2、第一行索引1、第一行索引0的值)
c = torch.gather(a, dim=-1, index=b)
print(c) #c=tensor([[5, 4, 3]])

torch.eq()

对两个张量Tensor进行逐元素的比较,若相同位置的两个元素相同,则返回True;若不同,返回False。输入的第二个张量可以是数字或张量,可以和第一个维度不同(广播扩展维度)
与torch.equal不同,torch.eq()是逐元素对比,torch.equal()是整个张量和张量是否相等。

print(torch.eq(torch.tensor([[1, 2], [3, 4]]), torch.tensor([[1, 1], [4, 4]])))
#tensor([[ True, False],# [False,  True]])

torch.ge()/torch.gt() & torch.le()/torch.lt()

ge(input, other): 逐元素对比input>=other
gt(input, other): 逐元素对比input>other
le(input, other): 逐元素对比input<=other
lt(input, other): 逐元素对比input

print(torch.lt(torch.tensor([[1, 2], [3, 4]]), torch.tensor([[1, 1], [4, 4]])))
#tensor([[False, False],#[ True, False]])

torch.eye(raw, col)

返回一个 对角线上为1,其他地方为0 的二维张量。
torch.eye(raw, col), 其中raw是必须给出的,col可不给,默认为raw

torch.eye(3)
# tensor([[ 1.,  0.,  0.],# [ 0.,  1.,  0.],# [ 0.,  0.,  1.]])

torch.masked_fill()

a.masked_fill(mask, value), 其中mask必须是一个二值张量(ByteTensor),且大小维度必须和a一样。
该函数将a中对应mask为1的值替换为value。

a = torch.tensor([1,2,3,5,2,1])
a = a[:,None]
mask = torch.eq(a, a.t()).bool()
print(mask)
#tensor([[ True, False, False, False, False,  True],# [False,  True, False, False,  True, False],# [False, False,  True, False, False, False],# [False, False, False,  True, False, False],# [False,  True, False, False,  True, False],# [ True, False, False, False, False,  True]])
eye = torch.eye(mask.shape[0], mask.shape[1]).bool()
print(eye)
#tensor([[ True, False, False, False, False, False],# [False,  True, False, False, False, False],# [False, False,  True, False, False, False],# [False, False, False,  True, False, False],# [False, False, False, False,  True, False],# [False, False, False, False, False,  True]])
mask_pos = mask.masked_fill(eye, 0)
print(mask_pos)
#tensor([[False, False, False, False, False,  True],# [False, False, False, False,  True, False],# [False, False, False, False, False, False],# [False, False, False, False, False, False],# [False,  True, False, False, False, False],# [ True, False, False, False, False, False]])

相关内容

热门资讯

AWSECS:访问外部网络时出... 如果您在AWS ECS中部署了应用程序,并且该应用程序需要访问外部网络,但是无法正常访问,可能是因为...
AWSElasticBeans... 在Dockerfile中手动配置nginx反向代理。例如,在Dockerfile中添加以下代码:FR...
银河麒麟V10SP1高级服务器... 银河麒麟高级服务器操作系统简介: 银河麒麟高级服务器操作系统V10是针对企业级关键业务...
北信源内网安全管理卸载 北信源内网安全管理是一款网络安全管理软件,主要用于保护内网安全。在日常使用过程中,卸载该软件是一种常...
AWR报告解读 WORKLOAD REPOSITORY PDB report (PDB snapshots) AW...
AWS管理控制台菜单和权限 要在AWS管理控制台中创建菜单和权限,您可以使用AWS Identity and Access Ma...
​ToDesk 远程工具安装及... 目录 前言 ToDesk 优势 ToDesk 下载安装 ToDesk 功能展示 文件传输 设备链接 ...
群晖外网访问终极解决方法:IP... 写在前面的话 受够了群晖的quickconnet的小水管了,急需一个新的解决方法&#x...
不能访问光猫的的管理页面 光猫是现代家庭宽带网络的重要组成部分,它可以提供高速稳定的网络连接。但是,有时候我们会遇到不能访问光...
Azure构建流程(Power... 这可能是由于配置错误导致的问题。请检查构建流程任务中的“发布构建制品”步骤,确保正确配置了“Arti...