MERMAID Image Classification Results

Visualizations to show model performance

Author

Iain R. Caldwell

Published

June 3, 2025

Context - Visualizing MERMAID image classification performance

This code visualizes model performance results from the MERMAID image classification model.

Loading package libraries and setting parameters

The following libraries and parameters are used throughout the code.

Show the code
# Session setup: start from an empty global environment and print numbers in
# fixed (non-scientific) notation.
# NOTE(review): rm(list = ls()) only clears the global environment; it does not
# detach packages or reset options, and it will wipe a caller's workspace if
# this script is source()d. Consider running in a fresh R session instead.
rm(list = ls()) # remove past stored objects
options(scipen = 999) # turn off scientific notation

#Package libraries
library(readxl)
library(tidyverse)
library(ggplot2)
library(mermaidr)
library(ggtext)
library(xfun)
library(plotly) #for turning the confusion matrices into interactive plots
library(htmlwidgets) #saving interactive plots as html files
library(DT) #for interactive data tables

Loading results files

Show the code
### Overall performance: "Metrics" sheet of the classifier workbook
overallClassReportTBL <- read_excel("../data/classifier_metrics.xlsx",
                                    sheet = "Metrics")

### Per-label performance (labels are UUIDs): "Classification_Report" sheet
labelClassReportTBL <- read_excel("../data/classifier_metrics.xlsx",
                                  sheet = "Classification_Report")

### Label mapping --> UUID to character
labelMapTBL <- read_csv("../data/LabelMap.csv")

### MERMAID benthic attributes: keep each attribute name (as `ba`) and its parent
benthicAttTBL <- select(
  mermaid_get_reference(reference = "benthicattributes"),
  ba = name,
  parent
)

### User testing results (for confusion matrices)
userResTBL <- read_csv("../data/anonymizedUserTestingResults.csv")

Prepare data

Show the code
# Overall report: keep the precision / recall / F1 rows, spread the metric
# names (first, unnamed column `...1`) into columns, and tag the result as
# "Overall" so it can be stacked with the per-label rows later on.
overallClassReportTBL <- overallClassReportTBL %>%
  filter(...1 %in% c("f1_score", "precision", "recall")) %>%
  pivot_wider(names_from = ...1, values_from = `0.0`) %>%
  rename(`f1-score` = f1_score) %>%
  mutate(
    CoralFocus3Label = "<b>Overall</b>",
    ba = "Overall",
    tlc = "Overall"
  )

# Per-label report: name the first column `label_id`, drop the summary rows,
# attach the readable label from the label map, and keep only the label plus
# the three metrics.
labelClassReportTBL <- labelClassReportTBL %>%
  rename(label_id = ...1) %>%
  filter(!(label_id %in% c("accuracy", "macro avg", "weighted avg"))) %>%
  left_join(labelMapTBL, by = "label_id") %>%
  select(CoralFocus3Label, `f1-score`, precision, recall)

#### Get the top level categories for each of the labels
## Strip any " - <growth form>" suffix from the label so only the benthic
## attribute remains in `ba`
labelClassReportTBL <- labelClassReportTBL %>%
  mutate(
    ba = gsub(
      pattern = paste0(
        " - ",
        c("Branching", "Foliose", "Encrusting",
          "Plates or tables", "Massive", "Digitate"),
        collapse = "|"
      ),
      replacement = "",
      x = CoralFocus3Label
    )
  )

### One row per unique benthic attribute (top level categories are added later)
uniqueBaParentTBL <- distinct(labelClassReportTBL, ba)

# Function to find the top-level category
#
# Walks up the parent chain in `lookup_table` starting from `ba` until it
# reaches an entry whose parent is NA, and returns that entry's name.
#
# Arguments:
#   ba           - a single benthic attribute name (character scalar).
#   lookup_table - data frame with a `ba` column (attribute names) and a
#                  `parent` column (name of the parent attribute, NA at the top).
#
# Returns: the name of the top-level attribute, or NA_character_ when `ba` is
# missing/NA or when `ba` (or one of its ancestors) is not in `lookup_table`.
find_top_level <- function(ba, lookup_table) {
  # A missing or non-scalar input has no category to resolve
  if (length(ba) != 1L || is.na(ba)) {
    return(NA_character_)
  }

  # match() always gives exactly one index (the first hit) or NA, so the
  # conditions below are guaranteed to be length one
  row_idx <- match(ba, lookup_table$ba)
  if (is.na(row_idx)) {
    warning("Benthic attribute '", ba,
            "' not found in lookup table; returning NA",
            call. = FALSE)
    return(NA_character_)
  }

  parent <- lookup_table$parent[[row_idx]]
  if (is.na(parent)) {
    return(ba) # If no parent, the current ba is the top-level category
  }

  # Recursively find the top-level category
  find_top_level(parent, lookup_table)
}

# Resolve the top-level category (tlc) of every unique benthic attribute
uniqueBaTlcTBL <- uniqueBaParentTBL %>%
  mutate(tlc = map_chr(ba, find_top_level, lookup_table = benthicAttTBL))

# Attach the top-level category to each label's metrics
labelClassReportTBL <- left_join(labelClassReportTBL, uniqueBaTlcTBL, by = "ba")

## Best f1 score within each tlc, highest first (used to order tlc as a factor)
f1ScoreByTlcTBL <- labelClassReportTBL %>%
  group_by(tlc) %>%
  summarise(max_f1_score = max(`f1-score`)) %>%
  arrange(desc(max_f1_score))

# Stack the per-label rows on top of the overall row
allClassReportTBL <- bind_rows(labelClassReportTBL, overallClassReportTBL)

### Add an asterisk to all the user labels that are not represented in the model
userResTBL <- userResTBL %>%
  mutate(
    ba_user = if_else(
      ba_user %in% labelClassReportTBL$ba,
      as.character(ba_user),
      paste0(ba_user, "*")
    )
  )

Plot the results - horizontal barplots by label

Horizontal barplot with labels organized into top level categories

Show the code
# Labels ordered by ascending F1-score (sets the y-axis order within panels)
label_order <- pull(arrange(labelClassReportTBL, `f1-score`), CoralFocus3Label)

# Top-level categories in the order stored in f1ScoreByTlcTBL
tlc_order <- f1ScoreByTlcTBL[["tlc"]]

# Reshape to one row per label and metric, mark the overall row, then turn the
# plotting variables into factors with explicit level order
longAllClassReportTBL <- allClassReportTBL %>%
  pivot_longer(
    cols = c(precision, recall, `f1-score`),
    names_to = "Metric",
    values_to = "Score"
  ) %>%
  mutate(
    Group = if_else(CoralFocus3Label == "<b>Overall</b>", "Overall", "Label")
  ) %>%
  mutate(
    CoralFocus3Label = factor(CoralFocus3Label,
                              levels = c(label_order, "<b>Overall</b>")),
    Metric = factor(Metric, levels = c("f1-score", "precision", "recall")),
    Group = factor(Group, levels = c("Overall", "Label")),
    tlc = factor(tlc, levels = c("Overall", tlc_order))
  )

# Fill colour for each top-level category (names must match the tlc values)
tlcColors <- c(
  "Overall"                  = "black",
  "Hard coral"               = "#498FC9",
  "Macroalgae"               = "#B2B002",
  "Sand"                     = "#C1B180",
  "Soft coral"               = "#9BE5FA",
  "Rubble"                   = "#F5F7AF",
  "Other invertebrates"      = "#A6A6A6",
  "Cyanobacteria"            = "#860E00",
  "Bare substrate"           = "#F2F3F3",
  "Seagrass"                 = "#4D4D4D",
  "Crustose coralline algae" = "#FBD7D5",
  "Turf algae"               = "#D8EEA8"
)

# Facet strip text for each top-level category; an empty string leaves that
# category's strip blank
tlcLabels <- c(
  "Overall"                  = "",
  "Hard coral"               = "Hard coral",
  "Macroalgae"               = "",
  "Sand"                     = "",
  "Soft coral"               = "Soft coral",
  "Rubble"                   = "",
  "Other invertebrates"      = "Other invertebrates",
  "Cyanobacteria"            = "",
  "Bare substrate"           = "",
  "Seagrass"                 = "",
  "Crustose coralline algae" = "",
  "Turf algae"               = ""
)

# Horizontal bars of Score per label: one facet row per top-level category and
# one facet column per metric, with dotted guide lines at each x-axis break
ggplot(
  data = longAllClassReportTBL,
  mapping = aes(x = Score, y = CoralFocus3Label, fill = tlc)
) +
  geom_col(alpha = 0.75) +
  geom_vline(xintercept = seq(0, 1, by = 0.25), linetype = "dotted") +
  facet_grid(
    tlc ~ Metric,
    scales = "free_y",
    space = "free_y",
    labeller = labeller(tlc = tlcLabels)
  ) +
  scale_fill_manual(values = tlcColors) +
  scale_x_continuous(
    breaks = seq(0, 1, by = 0.25),
    labels = seq(0, 1, by = 0.25)
  ) +
  labs(
    title = "MERMAID Classification Model Performance",
    subtitle = "Overall results and by label",
    x = "Score",
    y = NULL
  ) +
  theme_minimal(base_size = 12) +
  theme(
    legend.position = "none",
    panel.spacing.x = unit(0.25, "lines"),
    strip.text = element_text(face = "bold", size = 12),
    # Row strips are drawn horizontally and left-aligned
    strip.text.y.right = element_text(angle = 0, hjust = 0),
    axis.ticks.x = element_line(),
    axis.text.x = element_text(size = 10),
    axis.title.x = element_text(size = 12, face = "bold",
                                margin = margin(t = 10)),
    # Markdown rendering so the "<b>Overall</b>" label shows in bold
    axis.text.y = element_markdown(size = 10)
  )