본문 바로가기

Code, Error, Tip, Etc.

2D convolution의 weight size. Channel 추가하기.

2D 컨볼루션 가중치의 모양은 [output_channel, input_chnnel, kernel_size]

 

기존 컨볼루션 네트워크를 그대로 사용하면서 추가 채널을 받고 싶으면

모양에 맞는 추가 tensor를 만들고 weight에 합쳐주면 된다.

weight_new_channels = torch.zeros(30,10,3,3)

new_weight = torch.cat((conv_in.weight,weight_new_channels),dim=1)
conv_in.weight = torch.nn.Parameter(new_weight)