欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页  >  IT编程

Pytorch 使用tensor特定条件判断索引

程序员文章站 2022-07-09 16:58:36
torch.where() 用于将两个broadcastable的tensor组合成新的tensor,类似于c++中的三元操作符“?:”区别于python numpy中的where()直接可以...

torch.where() 用于将两个broadcastable的tensor组合成新的tensor,类似于c++中的三元操作符“?:”

区别于python numpy中的where()直接可以找到特定条件元素的index

Pytorch 使用tensor特定条件判断索引

想要实现numpy中where()的功能,可以借助nonzero()

Pytorch 使用tensor特定条件判断索引

对应numpy中的where()操作效果:

Pytorch 使用tensor特定条件判断索引

补充:pytorch torch.tensor.detach()方法的用法及修改指定模块权重的方法

detach

detach的中文意思是分离,官方解释是返回一个新的tensor,从当前的计算图中分离出来

Pytorch 使用tensor特定条件判断索引

需要注意的是,返回的tensor和原tensor共享相同的存储空间,但是返回的 tensor 永远不会需要梯度

Pytorch 使用tensor特定条件判断索引

import torch as t
a = t.ones(10,)
b = a.detach()
print(b)
tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])

那么这个函数有什么作用?

–假如a网络输出了一个tensor类型的变量a, a要作为输入传入到b网络中,如果我想通过损失函数反向传播修改b网络的参数,但是不想修改a网络的参数,这个时候就可以使用detcah()方法

a = a(input)
a = detach()
b = b(a)
loss = criterion(b, target)
loss.backward()

来看一个实际的例子:

import torch as t
x = t.ones(1, requires_grad=true)
x.requires_grad   #true
y = t.ones(1, requires_grad=true)
y.requires_grad   #true
x = x.detach()   #分离之后
x.requires_grad   #false
y = x+y         #tensor([2.])
y.requires_grad   #我还是true
y.retain_grad()   #y不是叶子张量,要加上这一行
z = t.pow(y, 2)
z.backward()    #反向传播
y.grad        #tensor([4.])
x.grad        #none

以上代码就说明了反向传播到y就结束了,没有到达x,所以x的grad属性为none

既然谈到了修改模型的权重问题,那么还有一种情况是:

–假如a网络输出了一个tensor类型的变量a, a要作为输入传入到b网络中,如果我想通过损失函数反向传播修改a网络的参数,但是不想修改b网络的参数,这个时候又应该怎么办了?

这时可以使用tensor.requires_grad属性,只需要将requires_grad修改为false即可.

for param in b.parameters():
 param.requires_grad = false
a = a(input)
b = b(a)
loss = criterion(b, target)
loss.backward()

以上为个人经验,希望能给大家一个参考,也希望大家多多支持。如有错误或未考虑完全的地方,望不吝赐教。