PythonでXORとUCI_Iris_datasetを解くニューラルネットワーク
前回は線形の問題を解くニューラルネットワーク(Neural Net)のコードを書きましたが、今回は、非線形のXOR、巷で有名?なIrisデータセットを識別するNeural Netのコードを書きました。
前回のコードは2入力3出力の2層のニューラルネットワークをコーディングしましたが、平たく書いてしまったために層を増やすことも層の次元数(ユニット数)も変更することが簡単ではないものとなっていました。
今回は、できるだけ汎化的に使える構造でニューラルネットワークを書いてみました。もともと、1年以上前にPythonを初めて使った時にPythonの練習として
- Hello worldの出力
- Fizz Buzz問題
- とりあえず学習するNeural Network
- クラスや継承を使った3.よりスマートなNeural Network
を順にやっていました。
今回書いたプログラムは4.をちょこちょこ修正して(過去のコードのヒドさ、読みにくさは凄まじいもの...)バッチ学習などに対応したものにしました。
次回は、今回のコードをNumpyで書き換えて、より簡素に、より処理速度が速いPythonコードを書いて紹介したいと思います。
今回のコードは以下のものになります。
ちなみに記事の最後に最後にJupyter(最近ハマってる笑)をつかったコードも載せておきます。
Python script version. It's same content as the fo ...
今回のコードで使用したデータセットは、XOR(排他的論理和)とIris datasetです。XOR問題は、論理演算の1つのアレです。
詳しくは以下のWebページを参照してください。
XORとは|排他的論理和|EOR|EX-OR|eXclusive OR - 意味/定義 : IT用語辞典
XORを識別する関数を学習することは、2入力1出力の2クラス分類問題ですが、それぞれのクラスを1本の直線でわけることができない非線形の問題となっています。
http://www.gifu-nct.ac.jp/elec/deguchi/sotsuron/niwa/node11.html
もう1つのIris(アヤメ)データセットはカリフォルニア大学アーバイン校が提供するデータセットです。
アイドルユニットの画像が詰まったデータセットではありません。
http://manasite.net/animemanga/1661/
150個のデータが入っており、4入力の3クラス分類問題となっています。配布時には教師ラベルがIrisの花の種類の名前?となっているため、機械学習の手法をこのデータセットに適用する場合は教師ラベルを設計することも最初の問題となってきます。
UCI Machine Learning Repository: Iris Data Set
(以前、Python&機械学習デビューした時にはこの教師ラベル設計で苦しめられました...)
今回は、One-hot vectorで教師ラベルを振りました。
3クラスなら[[1, 0, 0], [0, 1, 0], [0, 0, 1]]などでしょうか。
最近は、自然言語処理の方で、1つの言語で使われる記号-character(アルファベット、数字など)をone-hot vectorで表して、文章をとてもスパースな行列にembeddingして入力データを表現してCNNに突っ込んで、文章分類を行うことがホットになってますね。
今回のコードでつまったのが、CSVの読み込みに以前はPythonの組み込みのCSVをimportして使っていたのが、Pandasを使ってみたところ、Data frame型で読み込まれて、組み込みのリスト型に直して整形する部分です。
pandasのdataframe型をnumpyのarray型に変換する - 新kensuke-miの日記
標準python、numpy、pandasを行ったり来たりするために① - Qiita
今回はNumpyを使わない縛りをしていたので、「Numpyないとこれできないんだー、不便すぎる(T_T)」みたいな気持ちに何度かなっていました。
結果的に、それぞれのデータセットで学習誤差(loss)の値は下がっていき、ほぼ100%の正解率を出すことができたので、こちらのコードで正しくNeural Netを学習して、非線形問題を解くことができる識別関数を獲得できたと思います。
排他的論理和の問題の学習時の誤差の推移は以下になりました。
縦軸が誤差の値(1epoch中のそれぞれのデータのlossの合計値を出して、1つのデータの平均のlossを出していなかったことがここで発覚、ミスってすんません)。横軸が学習epochになります。
Iris data settにおける学習時の誤差の推移は以下のようになりました。
データセットは、hold outで150個中、30個のデータ(全体の20%)をランダムにテストデータとして作成して、正解率(accuracy)を出しました。コードのハイパーパラメータではこの精度が最大っぽかったです。
以上になります。
次回はこのコードをNumpyを使ってよりスマートなコードに直したものを載せたいと思います。