상세 컨텐츠

본문 제목

머신러닝과 딥러닝(18)_포켓몬 분류

카테고리 없음

by teminam 2023. 6. 21. 10:13

본문

valid, test data의 차이

  • validation data: 하이퍼 파라미터 튜닝 등 성능의 차이를 검증, 모델 학습에 사용되지 않지만 관여는 함
  • test data: 모델의 성능에 영향을 미치지 않고 단지 최종적으로 모델의 성능을 평가

1. 포켓몬 149종 분류

 

Pokemon Generation One

Gotta train 'em all!

www.kaggle.com

 

Complete Pokemon Image Dataset

2,500+ clean labeled images, all official art, for Generations 1 through 8.

www.kaggle.com

 

 

# transforms: Compose를 사용하여 사이즈, Affine, RandomHorizontalFlip, ToTensor 역할
data_transforms = {
    'train': transforms.Compose([
        transforms.Resize([224, 224]),
        transforms.RandomAffine(0, shear=10, scale=(0.8 , 1.2)),  # 랜덤하게 변경할 것을 선택(인덱스 0번부터 10가지 선택, 크기는 범위 +-20%하여 랜덤하게 변경)
        transforms.RandomHorizontalFlip(),  # 랜덤하게 이미지 좌우 반전
        transforms.ToTensor()  # 이미지를 텐서형으로 변환
    ]),
    'validation': transforms.Compose([
        transforms.Resize([224, 224]),  # 사이즈 맞춤
        transforms.ToTensor()
    ])
}

# # 타겟 데이터를 받아 텐서형으로 바꿔주는 함수
# def target_transforms(target):
#   return torch.FloatTensor([target])
# 데이터셋, 데이터로더(batch_size를 32) 만들기
image_datasets = {
    'train': datasets.ImageFolder('train', data_transforms['train']), # data 폴더 안에 train 폴더를 데이터셋화
    'validation': datasets.ImageFolder('validation', data_transforms['validation'])
}

# 데이터로더
dataloaders ={
    'train': DataLoader(
        image_datasets['train'],
        batch_size=32,
        shuffle=True
  ),
    'validation':DataLoader(
        image_datasets['validation'],
        batch_size=32,
        shuffle=False
    )
}

print(len(image_datasets['train']), len(image_datasets['validation']))

# 1개의 batch만큼 이미지를 출력
imgs, labels = next(iter(dataloaders['train']))


fig, axes = plt.subplots(4, 8, figsize=(20, 10))

for img, label, ax in zip(imgs, labels, axes.flatten()):
  ax.set_title(label.item())
  ax.imshow(img.permute(1,2,0))  # 텐서에 저장되어있을 때 shape(컬러, 가로, 세로) -> matplotlib에서는 (가로, 세로, 컬러채널)
  ax.axis('off')

# 학습
optimizer = optim.Adam(model.parameters(), lr=0.001)

epochs = 5

for epoch in range(epochs+1):
  for phase in ['train', 'validation']:  # train과  validation 따로 반복문을 돌아
    if phase == 'train':
      model.train()
    else:
      model.eval()   # 학습 모드에 있던 메모리를 지우고 바로 Test모드(훨씬 빠름)

    sum_losses = 0
    sum_accs = 0

    for x_batch, y_batch in dataloaders[phase]:  # train이라면 train에 대한 데이터로더, validataion이라면 validation에 대한 데이터로더 (따로 쓰지 않고 합쳐서 씀)
      x_batch = x_batch.to(device)
      y_batch = y_batch.to(device)

      y_pred = model(x_batch)

      loss = nn.CrossEntropyLoss()(y_pred, y_batch.long())

      if phase == 'train':
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

      sum_losses = sum_losses + loss.item()

      y_prob = nn.Softmax(1)(y_pred)
      y_pred_index = torch.argmax(y_prob, axis=1)
      acc = (y_batch == y_pred_index).float().sum() / len(y_batch) * 100
      sum_accs = sum_accs + acc.item()

    avg_loss = sum_losses / len(dataloaders[phase])
    avg_acc = sum_accs / len(dataloaders[phase])

    print(f'{phase:10s}: Epoch {epoch+1:4d}/{epochs}, Loss: {avg_loss:.4f}, Accuracy: {avg_acc:.2f}%')

# 테스트(validation에 있는 2종의 포켓몬을 통해 분류테스트)
from PIL import Image  # 이미지를 코랩에서 띄우기


img1 = Image.open('/content/validation/Butterfree/0.jpg')   # validation에 있는 이미지 쓰기
img2 = Image.open('/content/validation/Charmeleon/2.jpg')

fig, axes = plt.subplots(1,2, figsize=(12,6))
axes[0].imshow(img1)
axes[0].axis('off')
axes[1].imshow(img2)
axes[1].axis('off')
plt.show()

fig, axes = plt.subplots(1,2, figsize=(12,6))

axes[0].set_title('{:.2f}% {},{:.2f}% {},{:.2f}% {}'.format(
    probs[0, 0] * 100, image_datasets['validation'].classes[indices[0,0]],
    probs[0, 1] * 100, image_datasets['validation'].classes[indices[0,1]],
    probs[0, 2] * 100, image_datasets['validation'].classes[indices[0,2]],
))
axes[0].imshow(img1)
axes[0].axis('off')

axes[1].set_title('{:.2f}% {},{:.2f}% {},{:.2f}% {}'.format(
    probs[1, 0] * 100, image_datasets['validation'].classes[indices[1,0]],
    probs[1, 1] * 100, image_datasets['validation'].classes[indices[1,1]],
    probs[1, 2] * 100, image_datasets['validation'].classes[indices[1,2]],
))
axes[1].imshow(img2)
axes[1].axis('off')
plt.show()