【AIプログラミング】GAN(生成的敵対ネットワーク)について勉強をする、PyTorchのチュートリアル1

AIプログラミング

いろいろアーティスティックなことができそうなGANというものに興味がでてきました。

DCGANというものがPyTorchのチュートリアルにあるね。

GAN(生成的敵対ネットワーク)というアルゴリズムでいろいろな画像を生成できたり、すごく面白そうと思いました。PyTorchにDCGANというチュートリアルがありましたので、それをもとに勉強中です。まずはGANとはどんなものなのかなど、ざっくりではありますが、チュートリアルやその他調査でまとめてみようと思います。

行なっているチュートリアルはこちらです。

こんな人の役に立つかも

・機械学習プログラミングを勉強している人

・GAN(生成的敵対ネットワーク)について大まかに知りたい人

・PyTorchでGANを勉強している人

スポンサーリンク

GANについて知る

GANは、画像を生成したりできるアルゴリズムです。生成モデルというような呼ばれ方をするようです。

具体的には、ジェネレータ(G)というデータを生成するネットワークと、ディスクリミネータ(D)という本物か偽物かを判定するネットワークが存在しており、ジェネレータは生成したデータでディスクリミネータを欺くために努力を行います。一方でディスクリミネータは本物のデータと偽物のデータを正確に見抜くように学習を行います。

GANは、2014年によって発表された教師なし学習アルゴリズムです。

また、DCGANは、GANを拡張したもので、次の論文で発表されました。

DCGANでは、バッチノーマライゼーション、ReakyReLUといった概念を取り入れることでより安定したGANを実現したそうです。 

利用用途

機械学習に役立つ利用用途としては、訓練データの水増しに利用されるらしいです。確かに、似たようで違うデータが必要な時は便利ですよね。

また、GANの進化系によって、架空の人物の顔を生成したり、様々なことができるようになってきています。エンターテイメント的な利用もされているようです。GANを利用して、今もいろいろな応用例が模索されています。

損失関数

損失関数も複雑です。DCGANの論文によると、次の損失関数になります。Gはジェネレータ、Dはディスクリミネータです。

Dのディスクリミネータは「logD(x)」を最大化しようとします。これは、本物と、Gが生成した偽物を正確に分類しようとします。一方で、Gは「log(1-D(G(z)))」を最小化しようとします。DとGの最小最大化問題というように言われています。

「ディスクリミネータ」の判定が、本物と偽物の予測結果が50%の予測となったときがモデルの均衡となるそうです。

チュートリアルでいきなり数式はちょっととっつきにくいですね。

実装してみると動き方がわかるかもしれないね。

訓練は、「ジェネレータ」「ディスクリミネータ」それぞれに対して順番に行われるようです。実際に実装しながら勉強を進めていきたいと思います。

DCGANについて

DCGANは、GANの拡張的なもので、論文で、明示的に構成が定義されています。

構成など

「ジェネレータ」と「ディスクリミネータ」という2つのモデルで構成されています。
※以下、画像を生成するように話を進めます。

ジェネレータ

ジェネレータは、訓練画像のように見える「偽」の出力を生成することです。

チュートリアルでは、説明上、「G」と表記されます。

畳み込み層、バッチノーマライゼーション層、ReLU活性化関数で構成されています。

ジェネレータに入力するデータである、「z」は入力するベクトルになります。「潜在ベクトル」と訳すことができるようなのですが、彫刻でいう削る前の原木のようなイメージでしょうか。「z」というベクトルを与えて、それがジェネレータによって加工されて画像になっていきます。

ディスクリミネータ

「ディスクリミネータ」の仕事はジェネレータの出力を本物か偽物判定することです。判定とは、分類問題の2値出力と同じようなイメージです。
ディスクリミネータ自体は、確率を出力する、分類器のニューラルネットワークです。説明上、「D」と表記されます。

畳み込み層、バッチノーマライゼーション層、LeakyReLUの活性化関数で構成されるネットワークです。

入力画像として、3×64×64(3チャンネルの画像)とします。

出力としては、スカラー確率で出力されます。偽物のデータか、本物のデータかが「本物データである確率」という形で出力されることになります。

ちょっと図にしてみました。

チュートリアルでは、論文に基づいたDCGANの構成、パラメータ設定を元に実装をしていくようです。実装が楽しみです。

バッチノーマライゼーションとは?

チュートリアルでは、バッチノーマライゼーションという言葉が普通に出てきて、なんとなく違和感なく読み進めていたのですが、まだ何も知らないということに気づきました 笑

バッチノーマライゼーションとは、データの平均を0にして分散を1にするとのことらしいです。CNNでは、プーリングを行なっていたところをバッチノーマライゼーションに変更しています。バッチノーマライゼーションを行うことがDCGANでとても重要な発見だったらしいですね。

まだこれをすると良いのか、という程度の理解に止まっています。

タイトルとURLをコピーしました