본문 바로가기
개발 팁

[Pytorch] tensor에서 특정 조건에 맞는 원소의 인덱스 구하기

by Felizang 2020. 5. 15.

Pytorch 프로그램을 하다 보면 tensor에서 특정 조건에 맞는 원소의 인덱스를 구해야 하는 경우가 많다

이런 경우 다음과 같이 하면 된다.

 

import torch

a = torch.tensor([1,2,3,4])

b = torch.tensor([0,4,3,5])

 

# 두 개의 tensor에서 다른 부분의 인덱스 구하기

idx = ((a-b)!=0).nonzero().flatten()

print(idx.tolist())   # [0,1,3]

 

# a 원소가 b 원소보다 큰 경우의 인덱스 구하기

idx = ((a-b)>0).nonzero().flatten()

print(idx.tolist())   # [0]

 

# a의 최대값과 최대값의 인덱스 구하기

print('Max: ', a.max(dim=0)[0].tolist())   # 4 

print('Argmax: ', a.max(dim=0)[1].tolist())  # 3

 

 

(주의) Tensor가 2차원인 경우는 인덱스가 2차원이니 flatten을 사용하지 말고 약간 다르게 해야...

 

 

# 예컨대, 원소가 0 값을 가지는 것을 방지하려면 아래처럼 하면 된다

# b = torch.tensor([0, 1.0, 3.0, 5.0]) 인 경우

eps = 0.00001

b[b<eps] = eps    # b <-- [0.00001, 1.0, 3.0, 5.0]   

 

댓글