【PyTorchチュートリアル】DCGANの損失関数や訓練について考える

AIプログラミング

DCGAN、チュートリアルはやったけれど、まだ雲をつかむイメージです。

腑に落ちてない感じだね。

PyTorchのDCGANチュートリアルではプログラムを実行してみて確かめる、という点で何となくそんなものか~程度の理解にとどまっていた、DCGANの損失関数や訓練について考えてみました。チュートリアルではいきなりLogを交えた数式の定義がでてきましたので、少し躊躇していましたが、プログラムの実装面からもう少し詳しく損失と訓練について考察をしてみました。

チュートリアルページはこちらです。

スポンサーリンク

チュートリアルのプログラム

今回、数式で考えるのではなく、実装からどのようになっているかを考察しました。実装のプログラムは次の部分です。

for epoch in range(num_epochs):
    # For each batch in the dataloader
    for i, data in enumerate(dataloader, 0):

        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        ## Train with all-real batch
        netD.zero_grad()
        # Format batch
        real_cpu = data[0].to(device)
        b_size = real_cpu.size(0)
        label = torch.full((b_size,), real_label, device=device)
        # Forward pass real batch through D
        output = netD(real_cpu).view(-1)
        # Calculate loss on all-real batch
        errD_real = criterion(output, label)
        # Calculate gradients for D in backward pass
        errD_real.backward()
        D_x = output.mean().item()

        ## Train with all-fake batch
        # Generate batch of latent vectors
        noise = torch.randn(b_size, nz, 1, 1, device=device)
        # Generate fake image batch with G
        fake = netG(noise)
        label.fill_(fake_label)
        # Classify all fake batch with D
        output = netD(fake.detach()).view(-1)
        # Calculate D's loss on the all-fake batch
        errD_fake = criterion(output, label)
        # Calculate the gradients for this batch
        errD_fake.backward()
        D_G_z1 = output.mean().item()
        # Add the gradients from the all-real and all-fake batches
        errD = errD_real + errD_fake
        # Update D
        optimizerD.step()

        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        netG.zero_grad()
        label.fill_(real_label)  # fake labels are real for generator cost
        # Since we just updated D, perform another forward pass of all-fake batch through D
        output = netD(fake).view(-1)
        # Calculate G's loss based on this output
        errG = criterion(output, label)
        # Calculate gradients for G
        errG.backward()
        D_G_z2 = output.mean().item()
        # Update G
        optimizerG.step()

ディスクリミネータの訓練の処理

ディスクリミネータは、本物の画像かジェネレータが生成した偽の画像かを判別する2クラス分類器のニューラルネットワークになります。これが、プログラムの実装だと、

①「Train with all-real batch」パート

②「Train with all-fake batch」パート

に分かれています。分かれているだけで、実際には「通常の2クラス分類問題」と何ら変わりがないことに気づくまで時間がかかってしまいました。

なぜこのように①と②に分かれているかというと、ラベル「0」である偽画像はその場でジェネレータに生成させるためです。

まずはすでに準備してある本物画像をラベル「1」として入力、誤差を計算し、勾配を算出します。次に、ジェネレータに画像を生成させてラベル「0」として入力し、誤差を計算、勾配を算出し計算しているだけのことでした。

また、算出した勾配は、明示的にリセットしないといけないので、本物データの勾配算出と偽物データの勾配算出をすべて行った後、stepにより重みの更新を行っています。

最後の「errD」は表示用に利用されているようです。

ジェネレータの処理

ジェネレータの訓練では、ディスクリミネータも利用します。

既にディスクリミネータの訓練で「fake」という画像を作成しているため、このままディスクリミネータにfakeを入力して答えを得ます。

fakeはディスクリミネータのところで「fake = netG(noise)」として作成した画像になります。

この時に、すでにディスクリミネータは訓練を一回している状態のため、新しいディスクリミネータに対してジェネレータの画像を判定することになります。ジェネレータはこの判定結果をもとに訓練をします。

ジェネレータはディスクリミネータの判定を得るという点以外は、非常にシンプルに訓練を行っています。

①ディスクリミネータにfakeを入力して判定を得ます。

②誤差を求めます

③勾配を求めます

④Gのみ重みの更新を行います。

optimizerGは、最初に

optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))

と定義したものです。Adam最適化手法で、Gのパラメータの重みを更新する処理になります。ここで、誤差からGのみの重みを更新する点がポイントですね。

数式ではいまいち理解できなかったことも実装面からアプローチすると理解できてきた気がします。

PyTorchのDCGANのチュートリアル記事

タイトルとURLをコピーしました