MERMAID Image Classification Results

Visualizations to show model performance

Author

Iain R. Caldwell

Published

June 3, 2025

Context - Visualizing MERMAID image classification performance

This code visualizes model performance results from the MERMAID image classification model.

Loading package libraries and setting parameters

The following libraries and parameters are used throughout the code.

Show the code
# Session options.
# The global environment is deliberately NOT cleared here (no `rm(list = ls())`):
# doing so would wipe the workspace of anyone who sources this file, and it does
# not reset loaded packages or options, so it offers no real reproducibility
# guarantee. Render/run in a fresh R session instead.
options(scipen = 999) # turn off scientific notation

#Package libraries
library(readxl)
library(tidyverse)
library(ggplot2)
library(mermaidr)
library(ggtext)
library(xfun)
library(plotly) #for turning the confusion matrices into interactive plots
library(htmlwidgets) #saving interactive plots as html files

Loading results files

Show the code
### Classifier metrics workbook, sheet "Metrics": overall performance
overallClassReportTBL <- "../data/classifier_metrics.xlsx" %>%
  read_excel(sheet = "Metrics")

### Same workbook, sheet "Classification_Report": performance per label (UUIDs)
labelClassReportTBL <- "../data/classifier_metrics.xlsx" %>%
  read_excel(sheet = "Classification_Report")

### Label mapping --> UUID to character
labelMapTBL <- read_csv("../data/LabelMap.csv")

### MERMAID benthic attributes: keep the attribute name (as `ba`) and its parent
benthicAttTBL <- select(
  mermaid_get_reference(reference = "benthicattributes"),
  ba = name,
  parent
)

### User testing results (for confusion matrices)
userResTBL <- read_csv("../data/anonymizedUserTestingResults.csv")

Prepare data

Show the code
# Overall report: keep only the precision / recall / f1_score rows, spread them
# into one column per metric (values come from the `0.0` column), and tag the
# result as the "Overall" entry so it can be stacked with the per-label report.
overallClassReportTBL <- overallClassReportTBL %>%
  filter(`...1` %in% c("precision", "recall", "f1_score")) %>%
  pivot_wider(names_from = `...1`, values_from = `0.0`) %>%
  # Match the metric column name used by the per-label report
  rename(`f1-score` = f1_score) %>%
  mutate(
    CoralFocus3Label = "<b>Overall</b>", # bold markup, rendered by element_markdown()
    ba = "Overall",
    tlc = "Overall"
  )

# Per-label report: the unnamed first column holds the label id. Drop the
# summary rows, attach the readable label from the label map, and keep only the
# label plus its three metrics.
labelClassReportTBL <- labelClassReportTBL %>%
  rename(label_id = `...1`) %>%
  filter(!is.element(label_id, c("accuracy", "weighted avg", "macro avg"))) %>%
  left_join(y = labelMapTBL, by = "label_id") %>%
  select(all_of(c("CoralFocus3Label", "f1-score", "precision", "recall")))

#### Derive the benthic attribute (`ba`) of each label
## A label may carry a growth-form suffix (" - Branching", " - Massive", ...);
## removing every such suffix leaves just the benthic attribute.
labelClassReportTBL <- mutate(
  labelClassReportTBL,
  ba = gsub(
    paste0(" - ",
           c("Branching",
             "Foliose",
             "Encrusting",
             "Plates or tables",
             "Massive",
             "Digitate"),
           collapse = "|"),
    "",
    CoralFocus3Label
  )
)

### One row per distinct benthic attribute (input for the top level category lookup)
uniqueBaParentTBL <- distinct(labelClassReportTBL, ba)

# Function to find the top-level category of a benthic attribute.
#
# Starting at `ba`, follows the `parent` column of `lookup_table` upwards until
# it reaches an attribute whose parent is NA; that attribute is the top-level
# category.
#
# Arguments:
#   ba           - a single benthic attribute name.
#   lookup_table - data frame with columns `ba` (attribute name) and `parent`
#                  (name of the parent attribute; NA for top-level attributes).
#
# Returns:
#   The name of the top-level category (character). NA_character_ is returned
#   when `ba` is NA, and (with a warning) when `ba` or one of its ancestors is
#   not present in `lookup_table`.
#
# Errors:
#   If `ba` is not length 1, or if the parent chain contains a cycle.
find_top_level <- function(ba, lookup_table) {
  if (length(ba) != 1L) {
    stop("`ba` must be a single value.", call. = FALSE)
  }

  current <- as.character(ba)
  visited <- character(0) # attributes already seen, used to detect cycles

  repeat {
    if (is.na(current)) {
      return(NA_character_)
    }
    if (current %in% visited) {
      stop("Cycle detected in the parent chain of '", ba, "'.", call. = FALSE)
    }
    visited <- c(visited, current)

    # match() yields exactly one index (the first match) or NA, so the checks
    # below always receive a length-one condition, even when the attribute is
    # missing from, or duplicated in, the lookup table.
    idx <- match(current, lookup_table$ba)
    if (is.na(idx)) {
      warning("'", current, "' was not found in the lookup table; ",
              "returning NA for '", ba, "'.", call. = FALSE)
      return(NA_character_)
    }

    parent <- as.character(lookup_table$parent[idx])
    if (is.na(parent)) {
      return(current) # If no parent, the current attribute is the top level
    }
    current <- parent # otherwise climb one level and repeat
  }
}

# Resolve the top level category (tlc) of every distinct benthic attribute,
# one lookup per attribute, using the MERMAID benthic attribute table
uniqueBaTlcTBL <- mutate(
  uniqueBaParentTBL,
  tlc = map_chr(ba, find_top_level, lookup_table = benthicAttTBL)
)

# Attach the top level category to each label via its benthic attribute
labelClassReportTBL <- left_join(labelClassReportTBL, uniqueBaTlcTBL, by = "ba")

## Best (maximum) f1 score within each top level category, highest first;
## this ordering is later used for the tlc factor levels
f1ScoreByTlcTBL <- labelClassReportTBL %>%
  group_by(tlc) %>%
  summarise(max_f1_score = max(`f1-score`), .groups = "drop") %>%
  arrange(desc(max_f1_score))

## Stack the per-label rows on top of the overall row
allClassReportTBL <- list(labelClassReportTBL, overallClassReportTBL) %>%
  bind_rows()

### Flag user labels the model does not know about: any user benthic attribute
### that is absent from the model's benthic attributes gets a trailing asterisk
userResTBL <- mutate(
  userResTBL,
  ba_user = ifelse(!(ba_user %in% labelClassReportTBL$ba),
                   paste0(ba_user, "*"),
                   as.character(ba_user))
)

Plot the results - horizontal barplots by label

Horizontal barplot with color gradient

Show the code
# Labels ordered by ascending F1-score (defines the y-axis order)
label_order <- pull(arrange(labelClassReportTBL, `f1-score`), CoralFocus3Label)

# Top level categories, in the order stored in f1ScoreByTlcTBL
tlc_order <- f1ScoreByTlcTBL$tlc

# Reshape to one row per label x metric for faceted plotting, then fix the
# display order of labels, metrics, groups and top level categories via factors
longAllClassReportTBL <- allClassReportTBL %>%
  pivot_longer(cols = c(precision, recall, `f1-score`),
               names_to = "Metric",
               values_to = "Score") %>%
  mutate(
    # Group separates the single "Overall" entry from the individual labels;
    # it is computed before CoralFocus3Label is turned into a factor
    Group = factor(ifelse(CoralFocus3Label == "<b>Overall</b>", "Overall", "Label"),
                   levels = c("Overall", "Label")),
    CoralFocus3Label = factor(CoralFocus3Label,
                              levels = c(label_order, "<b>Overall</b>")),
    Metric = factor(Metric, levels = c("f1-score", "precision", "recall")),
    tlc = factor(tlc, levels = c("Overall", tlc_order))
  )

# Horizontal bars of Score per label, filled by Score.
# Column facets = metric; row facets = Group ("Overall" vs. individual labels).
ggplot(data = longAllClassReportTBL,
       mapping = aes(y = CoralFocus3Label, x = Score, fill = Score)) +
  geom_col(alpha = 0.75) +
  # Dotted reference lines at every x-axis break
  geom_vline(xintercept = seq(0, 1, by = 0.25), linetype = "dotted") +
  # Free y scale/space: each row facet lists only its own labels
  facet_grid(rows = vars(Group), cols = vars(Metric),
             scales = "free_y", space = "free_y") +
  scale_x_continuous(breaks = seq(0, 1, by = 0.25),
                     labels = seq(0, 1, by = 0.25)) +
  scale_fill_viridis_c(option = "D", direction = -1) +
  labs(title = "MERMAID Classification Model Performance",
       subtitle = "Overall results and by label",
       x = "Score",
       y = NULL) +
  theme_minimal(base_size = 12) +
  theme(
    legend.position = "none",
    panel.spacing.x = unit(0.75, "lines"),
    strip.text = element_text(face = "bold", size = 14),
    strip.text.y.right = element_blank(), # hide the row (Group) strip labels
    axis.ticks.x = element_line(),
    axis.title.x = element_text(size = 12, face = "bold", margin = margin(t = 10)),
    axis.text.x = element_text(size = 10),
    axis.text.y = element_markdown(size = 10) # renders the <b>Overall</b> markup
  )

Horizontal barplot with labels organized into top level categories

Show the code
# Fill colour for each top level category (plus the "Overall" entry)
tlcColors <- c(
  "Overall" = "black",
  "Hard coral" = "#498FC9",
  "Macroalgae" = "#B2B002",
  "Sand" = "#C1B180",
  "Soft coral" = "#9BE5FA",
  "Rubble" = "#F5F7AF",
  "Other invertebrates" = "#A6A6A6",
  "Cyanobacteria" = "#860E00",
  "Bare substrate" = "#F2F3F3",
  "Seagrass" = "#4D4D4D",
  "Crustose coralline algae" = "#FBD7D5",
  "Turf algae" = "#D8EEA8"
)

# Facet strip text for each category: only three categories show their name,
# every other category gets an empty strip label
tlcLabels <- setNames(
  ifelse(names(tlcColors) %in% c("Hard coral", "Soft coral", "Other invertebrates"),
         names(tlcColors),
         ""),
  names(tlcColors)
)

# Horizontal bars of Score per label, filled by top level category (tlc).
# Column facets = metric; row facets = tlc, with strip text taken from tlcLabels.
ggplot(data = longAllClassReportTBL,
       mapping = aes(y = CoralFocus3Label, x = Score, fill = tlc)) +
  geom_col(alpha = 0.75) +
  # Dotted reference lines at every x-axis break
  geom_vline(xintercept = seq(0, 1, by = 0.25), linetype = "dotted") +
  # Free y scale/space: each row facet lists only its own labels
  facet_grid(rows = vars(tlc), cols = vars(Metric),
             scales = "free_y", space = "free_y",
             labeller = labeller(tlc = tlcLabels)) +
  scale_fill_manual(values = tlcColors) +
  scale_x_continuous(breaks = seq(0, 1, by = 0.25),
                     labels = seq(0, 1, by = 0.25)) +
  labs(title = "MERMAID Classification Model Performance",
       subtitle = "Overall results and by label",
       x = "Score",
       y = NULL) +
  theme_minimal(base_size = 12) +
  theme(
    legend.position = "none",
    panel.spacing.x = unit(0.25, "lines"),
    strip.text = element_text(face = "bold", size = 12),
    # Row strips are kept (not blanked) and drawn horizontally, left-aligned
    strip.text.y.right = element_text(angle = 0, hjust = 0),
    axis.ticks.x = element_line(),
    axis.title.x = element_text(size = 12, face = "bold", margin = margin(t = 10)),
    axis.text.x = element_text(size = 10),
    axis.text.y = element_markdown(size = 10) # renders the <b>Overall</b> markup
  )

Plot user data vs. model - confusion matrices

Confusion matrices are a visual representation of how frequently the model and users assign the same or different labels to points on images. Cells along the diagonal indicate agreement between the model and the users, while cells off the diagonal indicate disagreement.

Confusion matrix with all labels

This confusion matrix includes all labels assigned by the model or users. The labels with asterisks were assigned by users but were not included in the model, so could not have been assigned as such.

Show the code
## Note - this uses the best model guess with no score cutoff
# Every label used by either the users or the classifier, in alphabetical order
all_labels <- sort(union(userResTBL$ba_user, userResTBL$ba_classifier))

# Give both label columns the same full set of levels, then count every
# user/classifier combination; .drop = FALSE keeps combinations that never
# occur (count 0) so the matrix is complete
confusion_matrix <- userResTBL %>%
  mutate(across(c(ba_user, ba_classifier),
                ~ factor(.x, levels = all_labels))) %>%
  count(ba_user, ba_classifier, .drop = FALSE)

# Heatmap of user labels (y) against classifier labels (x); the tile fill
# encodes the count of each user/classifier label pair
topClassAllLabelsConfusionHeatmap <- confusion_matrix %>%
  ggplot(mapping = aes(y = ba_user, x = ba_classifier, fill = n)) +
  geom_tile(color = "white") +
  scale_fill_gradientn(name = "Count",
                       colors = c("white", "#6baed6", "#08306b")) +
  labs(x = "Classifier Labels",
       y = "User Labels",
       title = "Confusion Matrix: User vs. Classifier (all labels)") +
  theme_minimal(base_size = 12) +
  theme(legend.position = "right",
        plot.title = element_text(hjust = 0.5, size = 16, face = "bold"),
        axis.text.y = element_text(size = 8),
        axis.text.x = element_text(angle = 45, hjust = 1, size = 8))

# Static version (disabled): print and save as a jpg
#topClassAllLabelsConfusionHeatmap

#ggsave(plot = topClassAllLabelsConfusionHeatmap,
#       filename = "../plots/TopClassAllLabelsConfusionHeatmap.jpg",
#       width = 11,
#       height = 8)

# Interactive version: plotly widget whose tooltip shows x, y and fill
topClassAllLabelsConfusionInteractivePlot <- topClassAllLabelsConfusionHeatmap %>%
  ggplotly(tooltip = c("x", "y", "fill"))

# # Save the interactive plot as an HTML file
# saveWidget(topClassAllLabelsConfusionInteractivePlot,
#            "../plots/TopClassAllLabelsConfusionHeatmap.html",
#            selfcontained = TRUE)

# Display the interactive plot
topClassAllLabelsConfusionInteractivePlot

Confusion matrix with labels grouped by top level category

These plots show agreement/disagreement at a higher taxonomic level, where each label is aggregated to one of 11 top level categories.

Show the code
#### Confusion matrix at the top level category (TLC) level ####
# Every top level category used by either the users or the classifier,
# in alphabetical order
tlc_labels <- sort(union(userResTBL$tlc_user, userResTBL$tlc_classifier))

# Give both category columns the same full set of levels, then count every
# user/classifier combination; .drop = FALSE keeps combinations that never
# occur (count 0) so the matrix is complete
tlc_confusion_matrix <- userResTBL %>%
  mutate(across(c(tlc_user, tlc_classifier),
                ~ factor(.x, levels = tlc_labels))) %>%
  count(tlc_user, tlc_classifier, .drop = FALSE)

# Create the ggplot heatmap: user top level categories (y) against classifier
# top level categories (x); the tile fill encodes the count of each pair
topClassTlcConfusionHeatmap <- ggplot(tlc_confusion_matrix,
                                      aes(x = tlc_classifier,
                                          y = tlc_user,
                                          fill = n)) +
  geom_tile(color = "white") +
  scale_fill_gradientn(colors = c("white", "#6baed6", "#08306b"),
                       name = "Count") +
  # Title typo fixed: "caegories" -> "categories"
  labs(title = "Confusion Matrix: User vs. Classifier (top-level categories)",
       x = "Classifier Labels",
       y = "User Labels") +
  theme_minimal(base_size = 12) +
  theme(axis.text.x = element_text(angle = 45, hjust = 1),
        plot.title = element_text(hjust = 0.5, size = 16, face = "bold"),
        legend.position = "right")

# Static version (disabled): print and save as a jpg
# topClassTlcConfusionHeatmap
# 
# ggsave(plot = topClassTlcConfusionHeatmap,
#        filename = "../plots/TopClassTlcConfusionHeatmap.jpg",
#        width = 10)

# Convert to interactive plot using plotly (tooltip shows x, y and fill)
topClassTlcConfusionInteractivePlot <-
  ggplotly(topClassTlcConfusionHeatmap, tooltip = c("x", "y", "fill"))

# # Save the interactive plot as an HTML file
# saveWidget(topClassTlcConfusionInteractivePlot,
#            "../plots/TopClassTlcConfusionHeatmap.html",
#            selfcontained = TRUE)

# Display the interactive plot
topClassTlcConfusionInteractivePlot
 

Powered by

Logo