博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
Tensor索引操作
阅读量:5294 次
发布时间:2019-06-14

本文共 2315 字,大约阅读时间需要 7 分钟。

 
  1. #Tensor索引操作  
  2.     ''''' 
  3.     Tensor支持与numpy.ndarray类似的索引操作,语法上也类似 
  4.     如无特殊说明,索引出来的结果与原tensor共享内存,即修改一个,另一个会跟着修改 
  5.     '''  
  6.     import torch as t  
  7.       
  8.     a = t.randn(3,4)  
  9.     '''''tensor([[ 0.1986,  0.1809,  1.4662,  0.6693], 
  10.             [-0.8837, -0.0196, -1.0380,  0.2927], 
  11.             [-1.1032, -0.2637, -1.4972,  1.8135]])'''  
  12.     print(a[0])         #第0行  
  13.     '''''tensor([0.1986, 0.1809, 1.4662, 0.6693])'''  
  14.     print(a[:,0])       #第0列  
  15.     '''''tensor([ 0.1986, -0.8837, -1.1032])'''  
  16.     print(a[0][2])      #第0行第2个元素,等价于a[0,2]  
  17.     '''''tensor(1.4662)'''  
  18.     print(a[0][-1])     #第0行最后一个元素  
  19.     '''''tensor(0.6693)'''  
  20.     print(a[:2,0:2])    #前两行,第0,1列  
  21.     '''''tensor([[ 0.1986,  0.1809], 
  22.             [-0.8837, -0.0196]])'''  
  23.       
  24.     print(a[0:1,:2])    #第0行,前两列  
  25.     '''''tensor([[0.1986, 0.1809]])'''  
  26.     print(a[0,:2])      #注意两者的区别,形状不同  
  27.     '''''tensor([0.1986, 0.1809])'''  
  28.       
  29.     print(a>1)  
  30.     '''''tensor([[0, 0, 1, 0], 
  31.             [0, 0, 0, 0], 
  32.             [0, 0, 0, 1]], dtype=torch.uint8)'''  
  33.     print(a[a>1])        #等价于a.masked_select(a>1),选择结果与原tensor不共享内存空间  
  34.     print(a.masked_select(a>1))  
  35.     '''''tensor([1.4662, 1.8135]) 
  36.     tensor([1.4662, 1.8135])'''  
  37.     print(a[t.LongTensor([0,1])])  
  38.     '''''tensor([[ 0.1986,  0.1809,  1.4662,  0.6693], 
  39.             [-0.8837, -0.0196, -1.0380,  0.2927]])'''  
  40.       
  41.     ''''' 
  42.                             常用的选择函数 
  43.     index_select(input,dim,index)   在指定维度dim上选取,列如选择某些列、某些行 
  44.     masked_select(input,mask)       例子如上,a[a>0],使用ByteTensor进行选取 
  45.     non_zero(input)                 非0元素的下标 
  46.     gather(input,dim,index)         根据index,在dim维度上选取数据,输出size与index一样 
  47.     gather是一个比较复杂的操作,对一个二维tensor,输出的每个元素如下: 
  48.         out[i][j] = input[index[i][j]][j]   #dim = 0 
  49.         out[i][j] = input[i][index[i][j]]   #dim = 1 
  50.     '''  
  51.       
  52.     b = t.arange(0,16).view(4,4)  
  53.     '''''tensor([[ 0,  1,  2,  3], 
  54.             [ 4,  5,  6,  7], 
  55.             [ 8,  9, 10, 11], 
  56.             [12, 13, 14, 15]])'''  
  57.     index = t.LongTensor([[0,1,2,3]])  
  58.     print(b.gather(0,index))            #取对角线元素  
  59.     '''''tensor([[ 0,  5, 10, 15]])'''  
  60.       
  61.     index = t.LongTensor([[3,2,1,0]]).t()       #取反对角线上的元素  
  62.     print(b.gather(1,index))  
  63.     '''''tensor([[ 3], 
  64.             [ 6], 
  65.             [ 9], 
  66.             [12]])'''  
  67.       
  68.     index = t.LongTensor([[3,2,1,0]])           #取反对角线的元素,与上面不同  
  69.     print(b.gather(0,index))  
  70.     '''''tensor([[12,  9,  6,  3]])'''  
  71.       
  72.     index = t.LongTensor([[0,1,2,3],[3,2,1,0]]).t()  
  73.     print(b.gather(1,index))  
  74.     '''''tensor([[ 0,  3], 
  75.             [ 5,  6], 
  76.             [10,  9], 
  77.             [15, 12]])'''  
  78.       
  79.     ''''' 
  80.     与gather相对应的逆操作是scatter_,gather把数据从input中按index取出,而 
  81.     scatter_是把取出的数据再放回去,scatter_函数时inplace操作 
  82.     out = input.gather(dim,index) 
  83.     out = Tensor() 
  84.     out.scatter_(dim,index) 
  85.     '''  
  86.       
  87.     x = t.rand(2, 5)  
  88.     print(x)  
  89.     c = t.zeros(3, 5).scatter_(0, t.LongTensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), x)  
  90.     print(c)  
  91. 2018-10-23 20:30:30       

转载于:https://www.cnblogs.com/monkeyT/p/9839150.html

你可能感兴趣的文章
三.野指针和free
查看>>
activemq5.14+zookeeper3.4.9实现高可用
查看>>
TCP/IP详解学习笔记(3)IP协议ARP协议和RARP协议
查看>>
简单【用户输入验证】
查看>>
python tkinter GUI绘制,以及点击更新显示图片
查看>>
Spark基础脚本入门实践3:Pair RDD开发
查看>>
HDU4405--Aeroplane chess(概率dp)
查看>>
CS0103: The name ‘Scripts’ does not exist in the current context解决方法
查看>>
20130330java基础学习笔记-语句_for循环嵌套练习2
查看>>
Spring面试题
查看>>
窥视SP2010--第一章节--SP2010开发者路线图
查看>>
MVC,MVP 和 MVVM 的图示,区别
查看>>
C语言栈的实现
查看>>
代码为什么需要重构
查看>>
TC SRM 593 DIV1 250
查看>>
SRM 628 DIV2
查看>>
2018-2019-2 20165314『网络对抗技术』Exp5:MSF基础应用
查看>>
统计单词,字符,和行
查看>>
Python-S9-Day127-Scrapy爬虫框架2
查看>>
使用Chrome(PC)调试移动设备上的网页
查看>>