Computer Vision News - June 2021
Computer Vision Tool

Finally, we define the `train` function:

```julia
function train(; kws...)
    args = Args(; kws...)

    @info("Loading data set")
    train_set, test_set = get_processed_data(args)

    # Define our model. We will use a simple convolutional architecture with
    # three iterations of Conv -> ReLU -> MaxPool, followed by a final Dense layer.
    @info("Building model...")
    model = build_model(args)

    # Load model and datasets onto GPU, if enabled
    train_set = gpu.(train_set)
    test_set = gpu.(test_set)
    model = gpu(model)

    # Make sure our model is nicely precompiled before starting our training loop
    model(train_set[1][1])

    # `loss()` calculates the crossentropy loss between our prediction `y_hat`
    # (calculated from `model(x)`) and the ground truth `y`. We augment the data
    # a bit, adding Gaussian random noise to our image to make it more robust.
    function loss(x, y)
        x̂ = augment(x)
        ŷ = model(x̂)
        return logitcrossentropy(ŷ, y)
    end

    # Train our model with the given training set using the ADAM optimizer and
    # printing out performance against the test set as we go.
    opt = ADAM(args.lr)

    @info("Beginning training loop...")
    best_acc = 0.0
    last_improvement = 0
    for epoch_idx in 1:args.epochs
        # Train for a single epoch
        Flux.train!(loss, params(model), train_set, opt)

        # Terminate on NaN
        if anynan(Flux.params(model))
            @error "NaN params"
            break
        end

        # Calculate accuracy:
        acc = accuracy(test_set..., model)
```
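The `augment` helper called inside `loss` is not shown in this excerpt. Going only by the comment above ("adding Gaussian random noise to our image"), a minimal sketch might look like the following; the noise scale `0.1f0` is an illustrative assumption, not a value given in the article:

```julia
# Hypothetical sketch of `augment`: perturb the input batch with small
# Gaussian noise so the model becomes more robust to pixel-level variation.
# The 0.1f0 scale is an assumption chosen for illustration.
augment(x) = x .+ 0.1f0 .* randn(Float32, size(x))

# Usage on a fake 28x28 grayscale batch of 4 images (MNIST-shaped)
x = rand(Float32, 28, 28, 1, 4)
x̂ = augment(x)
size(x̂) == size(x)  # the noise is elementwise, so the shape is unchanged
```

Because the noise is resampled on every call, each epoch sees a slightly different version of the same image, which is what makes this a cheap form of data augmentation.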