욤미의 개발일지

6. 파이토치(PyTorch): 클래스로 파이토치 모델 구현하기 본문

PyTorch

6. 파이토치(PyTorch): 클래스로 파이토치 모델 구현하기

욤미 2023. 6. 5. 16:25
728x90
반응형

📌 구글 colab에서 실습한 내용


파이토치의 대부분의 구현체들은 대부분 모델을 생성할 때 클래스(Class)를 사용한다. 선형 회귀를 클래스로 구현해보자.


1. 모델을 클래스로 구현하기

  • __int__() 생성자: 객체가 생성될 때 자동으로 호출
  • super(): 클래스는 nn.Module 클래스의 속성들을 가지고 초기화
  • foward():모델이 학습데이터를 입력받아서 forward 연산을 진행시키는 함수로 model 객체를 데이터와 함께 호출하면 자동으로 실행된다.
class LinearRegressionModel(nn.Module): # torch.nn.Module을 상속받는 클래스
  def __init__(self): # 생성자, 객체가 생성될 때 자동으로 호출 
    super().__init__()
    self.linear = nn.Linear(1, 1) #  단순 선형 회귀, input_dim = 1, output_dim = 1
  
  def forward(self, x): # 순전파
    return self.linear(x)
model = LinearRegressionModel() # 모델 선언
728x90
반응형
Comments