PythonでTripletな組み合わせを作る
ラベルに基づいて,ある基準となるデータと,そのデータと同じクラスのデータ,異なるクラスのデータの3つのデータの組み合わせを作るPythonプログラムについて書きます.
何かしらのデータdataとdataのクラスに基づいたラベルtargetが付与されたデータセットdatasetが手元にあったとき,ラベル情報に基づいてdataの組み合わせを作りたいと思いました.
その組み合わせは,ある1つのx_anchorに対して,そのx_anchorと同じクラスのデータx_positive,x_anchorと異なるクラスのデータx_negativeで構成される3つのデータの(tripletな)組み合わせです.そしてその組み合わせをbatch_size個だけ用意するプログラムを書くときに少し苦しんだので以下で説明していきます.
この組み合わせの問題ですが,例えば箱の中にリンゴとナシが4個ずつ入っていて,リンゴとナシにはそれぞれマジックで番号が書いてあります.箱の中を見て1つ取り出してメモって,戻して...の動作を3回繰り返します.そのうち2回は同じ果物でもう1つは異なる果物じゃなきゃダメ.そしてこの動作をてきとーに10回繰り返してください.という問題.さぁ,どうしましょう?
こんな問題をプログラムで記述したいと思いました.
まず思いついたのは,条件を満たす組み合わせを書き出して,その中から10回分だったらその数だけてきとーに選ぶという手法.For文で組み合わせをかたっぱしから総当たりしてリストを作って,リストから任意の数だけ取るだけ.とても簡単.
リンゴとナシが4個ずつ存在する問題なら問題ないですが,杏子,ライチ,スイカ,メロン,イチゴ,モモ,ミカン,ブドウ,イチジク,ザクロの10種類の果物がそれぞれ1000個ずつ箱に入っている問題だと9000万通りも作れてしまい普通のパソコンではメモリに載らない.
問題によって計算時間が異なったり,プログラムを実行する環境を選ぶコードを作るのはマズいということになりました.
次に考えたのは,10回分の組み合わせがほしい場合は,10回てきとーに果物を取り出して種類と番号をメモっておく.次にすべての種類の果物を10個ずつてきとーに小箱に移しておいて,小箱から同じ種類と異なる種類のものを1つずつてきとーに取り出す作業を10回行うという作戦.これだと1つの果物が何個ずつ用意されていても小箱の中しか見ないためてきとーに選ぶ労力が減る(メモリの使用量は1つのクラスごとのデータ数に依存しない)と考えたので実装しました.
小箱に移す作業はnumpy.where()で,小箱から取り出すのはpythonの組み込みのリストのメソッドpop()を使います.同じクラスじゃない果物をランダムに1つ取り出すという作業の「じゃないものをランダムに」というコードを書くのが難しくて,二度と読みたくないようなコードを書いてしまいました.
numpy.random.choice(果物の種類, 取り出す数, それぞれの種類からどれくらいの確率で取り出すか決める確率のリスト)で3つ目の引数の確率のリストのうち,取り出したくないものだけ確率を0にしてあげると,「じゃないものを(等しい確率で)ランダムに取り出す」ことができます.
もっとスマートなtripletな組み合わせを作る方法がありそうですが,これが限界でした.
今回の実験に使ったコードと解説は一番下に載せてあります.
今回の実装でよく使ったコードを最後にまとめておきます.
import numpy
matsuge = [1, 2, 3, 4, 5]
# 使える
hoge = numpy.random.permutation(matsuge)
# 使えない
hoge = numpy.random.shuffle(matsuge)
# 使える.matsugeのリストそのものがシャッフルされる
numpy.random.shuffle(matsuge)
# matsugeの要素の値がnのインデックス群が返ってくる.タプルで返ってくる.
n=1
matsuge = [i for i in range(10) for j in range(10)]
indexes = numpy.where(matsuge == n)[0]
# リストの最後の要素をリストから削除して返す
matsuge = [1, 2, 3, 4, 5].pop()
# pのリストの確率で0~2までの値のうち1つを返す(pは省略可).今回だと0が返ってくる確率が80%,2は0%.タプルで返ってくる.
matsuge = numpy.random.choice(3, 1, p=[0.8, 0.2, 0.])[0]
最後に今回のソースコード
Make a batch size of shuffle and combination array ...