This code visualizes model performance results from the MERMAID image classification model.
Loading package libraries and setting parameters
The following libraries and parameters are used throughout the code.
Show the code
# Setup ----
# NOTE: the former `rm(list = ls())` call was dropped. It only deletes
# user-created objects in the global environment (it does not detach packages
# or reset options), and it silently destroys the workspace of anyone who
# sources this code. Restart the R session to get a clean slate instead.

options(scipen = 999) # turn off scientific notation

# Package libraries
library(readxl)
library(tidyverse)
library(ggplot2)
library(mermaidr)
library(ggtext)
library(xfun)
library(plotly) # for turning the confusion matrices into interactive plots
library(htmlwidgets) # saving interactive plots as html files
Loading results files
Show the code
# Load every input table used by the later chunks.

### Report with the overall performance
overallClassReportTBL <- read_excel(
  path = "../data/classifier_metrics.xlsx",
  sheet = "Metrics"
)

### Report with the performance per label (UUIDs)
labelClassReportTBL <- read_excel(
  path = "../data/classifier_metrics.xlsx",
  sheet = "Classification_Report"
)

### Label mapping --> UUID to character
labelMapTBL <- read_csv(file = "../data/LabelMap.csv")

### MERMAID benthic attributes
# Keep only the attribute name (renamed to `ba`) and its parent.
benthicAttTBL <- select(
  mermaid_get_reference(reference = "benthicattributes"),
  ba = name,
  parent
)

### User testing results (for confusion matrices)
userResTBL <- read_csv(file = "../data/anonymizedUserTestingResults.csv")
Prepare data
Show the code
# Prepare data ----

### Overall report: keep the three summary metrics and spread them into one
### row with one column per metric. `...1` holds the metric names (presumably
### the unnamed first spreadsheet column) and `0.0` holds their values.
overallClassReportTBL <- overallClassReportTBL %>%
  filter(...1 %in% c("precision", "recall", "f1_score")) %>%
  pivot_wider(
    names_from = ...1,
    values_from = `0.0`
  ) %>%
  mutate(
    CoralFocus3Label = "<b>Overall</b>",
    ba = "Overall",
    tlc = "Overall"
  ) %>%
  rename(`f1-score` = f1_score)

### Per-label report: drop the summary rows, translate the label id into its
### readable label and keep only the label plus the three metrics.
labelClassReportTBL <- labelClassReportTBL %>%
  rename(label_id = ...1) %>%
  filter(!label_id %in% c("accuracy", "weighted avg", "macro avg")) %>%
  left_join(labelMapTBL, by = "label_id") %>%
  select(CoralFocus3Label, `f1-score`, precision, recall)

#### Get the top level categories for each of the labels

## Extract just the benthic attribute from the label.
## `paste()` uses its default separator (a space), so the alternatives are
## " - Branching", " - Foliose", ... and each such suffix is stripped.
labelClassReportTBL <- labelClassReportTBL %>%
  mutate(
    ba = gsub(
      pattern = paste(
        " -",
        c(
          "Branching",
          "Foliose",
          "Encrusting",
          "Plates or tables",
          "Massive",
          "Digitate"
        ),
        collapse = "|"
      ),
      replacement = "",
      x = CoralFocus3Label
    )
  )

### Get all the unique benthic attributes and assign them to the top level
### categories
uniqueBaParentTBL <- labelClassReportTBL %>%
  select(ba) %>%
  distinct()

#' Find the top-level category of a benthic attribute
#'
#' Walks up the `parent` chain of `lookup_table` starting at `ba` until it
#' reaches an attribute that has no parent.
#'
#' @param ba A single benthic attribute name.
#' @param lookup_table A data frame with columns `ba` (attribute name) and
#'   `parent` (name of the parent attribute, `NA` for top-level attributes).
#' @return The name of the top-level ancestor. If `ba` (or one of its
#'   ancestors) is missing from `lookup_table`, the last name that could be
#'   resolved is returned; `NA` input gives `NA_character_`. The previous
#'   recursive version stopped with an error in both situations.
find_top_level <- function(ba, lookup_table) {
  if (is.na(ba)) {
    return(NA_character_)
  }

  current <- ba
  visited <- character(0)

  repeat {
    # match() yields exactly one index (the first hit) or NA, so the
    # conditions below are always length one, even with duplicated names.
    row_index <- match(current, lookup_table$ba)
    if (is.na(row_index)) {
      return(current)
    }

    parent <- lookup_table$parent[[row_index]]
    # Stop at a root; also stop instead of looping forever should the
    # reference data ever contain a cyclic parent chain.
    if (is.na(parent) || parent %in% c(visited, current)) {
      return(current)
    }

    visited <- c(visited, current)
    current <- parent
  }
}

uniqueBaTlcTBL <- uniqueBaParentTBL %>%
  mutate(
    tlc = vapply(
      ba,
      find_top_level,
      character(1),
      lookup_table = benthicAttTBL,
      USE.NAMES = FALSE
    )
  )

labelClassReportTBL <- labelClassReportTBL %>%
  left_join(uniqueBaTlcTBL, by = "ba")

## Get the maximum f1 score by tlc to order tlc as a factor
f1ScoreByTlcTBL <- labelClassReportTBL %>%
  group_by(tlc) %>%
  summarise(max_f1_score = max(`f1-score`)) %>%
  ungroup() %>%
  arrange(desc(max_f1_score))

allClassReportTBL <- bind_rows(labelClassReportTBL, overallClassReportTBL)

### Add an asterisk to all the user labels that are not represented in the
### model
# NOTE(review): a missing `ba_user` that is not matched in
# `labelClassReportTBL$ba` becomes the literal label "NA*" here - confirm
# that this is intended.
userResTBL <- userResTBL %>%
  mutate(
    ba_user = ifelse(
      ba_user %in% labelClassReportTBL$ba,
      as.character(ba_user),
      paste0(ba_user, "*")
    )
  )
Plot the results - horizontal barplots by label
Horizontal barplot with color gradient
Show the code
# Orderings used for the factor levels below.
# Labels are ordered by ascending F1-score; top level categories keep the
# order of `f1ScoreByTlcTBL`.
label_order <- arrange(labelClassReportTBL, `f1-score`)$CoralFocus3Label
tlc_order <- f1ScoreByTlcTBL$tlc

# Reshape data for faceted plotting: one row per label and metric, flagged as
# either the overall result or an individual label, then turn the plotting
# columns into ordered factors.
longAllClassReportTBL <- allClassReportTBL %>%
  pivot_longer(
    cols = c(precision, recall, `f1-score`),
    names_to = "Metric",
    values_to = "Score"
  ) %>%
  mutate(
    Group = ifelse(CoralFocus3Label == "<b>Overall</b>", "Overall", "Label")
  ) %>%
  mutate(
    CoralFocus3Label = factor(
      CoralFocus3Label,
      levels = c(label_order, "<b>Overall</b>")
    ),
    Metric = factor(Metric, levels = c("f1-score", "precision", "recall")),
    Group = factor(Group, levels = c("Overall", "Label")),
    tlc = factor(tlc, levels = c("Overall", tlc_order))
  )

# Create a plot with group as the row facets and metric as the column facets
ggplot(
  data = longAllClassReportTBL,
  mapping = aes(x = Score, y = CoralFocus3Label, fill = Score)
) +
  geom_col(alpha = 0.75) +
  geom_vline(
    xintercept = c(0, 0.25, 0.5, 0.75, 1),
    linetype = "dotted"
  ) +
  facet_grid(Group ~ Metric, scales = "free_y", space = "free_y") +
  scale_fill_viridis_c(option = "D", direction = -1) +
  scale_x_continuous(
    breaks = c(0, 0.25, 0.5, 0.75, 1),
    labels = c(0, 0.25, 0.5, 0.75, 1)
  ) +
  labs(
    title = "MERMAID Classification Model Performance",
    subtitle = "Overall results and by label",
    x = "Score",
    y = NULL
  ) +
  theme_minimal(base_size = 12) +
  theme(
    legend.position = "none",
    panel.spacing.x = unit(0.75, "lines"),
    strip.text = element_text(face = "bold", size = 14),
    strip.text.y.right = element_blank(),
    axis.ticks.x = element_line(),
    axis.text.x = element_text(size = 10),
    axis.text.y = element_markdown(size = 10),
    axis.title.x = element_text(
      size = 12,
      face = "bold",
      margin = margin(t = 10)
    )
  )
Horizontal barplot with labels organized into top level categories
Confusion matrices are a visual representation of how frequently the model and users assign the same or different labels to points on images. Cells along the diagonal indicate agreement between the model and the users; cells off the diagonal indicate disagreement.
Confusion matrix with all labels
This confusion matrix includes all labels assigned by the model or users. Labels marked with an asterisk were assigned by users but are not included in the model, so the model could not have assigned them.
Show the code
## Note - this uses the best model guess with no score cutoff

# Every label used by either the users or the classifier, in alphabetical
# order (union() already removes duplicates).
all_labels <- sort(union(userResTBL$ba_user, userResTBL$ba_classifier))

# Count every user/classifier label pair. Both columns share the same factor
# levels and `.drop = FALSE` keeps the pairs that never occur (n = 0).
confusion_matrix <- count(
  mutate(
    userResTBL,
    ba_user = factor(ba_user, levels = all_labels),
    ba_classifier = factor(ba_classifier, levels = all_labels)
  ),
  ba_user,
  ba_classifier,
  .drop = FALSE
)

# Create the ggplot heatmap
topClassAllLabelsConfusionHeatmap <- ggplot(
  data = confusion_matrix,
  mapping = aes(x = ba_classifier, y = ba_user, fill = n)
) +
  geom_tile(color = "white") +
  scale_fill_gradientn(
    name = "Count",
    colors = c("white", "#6baed6", "#08306b")
  ) +
  labs(
    title = "Confusion Matrix: User vs. Classifier (all labels)",
    x = "Classifier Labels",
    y = "User Labels"
  ) +
  theme_minimal(base_size = 12) +
  theme(
    plot.title = element_text(hjust = 0.5, size = 16, face = "bold"),
    axis.text.x = element_text(angle = 45, hjust = 1, size = 8),
    axis.text.y = element_text(size = 8),
    legend.position = "right"
  )

# topClassAllLabelsConfusionHeatmap
# ggsave(plot = topClassAllLabelsConfusionHeatmap,
#        filename = "../plots/TopClassAllLabelsConfusionHeatmap.jpg",
#        width = 11,
#        height = 8)

# Convert to interactive plot using plotly
topClassAllLabelsConfusionInteractivePlot <- ggplotly(
  topClassAllLabelsConfusionHeatmap,
  tooltip = c("x", "y", "fill")
)

# # Save the interactive plot as an HTML file
# saveWidget(topClassAllLabelsConfusionInteractivePlot,
#            "../plots/TopClassAllLabelsConfusionHeatmap.html",
#            selfcontained = TRUE)

topClassAllLabelsConfusionInteractivePlot
Confusion matrix with labels grouped by top level category
These plots show agreement/disagreement at a higher taxonomic level, where each label is aggregated to one of 11 top level categories.
Show the code
#### Create another confusion matrix with only the TLCs ####

# Get all unique top level categories (user or classifier) and sort them
# alphabetically
tlc_labels <- sort(
  union(
    unique(userResTBL$tlc_user),
    unique(userResTBL$tlc_classifier)
  )
)

# Ensure all levels are included and ordered alphabetically; `.drop = FALSE`
# keeps the category pairs that never occur (n = 0).
tlc_confusion_matrix <- userResTBL %>%
  mutate(
    tlc_user = factor(tlc_user, levels = tlc_labels),
    tlc_classifier = factor(tlc_classifier, levels = tlc_labels)
  ) %>%
  count(tlc_user, tlc_classifier, .drop = FALSE)

# Create the ggplot heatmap
# (title typo fixed: "caegories" -> "categories")
topClassTlcConfusionHeatmap <- ggplot(
  tlc_confusion_matrix,
  aes(
    x = tlc_classifier,
    y = tlc_user,
    fill = n
  )
) +
  geom_tile(color = "white") +
  scale_fill_gradientn(
    colors = c("white", "#6baed6", "#08306b"),
    name = "Count"
  ) +
  labs(
    title = "Confusion Matrix: User vs. Classifier (top-level categories)",
    x = "Classifier Labels",
    y = "User Labels"
  ) +
  theme_minimal(base_size = 12) +
  theme(
    axis.text.x = element_text(angle = 45, hjust = 1),
    plot.title = element_text(hjust = 0.5, size = 16, face = "bold"),
    legend.position = "right"
  )

# topClassTlcConfusionHeatmap
#
# ggsave(plot = topClassTlcConfusionHeatmap,
#        filename = "../plots/TopClassTlcConfusionHeatmap.jpg",
#        width = 10)

# Convert to interactive plot using plotly
topClassTlcConfusionInteractivePlot <- ggplotly(
  topClassTlcConfusionHeatmap,
  tooltip = c("x", "y", "fill")
)

# # Save the interactive plot as an HTML file
# saveWidget(topClassTlcConfusionInteractivePlot,
#            "../plots/TopClassTlcConfusionHeatmap.html",
#            selfcontained = TRUE)

topClassTlcConfusionInteractivePlot