TensorFlowで訓練済みモデルの保存(Save)と復元(Saver)の使い方

2018年8月10日

使っているFX会社はどこ?
テクニカル分析用(チャートが使いやすい) → DMM FX
Python API用(FX APIはここしかない) → OANDA ジャパン
デイトレ用(スプレッドがとにかく狭い) → SBI FXトレード

TensorFlowで訓練したモデルの保存(Saver)の基本的な使い方まとめ。シンプルな線形回帰モデルを構築して、訓練済みモデルの保存から復元(saver.restor)を使う。

こんばんは、新米データサイエンティスト(@algon_fx)です。先日の投稿で8月はTwitterのAPIを使った自然言語処理でのFX予想をすると言っときながら・・先週末はランダムフォレストに没頭してしまいました(汗)

機械学習をやっていると、Pythonの様々なライブラリを使うことになります。当然ではありますが、それぞれライブラリ毎の操作を覚えなくてはいけないのですが、少し時間が経過すると詳細を忘れてしまい、「あれ?これってどうやるんだっけ?」と同じことを調べる羽目になります。

これは完全に個人的な意見ですが、特にTensorFlowはその傾向が強い。他の機械学習ライブラリと比較しても、少し特殊な部分というか、使う上で面倒なことが多い。

でもスクラッチで最急降下法を書くより圧倒的に便利な訳で、いつもお世話になっています笑。愚痴はさておき、今回はそんなTensorFlowの「訓練済みのモデルの保存と復元」についてです。

機械学習トレード(機械学習を使ったFXトレードを勝手に命名)やっていると、あ!あの時に試したあのモデルをちょっと今使ってみたい!ってニーズ結構あると思います。

その都度、データを読み込んで訓練させても問題ありませんが、例えば訓練データが5秒足過去10年分+テクニカル指標30個などになると、モデル訓練で多少の時間が割かれます。

まぁいずれにせよ、訓練したモデルを保存して、使いたいときに復元して使える操作を覚えておくに越したことはありません。

こんな状況を想定してみよう

とある時間帯にドル円をトレードしていたら、ドル円の1分足の安値(Low)に明らかな上昇トレンドが出てきた。下記は過去20分の1分足の安値のチャートです。

ここはデーターサイエンティストっぽく、線形回帰のモデルを構築して、このトレンドをより具体的に解析してみよう。線形モデルってのは、この赤いポイント(安値)にもっともフィットする直線を引くという意味です。

図で表すことこんな感じ。青色の直線は過去20分の安値を元に導き出した直線(線形モデル)です。対して青色のポイントは今から将来(1分〜5分先)のレートを表しています。

線形モデルを導きだして、将来のレートとの差分(緑の矢印)を計算して、誤差が大きくなれば売りのチャンス!というのが想定です。

余談ですが、線形回帰ですが非常に単純な機械学習の手法ではありますが、使い方次第では機械学習トレードで強い武器になると私は思ってます。

例えば一定期間(例:1分足を10本)とって、その都度の線形モデル(直線)を算出するとします。それぞれの期間の直線の傾きはトレンドの強さを直接的に教えてくれますし、異なる期間の線形モデルを複数算出することで、より実践的なトレードのシグナルとして使えると考えています。(注意:実際にやってみましたが、損しました!改善したら役に立つって意味です。笑)

と、少し話がそれましたが、今回の大枠の手順は下記の通りです。

(1)模擬レートのデータを作成
(2)TensorFlowで線形回帰モデルを訓練
(3)線形モデルとレートをチャートで確認
(4)訓練したモデルを保存
(5)訓練済みモデルの復元

ではやってみましょう!

STEP1 模擬レートのデータを作成

想定の通り20分足で上昇トレンドのレートを作成します。まずは必要なライブラリをインポートしましょう。

次に実際のレートのようなランダムウォークになるように、かつ上昇トレンドとなるようなレートを作成しましょう。

念のためY(安値)を確認してみます。乱数を使ってるので、各自の結果とは異なります。

作ったデータをプロットしてみましょう。

まぁそれなりの上昇トレンドの模擬データとなりました。

STEP2 TensorFlowで線形回帰モデルを訓練

データが作れましたので、いよいよ主役となるTensorFlowさんに協力をいただいて、線形回帰モデルの訓練を行いましょう。

まずは学習率、反復計算回数(エポック)、さらに実際のレートと予測したレートの誤差(コスト)を反復計算中に確認したいので、表示回数も設定してあげます。

次にTensorFlowのプレースホルダーと変数を作ります。ここで注意点が一つ!!

後に訓練ずみモデルの保存と復元を行う訳ですが、プレースホルダーと変数には「name」を指定しましょう。これを忘れると、訓練済みモデルの復元が非常に面倒になります。

続いて、線形回帰モデル、コスト関数、最急降下法(Gradient Descent)を構築しましょう。

では初期化してセッションを開始しましょう。

これで訓練の準備が整いましたので、先ほど作った模擬データを使ってモデル訓練を行います。途中、学習がちゃんと上手くいっているのか確認のためエポック200回毎にコスト(平均二乗誤差)を表示させます。

大丈夫そうですね。しっかり1000回のエポックでコストが最小に近づいているのが確認できます。参考までにですが上記で表示をしたWが線形回帰の「傾き」を表します。レートが上昇トレンドの場合は傾きの正の値となり、値が増えれば増えるほどレートが急激に上がっていることを表します。

対してレートが下降トレンドの場合、Wは負の値となり、低ければ低いどレートが急激に下がっていることを表します。機械学習的にはWを係数(Coefficient)、Bを定数項(BiasまたはY-intercept)と呼びます。

モデルの訓練も完了したので、次は実際に線形モデルとレートをチャートに落としてみましょう。

STEP3 線形モデルとレートをチャートで確認

では、訓練済みの線形モデルをチャートに落とします。

まぁ、当然ではありますが、しっかりフィットした直線を引くことができました。

STEP4 訓練したモデルを保存

では訓練した線形モデルをTensorFlowのSaver()を使って保存しましょう。コードは非常に単純です。

linearという名前でこの訓練済みモデルを保存しました。Saver()を実行すると、下記の4つのファイルが作成されます。(Jupyter Notebookでやっている方は、コードを実行したディレクト内にファイルが作成されているはずです)

*checkoint
*linera.data-00000-of-00001
*linear.index
*linear.meta

それぞれ意味のあるファイルです。詳しくはTensorFlowの公式ドキュメントで確認してみください。(tf.train.Saver

STEP5 訓練済みモデルの復元

先ほど保存した訓練済みモデルの復元(restore)をしてみましょう。Jupyter Notebookでやっている方は、試しに新しいノートブックを作ってやってみてください。

まずは必要なライブラリをインポートします。

では先ほど書き出した訓練済みモデルをrestore()を使って読み込みましょう。まずはセッションを立てて、先ほど書き出した「linear.meta」をimport_meta_graph()を使って読み込みます。

いよいよモデルの復元です。latest_checkpoint()を使って最後の状態のモデルを復元させます。

パラメータが「linear」から復元されましたと出力されています。本当に復元されたのか、確認してみましょう。

上記にある通り、get_tensor_by_name()を使って、モデル訓練の時に名前をつけた(Nameに指定した値)を指定することで、パラメータを復元させることが可能です。

では復元したW(係数)とB(定数項)を確認してみましょう。

STEP2で訓練した際に最後の反復計算をした時のWとBを表示させていました。下記が訓練済みモデルのWとBでした。

ご覧の通り復元したWとBと、訓練した際に算出したWとBの値が一致しているのが確認できます。つまり、訓練済みモデルのパラメータが正しく復元された訳です。

例えば復元したモデルで21分目の安値を予測するとしたら、下記のようにします。

そもそも線形回帰なので予測したレートにはさほど意味はありませんが、トレンドの強さや過去のトレンドと今のレートを比較する際には役に立ちます。

まとめと次の課題

今回はTensorFlowを使ってシンプルな線形回帰モデルの構築を行い、Saver()を使って訓練済みモデルの保存と復元を行う流れを行いました。

記事中でも触れましたが、線形モデルはとてもシンプルな機械学習の手法ですが、使い方次第で強力なトレードツールになり得ると(私は)考えています。

次の課題としては…ニューラルネットワークを使ったFX予想(初心者版)をしっかり書き上げます笑。実は先月半ばくらいから、時間を見つけて少しづつ書いていたのですが、、ニューラルネットワークの説明に踏み込み過ぎて…もはや収拾がつかない状況に…汗

以上となります!ブログ読んでいただきありがとうございます!Twitterでも色々と発信しているので、是非フォローお願いします!

【人気記事】

私の機械学習の開発環境&トレード環境

2018年8月10日機械学習チュートリアル

Posted by algon