ABOUT ME

-

Today
-
Yesterday
-
Total
-
  • [분류] 결정 트리
    머신러닝 & 딥러닝 2021. 10. 13. 21:23

    1  분류 알고리즘

    분류는 학습 데이터로 주어진 데이터의 피처와 레이블값을 머신러닝 알고리즘으로 학습해 모델을 생성하고, 생성된 모델에 새로운 데이터 값이 주어졌을 때 미지의 레이블 값을 예측하는 것.

    • 나이브 베이즈 : 베이즈 통계와 생성 모델에 기반
    • 로지스틱 회귀 : 독립 변수와 종속 변수의 선형 관계성에 기반
    • 결정 트리 : 데이터 균일도에 따른 규칙 기반
    • 서포트 벡터 머신 : 개별 클래스 간의 최대 분류 마진을 효과적으로 찾음
    • 최소 근접 알고리즘 : 근접 거리를 기준으로 함
    • 신경망 : 심층 연결 기반
    • 앙상블 : 서로 다른 (또는 같은) 머신러닝 알고리즘을 결합

     

    2  결정 트리 주요 하이퍼 파라미터

    • max_depth : 트리의 최대 깊이 규정. 디폴트는 none
    • max_features : 분할 하는데 고려할 최대 피처 개수
    • min_samples_split
    • min_samples_leaf : 말단 노드가 되기 위한 최소한의 샘플 데이터 수
    • max_leaf_nodes

     

    3  Graphviz를 이용한 결정 트리 모델의 시각화

    brew install graphviz

     

     

    4  결정 트리 모델 시각화

    from sklearn.tree import DecisionTreeClassifier
    from sklearn.datasets import load_iris
    from sklearn.model_selection import train_test_split
    import warnings
    warnings.filterwarnings('ignore')
    
    # DecisionTree Classifier 생성
    dt_clf = DecisionTreeClassifier(random_state=156)
    
    # 붓꽃 데이터를 로딩하고, 학습과 테스트 데이터 셋으로 분리
    iris_data = load_iris()
    X_train , X_test , y_train , y_test = train_test_split(iris_data.data, iris_data.target,
                                                           test_size=0.2,  random_state=11)
    
    # DecisionTreeClassifer 학습. 
    dt_clf.fit(X_train , y_train)

    DecisionTreeClassifier(random_state=156)

    from sklearn.tree import export_graphviz
    
    # export_graphviz()의 호출 결과로 out_file로 지정된 tree.dot 파일을 생성함. 
    export_graphviz(dt_clf, out_file="tree.dot", class_names=iris_data.target_names , \
    feature_names = iris_data.feature_names, impurity=True, filled=True)
    import graphviz
    
    # 위에서 생성된 tree.dot 파일을 Graphviz 읽어서 Jupyter Notebook상에서 시각화 
    with open("tree.dot") as f:
        dot_graph = f.read()
    graphviz.Source(dot_graph)
    import seaborn as sns
    import numpy as np
    %matplotlib inline
    
    # feature importance 추출 
    print("Feature importances:\n{0}".format(np.round(dt_clf.feature_importances_, 3)))
    
    # feature별 importance 매핑
    for name, value in zip(iris_data.feature_names , dt_clf.feature_importances_):
        print('{0} : {1:.3f}'.format(name, value))
    
    # feature importance를 column 별로 시각화 하기 
    sns.barplot(x=dt_clf.feature_importances_ , y=iris_data.feature_names)

    Feature importances: [0.025 0. 0.555 0.42 ]

    sepal length (cm) : 0.025

    sepal width (cm) : 0.000

    petal length (cm) : 0.555

    petal width (cm) : 0.420

     

     

     

     

     

    GitHub - DAWUNHAN/Machine-Learning

    Contribute to DAWUNHAN/Machine-Learning development by creating an account on GitHub.

    github.com

     

    댓글