ABOUT ME

-

Today
-
Yesterday
-
Total
-
  • [군집화] Mean Shift
    머신러닝 & 딥러닝 2021. 11. 3. 19:04

    Mean Shift

    • KDE (Kernel Density Estimation)을 이용하여 데이터 포인트들이 데이터 분포가 높은 곳으로 이동하면서 군집화를 수행
    • 별도의 군집화 개수를 지정하지 않으며 mean shift는 데이터 분포도에 기반하여 자동으로 군집화 개수를 정함
    1. 개별 데이터의 특정 반경 내에 주변 데이터를 포함한 데이터 분포도 계산
    2. 데이터 분포도가 높은 방향으로 중심점 이동
    3. 중심점을 따라 해당 데이터 이동
    4. 이동된 데이터의 특정 반경 내에 다시 데이터 분포 계산 후, 2번 3번 스텝을 반복
    5. 가장 분포도가 높은 곳으로 이동하면 더 이상 해당 데이터는 움직이지 않고 수렴
    6. 모든 데이터를 1~5까지 수행하면서 군집 중심점을 찾음.

    KDE (Kernal Density Estimation)

    • KDE는 커널 함수를 통해 어떤 변수의 확률 밀도 함수를 추정하는 방식. 관측된 데이터 각각에 커널 함수를 적용한 값을 모두 더한 뒤 데이터 건수로 나누어서 확률 밀도 함수를 추정.
    • 개별 관측 데이터들에 커널 함수를 적용한 뒤, 커널 함수들의 적용값을 모두 합한 뒤에 개별 관측 데이터의 건수로 나누어서 확률 밀도 함수를 추정하는 방식. 대표적으로 가우시안 분포 함수가 사용됨.
    • 작은 h값은 좁고 spike한 KDE로 변동성이 큰 확률 밀도 함수를 추정 (오버 피팅)
    • 큰 h값은 과도하게 smoothing 된 KDE로 단순화된 확률 밀도 함수를 추정 (언더 피팅)

    확률 밀도 추정 방법

    • 모수적 추정 (Parametric) : 데이터가 특정 데이터 분포를 따른 다는 가정 하에 데이터 분포를 찾는 방법. e.g. Gaussian Mixture.
    • 비모수적 추정 (Non-Parametric) : 데이터가 특정 분포를 따르지 않는 다는 가정 하에 밀도를 추정. 예를 들어, KDE.
      • 히스토그램

    seaborn의 distplot()을 이용하여 KDE 시각화

    import numpy as np
    import seaborn as sns
    import matplotlib.pyplot as plt
    %matplotlib inline
    
    sns.set(color_codes=True)
    
    np.random.seed(0)
    x = np.random.normal(0, 1, size=30)
    print(x)
    sns.distplot(x);
    [ 1.76405235  0.40015721  0.97873798  2.2408932   1.86755799 -0.97727788
      0.95008842 -0.15135721 -0.10321885  0.4105985   0.14404357  1.45427351
      0.76103773  0.12167502  0.44386323  0.33367433  1.49407907 -0.20515826
      0.3130677  -0.85409574 -2.55298982  0.6536186   0.8644362  -0.74216502
      2.26975462 -1.45436567  0.04575852 -0.18718385  1.53277921  1.46935877]
    
    /Users/terrydawunhan/opt/anaconda3/lib/python3.8/site-packages/seaborn/distributions.py:2619: FutureWarning: `distplot` is a deprecated function and will be removed in a future version. Please adapt your code to use either `displot` (a figure-level function with similar flexibility) or `histplot` (an axes-level function for histograms).
      warnings.warn(msg, FutureWarning)
    

    sns.distplot(x, rug=True)
    /Users/terrydawunhan/opt/anaconda3/lib/python3.8/site-packages/seaborn/distributions.py:2619: FutureWarning: `distplot` is a deprecated function and will be removed in a future version. Please adapt your code to use either `displot` (a figure-level function with similar flexibility) or `histplot` (an axes-level function for histograms).
      warnings.warn(msg, FutureWarning)
    /Users/terrydawunhan/opt/anaconda3/lib/python3.8/site-packages/seaborn/distributions.py:2103: FutureWarning: The `axis` variable is no longer used and will be removed. Instead, assign variables directly to `x` or `y`.
      warnings.warn(msg, FutureWarning)
    
    <AxesSubplot:ylabel='Density'>
    

    sns.distplot(x, kde=False, rug=True)
    /Users/terrydawunhan/opt/anaconda3/lib/python3.8/site-packages/seaborn/distributions.py:2619: FutureWarning: `distplot` is a deprecated function and will be removed in a future version. Please adapt your code to use either `displot` (a figure-level function with similar flexibility) or `histplot` (an axes-level function for histograms).
      warnings.warn(msg, FutureWarning)
    /Users/terrydawunhan/opt/anaconda3/lib/python3.8/site-packages/seaborn/distributions.py:2103: FutureWarning: The `axis` variable is no longer used and will be removed. Instead, assign variables directly to `x` or `y`.
      warnings.warn(msg, FutureWarning)
    
    
    

    sns.distplot(x, hist=False, rug=True);
    /Users/terrydawunhan/opt/anaconda3/lib/python3.8/site-packages/seaborn/distributions.py:2619: FutureWarning: `distplot` is a deprecated function and will be removed in a future version. Please adapt your code to use either `displot` (a figure-level function with similar flexibility) or `kdeplot` (an axes-level function for kernel density plots).
      warnings.warn(msg, FutureWarning)
    /Users/terrydawunhan/opt/anaconda3/lib/python3.8/site-packages/seaborn/distributions.py:2103: FutureWarning: The `axis` variable is no longer used and will be removed. Instead, assign variables directly to `x` or `y`.
      warnings.warn(msg, FutureWarning)
    

    개별 관측데이터에 대해 가우시안 커널 함수를 적용

    from scipy import stats
    
    #x = np.random.normal(0, 1, size=30)
    bandwidth = 1.06 * x.std() * x.size ** (-1 / 5.)
    support = np.linspace(-4, 4, 200)
    
    kernels = []
    for x_i in x:
        kernel = stats.norm(x_i, bandwidth).pdf(support)
        kernels.append(kernel)
        plt.plot(support, kernel, color="r")
    
    sns.rugplot(x, color=".2", linewidth=3);

    from scipy.integrate import trapz
    density = np.sum(kernels, axis=0)
    density /= trapz(density, support)
    plt.plot(support, density);

    seaborn은 kdeplot()으로 kde곡선을 바로 구할 수 있음

    sns.kdeplot(x, shade=True);

    bandwidth에 따른 KDE 변화

    sns.kdeplot(x)
    sns.kdeplot(x, bw=.2, label="bw: 0.2")
    sns.kdeplot(x, bw=2, label="bw: 2")
    plt.legend();

    import warnings
    
    warnings.filterwarnings(action='ignore')

    사이킷런을 이용한 Mean Shift

    • make_blobs()를 이용하여 2개의 feature와 3개의 군집 중심점을 가지는 임의의 데이터 200개를 생성하고 MeanShift를 이용하여 군집화 수행
    import numpy as np
    from sklearn.datasets import make_blobs
    from sklearn.cluster import MeanShift
    
    X, y = make_blobs(n_samples=200, n_features=2, centers=3, 
                      cluster_std=0.8, random_state=0)
    
    meanshift= MeanShift(bandwidth=0.9)
    cluster_labels = meanshift.fit_predict(X)
    print('cluster labels 유형:', np.unique(cluster_labels))
    cluster labels 유형: [0 1 2 3 4 5 6 7]
    

    커널함수의 bandwidth크기를 1로 약간 증가 후에 Mean Shift 군집화 재 수행

    meanshift= MeanShift(bandwidth=1)
    cluster_labels = meanshift.fit_predict(X)
    print('cluster labels 유형:', np.unique(cluster_labels))
    cluster labels 유형: [0 1 2]
    

    최적의 bandwidth값을 estimate_bandwidth()로 계산 한 뒤에 다시 군집화 수행

    from sklearn.cluster import estimate_bandwidth
    
    bandwidth = estimate_bandwidth(X,quantile=0.25)
    print('bandwidth 값:', round(bandwidth,3))
    bandwidth 값: 1.689
    
    import pandas as pd
    
    clusterDF = pd.DataFrame(data=X, columns=['ftr1', 'ftr2'])
    clusterDF['target'] = y
    
    # estimate_bandwidth()로 최적의 bandwidth 계산
    best_bandwidth = estimate_bandwidth(X, quantile=0.25)
    
    meanshift= MeanShift(bandwidth = best_bandwidth)
    cluster_labels = meanshift.fit_predict(X)
    print('cluster labels 유형:',np.unique(cluster_labels))
    cluster labels 유형: [0 1 2]
    
    import matplotlib.pyplot as plt
    %matplotlib inline
    
    clusterDF['meanshift_label']  = cluster_labels
    centers = meanshift.cluster_centers_
    unique_labels = np.unique(cluster_labels)
    markers=['o', 's', '^', 'x', '*']
    
    for label in unique_labels:
        label_cluster = clusterDF[clusterDF['meanshift_label']==label]
        center_x_y = centers[label]
        # 군집별로 다른 marker로 scatter plot 적용
        plt.scatter(x=label_cluster['ftr1'], y=label_cluster['ftr2'], edgecolor='k', 
                    marker=markers[label] )
    
        # 군집별 중심 시각화
        plt.scatter(x=center_x_y[0], y=center_x_y[1], s=200, color='white',
                    edgecolor='k', alpha=0.9, marker=markers[label])
        plt.scatter(x=center_x_y[0], y=center_x_y[1], s=70, color='k', edgecolor='k', 
                    marker='$%d$' % label)
    
    plt.show()

    print(clusterDF.groupby('target')['meanshift_label'].value_counts())
    target  meanshift_label
    0       0                  67
    1       2                  67
    2       1                  65
            2                   1
    Name: meanshift_label, dtype: int64
    

    댓글