torch.split()-程序员宅基地

技术标签: python  深度学习  pytorch  

torch.split()

官网链接:https://pytorch.org/docs/stable/torch.html
官网解释:Splits the tensor into chunks.——PyTorch中用于分割张量的函数。
作用:将一个多维张量分割成多个张量。

If split_size_or_sections is an integer type, then tensor will be split into equally sized chunks (if possible). Last chunk will be smaller if the tensor size along the given dimension dim is not divisible by split_size.
If split_size_or_sections is a list, then tensor will be split into len(split_size_or_sections) chunks with sizes in dim according to split_size_or_sections.

定义是:

torch.split(tensor, split_size_or_sections, dim=0)
参数解释:
- tensor:要分割的输入张量
- split_size_or_sections:
    - 如果是一个整数,则表示分割成每个张量里包含split_size_or_sections个张量,而不是分成split_size_or_sections个
    - 如果是一个列表,则表示对dim维度进行分割,分割为指定大小的张量
- dim:沿着哪个维度进行分割,默认是dim=0,第一维

例1,有这样一个3D张量:

# 生成大小为(2, 4, 8)的随机张量
random_tensor = torch.rand(2, 4, 8)
tensor([[[0.8644, 0.0177, 0.7970, 0.7016, 0.4632, 0.9147, 0.8053, 0.4261],
         [0.0450, 0.9565, 0.8375, 0.9347, 0.8196, 0.8751, 0.4523, 0.9660],
         [0.8350, 0.1566, 0.8367, 0.0345, 0.6804, 0.7308, 0.8989, 0.9943],
         [0.2294, 0.2361, 0.1537, 0.9923, 0.7680, 0.0824, 0.3566, 0.6546]],

        [[0.2106, 0.2736, 0.8687, 0.4333, 0.4102, 0.4820, 0.7104, 0.7776],
         [0.6558, 0.1098, 0.4384, 0.4891, 0.3681, 0.7371, 0.2555, 0.2687],
         [0.4181, 0.6644, 0.1816, 0.2111, 0.8317, 0.4180, 0.7011, 0.7221],
         [0.1922, 0.4405, 0.6633, 0.5787, 0.9912, 0.0370, 0.9894, 0.8748]]])

我们可以这样分割:

  1. torch.split(random_tensor, 2, dim=1) :分割第二维(dim=1)
split_2 = torch.split(random_tensor, 2, dim=1)  # 返回一个元组 tuple
split_2
# (tensor([[[0.8644, 0.0177, 0.7970, 0.7016, 0.4632, 0.9147, 0.8053, 0.4261],
#           [0.0450, 0.9565, 0.8375, 0.9347, 0.8196, 0.8751, 0.4523, 0.9660]],
 
#          [[0.2106, 0.2736, 0.8687, 0.4333, 0.4102, 0.4820, 0.7104, 0.7776],
#           [0.6558, 0.1098, 0.4384, 0.4891, 0.3681, 0.7371, 0.2555, 0.2687]]]),
#  tensor([[[0.8350, 0.1566, 0.8367, 0.0345, 0.6804, 0.7308, 0.8989, 0.9943],
#           [0.2294, 0.2361, 0.1537, 0.9923, 0.7680, 0.0824, 0.3566, 0.6546]],
 
#          [[0.4181, 0.6644, 0.1816, 0.2111, 0.8317, 0.4180, 0.7011, 0.7221],
#           [0.1922, 0.4405, 0.6633, 0.5787, 0.9912, 0.0370, 0.9894, 0.8748]]]))
len(split_2)   # 2   
split_2[0]
# tensor([[[0.8644, 0.0177, 0.7970, 0.7016, 0.4632, 0.9147, 0.8053, 0.4261],
#          [0.0450, 0.9565, 0.8375, 0.9347, 0.8196, 0.8751, 0.4523, 0.9660]],

#         [[0.2106, 0.2736, 0.8687, 0.4333, 0.4102, 0.4820, 0.7104, 0.7776],
#          [0.6558, 0.1098, 0.4384, 0.4891, 0.3681, 0.7371, 0.2555, 0.2687]]])
split_2[0].size()  # torch.Size([2, 2, 8])
split_2[1]
# tensor([[[0.8350, 0.1566, 0.8367, 0.0345, 0.6804, 0.7308, 0.8989, 0.9943],
#          [0.2294, 0.2361, 0.1537, 0.9923, 0.7680, 0.0824, 0.3566, 0.6546]],

#         [[0.4181, 0.6644, 0.1816, 0.2111, 0.8317, 0.4180, 0.7011, 0.7221],
#          [0.1922, 0.4405, 0.6633, 0.5787, 0.9912, 0.0370, 0.9894, 0.8748]]])
split_2[1].size()  # torch.Size([2, 2, 8])
  1. torch.split(random_tensor, 3, dim=1) 与上例对比
split_3 = torch.split(random_tensor, 3, dim=1)
split_3
# (tensor([[[0.8644, 0.0177, 0.7970, 0.7016, 0.4632, 0.9147, 0.8053, 0.4261],
#           [0.0450, 0.9565, 0.8375, 0.9347, 0.8196, 0.8751, 0.4523, 0.9660],
#           [0.8350, 0.1566, 0.8367, 0.0345, 0.6804, 0.7308, 0.8989, 0.9943]],
 
#          [[0.2106, 0.2736, 0.8687, 0.4333, 0.4102, 0.4820, 0.7104, 0.7776],
#           [0.6558, 0.1098, 0.4384, 0.4891, 0.3681, 0.7371, 0.2555, 0.2687],
#           [0.4181, 0.6644, 0.1816, 0.2111, 0.8317, 0.4180, 0.7011, 0.7221]]]),
#  tensor([[[0.2294, 0.2361, 0.1537, 0.9923, 0.7680, 0.0824, 0.3566, 0.6546]],
 
#          [[0.1922, 0.4405, 0.6633, 0.5787, 0.9912, 0.0370, 0.9894, 0.8748]]]))
len(split_3)  # 2   1维长度为4,第一次取3,第二次也应取3,但是剩余长度不够,所以取1
split_3[0]    # torch.Size([2, 3, 8])
# tensor([[[0.8644, 0.0177, 0.7970, 0.7016, 0.4632, 0.9147, 0.8053, 0.4261],
#          [0.0450, 0.9565, 0.8375, 0.9347, 0.8196, 0.8751, 0.4523, 0.9660],
#          [0.8350, 0.1566, 0.8367, 0.0345, 0.6804, 0.7308, 0.8989, 0.9943]],

#         [[0.2106, 0.2736, 0.8687, 0.4333, 0.4102, 0.4820, 0.7104, 0.7776],
#          [0.6558, 0.1098, 0.4384, 0.4891, 0.3681, 0.7371, 0.2555, 0.2687],
#          [0.4181, 0.6644, 0.1816, 0.2111, 0.8317, 0.4180, 0.7011, 0.7221]]])
split_3[1]   # torch.Size([2, 1, 8])
# tensor([[[0.2294, 0.2361, 0.1537, 0.9923, 0.7680, 0.0824, 0.3566, 0.6546]],

#         [[0.1922, 0.4405, 0.6633, 0.5787, 0.9912, 0.0370, 0.9894, 0.8748]]])
  1. torch.split(random_tensor, [1, 3], dim=1)
split_1_3 = torch.split(random_tensor, [1, 3], dim=1) # 列表中数值总和必须与原维度数值相等
split_1_3
# (tensor([[[0.8644, 0.0177, 0.7970, 0.7016, 0.4632, 0.9147, 0.8053, 0.4261]],
 
#          [[0.2106, 0.2736, 0.8687, 0.4333, 0.4102, 0.4820, 0.7104, 0.7776]]]),
#  tensor([[[0.0450, 0.9565, 0.8375, 0.9347, 0.8196, 0.8751, 0.4523, 0.9660],
#           [0.8350, 0.1566, 0.8367, 0.0345, 0.6804, 0.7308, 0.8989, 0.9943],
#           [0.2294, 0.2361, 0.1537, 0.9923, 0.7680, 0.0824, 0.3566, 0.6546]],
 
#          [[0.6558, 0.1098, 0.4384, 0.4891, 0.3681, 0.7371, 0.2555, 0.2687],
#           [0.4181, 0.6644, 0.1816, 0.2111, 0.8317, 0.4180, 0.7011, 0.7221],
#           [0.1922, 0.4405, 0.6633, 0.5787, 0.9912, 0.0370, 0.9894, 0.8748]]]))
len(split_1_3)  #2
split_1_3[0]  # torch.Size([2, 1, 8])
# tensor([[[0.8644, 0.0177, 0.7970, 0.7016, 0.4632, 0.9147, 0.8053, 0.4261]],

#         [[0.2106, 0.2736, 0.8687, 0.4333, 0.4102, 0.4820, 0.7104, 0.7776]]])
split_1_3[1]   # torch.Size([2, 3, 8])
# tensor([[[0.0450, 0.9565, 0.8375, 0.9347, 0.8196, 0.8751, 0.4523, 0.9660],
#          [0.8350, 0.1566, 0.8367, 0.0345, 0.6804, 0.7308, 0.8989, 0.9943],
#          [0.2294, 0.2361, 0.1537, 0.9923, 0.7680, 0.0824, 0.3566, 0.6546]],

#         [[0.6558, 0.1098, 0.4384, 0.4891, 0.3681, 0.7371, 0.2555, 0.2687],
#          [0.4181, 0.6644, 0.1816, 0.2111, 0.8317, 0.4180, 0.7011, 0.7221],
#          [0.1922, 0.4405, 0.6633, 0.5787, 0.9912, 0.0370, 0.9894, 0.8748]]])

例2,有这样一个3D张量:

random_tensor = torch.rand(2, 2, 3)
tensor([[[0.0445, 0.0481, 0.1199],
         [0.2850, 0.1215, 0.0584]],

        [[0.1323, 0.4458, 0.0899],
         [0.3338, 0.3624, 0.7511]]])
  1. torch.split(random_tensor, 2, dim=1):分割第二维(dim=1),第一次取两个张量,数据取完。这里本身就是两个张量,所以还是返回自身
split_2 = torch.split(random_tensor, 2, dim=1)   # 返回元组
split_2
# (tensor([[[0.0445, 0.0481, 0.1199],
#           [0.2850, 0.1215, 0.0584]],
 
#          [[0.1323, 0.4458, 0.0899],
#           [0.3338, 0.3624, 0.7511]]]),)
len(split_2)   # 1
split_2[0]     # torch.Size([2, 2, 3])
# tensor([[[0.0445, 0.0481, 0.1199],
#          [0.2850, 0.1215, 0.0584]],

#         [[0.1323, 0.4458, 0.0899],
#          [0.3338, 0.3624, 0.7511]]])
split_2[1]   # 报错
  1. torch.split(random_tensor, [1, 2], dim=2):沿第三维(dim=2)分割
split_1_2 = torch.split(random_tensor, [1, 2], dim=2) # 返回元组
split_1_2
# (tensor([[[0.0445],
#           [0.2850]],
 
#          [[0.1323],
#           [0.3338]]]),
#  tensor([[[0.0481, 0.1199],
#           [0.1215, 0.0584]],
 
#          [[0.4458, 0.0899],
#           [0.3624, 0.7511]]]))
len(split_1_2)  # 2
split_1_2[0]    # torch.Size([2, 2, 1])
# tensor([[[0.0445],
#          [0.2850]],

#         [[0.1323],
#          [0.3338]]])
split_1_2[1]   # torch.Size([2, 2, 2])
# tensor([[[0.0481, 0.1199],
#          [0.1215, 0.0584]],

#         [[0.4458, 0.0899],
#          [0.3624, 0.7511]]])

所以,torch.split()是一个很有用的函数,可以轻松地将张量分割成任意形状和大小的张量列表,以用于后续处理。

Tips:
感谢@qq_42798074指正
感谢@qq_41720271指正

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/qq_39398917/article/details/131075177

智能推荐

攻防世界_难度8_happy_puzzle_攻防世界困难模式攻略图文-程序员宅基地

文章浏览阅读645次。这个肯定是末尾的IDAT了,因为IDAT必须要满了才会开始一下个IDAT,这个明显就是末尾的IDAT了。,对应下面的create_head()代码。,对应下面的create_tail()代码。不要考虑爆破,我已经试了一下,太多情况了。题目来源:UNCTF。_攻防世界困难模式攻略图文

达梦数据库的导出(备份)、导入_达梦数据库导入导出-程序员宅基地

文章浏览阅读2.9k次,点赞3次,收藏10次。偶尔会用到,记录、分享。1. 数据库导出1.1 切换到dmdba用户su - dmdba1.2 进入达梦数据库安装路径的bin目录,执行导库操作  导出语句:./dexp cwy_init/[email protected]:5236 file=cwy_init.dmp log=cwy_init_exp.log 注释:   cwy_init/init_123..._达梦数据库导入导出

js引入kindeditor富文本编辑器的使用_kindeditor.js-程序员宅基地

文章浏览阅读1.9k次。1. 在官网上下载KindEditor文件,可以删掉不需要要到的jsp,asp,asp.net和php文件夹。接着把文件夹放到项目文件目录下。2. 修改html文件,在页面引入js文件:<script type="text/javascript" src="./kindeditor/kindeditor-all.js"></script><script type="text/javascript" src="./kindeditor/lang/zh-CN.js"_kindeditor.js

STM32学习过程记录11——基于STM32G431CBU6硬件SPI+DMA的高效WS2812B控制方法-程序员宅基地

文章浏览阅读2.3k次,点赞6次,收藏14次。SPI的详情简介不必赘述。假设我们通过SPI发送0xAA,我们的数据线就会变为10101010,通过修改不同的内容,即可修改SPI中0和1的持续时间。比如0xF0即为前半周期为高电平,后半周期为低电平的状态。在SPI的通信模式中,CPHA配置会影响该实验,下图展示了不同采样位置的SPI时序图[1]。CPOL = 0,CPHA = 1:CLK空闲状态 = 低电平,数据在下降沿采样,并在上升沿移出CPOL = 0,CPHA = 0:CLK空闲状态 = 低电平,数据在上升沿采样,并在下降沿移出。_stm32g431cbu6

计算机网络-数据链路层_接收方收到链路层数据后,使用crc检验后,余数为0,说明链路层的传输时可靠传输-程序员宅基地

文章浏览阅读1.2k次,点赞2次,收藏8次。数据链路层习题自测问题1.数据链路(即逻辑链路)与链路(即物理链路)有何区别?“电路接通了”与”数据链路接通了”的区别何在?2.数据链路层中的链路控制包括哪些功能?试讨论数据链路层做成可靠的链路层有哪些优点和缺点。3.网络适配器的作用是什么?网络适配器工作在哪一层?4.数据链路层的三个基本问题(帧定界、透明传输和差错检测)为什么都必须加以解决?5.如果在数据链路层不进行帧定界,会发生什么问题?6.PPP协议的主要特点是什么?为什么PPP不使用帧的编号?PPP适用于什么情况?为什么PPP协议不_接收方收到链路层数据后,使用crc检验后,余数为0,说明链路层的传输时可靠传输

软件测试工程师移民加拿大_无证移民,未受过软件工程师的教育(第1部分)-程序员宅基地

文章浏览阅读587次。软件测试工程师移民加拿大 无证移民,未受过软件工程师的教育(第1部分) (Undocumented Immigrant With No Education to Software Engineer(Part 1))Before I start, I want you to please bear with me on the way I write, I have very little gen...

随便推点

Thinkpad X250 secure boot failed 启动失败问题解决_安装完系统提示secureboot failure-程序员宅基地

文章浏览阅读304次。Thinkpad X250笔记本电脑,装的是FreeBSD,进入BIOS修改虚拟化配置(其后可能是误设置了安全开机),保存退出后系统无法启动,显示:secure boot failed ,把自己惊出一身冷汗,因为这台笔记本刚好还没开始做备份.....根据错误提示,到bios里面去找相关配置,在Security里面找到了Secure Boot选项,发现果然被设置为Enabled,将其修改为Disabled ,再开机,终于正常启动了。_安装完系统提示secureboot failure

C++如何做字符串分割(5种方法)_c++ 字符串分割-程序员宅基地

文章浏览阅读10w+次,点赞93次,收藏352次。1、用strtok函数进行字符串分割原型: char *strtok(char *str, const char *delim);功能:分解字符串为一组字符串。参数说明:str为要分解的字符串,delim为分隔符字符串。返回值:从str开头开始的一个个被分割的串。当没有被分割的串时则返回NULL。其它:strtok函数线程不安全,可以使用strtok_r替代。示例://借助strtok实现split#include <string.h>#include <stdio.h&_c++ 字符串分割

2013第四届蓝桥杯 C/C++本科A组 真题答案解析_2013年第四届c a组蓝桥杯省赛真题解答-程序员宅基地

文章浏览阅读2.3k次。1 .高斯日记 大数学家高斯有个好习惯:无论如何都要记日记。他的日记有个与众不同的地方,他从不注明年月日,而是用一个整数代替,比如:4210后来人们知道,那个整数就是日期,它表示那一天是高斯出生后的第几天。这或许也是个好习惯,它时时刻刻提醒着主人:日子又过去一天,还有多少时光可以用于浪费呢?高斯出生于:1777年4月30日。在高斯发现的一个重要定理的日记_2013年第四届c a组蓝桥杯省赛真题解答

基于供需算法优化的核极限学习机(KELM)分类算法-程序员宅基地

文章浏览阅读851次,点赞17次,收藏22次。摘要:本文利用供需算法对核极限学习机(KELM)进行优化,并用于分类。

metasploitable2渗透测试_metasploitable2怎么进入-程序员宅基地

文章浏览阅读1.1k次。一、系统弱密码登录1、在kali上执行命令行telnet 192.168.26.1292、Login和password都输入msfadmin3、登录成功,进入系统4、测试如下:二、MySQL弱密码登录:1、在kali上执行mysql –h 192.168.26.129 –u root2、登录成功,进入MySQL系统3、测试效果:三、PostgreSQL弱密码登录1、在Kali上执行psql -h 192.168.26.129 –U post..._metasploitable2怎么进入

Python学习之路:从入门到精通的指南_python人工智能开发从入门到精通pdf-程序员宅基地

文章浏览阅读257次。本文将为初学者提供Python学习的详细指南,从Python的历史、基础语法和数据类型到面向对象编程、模块和库的使用。通过本文,您将能够掌握Python编程的核心概念,为今后的编程学习和实践打下坚实基础。_python人工智能开发从入门到精通pdf

推荐文章

热门文章

相关标签