Code, Error, Tip, Etc.
2D convolution의 weight size. Channel 추가하기.
Ostin
2023. 1. 13. 14:55
2D 컨볼루션 가중치의 모양은 [output_channel, input_chnnel, kernel_size]
기존 컨볼루션 네트워크를 그대로 사용하면서 추가 채널을 받고 싶으면
모양에 맞는 추가 tensor를 만들고 weight에 합쳐주면 된다.
weight_new_channels = torch.zeros(30,10,3,3)
new_weight = torch.cat((conv_in.weight,weight_new_channels),dim=1)
conv_in.weight = torch.nn.Parameter(new_weight)