深度学习pytorch——高阶OP(where & gather)(持续更新)

03-27 1164阅读 0评论

where

1、我们为什么需要where?

我们经常需要一个数据来自好几个的取值,而这些取值通常是不规律的,这就会导致使用传统的拆分和合并会非常的麻烦。我们也可以使用for循环嵌套来取值,也是可以的,但是使用for循环就意味着是python,那并没有很好的利用pytorch提供的使用gpu加速计算,当数据量非常大的话,会很大的拉低效率,因此我们使用pytorch提供的where。

2、where的使用

语法:torch.where(condition, x, y)  ------>  tensor

返回值:最后的返回值是一个张量,最后每个元素来自数据x,还是数据y依赖于条件。

使用where的条件:x.shape = y.shape = c.shape = condition.shape(c为结果,condition为0 1矩阵)

深度学习pytorch——高阶OP(where & gather)(持续更新)

代码示例:

cond = torch.tensor([[0.6,0.7],[0.8,0.4]])
a = torch.zeros(2,2)
b = torch.ones(2,2)
print(torch.where(cond>0.5,a,b))
# tensor([[0., 0.],
#         [0., 1.]])

gather

1、我们为什么需要gather?

gather:根据index收集数据。

不使用gather的情况:

深度学习pytorch——高阶OP(where & gather)(持续更新)

可以从上图中看出,索引是非常繁琐的,而且不小心就看错了,虽说也不是很难,但是深度学习处理的数据都是非常庞大的,比如一个1024*1024的图片,这时候内心是崩溃的🌹。还有一点,我们可以使用gpu帮助我们加快数据处理的效率。 

2、gather的使用

语法:torch.gather(input, dim, index, out=None) -----> tensor

input:表

dim:在哪个维度查表

index:索引表

代码示例:

prob=torch.randn(4,10)
idx=prob.topk(dim=1,k=3)
idx=idx[1]
# 以上为了得到索引表
label=torch.arange(10)+100
print(torch.gather(label.expand(4,10),dim=1,index=idx))

免责声明
本网站所收集的部分公开资料来源于AI生成和互联网,转载的目的在于传递更多信息及用于网络分享,并不代表本站赞同其观点和对其真实性负责,也不构成任何其他建议。
文章版权声明:除非注明,否则均为主机测评原创文章,转载或复制请以超链接形式并注明出处。

发表评论

快捷回复: 表情:
评论列表 (暂无评论,1164人围观)

还没有评论,来说两句吧...

目录[+]