Computer Vision News - January 2020
We start by loading a pretrained VGG model that will give us the neural representation of an image. Since our trainable parameters are the image pixels, we then freeze all the layers of the network so that they are not trained. PyTorch gives us an easy way to do this:

```python
from torchvision import models

# Keep only the convolutional part ("features") of a pretrained VGG16.
vgg = models.vgg16(pretrained=True).features

# Freeze the weights: the image pixels are the only trainable parameters.
for param in vgg.parameters():
    param.requires_grad_(False)
```

Another feature that is easy to use in PyTorch is assigning computations to the GPU. With PyTorch we can assign each computation in our network to a specific device. For example, here we assign all the computation of the VGG features to the GPU:

```python
import os
import torch

# Expose only the first GPU to this process (set before any CUDA call).
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Fall back to the CPU when no GPU is available.
device = torch.device("cpu")
if torch.cuda.is_available():
    device = torch.device("cuda")
    print("working on GPU")

vgg = vgg.to(device)
```

For the style transfer, we need to define several utility functions. At the end of this article you will find tensor_to_image, transformation and correlation_matrix, which are fairly technical functions. One very interesting function is get_features: it takes the model and an image as input, passes the image forward through the network and returns the outputs of selected intermediate layers. Each layer in the VGG features module is registered under a name (its index as a string), so we simply map the indices we care about to the conventional layer names. In Keras/TensorFlow this is quite complex to do, but with PyTorch it becomes simply:

```python
def get_features(image, model):
    # Indices of the chosen convolutional layers inside vgg16().features.
    layers = {'0': 'conv1_1',
              '5': 'conv2_1',
              '10': 'conv3_1',
              '17': 'conv4_1',
              '19': 'conv4_2',
              '24': 'conv5_1'}
    x = image
    features = {}
    # Pass the image through the layers in order, keeping the selected outputs.
    # VGG's ReLUs are in-place, so each stored conv output ends up holding
    # its ReLU-activated values.
    for name, layer in model.named_children():
        x = layer(x)
        if name in layers:
            features[layers[name]] = x
    return features
```

We now have everything we need to do the style transfer. We just need to load the content image, img1, and the style image, img2, then normalize and resize them to fit our network. Note that we also assign them to the GPU with a simple command:
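As an aside on get_features, here is a minimal sketch that runs the same loop on a small, randomly initialized nn.Sequential standing in for the VGG16 feature extractor. The toy network, its layer names and the extra layers argument are hypothetical, chosen so that nothing has to be downloaded. The loop uses model.named_children(), the public counterpart of iterating model._modules.items(), and the sketch also checks that gradients reach the input pixels while the frozen weights receive none.

```python
import torch
import torch.nn as nn


def get_features(image, model, layers):
    """Return {layer_name: activation} for the children of `model` listed in `layers`.

    Same loop as get_features in the text, except that the index-to-name
    mapping is passed in (hypothetical parameter) instead of hard-coded.
    """
    features = {}
    x = image
    for name, layer in model.named_children():
        x = layer(x)
        if name in layers:
            features[layers[name]] = x
    return features


# Hypothetical stand-in for vgg16().features: random weights, no download.
toy = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),   # child '0'
    nn.ReLU(),                                   # child '1'
    nn.MaxPool2d(2),                             # child '2'
    nn.Conv2d(8, 16, kernel_size=3, padding=1),  # child '3'
    nn.ReLU(),                                   # child '4'
)
for p in toy.parameters():
    p.requires_grad_(False)  # frozen, as with VGG in the text

# The pixels are what gets optimized, so the image requires gradients.
image = torch.rand(1, 3, 32, 32, requires_grad=True)
feats = get_features(image, toy, {'0': 'conv1_1', '3': 'conv2_1'})

for name, act in feats.items():
    print(name, tuple(act.shape))
# conv1_1 (1, 8, 32, 32)
# conv2_1 (1, 16, 16, 16)

# Gradients flow back to the pixels even though every weight is frozen.
feats['conv2_1'].mean().backward()
print(image.grad.shape)    # torch.Size([1, 3, 32, 32])
print(toy[0].weight.grad)  # None
```

Passing the index-to-name mapping as an argument, rather than hard-coding it, makes it easy to try different content and style layers; with the real network one would pass vgg and its own layer indices in place of toy.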