파이토치 3-4. nn.Module로 구현하는 선형 회귀


파이토치 3-4. nn.Module로 구현하는 선형 회귀

파이토치에서는 선형회귀모델이 nn.Linear() , 평균 제곱오차는 nn.functional.mse_loss() 라는 함수로 구현되어 있다. #선형 회귀 모델 import torch.nn as nn model = nn.Linear(input_dim, output_dim) # 평균 제곱 오차 import torch.nn.functional as F cost = F.mse_loss(prediction, y_train) model 에는 w와 b가 저장되어 있는데 한번 불러보자 w 에는 0.5153 b에는 -0.4414 가 들어가 있는데 이는 랜덤으로 초기화 되어있는 것이다 optimizer는 model.parameters() 를 이용하여 w와 b를 전달한다. lr는 0.01로 설정한다 . 훈련시킨 모델을 테스트 해보자 y= 2x 가 정답이었는데 4를 넣으니 8이 나왔다. w는 2이고 b는 0인 것을 확인할 수 있다. 총 정리 해보자면 H(x) 식에 입력x로 부터 예측된 y를 얻는 것을 ...



원문링크 : 파이토치 3-4. nn.Module로 구현하는 선형 회귀