Worth spreading

cs231n (2017) - lecture3 _ 1 (Loss function) 본문

카테고리 없음

cs231n (2017) - lecture3 _ 1 (Loss function)

annual 2017. 11. 22. 17:43


안녕하세요! 

지난 시간에 뉴럴넷을 만들기 위한 첫번째 레고블럭인 선형분류기(linear classifier)를 배웠습니다.

선형분류기는 '파라미터적 접근법(parametric approach)'을 이용한다고 했습니다. 학습데이터를 summarize한 값을 파라미터(W) 값에 반영시킨다고 했죠.

 

우리의 선형분류기는 매우 심플했습니다. 

[32x32x3] 크기의 이미지를 입력받아서 한줄로 된 행렬(혹은 벡터)로 쭈욱 펴줬습니다( [32x32x3] -> [3072 x 1] )

그리고 W([10 x 3072])를 기다란 벡터([3072x1]) dot product해서 각 클래스의 점수에 해당하는 10개의 숫자(score)가 나오도록 만들었습니다.

그리고 여기서 나온 숫자들은 각 카테고리가 정답이 될 확률을 나타낸다고 했습니다.

여기서 우리 W의 각 행(row)( [10x3072]의 행렬 중 [1x3072] 한줄 )들은  10개의 카테고리의 각 이미지들에 대한 형상(templete)을 나타낸다고 했습니다.

각 행들이 입력데이터([3072x1])와 dot product되어 각 클래스에 대한 점수들이 나오게 되는겁니다.


이렇게 우리는 지난 시간에 선형분류기(linear classifier)가 무엇인지 대락젹으로 살펴보았습니다.

하지만 우리는 W 값을 고르는 방법이나, 학습데이터를 이용해 좋은 W값을 알아내는 방법에 대해서 아직 이야기하지 않았습니다.

이번 강의(lec3)에서는 이러한 것들을 다루도록 하겠습니다.

 

위 슬라이드를 보시면 세 개의 이미지에 대한 선형분류 결과가 각 열(column)에 나와있습니다.

고양이사진에 대한 선형분류기를 보시면 고양이의 점수(score)는 2.9입니다. 그런데 강아지나 개구리, 사슴 등의 점수가 더 높게 나와 있네요.

자동차사진에 대한 선형분류기는 올바르게 작동됐네요. 자동차의 점수가 6.04로 가장 높습니다.

하지만 개구리에 대한 선형분류기는 개구리 점수를 -4.34 로 주어 정답과 거리가 먼 답을 냈습니다. 


여기서 우리가 알고리즘을 만들어 '좋은 W값'을 자동으로 찾아내게 하려면 지금의 W가 얼마나 안좋은 지를 나타낼 수 있어야 할 것 같습니다.

이것을 함수로 구현한다면 이 함수는 W를 입력받아 지금의 W가 얼마나 나쁜지를 말해주는 함수여야 할 것입니다.

우리는 이 함수를 loss function이라고 부를 것입니다.


우리는 이번 강의에서 이미지분류문제에 바로 적용 시킬 수 있는 여러가지의 loss function을 살펴볼 것입니다.


우선 개념만 간단히 익히기 위해서 세 개의 카테고리를 가진 예제를 보겠습니다.

이 전 슬라이드에서 본 것과 마찬가지로 자동차에 대해서만 분류를 잘 해냈으며 고양이에 대해서는 그럭저럭 아쉽게 틀렸습니다.하지만 개구리에 대해서는 여전히 분류를 잘 못하고 있습니다.


우선 loss function을 살펴보기 위해 우리의 dataset부터 정의하겠습니다.

우리의 dataset은  N개로 이루어져있으며 x_i는 i번째 이미지, y_i는 i번째 이미지의 정답을 나타냅니다. 이 정답은 클래스의 개수에 따라 범위가 달라지는데요, 우리 예제에서 클래스의 개수는 3개이므로 1~3 혹은 0~2 값을 가질 수 있습니다.

그리고 우리는 loss function을 다음과 같이 이용할 것입니다

loss function은 L( f(x,W) , y ) 처럼 생겼으며 f(x,W)와  y를 입력으로 받고 있습니다.  f(x,w)는 우리가 사용했던 선형분류기이고 y는 위에서 정의했던 정답 label입니다.

우리는 각 데이터 i에 대한 모든  loss function의 값을 더하고 이걸 데이터 개수인 N으로 나눠서 우리 W의 loss 값을 알아낼 것입니다.


이제 첫 번째 loss function을 살펴보겠습니다.


첫 번째 loss function은 'Multiclass SVM loss'라고 불립니다. 여러 개의 클래스를 분류할 때 쓴다고 해서 multiclass라고 합니다.

위 슬라이드를 보시면 f(x_i, W)를 s로 축약하기로 했습니다.

loss score는 다음과 같이 구해집니다. 


*분류할 카테고리들 중에서 정답 클래스를 제외한 나머지 클래스들에 대해서 loss를 구할 것입니다.

예를 들어 cifar10이라면 10개의 클래스가 있지만 정답 클래스는 제외하므로 9개의 클래스에 대하여 loss score를 구합니다.

이제 9개의 클래스 (정답이 아닌 라벨 1,2,3,4,5,6,7,8,9) 에 대해 다음 연산을 반복합니다

<

1.  만약 (정답이 아닌 라벨1 의 점수 + 1) < (정답인 라벨의 점수)  이라면 이 클래스의 loss는 0입니다.

    정답라벨의 점수가 정답이 아닌 라벨의 점수보다 높으니까 잘한거죠 ! 그러니까 loss는 0입니다

2. 만약 그렇지 않다면 (정답인 라벨의 점수가 더 낮다면) loss 값은 ( 정답이 아닌 라벨1 의 점수 - 정답 라벨의 점수 + 1)이 됩니다.

---------------

이것을 C언어로 작성하면 다음과 같습니다. ( C를 모르시면 넘어가셔도 무방합니다)

//정답을 10이라고 하겠습니다 때문에 10에 대해서는 for문을 돌지 않습니다

for(int i=1; i<=9; i++) {

if( (score[i] + 1) < score[10] )

loss[i] = 0;

else

loss[i] = score[i] - score[10] + 1;

}

----------------

multiclass SVM loss를 그래프로 그려보면 다음과 같습니다. 보시면 그래프의 모양이 경첩(hinge)같이 생겼습니다. 때문에 hinge loss라고 불리기도 합니다.

여기서 loss function은 'safety margin'이라고 불리며 안정적인 분류를 위해 살짝 더 엄격하게 판단한다는 의미에서 1을 더해준 것입니다.

그래프에서 보시면 x축은 S_y_i값으로 정답 라벨의 점수, y축은 loss를 뜻합니다. x값이 커질수록 loss가 0에 가까워지는 것을 볼 수 있습니다. 


위에서 설명한 것을 수식으로 나타내면 위 슬라이드의 하단과 같이 max를 이용해서 정의할 수 있습니다.

(S_j - S_y_i + 1)의 값이 0보다 작으면 loss의 값은 0 이 되고 그렇지 않으면 이 값은 그대로 loss가 됩니다.

위 식을 한 문장으로 정리하면 "정답 라벨의 점수가 높으면  loss는 작다"입니다.


이제 위에서 봤던 예제에 SVM loss를 적용해보겠습니다 !


다른 부분은 보지 마시고 파란색으로 표시된 부분만 보시기 바랍니다.

여기서는 고양이 사진을 분류한 것에 대한 loss를 구했습니다. 정답라벨에 대해서는 연산을 하지 않는다고 했으므로 자동차와 개구리에 대한 점수인 5.1과 -1.7에 대한 loss  값을 계산합니다.

자동차클래스에 대한 loss : max(0 , 5.1 - 3.2 + 1)  ->  2.9

개구리클래스에 대한 loss : max(0, -1.7 - 3.2 + 1)  -> 0

 2.9 + 0 = 2.9


보시면 자동차를 고양이보다 높은 점수로 예측했으므로 자동차클래스에 대한 loss는 0보다 크게 나오는 것이 당연합니다.

또한 개구리는 고양이보다 1이상 작은 값을 가지고 있으므로 loss는 0이 나왔구요.


두 번째 예제로 자동차 사진에 대한 loss는 자동차의 점수(4.9)가 다른 두 클래스(1.3 , 2.0)에 비해 높으므로 loss는 0이 나올 것입니다.

하지만  세 번째 예제인 개구리는 loss가 꽤나 나올 것 같습니다. 계산을 해보겠습니다


고양이클래스에 대한 loss : max(0 , 2.2 -(-3.1) + 1) -> 6.3

자동차클래스에 대한 loss : max(0 , 2.5 -  -(3.1) + 1) -> 6.6

6.3 + 6.6 = 12.9

이제 우리의 최종 loss를 구할 수 있습니다. 최종 loss는 각 클래스들의 loss들의 평균으로

(2.9 + 0 + 12.9) / 3 = 5.27

5.27이 여기서 사용한 W의 loss가 됩니다.


# Johnson's Questions

여기서 존슨 선생님이 학생들에게 세 가지 질문을 던집니다

Q #1. What happens to loss if car scores change a bit?

A :  'a bit'이 얼마나일지는 모르지만 작다고 생각했을 때, loss 값은 변하지 않을 것이다.

SVM loss에서는 정답값이 틀린 값과 1이상 차이가 나게 만든다(safety margin). 때문에 이러한 작은 변화에 대해서는 safety margin이 버텨주어 loss값은 변하지 않을 것이다. 

( 또한 위에서 우리 자동차 다른 클래스에 비해 큰 점수를 갖고 있다. )


Q #2. What is the min/max possible loss?

A : 최소값 : 0 , 최대값 : infinete(무한대)

위 슬라이드의 그래프에 나타나듯 loss 는 0 이하로 내려가지 않는다. 반대로 정답 라벨의 예측값이 작아질수록 loss는 커지므로 무한대에 수렴하게 된다.


Q #3. At initialization W is small so all s ~ 0. What is the loss? (파라미터 값이 매우 작아서 0에 가깝다면 loss는 어떻게 될까?)

A : loss 는 (클래스의 개수 - 1)의 값을 갖는다.

SVM loss는 정답이 아닌 라벨에 대해서만 loss값을 계산한다고 했다. 

위에서 보았던 공식 ( S_j - S_y_i + 1) 에서 S는 입력데이터 x와 파라미터W의 곱셈이었다. 만약 W가 매우 작다면 S도 매우 작은 값이 나올 것이다. 때문에 모든 loss는 1이 될 것이고 이들이 더해져 loss 는 (클래스의 개수 - 1)이 될 것이다.


* Q#3은 실제로 SVM을 사용했을 때 버그가 있는지 찾아낼 수 있는 트릭 중 하나이다. W을 매우 작은 값으로 설정한 뒤 loss를 구해봤을 때 (클래스 개수-1)의 값이 나오지 않는다면 뭔가 코드에 이상이 있는 것이므로 다시한번 살펴봐야 할 것이다.


Q #4. What if the sum was over all classes? (including j = y_i)  (정답 라벨의 loss까지 더하면 어떻게 될까?)

A : loss가 1 증가할 것이다.

당연한 것이다. ( S_j - S_y_i + 1) 에서 S_j와 S_y_i가 같은 값이 될 것이므로 1이 남게되고 이것이 loss에 더해질 것이다.
가장 작은 loss가 0인 것이 직관적으로도 이해하기 좋기 때문에 통상적으로 정답 클래스에 대한 계산은 하지 않는다.


Q #5. What if we used mean instead of sum? (SVM 공식에서 시그마 대신 평균을 구한다면 어떻게 될까?)

A : 답은 변하지 않는다.

이것은 단지 이전의 loss를 n으로 나눈 것과 같은 역할을 한다.(rescale했다고 표현한다).


Q #6. What if we used 

이렇게 max를 취해준 뒤에 제곱을 한다면 이전의 SVM function과 달라질까?


A : 달라진다

기존의 SVM function에서 나온 값에 제곱을 취하게되면 기존의 SVM function과 다른 모양의 그래프가 그려진다. 따라서 좋고 나쁨을 판단하는 정도도 어느정도 달라지게 된다. 

위 식은 Squared hinge loss 라고도 불리며 실제로 사용되는 trick이다.


위 코드는 Multiclass SVM Loss를 파이썬으로 구현한 코드입니다. max함수, dot product, sum 함수는 모두 numpy 함수를 사용했습니다.

* numpy를 이용하면 다양한 연산(특히 선형대수와 관련된)을 간단한 코드로 수행할 수 있으며 numpy의 많은 함수들이 c로 구현돼있기 때문에 실행속도나 메모리 사용량 면에서도 아주 좋습니다.


위의 L_i_vectorized 함수를 정답이 아닌 클래스들에 대해서 반복적으로 호출하며 loss를 계산하면 됩니다.


Question : 만약 loss가 0인 W를 찾았다면 이 W는 유일(unique)할까?

Answer : No. 2*W도 loss가 0이 된다.

  


위 슬라이드를 보시면 w에 대한 loss가 0이라면 2*w에 대한 loss도 0이 되는 것을 보실 수 있습니다.

다음 시간에는 loss function의 연장선으로 우리의 classifier가 좀 더 보편적인 판단을 하도록 만들어 줄 Regularization을 살펴보겠습니다.

Comments