본문 바로가기
machine learning/pytorch

[pytorch] list of tensors 의 평균구하기

by 단창 2020. 8. 21.

my_list = [tensor(0.8223, device='cuda:0'), tensor(1.8351, device='cuda:0'), tensor(1.4888, device='cuda:0'),]

 

np.mean(my_list) 하면, 

TypeError: mean(): argument 'input' (position 1) must be Tensor, not list

에러가 난다. 

해결> 

mean = torch.mean(torch.stack(my_list))

반응형