Pytorch 프로그램을 하다 보면 tensor에서 특정 조건에 맞는 원소의 인덱스를 구해야 하는 경우가 많다
이런 경우 다음과 같이 하면 된다.
import torch
a = torch.tensor([1,2,3,4])
b = torch.tensor([0,4,3,5])
# 두 개의 tensor에서 다른 부분의 인덱스 구하기
idx = ((a-b)!=0).nonzero().flatten()
print(idx.tolist()) # [0,1,3]
# a 원소가 b 원소보다 큰 경우의 인덱스 구하기
idx = ((a-b)>0).nonzero().flatten()
print(idx.tolist()) # [0]
# a의 최대값과 최대값의 인덱스 구하기
print('Max: ', a.max(dim=0)[0].tolist()) # 4
print('Argmax: ', a.max(dim=0)[1].tolist()) # 3
(주의) Tensor가 2차원인 경우는 인덱스가 2차원이니 flatten을 사용하지 말고 약간 다르게 해야...
# 예컨대, 원소가 0 값을 가지는 것을 방지하려면 아래처럼 하면 된다
# b = torch.tensor([0, 1.0, 3.0, 5.0]) 인 경우
eps = 0.00001
b[b<eps] = eps # b <-- [0.00001, 1.0, 3.0, 5.0]
'개발 팁' 카테고리의 다른 글
Base64로 인코딩된 그림 저장하기 (Python) (0) | 2020.06.01 |
---|---|
원격접속 프로그램 (0) | 2020.05.28 |
우분투 cuda와 nvidia 드라이버 버전 충돌 해결방법 (0) | 2020.02.07 |
우분투 nvidia 드라이버 설치하기 (1) | 2020.02.03 |
ImportError: cannot import name 'PILLOW_VERSION' (0) | 2020.02.03 |
댓글