https://stackoverflow.com/questions/55982067/how-to-correctly-access-elements-in-a-3d-pytorch-tensor
>>> a = torch.arange(4*3*2).reshape(4,3,2)
>>> a[:, :, 0]
tensor([[ 0, 2, 4],
[ 6, 8, 10],
[12, 14, 16],
[18, 20, 22]])
>>> a[:, :, 0].T
tensor([[ 0, 6, 12, 18],
[ 2, 8, 14, 20],
[ 4, 10, 16, 22]])
# input tensor to work with
In [11]: a = torch.arange(4*3*2).reshape(4,3,2)
In [12]: a.shape
Out[12]: torch.Size([4, 3, 2])
In [13]: a
Out[13]:
tensor([[[ 0, 1],
[ 2, 3],
[ 4, 5]],
[[ 6, 7],
[ 8, 9],
[10, 11]],
[[12, 13],
[14, 15],
[16, 17]],
[[18, 19],
[20, 21],
[22, 23]]])
In [14]: a[:, :, 0]
Out[14]:
tensor([[ 0, 2, 4],
[ 6, 8, 10],
[12, 14, 16],
[18, 20, 22]])
Explanation/Clarification:
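The tensor has shape [4, 3, 2], where 4 represents the number of blocks (block-0, ..., block-3). Next, 3 represents the number of rows in each block. Finally, 2 represents the number of columns in each row. We slice this using `a[:, :, 0]`: the first two `:` keep every block and every row, and the `0` picks column 0 of each row. The result therefore has shape [4, 3] (one row per block, holding the first column of each of that block's 3 rows), as shown in Out[14] above.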
To access a block, we need only one index (e.g. a[0], ..., a[3]). To access a specific row in a specific block, we need two indices (e.g. a[0, 1], ..., a[3, 2]). To access a specific column of a specific row in a specific block, we need three indices (e.g. a[0, 1, 1]).
_______________________________________________________________________________
>>> a = torch.arange(64).reshape(8,4,2)
>>> a
tensor([[[ 0, 1],
[ 2, 3],
[ 4, 5],
[ 6, 7]],
[[ 8, 9],
[10, 11],
[12, 13],
[14, 15]],
[[16, 17],
[18, 19],
[20, 21],
[22, 23]],
[[24, 25],
[26, 27],
[28, 29],
[30, 31]],
[[32, 33],
[34, 35],
[36, 37],
[38, 39]],
[[40, 41],
[42, 43],
[44, 45],
[46, 47]],
[[48, 49],
[50, 51],
[52, 53],
[54, 55]],
[[56, 57],
[58, 59],
[60, 61],
[62, 63]]])
>>> a = torch.arange(64).reshape(8,2,4)
>>> a
tensor([[[ 0, 1, 2, 3],
[ 4, 5, 6, 7]],
[[ 8, 9, 10, 11],
[12, 13, 14, 15]],
[[16, 17, 18, 19],
[20, 21, 22, 23]],
[[24, 25, 26, 27],
[28, 29, 30, 31]],
[[32, 33, 34, 35],
[36, 37, 38, 39]],
[[40, 41, 42, 43],
[44, 45, 46, 47]],
[[48, 49, 50, 51],
[52, 53, 54, 55]],
[[56, 57, 58, 59],
[60, 61, 62, 63]]])
>>> a = torch.arange(128).reshape(2,8,2,4)
>>> a
tensor([[[[ 0, 1, 2, 3],
[ 4, 5, 6, 7]],
[[ 8, 9, 10, 11],
[ 12, 13, 14, 15]],
[[ 16, 17, 18, 19],
[ 20, 21, 22, 23]],
[[ 24, 25, 26, 27],
[ 28, 29, 30, 31]],
[[ 32, 33, 34, 35],
[ 36, 37, 38, 39]],
[[ 40, 41, 42, 43],
[ 44, 45, 46, 47]],
[[ 48, 49, 50, 51],
[ 52, 53, 54, 55]],
[[ 56, 57, 58, 59],
[ 60, 61, 62, 63]]],
[[[ 64, 65, 66, 67],
[ 68, 69, 70, 71]],
[[ 72, 73, 74, 75],
[ 76, 77, 78, 79]],
[[ 80, 81, 82, 83],
[ 84, 85, 86, 87]],
[[ 88, 89, 90, 91],
[ 92, 93, 94, 95]],
[[ 96, 97, 98, 99],
[100, 101, 102, 103]],
[[104, 105, 106, 107],
[108, 109, 110, 111]],
[[112, 113, 114, 115],
[116, 117, 118, 119]],
[[120, 121, 122, 123],
[124, 125, 126, 127]]]])
_______________________________________________________________________________
>>> a = torch.arange(128).reshape(16,2,2,4)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
RuntimeError: shape '[16, 2, 2, 4]' is invalid for input of size 128
>>> a = torch.arange(128).reshape(16,2,2,2)
>>> a
tensor([[[[ 0, 1],
[ 2, 3]],
[[ 4, 5],
[ 6, 7]]],
[[[ 8, 9],
[ 10, 11]],
[[ 12, 13],
[ 14, 15]]],
[[[ 16, 17],
[ 18, 19]],
[[ 20, 21],
[ 22, 23]]],
[[[ 24, 25],
[ 26, 27]],
[[ 28, 29],
[ 30, 31]]],
[[[ 32, 33],
[ 34, 35]],
[[ 36, 37],
[ 38, 39]]],
[[[ 40, 41],
[ 42, 43]],
[[ 44, 45],
[ 46, 47]]],
[[[ 48, 49],
[ 50, 51]],
[[ 52, 53],
[ 54, 55]]],
[[[ 56, 57],
[ 58, 59]],
[[ 60, 61],
[ 62, 63]]],
[[[ 64, 65],
[ 66, 67]],
[[ 68, 69],
[ 70, 71]]],
[[[ 72, 73],
[ 74, 75]],
[[ 76, 77],
[ 78, 79]]],
[[[ 80, 81],
[ 82, 83]],
[[ 84, 85],
[ 86, 87]]],
[[[ 88, 89],
[ 90, 91]],
[[ 92, 93],
[ 94, 95]]],
[[[ 96, 97],
[ 98, 99]],
[[100, 101],
[102, 103]]],
[[[104, 105],
[106, 107]],
[[108, 109],
[110, 111]]],
[[[112, 113],
[114, 115]],
[[116, 117],
[118, 119]]],
[[[120, 121],
[122, 123]],
[[124, 125],
[126, 127]]]])
_______________________________________________________________________________
>>> a = torch.arange(256).reshape(16,2,4,2)
>>> a
tensor([[[[ 0, 1],
[ 2, 3],
[ 4, 5],
[ 6, 7]],
[[ 8, 9],
[ 10, 11],
[ 12, 13],
[ 14, 15]]],
[[[ 16, 17],
[ 18, 19],
[ 20, 21],
[ 22, 23]],
[[ 24, 25],
[ 26, 27],
[ 28, 29],
[ 30, 31]]],
[[[ 32, 33],
[ 34, 35],
[ 36, 37],
[ 38, 39]],
[[ 40, 41],
[ 42, 43],
[ 44, 45],
[ 46, 47]]],
[[[ 48, 49],
[ 50, 51],
[ 52, 53],
[ 54, 55]],
[[ 56, 57],
[ 58, 59],
[ 60, 61],
[ 62, 63]]],
[[[ 64, 65],
[ 66, 67],
[ 68, 69],
[ 70, 71]],
[[ 72, 73],
[ 74, 75],
[ 76, 77],
[ 78, 79]]],
[[[ 80, 81],
[ 82, 83],
[ 84, 85],
[ 86, 87]],
[[ 88, 89],
[ 90, 91],
[ 92, 93],
[ 94, 95]]],
[[[ 96, 97],
[ 98, 99],
[100, 101],
[102, 103]],
[[104, 105],
[106, 107],
[108, 109],
[110, 111]]],
[[[112, 113],
[114, 115],
[116, 117],
[118, 119]],
[[120, 121],
[122, 123],
[124, 125],
[126, 127]]],
[[[128, 129],
[130, 131],
[132, 133],
[134, 135]],
[[136, 137],
[138, 139],
[140, 141],
[142, 143]]],
[[[144, 145],
[146, 147],
[148, 149],
[150, 151]],
[[152, 153],
[154, 155],
[156, 157],
[158, 159]]],
[[[160, 161],
[162, 163],
[164, 165],
[166, 167]],
[[168, 169],
[170, 171],
[172, 173],
[174, 175]]],
[[[176, 177],
[178, 179],
[180, 181],
[182, 183]],
[[184, 185],
[186, 187],
[188, 189],
[190, 191]]],
[[[192, 193],
[194, 195],
[196, 197],
[198, 199]],
[[200, 201],
[202, 203],
[204, 205],
[206, 207]]],
[[[208, 209],
[210, 211],
[212, 213],
[214, 215]],
[[216, 217],
[218, 219],
[220, 221],
[222, 223]]],
[[[224, 225],
[226, 227],
[228, 229],
[230, 231]],
[[232, 233],
[234, 235],
[236, 237],
[238, 239]]],
[[[240, 241],
[242, 243],
[244, 245],
[246, 247]],
[[248, 249],
[250, 251],
[252, 253],
[254, 255]]]])
_______________________________________________________________________________
>>> a = torch.arange(256).reshape(16,2,2,4)
>>> a
tensor([[[[ 0, 1, 2, 3],
[ 4, 5, 6, 7]],
[[ 8, 9, 10, 11],
[ 12, 13, 14, 15]]],
[[[ 16, 17, 18, 19],
[ 20, 21, 22, 23]],
[[ 24, 25, 26, 27],
[ 28, 29, 30, 31]]],
[[[ 32, 33, 34, 35],
[ 36, 37, 38, 39]],
[[ 40, 41, 42, 43],
[ 44, 45, 46, 47]]],
[[[ 48, 49, 50, 51],
[ 52, 53, 54, 55]],
[[ 56, 57, 58, 59],
[ 60, 61, 62, 63]]],
[[[ 64, 65, 66, 67],
[ 68, 69, 70, 71]],
[[ 72, 73, 74, 75],
[ 76, 77, 78, 79]]],
[[[ 80, 81, 82, 83],
[ 84, 85, 86, 87]],
[[ 88, 89, 90, 91],
[ 92, 93, 94, 95]]],
[[[ 96, 97, 98, 99],
[100, 101, 102, 103]],
[[104, 105, 106, 107],
[108, 109, 110, 111]]],
[[[112, 113, 114, 115],
[116, 117, 118, 119]],
[[120, 121, 122, 123],
[124, 125, 126, 127]]],
[[[128, 129, 130, 131],
[132, 133, 134, 135]],
[[136, 137, 138, 139],
[140, 141, 142, 143]]],
[[[144, 145, 146, 147],
[148, 149, 150, 151]],
[[152, 153, 154, 155],
[156, 157, 158, 159]]],
[[[160, 161, 162, 163],
[164, 165, 166, 167]],
[[168, 169, 170, 171],
[172, 173, 174, 175]]],
[[[176, 177, 178, 179],
[180, 181, 182, 183]],
[[184, 185, 186, 187],
[188, 189, 190, 191]]],
[[[192, 193, 194, 195],
[196, 197, 198, 199]],
[[200, 201, 202, 203],
[204, 205, 206, 207]]],
[[[208, 209, 210, 211],
[212, 213, 214, 215]],
[[216, 217, 218, 219],
[220, 221, 222, 223]]],
[[[224, 225, 226, 227],
[228, 229, 230, 231]],
[[232, 233, 234, 235],
[236, 237, 238, 239]]],
[[[240, 241, 242, 243],
[244, 245, 246, 247]],
[[248, 249, 250, 251],
[252, 253, 254, 255]]]])
_______________________________________________________________________________
>>> a = torch.arange(1024).reshape(16,2,2,4,4)
>>> a
tensor([[[[[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11],
[ 12, 13, 14, 15]],
[[ 16, 17, 18, 19],
[ 20, 21, 22, 23],
[ 24, 25, 26, 27],
[ 28, 29, 30, 31]]],
[[[ 32, 33, 34, 35],
[ 36, 37, 38, 39],
[ 40, 41, 42, 43],
[ 44, 45, 46, 47]],
[[ 48, 49, 50, 51],
[ 52, 53, 54, 55],
[ 56, 57, 58, 59],
[ 60, 61, 62, 63]]]],
[[[[ 64, 65, 66, 67],
[ 68, 69, 70, 71],
[ 72, 73, 74, 75],
[ 76, 77, 78, 79]],
[[ 80, 81, 82, 83],
[ 84, 85, 86, 87],
[ 88, 89, 90, 91],
[ 92, 93, 94, 95]]],
[[[ 96, 97, 98, 99],
[ 100, 101, 102, 103],
[ 104, 105, 106, 107],
[ 108, 109, 110, 111]],
[[ 112, 113, 114, 115],
[ 116, 117, 118, 119],
[ 120, 121, 122, 123],
[ 124, 125, 126, 127]]]],
[[[[ 128, 129, 130, 131],
[ 132, 133, 134, 135],
[ 136, 137, 138, 139],
[ 140, 141, 142, 143]],
[[ 144, 145, 146, 147],
[ 148, 149, 150, 151],
[ 152, 153, 154, 155],
[ 156, 157, 158, 159]]],
[[[ 160, 161, 162, 163],
[ 164, 165, 166, 167],
[ 168, 169, 170, 171],
[ 172, 173, 174, 175]],
[[ 176, 177, 178, 179],
[ 180, 181, 182, 183],
[ 184, 185, 186, 187],
[ 188, 189, 190, 191]]]],
...,
[[[[ 832, 833, 834, 835],
[ 836, 837, 838, 839],
[ 840, 841, 842, 843],
[ 844, 845, 846, 847]],
[[ 848, 849, 850, 851],
[ 852, 853, 854, 855],
[ 856, 857, 858, 859],
[ 860, 861, 862, 863]]],
[[[ 864, 865, 866, 867],
[ 868, 869, 870, 871],
[ 872, 873, 874, 875],
[ 876, 877, 878, 879]],
[[ 880, 881, 882, 883],
[ 884, 885, 886, 887],
[ 888, 889, 890, 891],
[ 892, 893, 894, 895]]]],
[[[[ 896, 897, 898, 899],
[ 900, 901, 902, 903],
[ 904, 905, 906, 907],
[ 908, 909, 910, 911]],
[[ 912, 913, 914, 915],
[ 916, 917, 918, 919],
[ 920, 921, 922, 923],
[ 924, 925, 926, 927]]],
[[[ 928, 929, 930, 931],
[ 932, 933, 934, 935],
[ 936, 937, 938, 939],
[ 940, 941, 942, 943]],
[[ 944, 945, 946, 947],
[ 948, 949, 950, 951],
[ 952, 953, 954, 955],
[ 956, 957, 958, 959]]]],
[[[[ 960, 961, 962, 963],
[ 964, 965, 966, 967],
[ 968, 969, 970, 971],
[ 972, 973, 974, 975]],
[[ 976, 977, 978, 979],
[ 980, 981, 982, 983],
[ 984, 985, 986, 987],
[ 988, 989, 990, 991]]],
[[[ 992, 993, 994, 995],
[ 996, 997, 998, 999],
[1000, 1001, 1002, 1003],
[1004, 1005, 1006, 1007]],
[[1008, 1009, 1010, 1011],
[1012, 1013, 1014, 1015],
[1016, 1017, 1018, 1019],
[1020, 1021, 1022, 1023]]]]])
_______________________________________________________________________________
>>> a = torch.arange(256).reshape(4,2,2,4,4)
>>> a
tensor([[[[[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11],
[ 12, 13, 14, 15]],
[[ 16, 17, 18, 19],
[ 20, 21, 22, 23],
[ 24, 25, 26, 27],
[ 28, 29, 30, 31]]],
[[[ 32, 33, 34, 35],
[ 36, 37, 38, 39],
[ 40, 41, 42, 43],
[ 44, 45, 46, 47]],
[[ 48, 49, 50, 51],
[ 52, 53, 54, 55],
[ 56, 57, 58, 59],
[ 60, 61, 62, 63]]]],
[[[[ 64, 65, 66, 67],
[ 68, 69, 70, 71],
[ 72, 73, 74, 75],
[ 76, 77, 78, 79]],
[[ 80, 81, 82, 83],
[ 84, 85, 86, 87],
[ 88, 89, 90, 91],
[ 92, 93, 94, 95]]],
[[[ 96, 97, 98, 99],
[100, 101, 102, 103],
[104, 105, 106, 107],
[108, 109, 110, 111]],
[[112, 113, 114, 115],
[116, 117, 118, 119],
[120, 121, 122, 123],
[124, 125, 126, 127]]]],
[[[[128, 129, 130, 131],
[132, 133, 134, 135],
[136, 137, 138, 139],
[140, 141, 142, 143]],
[[144, 145, 146, 147],
[148, 149, 150, 151],
[152, 153, 154, 155],
[156, 157, 158, 159]]],
[[[160, 161, 162, 163],
[164, 165, 166, 167],
[168, 169, 170, 171],
[172, 173, 174, 175]],
[[176, 177, 178, 179],
[180, 181, 182, 183],
[184, 185, 186, 187],
[188, 189, 190, 191]]]],
[[[[192, 193, 194, 195],
[196, 197, 198, 199],
[200, 201, 202, 203],
[204, 205, 206, 207]],
[[208, 209, 210, 211],
[212, 213, 214, 215],
[216, 217, 218, 219],
[220, 221, 222, 223]]],
[[[224, 225, 226, 227],
[228, 229, 230, 231],
[232, 233, 234, 235],
[236, 237, 238, 239]],
[[240, 241, 242, 243],
[244, 245, 246, 247],
[248, 249, 250, 251],
[252, 253, 254, 255]]]]])
>>>