Pytorch:torch.gather 함수


Pytorch:torch.gather 함수

파이썬 딥러닝이나 강화학습중 간혹 gether 라는 함수가 있는데 이해가 안되서 정리해보는 그런 포스팅입니다 gather란 input텐서가 입력으로 주어지고, 차원 dim을 따라서 각 행으로부터 값을 취해, 새로운 텐서를 반환한다(return) out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0 out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1 out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2 예시) dim=0 import torch # 2x3 크기의 텐서를 생성합니다. x = torch.tensor([[1, 2, 3], [4, 5, 6]]) # dim=0으로 gather합니다. # 첫번째 인덱스는 [1, 2, 3]이고, 두번째 인덱스는 [4, 5, 6]입니다. # 따라서 idx가 0이면 [1, 2, 3], 1이면 [4, 5, 6]...



원문링크 : Pytorch:torch.gather 함수