公開日:2019-08-15
最終更新日:2019-08-26
最終更新日:2019-08-26
skl07-0:準備
次のcurry2.csvは第5章で用いたBobのカレーに対する評価履歴データである.以下のデータをdataディレクトリに配置したうえで,次のコードを実行しよう.
curry2.csv:カレーに対する評価履歴データ(データIDid,辛さspicy {0-100},とろみthickness {0-100},評価値rating {0=嫌い, 1=好き})
|
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 |
>>> import numpy as np >>> import pandas as pd >>> >>> # データの読込み >>> curry2 = pd.read_csv('data/curry2.csv', index_col=0) >>> >>> feature_names = np.array(curry2.columns[:-1]) >>> target_names = ['dislike', 'like'] >>> >>> curry2_X = np.array(curry2[feature_names]) >>> curry2_y = np.array(curry2['rating']) >>> >>> # SVCオブジェクトの生成 >>> from sklearn.svm import SVC >>> svc = SVC(kernel='rbf') |
skl07-1:パラメタ検証範囲の設定
Cとgammaのパラメタ検証範囲をそれぞれ10^-6, 10^-5, ... 10^5, 10^6とし,それぞれCs, gammasとしよう.
難易度:★★
| ミッション | 説明 |
|---|---|
| 1 | numpy.logspace()関数を使う. |
| 2 | Cのパラメタ検証範囲をCsとする. |
| 3 | gammaのパラメタ検証範囲をgammasとする. |
skl07-2:
GridSearchCVのインポートGridSearchCVをインポートしよう.
難易度:★
| ミッション | 説明 |
|---|---|
| 1 | GridSearchCVをインポートする. |
skl07-3:
GridSearchCVオブジェクトの生成GridSearchCVオブジェクトを生成しよう.ここで,パラメタはestimator=svc, param_grid=dict(C=Cs, gamma=gammas)とする.
難易度:★★
| ミッション | 説明 |
|---|---|
| 1 | GridSearchCV()コンストラクタを呼び出す. |
| 2 | estimatorパラメタを指定する. |
| 3 | param_gridパラメタを指定する. |
| 4 | 生成したGridSearchCVオブジェクトをgsとする. |
skl07-4:グリッドサーチの実行
データセットを基にgsにより学習しよう.
難易度:★★
| ミッション | 説明 |
|---|---|
| 1 | GridSearchCV.fit()メソッドを使う. |
skl07-5:パラメタ
Cの最適値の取得グリッドサーチにより得られたCの最適値を取得しよう.
難易度:★★
| ミッション | 説明 |
|---|---|
| 1 | GridSearchCV.best_estimator_.C属性使う. |
| 2 | 取得した最適値をC_bestとする. |
skl07-6:パラメタ
gammaの最適値の取得グリッドサーチにより得られたgammaの最適値を取得しよう.
難易度:★★
| ミッション | 説明 |
|---|---|
| 1 | GridSearchCV.best_estimator_.gamma属性使う. |
| 2 | 取得した最適値をgamma_bestとする. |
skl07-7:最適パラメタによる
SVCオブジェクトの生成グリッドサーチにより得られた最適パラメタC_bestとgamma_bestを使ってSVCオブジェクトを生成しよう.ここで,パラメタはkernel='rbf', C=C_best, gamma=gamma_bestとする.
難易度:★★
| ミッション | 説明 |
|---|---|
| 1 | SVC()コンストラクタを呼び出す. |
| 2 | kernelパラメタを指定する. |
| 3 | Cパラメタを指定する. |
| 4 | gammaパラメタを指定する. |
| 5 | 生成したSVCオブジェクトをsvc_bestとする. |
skl07-8:交差検証
svc_bestによりデータセットに対して5分割交差検証による5回分のスコアを取得しよう.ここで,パラメタはcv=k_fold, scoring='accuracy'とする.
難易度:★★
| ミッション | 説明 |
|---|---|
| 1 | KFoldをインポートする. |
| 2 | cross_val_scoreをインポートする. |
| 3 | KFold()コンストラクタを呼び出す. |
| 4 | 生成したKFoldオブジェクトをk_foldとする. |
| 5 | cross_val_score()関数を使う. |
| 5 | cvパラメタを指定する. |
| 6 | scoringパラメタを指定する. |
skl07-9:学習モデルの可視化
次のコードはsvc_bestによる学習モデルを可視化するものである.次のコードをskl07_plt.pyというファイル名で保存し,python3コマンドで実行しよう.
|
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 |
import numpy as np import pandas as pd import matplotlib.pyplot as plt from matplotlib.colors import ListedColormap from sklearn.svm import SVC from sklearn.model_selection import GridSearchCV # data curry2 = pd.read_csv('data/curry2.csv', index_col=0) feature_names = np.array(curry2.columns[:-1]) target_names = ['dislike', 'like'] curry2_X = np.array(curry2[feature_names]) curry2_y = np.array(curry2['rating']) # train svc = SVC(kernel='rbf') Cs = np.logspace(-6, 6, 13) gammas = np.logspace(-6, 6, 13) gs = GridSearchCV(estimator=svc, param_grid=dict(C=Cs, gamma=gammas)) gs.fit(curry2_X, curry2_y) C_best = gs.best_estimator_.C gamma_best = gs.best_estimator_.gamma svc_best = SVC(kernel='rbf', C=C_best, gamma=gamma_best) svc_best.fit(curry2_X, curry2_y) # plot cmap_light = ListedColormap(['#CCCCFF', '#FFCCCC']) cmap_dark = ListedColormap(['#8888FF', '#FF8888']) x_min = 0 x_max = 100 y_min = 0 y_max = 100 xx, yy = np.mgrid[x_min:x_max:200j, y_min:y_max:200j] Z = svc_best.decision_function(np.c_[xx.ravel(), yy.ravel()]) Z = Z.reshape(xx.shape) plt.pcolormesh(xx, yy, Z, cmap=cmap_light) plt.contour(xx, yy, Z, colors=['k', 'k', 'k'], linestyles=['--', '-', '--'], levels=[-.5, 0, .5]) plt.scatter(svc_best.support_vectors_[:, 0], svc_best.support_vectors_[:, 1], s=80, facecolors='none', edgecolors='k') plt.scatter(curry2_X[:, 0], curry2_X[:, 1], c=curry2_y, cmap=cmap_dark, edgecolors='k') plt.title("curry2") plt.xlabel('spicy') plt.ylabel('thickness') plt.xlim(xx.min(), xx.max()) plt.ylim(yy.min(), yy.max()) plt.show() |
難易度:★
| ミッション | 説明 |
|---|---|
| 1 | python3コマンドでskl07_plt.pyを実行する. |