pytorch에서 nn.Module 을 기반으로 만들어진 모듈은
forward 함수를 호출 하지도 않았는데 모델 instance을 호출하는 것 만으로 forward가 실행된다.
왜 이럴까?
python의 __init__ 은 모두가 알듯이 생성자 함수이다.
클래스의 instance가 생성되면 __init__가 호출된다.
__call__은 호출자 함수인데 instance가 호출되면 __call__가 실행된다. python은 클래스 인스턴스도 함수 처럼 호출할 수 있다
class test_class():
def __init__(self, n1,n2,n3):
print('called __init__')
self.n1 = n1
self.n2 = n2
self.n3 = n3
print(n1, n2, n3)
def __call__(self):
print('called __call__')
test = test_class(10,20,30)
test()
test 인스턴스를 생성하면
"
called __init__
10 20 30
"
test인스턴스를 호출하면
"
called __call__
"
이렇게 호출자함수 __call__가 호출된다.
__call__을 오버라이드하여 이렇게도 쓸수있다
class test_class():
def __init__(self, n1,n2,n3):
print('called __init__')
self.n1 = n1
self.n2 = n2
self.n3 = n3
print(n1, n2, n3)
def __call__(self):
print('called __call__')
def multiply(self):
return self.n1*self.n2
__call__ = multiply
test = test_class(10,20,30)
test()
이러면 test인스턴스 호출시 multiply가 호출된다.
"
called __init__
10 20 30
200
"
pytorch의 forward도 이렇게 구성되어있는데
https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/module.py
1176 line에 __call__ : Callable[..., Any] = _call_impl
가 있고
__call_impl에는
forward_call = (self._slow_forward if torch._C._get_tracing_state() else self.forward)
이렇게 self.forward 가 호출된다그래서 pytorch에서 모델클래스는 forword이름로 작성해야하는 것이고 instance호출로 자동 forward가 실행되는것이다.
'machine learning > pytorch' 카테고리의 다른 글
[pytorch] Variable -> tensor 로 통합 (0) | 2020.08.21 |
---|---|
[pytorch] list of tensors 의 평균구하기 (0) | 2020.08.21 |