Partial Dependence Plot
어떤 black-box model 내부에서 각 feature가 prediction에 어떻게 영향을
주는지를 시각적으로 보여주는 기법이다.
기본 원리
validation dataset에서 구하고자하는 feature의 백분위수(percentile)를 전체적으로 균등하게 추출한다음 각 percentile값만을
feature의 값으로 가지는 여러가지 dataset을 만든다. 예를들면 11개의
백분위수로 나누면 feature값으로 0%백분위수만을 가지는 dataset, 10%
백분위수만을 가진 dataset, ... 100% 분위수만을 가진 데이터셋 이렇게 11개의
dataset이 만들어진다. 이제 각 dataset을 이용해 prediction을 수행하면
해당 feature값이 변화함에 따라 predicted target이 어떻게 변하는지 알 수가
있다.
|
"YearBuilt"값의 변화에 대한 target response의 변화를 나타낸다. |
좌측에서 첫번째 점이 첫 dataset으로 prediction을 수행한 결과로 이 값이 y축의
기준점(0)이 된다. 예제에서는 "YearBuilt"값이 1880일때의 예측값이 기준이다.
각각의 점들은 각 dataset으로 예측을 수행했을때의 값을 나타내고 파란영역은 그
분산을 나타낸다. 예제를 보면 알듯이, PDP를 통해 대략적인 상관관계를 유추
가능하다. 이 model에서 target은 'YearBuilt'와 강한 양의 상관관계를 가지고
있음을 알수 있다. 예제에서는 아니지만, 2차함수나 3차함수의 모양을한
non-linear한 상관관계를 보여줄 수도 있다.
2D PDP
PDP를 각각의 feature에 대해 모두 구할수 있는만큼, 여러개의 feature들을 변수로
주었을 때의 target response의 변화도 구할 수 있다. 이론적으로는 model의 전체
feature들의 집합의 모든 부분집합에 대한 PDP를 구할 수 있지만, 시각적으로
적절히 표현하기에는 원소가 2개인 부분집합이 최대 한계라고 할 수있다. 그 원소
2개짜리 부분 집합에대한 PDP가 2D PDP다. 2D PDP는 해당 feature쌍이 target
response에 어떻게 영향을 주는지를 나타낸다.
|
색으로 구분된 영역으로 '1stFlrSF', '2ndFlrSF'쌍에 대한 target response값을 나타낸다. |
'1stFlrSF'와 '2ndFlrSF'의 합이 일정한 부분을 따라 선을 긋는다면 그 선에 영역이 나란히 나타남을 알 수가 있다. 완벽하지는 않지만 '1stFlrSF'와 '2ndFlrSF'의 합에 target response가 비례한다는 것을 알 수가 있다.
Permutation Importance와의 차이
Permutation Importance : Model의 각 feature가 prediction에 얼만큼의
영향을 끼치는가?
= 각 feature와 target간의 상관관계 세기
Partial Dependence Plot : Model의 각 feature가 prediction에
어떻게 영향을 끼치는가?
= 각 feature와 target간의 상관관계
장, 단점 : Permutation Importance와 같다
장점
- 별도의 재학습(re-training)과정이 필요가 없어서 빠르다. 학습은 예측보다 시간이 훨신 많이 걸린다.
단점
- under or over-fitting된 model에 대해서는 model의 feature importance가 실제 feature-target간의 상관관계와는 상이할 수 있다. 언제까지나 model에 종속적인 속성임을 알아야한다.
- 각 feature들이 모두 독립변수라는 가정을 한다. 이는 Model들이 실제로 그런 가정을 기반으로 학습되기 때문이기도 하다. 때문에 실제 dataset에서는 feature들간의 상관관계가 존재함에 따라 target과의 상관관계가 다를 수 있다.
적용
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from matplotlib import pyplot as plt | |
from pdpbox import pdp | |
#my_model, val_X, feature_list | |
# isolate pdp | |
pdp_iso = pdp.pdp_isolate(model=my_model, dataset=val_X, model_features=feature_list, feature='feature1') | |
pdp.pdp_plot(pdp_iso, 'feature1') | |
plt.show() | |
# interact pdp | |
pdp_inter = pdp.pdp_interact(model=my_model, dataset=val_X, model_features=feature_list, features=['feature1', 'feature2']) | |
pdp.pdp_interact_plot(pdp_interact_out=pdp_inter, feature_names=['feature1', 'feature2'], plot_type='contour') | |
plt.show() |
댓글
댓글 쓰기