torch.stackの挙動が気になりましたので、いろいろと触ってみます。
テンソルの軸という部分が混乱しますね。
PyTorchのチュートリアルをやってきて、自在にPyTorchを操るためには、テンソルのデータ形式について感覚をつかむことが重要な気がしています。今回は、torch.stackについていろいろと触り、テンソルがどのように結合されるのかを見てみました。
torch.stackの公式ドキュメントはこちらです。
stackの基本的な動作を確認
torchをimportしてstackが利用できます。
import torch
#2×3の行列をテンソルとして準備
x = torch.randn(2, 3)
x
tensor([[ 0.7329, -1.1706, -0.3256],
[-0.3038, 1.0603, 1.0140]])
#リストを作成します。
x_list = []
x_list.append(x)
x_list.append(x)
x_list.append(x)
torch.stackに与えるテンソルは、全て同じサイズである必要があります。 今回はすべて同じテンソルの複製なので大丈夫です。
dim=0
まずは、デフォルトでdim=0の設定での結合です。
print(torch.stack(x_list))
torch.stack(x_list).size()
tensor([[[ 0.7329, -1.1706, -0.3256],
[-0.3038, 1.0603, 1.0140]],
[[ 0.7329, -1.1706, -0.3256],
[-0.3038, 1.0603, 1.0140]],
[[ 0.7329, -1.1706, -0.3256],
[-0.3038, 1.0603, 1.0140]]])
torch.Size([3, 2, 3])
2階テンソルの行列を軸0で結合すると、次のイメージのように、テンソルが結合されました。
単純に、画像枚数として一つのデータにまとめるイメージを持つことができます。
dim=1
次に、dim=1の場合を見てみます。
print(torch.stack(x_list, dim=1))
torch.stack(x_list, dim=1).size()
tensor([[[ 0.7329, -1.1706, -0.3256],
[ 0.7329, -1.1706, -0.3256],
[ 0.7329, -1.1706, -0.3256]],
[[-0.3038, 1.0603, 1.0140],
[-0.3038, 1.0603, 1.0140],
[-0.3038, 1.0603, 1.0140]]])
torch.Size([2, 3, 3])
テンソルのサイズの1の以外をすでにあるサイズで固定して新しいテンソルに結合します。
dim=2
最後にdim=2です。
print(torch.stack(x_list, dim=2))
torch.stack(x_list, dim=2).size()
tensor([[[ 0.7329, 0.7329, 0.7329],
[-1.1706, -1.1706, -1.1706],
[-0.3256, -0.3256, -0.3256]],
[[-0.3038, -0.3038, -0.3038],
[ 1.0603, 1.0603, 1.0603],
[ 1.0140, 1.0140, 1.0140]]])
torch.Size([2, 3, 3])
軸の方向がさっきと変わりましたが動作としては似ています。
やっぱり、軸が0以外の時は非常にわかりにくいです・・・
カラー画像データ
RGB3チャンネルのカラー画像データを軸0でまとめてみます。
一般的なカラー画像では、赤(R)のみを0~255で表現する画像、緑(G)のみを0~255で表現する画像、青(B)のみを0~255で表現する画像の3枚(3チャンネル)の画像をひとつにしてカラー画像を表現しています。(透明情報のアルファが加わって4チャンネルとなる場合もあります。)
PyTorchのテンソルでこのカラー画像1枚を表現すると、次のようなテンソルのサイズになります。画像サイズは2×3のサイズとしています。
画像ちっさいですね・・・
3チャンネルの2×3のカラー画像を次のようにテンソルで表現できます。
rgb = torch.randn(3, 2, 3)
rgb
tensor([[[ 1.0691, 0.0893, -0.9815],
[-0.6828, -1.6071, 0.4074]],
[[ 1.1217, 1.9044, 0.1318],
[ 0.0065, 1.7577, -0.4926]],
[[ 0.2009, 0.2247, -0.6303],
[-1.0737, 0.1366, -1.2880]]])
今回は、4枚の画像を結合して一つのテンソルにします。
#リストを作成します。
rgb_list = []
rgb_list.append(rgb)
rgb_list.append(rgb)
rgb_list.append(rgb)
rgb_list.append(rgb)
stackでdim=0で結合します。
print(torch.stack(rgb_list))
torch.stack(rgb_list).size()
tensor([[[[ 1.0691, 0.0893, -0.9815],
[-0.6828, -1.6071, 0.4074]],
[[ 1.1217, 1.9044, 0.1318],
[ 0.0065, 1.7577, -0.4926]],
[[ 0.2009, 0.2247, -0.6303],
[-1.0737, 0.1366, -1.2880]]],
[[[ 1.0691, 0.0893, -0.9815],
[-0.6828, -1.6071, 0.4074]],
[[ 1.1217, 1.9044, 0.1318],
[ 0.0065, 1.7577, -0.4926]],
[[ 0.2009, 0.2247, -0.6303],
[-1.0737, 0.1366, -1.2880]]],
[[[ 1.0691, 0.0893, -0.9815],
[-0.6828, -1.6071, 0.4074]],
[[ 1.1217, 1.9044, 0.1318],
[ 0.0065, 1.7577, -0.4926]],
[[ 0.2009, 0.2247, -0.6303],
[-1.0737, 0.1366, -1.2880]]],
[[[ 1.0691, 0.0893, -0.9815],
[-0.6828, -1.6071, 0.4074]],
[[ 1.1217, 1.9044, 0.1318],
[ 0.0065, 1.7577, -0.4926]],
[[ 0.2009, 0.2247, -0.6303],
[-1.0737, 0.1366, -1.2880]]]])
torch.Size([4, 3, 2, 3])
サイズの0番目が4となって、4枚の画像を一つのテンソルとしてまとめることができました。このようにして、ミニバッチのような画像の塊としてニューラルネットでは処理されていきます。
GitHubにプログラムを配置しますので、ご利用ください。