公開日:2019-08-15
最終更新日:2019-08-26
最終更新日:2019-08-26
skl07-0:準備
次のcurry2.csv
は第5章で用いたBobのカレーに対する評価履歴データである.以下のデータをdata
ディレクトリに配置したうえで,次のコードを実行しよう.
curry2.csv
:カレーに対する評価履歴データ(データIDid
,辛さspicy {0-100}
,とろみthickness {0-100}
,評価値rating {0=嫌い, 1=好き}
)
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 |
>>> import numpy as np >>> import pandas as pd >>> >>> # データの読込み >>> curry2 = pd.read_csv('data/curry2.csv', index_col=0) >>> >>> feature_names = np.array(curry2.columns[:-1]) >>> target_names = ['dislike', 'like'] >>> >>> curry2_X = np.array(curry2[feature_names]) >>> curry2_y = np.array(curry2['rating']) >>> >>> # SVCオブジェクトの生成 >>> from sklearn.svm import SVC >>> svc = SVC(kernel='rbf') |
skl07-1:パラメタ検証範囲の設定
C
とgamma
のパラメタ検証範囲をそれぞれ10^-6, 10^-5, ... 10^5, 10^6
とし,それぞれCs
, gammas
としよう.
難易度:★★
ミッション | 説明 |
---|---|
1 | numpy.logspace() 関数を使う. |
2 | C のパラメタ検証範囲をCs とする. |
3 | gamma のパラメタ検証範囲をgammas とする. |
skl07-2:
GridSearchCV
のインポートGridSearchCV
をインポートしよう.
難易度:★
ミッション | 説明 |
---|---|
1 | GridSearchCV をインポートする. |
skl07-3:
GridSearchCV
オブジェクトの生成GridSearchCV
オブジェクトを生成しよう.ここで,パラメタはestimator=svc, param_grid=dict(C=Cs, gamma=gammas)
とする.
難易度:★★
ミッション | 説明 |
---|---|
1 | GridSearchCV() コンストラクタを呼び出す. |
2 | estimator パラメタを指定する. |
3 | param_grid パラメタを指定する. |
4 | 生成したGridSearchCV オブジェクトをgs とする. |
skl07-4:グリッドサーチの実行
データセットを基にgs
により学習しよう.
難易度:★★
ミッション | 説明 |
---|---|
1 | GridSearchCV.fit() メソッドを使う. |
skl07-5:パラメタ
C
の最適値の取得グリッドサーチにより得られたC
の最適値を取得しよう.
難易度:★★
ミッション | 説明 |
---|---|
1 | GridSearchCV.best_estimator_.C 属性使う. |
2 | 取得した最適値をC_best とする. |
skl07-6:パラメタ
gamma
の最適値の取得グリッドサーチにより得られたgamma
の最適値を取得しよう.
難易度:★★
ミッション | 説明 |
---|---|
1 | GridSearchCV.best_estimator_.gamma 属性使う. |
2 | 取得した最適値をgamma_best とする. |
skl07-7:最適パラメタによる
SVC
オブジェクトの生成グリッドサーチにより得られた最適パラメタC_best
とgamma_best
を使ってSVC
オブジェクトを生成しよう.ここで,パラメタはkernel='rbf', C=C_best, gamma=gamma_best
とする.
難易度:★★
ミッション | 説明 |
---|---|
1 | SVC() コンストラクタを呼び出す. |
2 | kernel パラメタを指定する. |
3 | C パラメタを指定する. |
4 | gamma パラメタを指定する. |
5 | 生成したSVC オブジェクトをsvc_best とする. |
skl07-8:交差検証
svc_best
によりデータセットに対して5分割交差検証による5回分のスコアを取得しよう.ここで,パラメタはcv=k_fold, scoring='accuracy'
とする.
難易度:★★
ミッション | 説明 |
---|---|
1 | KFold をインポートする. |
2 | cross_val_score をインポートする. |
3 | KFold() コンストラクタを呼び出す. |
4 | 生成したKFold オブジェクトをk_fold とする. |
5 | cross_val_score() 関数を使う. |
5 | cv パラメタを指定する. |
6 | scoring パラメタを指定する. |
skl07-9:学習モデルの可視化
次のコードはsvc_best
による学習モデルを可視化するものである.次のコードをskl07_plt.py
というファイル名で保存し,python3
コマンドで実行しよう.
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 |
import numpy as np import pandas as pd import matplotlib.pyplot as plt from matplotlib.colors import ListedColormap from sklearn.svm import SVC from sklearn.model_selection import GridSearchCV # data curry2 = pd.read_csv('data/curry2.csv', index_col=0) feature_names = np.array(curry2.columns[:-1]) target_names = ['dislike', 'like'] curry2_X = np.array(curry2[feature_names]) curry2_y = np.array(curry2['rating']) # train svc = SVC(kernel='rbf') Cs = np.logspace(-6, 6, 13) gammas = np.logspace(-6, 6, 13) gs = GridSearchCV(estimator=svc, param_grid=dict(C=Cs, gamma=gammas)) gs.fit(curry2_X, curry2_y) C_best = gs.best_estimator_.C gamma_best = gs.best_estimator_.gamma svc_best = SVC(kernel='rbf', C=C_best, gamma=gamma_best) svc_best.fit(curry2_X, curry2_y) # plot cmap_light = ListedColormap(['#CCCCFF', '#FFCCCC']) cmap_dark = ListedColormap(['#8888FF', '#FF8888']) x_min = 0 x_max = 100 y_min = 0 y_max = 100 xx, yy = np.mgrid[x_min:x_max:200j, y_min:y_max:200j] Z = svc_best.decision_function(np.c_[xx.ravel(), yy.ravel()]) Z = Z.reshape(xx.shape) plt.pcolormesh(xx, yy, Z, cmap=cmap_light) plt.contour(xx, yy, Z, colors=['k', 'k', 'k'], linestyles=['--', '-', '--'], levels=[-.5, 0, .5]) plt.scatter(svc_best.support_vectors_[:, 0], svc_best.support_vectors_[:, 1], s=80, facecolors='none', edgecolors='k') plt.scatter(curry2_X[:, 0], curry2_X[:, 1], c=curry2_y, cmap=cmap_dark, edgecolors='k') plt.title("curry2") plt.xlabel('spicy') plt.ylabel('thickness') plt.xlim(xx.min(), xx.max()) plt.ylim(yy.min(), yy.max()) plt.show() |
難易度:★
ミッション | 説明 |
---|---|
1 | python3 コマンドでskl07_plt.py を実行する. |