【AIプログラミング】GAN(生成的敵対ネットワーク)について勉強をする、PyTorchのチュートリアル3

AIプログラミング

今回は、ジェネレータのところまで進みました。

あまり進んでないね・・・

今回もPyTorchのDCGANチュートリアルを進めました。実装に入って気がついたのですが、DCGANで使われている畳み込みがどうやら今までCNNで利用してきた畳み込みとは一味違うものらしくそこも理解する必要がありました。色々調べていたのですが、いまいち内容の理解まで至っていません^^;

また、チュートリアルの翻訳+私の気持ちが入っているという微妙に読みにくい内容となっています・・・

前回の記事はこちらです。

スポンサーリンク

重み初期化関数の定義

DCGANの論文によると、すべてのモデルの重みの平均が「0」標準偏差が「0.02」の正規分布からランダムに初期化することをと指定しています。

# custom weights initialization called on netG and netD
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

nnのinit.normal_というメソッドで、「平均0、標準偏差0.02の正規分布」から取得した値でテンソルを埋めることができます。

「normal_(テンソル,平均,標準偏差)」

ジェネレータ

畳み込み層が「ConvTranspose2d」というものを利用しています。ただの畳み込みではなく、「転置畳み込み」という処理を行っているようです。

翻訳で「転置」ってついてきていたから少しだけ気になっていたんですが・・・

転置畳み込みは、畳み込みとは反対のデコードのようなイメージの畳み込みとのことです。

一般的なCNNなどに利用される畳み込みは、

「画像データ→特徴を抽出したデータ」

というような変換が行われます。

転置畳み込みは、反対に、

「何かしらのデータ→画像データ」

へと変換をするような畳み込み処理となります。

また、一般的なCNNの畳み込みとは違い、プーリング層の代わりにバッチノーマライゼーションという処理が行われています。そして、最後は全結合のニューラルネットワークは存在していません。最後の活性化関数は「tanh()」を利用しています。

#ジェネレータ
class Generator(nn.Module):
    def __init__(self, ngpu):
        super(Generator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            nn.ConvTranspose2d( nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # state size. (ngf*8) x 4 x 4
            #ここから図のConv1?
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # state size. (ngf*4) x 8 x 8
            nn.ConvTranspose2d( ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # state size. (ngf*2) x 16 x 16
            nn.ConvTranspose2d( ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # state size. (ngf) x 32 x 32
            nn.ConvTranspose2d( ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh()
            # state size. (nc) x 64 x 64
        )

    def forward(self, input):
        return self.main(input)

ConvTranspose2dは次のような入力パラメータをとります。

ConvTranspose2d(入力チャンネル数,出力,フィルターサイズ,ストライド,パディング)となっているようです。

「nz = 100」「ngf = 64」「nc = 3」

のパラメータを参考に図の中のデータの流れを追いましたが、よくわかりませんでした^^;

CNNの時と同様、絵の数値とはあっていないのでしょうか・・・

チュートリアルのジェネレータの処理イメージ

CNNの時もそうでしたが、絵にとらわれ過ぎて先に進めませんでしたので、ほどほどにしておきます。

これで、ジェネレータをインスタンス化して、weights_init関数を適用できます。印刷されたモデルをチェックして、ジェネレーターオブジェクトの構造を確認ししてみます。

# Create the generator
netG = Generator(ngpu).to(device)

# Handle multi-gpu if desired
if (device.type == 'cuda') and (ngpu > 1):
    netG = nn.DataParallel(netG, list(range(ngpu)))

# Apply the weights_init function to randomly initialize all weights
#  to mean=0, stdev=0.2.
netG.apply(weights_init)

# Print the model
print(netG)

初期化の際、「netG.apply(weight_init)」というように行う点が慣れていない書き方でした。

ジェネレータの出力結果はこのようになりました。

Generator(
  (main): Sequential(
    (0): ConvTranspose2d(100, 512, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (7): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU(inplace=True)
    (9): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (10): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (11): ReLU(inplace=True)
    (12): ConvTranspose2d(64, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (13): Tanh()
  )
)

参考にさせていただいた資料

今回は、ジェネレータのプログラミング部分について勉強しました。内容について、転置畳み込みや、バッチノーマライゼーションといった概念が出てきました。この点は、次の資料が大いに参考になりました。

[slideshare id=188544721&doc=mlstudymtgforgan1-191030083556]

最近出番少ない・・・

タイトルとURLをコピーしました