Modeling

Categories: ALM, EXAM, R
Author: Thomas Gorman

Published: March 27, 2024

Code
pacman::p_load(dplyr, purrr, tidyr, ggplot2, data.table, here, patchwork, conflicted,
               stringr, future, furrr, knitr, reactable, flextable, ggstance, htmltools,
               ggdist, ggh4x, wesanderson)
conflict_prefer_all("dplyr", quiet = TRUE) # prefer dplyr::filter etc. over stats/base
options(scipen = 999)
walk(c("Display_Functions","fun_alm","fun_indv_fit","fun_model"), ~ source(here::here(paste0("Functions/", .x, ".R"))))

Modeling Results

Code
#ds <- readRDS(here::here("data/e1_md_11-06-23.rds"))  |> as.data.table()
ds <- readRDS(here::here("data/e2_md_02-23-24.rds"))  |> as.data.table()
nbins <- 3

fd <- readRDS(here("data/e2_08-21-23.rds"))
test <- fd |> filter(expMode2 == "Test") 
testAvg <- test %>% group_by(id, condit, vb, bandInt,bandType,tOrder) %>%
  summarise(nHits=sum(dist==0),vx=mean(vx),dist=mean(dist),sdist=mean(sdist),n=n(),Percent_Hit=nHits/n)

trainAvg <- fd |> filter(expMode2 == "Train") |> group_by(id) |> 
  mutate(tr=trial,x=vb,Block=case_when(expMode2=="Train" ~ cut(tr,breaks=seq(1,max(tr), length.out=nbins+1),include.lowest=TRUE,labels=FALSE),
                                         expMode2=="Test" ~ 4)) |> 
  group_by(id,condit,vb,x,Block) |> 
  summarise(dist=mean(dist),y=mean(vx))

# input/output activation nodes shared by ALM and EXAM; values are the lower
# bounds of the six velocity bands
input_layer <<- output_layer <<-  c(100,350,600,800,1000,1200)
ids2 <- c(1,66,36)

#file_name <- "e2_n_iter_50_ntry_200_2506"
#file_name <- "n_iter_400_ntry_100_2944"
#file_name <- "e2_n_iter_100_ntry_200_3436"
file_name <- "e2_n_iter_200_ntry_300_2344"

ind_fits <- map(list.files(here('data/abc_reject', file_name), full.names = TRUE), readRDS)
ind_fits_df <- ind_fits |> map(~list(dat=.x[[1]], Model = .x[["Model"]], Fit_Method=.x[["Fit_Method"]]))
ind_fits_df <- ind_fits_df |> map(~rbindlist(.x$dat) |> mutate(Model = .x$Model, Fit_Method = .x$Fit_Method)) |> rbindlist() 
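The fits loaded above are posterior samples produced by an approximate Bayesian computation (ABC) rejection procedure (hence the abc_reject directory and the n_iter/ntry components of the file names). A minimal sketch of the rejection idea, with sim_fun, observed, tolerance, and the uniform priors as hypothetical stand-ins:

Code
# ABC rejection, schematically: draw candidate (c, lr) pairs from a prior,
# simulate responses, and keep only draws whose mean error falls below a tolerance
abc_reject_sketch <- function(n_iter, sim_fun, observed, tolerance) {
  draws <- rbindlist(lapply(seq_len(n_iter), function(i) {
    pars <- data.table(c = runif(1, 1e-5, 5), lr = runif(1, 1e-5, 5)) # hypothetical priors
    pars[, mean_error := mean(abs(observed - sim_fun(c, lr)))]
    pars
  }))
  draws[mean_error < tolerance] # retained draws approximate the posterior
}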
Code
generate_data <- function(Model, post_samples, data, num_samples = 1, return_dat = "train_data, test_data") {
  # Filter data for the specific id without invalidating selfref
  sbj_data <- copy(data[id == post_samples$id[1]])
  # ifelse() cannot return closures; use if/else to select the simulation function
  simulation_function <- if (Model == "EXAM") full_sim_exam else full_sim_alm

  target_data <- switch(return_dat,
                        "test_data" = copy(sbj_data[expMode2 == "Test"]),
                        "train_data" = copy(sbj_data[expMode2 == "Train"]),
                        "train_data, test_data" = copy(sbj_data[expMode2 %in% c("Test", "Train")]))
  
  post_samples <- post_samples[order(mean_error)][1:num_samples, .(c, lr, mean_error, rank = .I)]

  simulated_data_list <- lapply(1:nrow(post_samples), function(i) {
    params <- post_samples[i]
    sim_data <- simulation_function(sbj_data, params$c, params$lr, input_layer = input_layer, 
                                    output_layer = output_layer, return_dat = return_dat)
    sim_data_dt <- data.table(id = sbj_data$id[1], condit = sbj_data$condit[1], 
                              expMode2 = target_data$expMode2, Model = Model,tr=target_data$tr,
                              y = target_data$y, x = target_data$x, c = params$c, 
                              lr = params$lr, mean_error = params$mean_error, rank = i,
                              pred = sim_data)
    return(sim_data_dt)
  })
  
  result_dt <- rbindlist(simulated_data_list)
  setcolorder(result_dt, c("id", "condit", "expMode2","tr", "c", "lr", "x", "y", "pred"))
  return(result_dt)
}
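# Illustrative usage (not run): simulate test-phase predictions for one subject
# from their 10 best posterior draws. `sbj_samples` is a hypothetical data.table
# of that subject's posterior samples (columns id, c, lr, mean_error).
# sim <- generate_data("EXAM", sbj_samples, ds, num_samples = 10, return_dat = "test_data")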

#future::plan(multisession)

nestSbjModelFit <- ind_fits_df %>% nest(.by=c(id,Model,Fit_Method))


#  post_dat <- nestSbjModelFit |> mutate(pp=furrr::future_pmap(list(id,Model,Fit_Method,data), ~{
#     generate_data(..2, ..4 |> mutate(id=..1), ds, num_samples = 50, return_dat="test_data")
#     })) |> 
#    select(Fit_Method,pp,-data) |>  
#    unnest(pp) |>  filter(expMode2=="Test") |> as.data.table()

# saveRDS(post_dat, here("data/model_cache/post_dat_e2.rds"))

post_dat <- readRDS(here("data/model_cache/post_dat_e2.rds"))



# average posterior predictions within subject, velocity band, and posterior draw;
# error  = difference between mean observed and mean predicted responses,
# error2 = mean of the trial-level observed-minus-predicted differences
post_dat_avg <- post_dat |> group_by(id, condit, Model, Fit_Method, x, c, lr, rank) |> 
  mutate(error2 = y - pred) |>
  summarise(y = mean(y), pred = mean(pred), error = y - pred, error2 = mean(error2)) |> as.data.table()

setorder(post_dat_avg, id, x, rank)
post_dat_l <- melt(post_dat_avg, id.vars = c("id", "condit", "Model", "Fit_Method", "x", "c", "lr", "rank","error"),
                   measure.vars = c("pred", "y"), variable.name = "Resp", value.name = "val")
post_dat_l[, Resp := fifelse(Resp == "y", "Observed",
                             fifelse(Model == "ALM", "ALM", "EXAM"))]
setorder(post_dat_l, id, Resp)
#rm(post_dat_avg)

# distance from the target band [x, x + 200]: 0 for responses inside the band,
# otherwise the absolute distance to the nearer band boundary
post_dat_l <- post_dat_l |> mutate(dist = case_when(
    val >= x & val <= x + 200 ~ 0,            # within band
    val < x ~ abs(x - val),                   # below band
    val > x + 200 ~ abs(val - (x + 200)),     # above band
    TRUE ~ NA_real_
  ))


# organize training data predictions
 pd_train <- nestSbjModelFit |> mutate(pp=furrr::future_pmap(list(id,Model,Fit_Method,data), ~{
   generate_data(..2, ..4 |> mutate(id=..1), ds, num_samples = 20, return_dat="train_data")
    })) |>
   select(Fit_Method,pp,-data) |>
  unnest(pp) |> as.data.table() |> filter(expMode2=="Train")

#saveRDS(pd_train, here("data/model_cache/pd_train.rds"))

#pd_train <- readRDS(here("data/model_cache/pd_train.rds"))

# bin each subject's training trials into nbins equal-width blocks
nbins <- 3
pd_train <- pd_train |> group_by(id,condit,Model,Fit_Method) |>
  mutate(Block=cut(tr,breaks=seq(1,max(tr), length.out=nbins+1),include.lowest=TRUE,labels=FALSE))
setorder(pd_train, id, x,Block, rank)

pd_train_l <- melt(pd_train, id.vars = c("id", "condit", "Model","Block", "Fit_Method", "x", "c", "lr", "rank"),
                   measure.vars = c("pred", "y"), variable.name = "Resp", value.name = "val") |> as.data.table()
pd_train_l[, Resp := fifelse(Resp == "y", "Observed",
                             fifelse(Model == "ALM", "ALM", "EXAM"))] 
setorder(pd_train_l, id,Block, Resp) 

pd_train_l <- pd_train_l  |>
  mutate(dist = case_when(
    val >= x & val <= x + 200 ~ 0,                 
    val < x ~ abs(x - val),                       
    val > x + 200 ~ abs(val - (x + 200)),           
    TRUE ~ NA_real_                                 
  ))
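# The same band-distance rule is applied to both test and training predictions;
# a small helper (hypothetical, not part of the original pipeline) would avoid
# duplicating the case_when logic:
band_dist <- function(val, x, width = 200) {
  case_when(
    val >= x & val <= x + width ~ 0,
    val < x ~ abs(x - val),
    TRUE ~ abs(val - (x + width))
  )
}
# e.g. pd_train_l |> mutate(dist = band_dist(val, x))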

#plan(sequential)

Group-level aggregations

Code
post_tabs <- abc_tables(post_dat,post_dat_l)

post_tabs$agg_pred_full |> 
  mutate(Fit_Method=rename_fm(Fit_Method)) |>
  flextable::tabulator(rows=c("Fit_Method","Model"), columns=c("condit"),
                       `ME` = as_paragraph(mean_error)) |> as_flextable()
#post_tabs$agg_pred_full |> pander::pandoc.table()
Table 1: Mean model errors predicting testing data, aggregated over all participants and velocity bands. Note that Fit Method refers to how model parameters were optimized, while error values reflect the mean absolute error across the six testing bands.

| Fit Method                  | Model | Constant | Varied |
|-----------------------------|-------|---------:|-------:|
| Fit to Test Data            | ALM   |    319.2 |  233.9 |
|                             | EXAM  |    219.8 |  205.0 |
| Fit to Test & Training Data | ALM   |    337.6 |  275.4 |
|                             | EXAM  |    234.6 |  222.6 |
| Fit to Training Data        | ALM   |    412.6 |  342.0 |
|                             | EXAM  |    367.3 |  295.7 |
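For reference, these aggregate values can be approximated directly from the posterior predictions (the exact aggregation used by abc_tables may differ); a minimal sketch:

Code
post_dat |>
  group_by(condit, Model, Fit_Method) |>
  summarise(mean_error = mean(abs(y - pred)), .groups = "drop") |>
  arrange(Fit_Method, Model)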

Code
##| layout: [[45,-5, 45], [100]]
##| fig-subcap: ["Model Residuals - training data", "Model Residuals - testing data","Full posterior predictive distributions vs. observed data from participants."]
train_resid <- pd_train |> group_by(id,condit,Model,Fit_Method, Block) |> 
  summarise(y = mean(y), pred = mean(pred), error = y - pred) |>
  ggplot(aes(x = Block, y = abs(error), fill=Model)) + 
  stat_bar + 
  ggh4x::facet_nested_wrap(rename_fm(Fit_Method)~condit, scales="free",ncol=2) +
  scale_fill_manual(values=wes_palette("AsteroidCity2"))+
  labs(title="Model Residual Errors - Training Stage", y="RMSE", x= "Training Block") +
  theme(legend.title = element_blank(), legend.position="top")

test_resid <-  post_dat |> 
   group_by(id,condit,x,Model,Fit_Method,rank) |>
   summarize(error=mean(abs(y-pred)),n=n()) |>
   group_by(id,condit,x,Model,Fit_Method) |>
   summarize(error=mean(error)) |>
  mutate(vbLab = factor(paste0(x,"-",x+200))) |>
  ggplot(aes(x = vbLab, y = abs(error), fill=Model)) + 
  stat_bar + 
  scale_fill_manual(values=wes_palette("AsteroidCity2"))+
  ggh4x::facet_nested_wrap(rename_fm(Fit_Method)~condit, axes = "all",ncol=2,scales="free") +
  labs(title="Model Residual Errors - Testing Stage",y="RMSE", x="Velocity Band") +
  theme(axis.text.x = element_text(angle = 45, hjust = 0.5, vjust = 0.5)) 

group_pred <- post_dat_l |> 
  mutate(vbLab = factor(paste0(x,"-",x+200),levels=levels(testAvg$vb))) |>
  ggplot(aes(x=val,y=vbLab,col=Resp)) + 
  stat_pointinterval(position=position_dodge(.5), alpha=.9) + 
  scale_color_manual(values=wes_palette("AsteroidCity2"))+
  ggh4x::facet_nested_wrap(rename_fm(Fit_Method)~condit, axes = "all",ncol=2,scales="free") +
  labs(title="Posterior Predictions - Testing Stage",y="Velocity Band (lower bound)", x="X Velocity") +
theme(legend.title=element_blank(),axis.text.y = element_text(angle = 45, hjust = 0.5, vjust = 0.5))


((train_resid | test_resid) / group_pred) +
  plot_layout(heights=c(1,1.5)) & 
  plot_annotation(tag_levels = list(c('A1','A2','B')),tag_suffix = ') ') & 
  theme(plot.tag.position = c(0, 1))
Figure 1: A1, A2) Model residuals for each combination of training condition, fit method, and model, for the training (A1) and testing (A2) stages. Residuals reflect the difference between observed and predicted values; lower values indicate better model fit. Note that the y axes are scaled differently between facets. B) Full posterior predictive distributions vs. observed data from participants. Points represent median values; thicker intervals represent 66% credible intervals and thin intervals represent 95% credible intervals around the median.
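The intervals in panel B are drawn by ggdist::stat_pointinterval, whose defaults correspond to 66% and 95% quantile intervals; the same summaries can be computed by hand from the post_dat_l table built above. A minimal sketch:

Code
post_dat_l |>
  filter(Resp != "Observed") |>
  group_by(condit, Fit_Method, Model, x) |>
  summarise(median = median(val),
            lower66 = quantile(val, .17),  upper66 = quantile(val, .83),
            lower95 = quantile(val, .025), upper95 = quantile(val, .975),
            .groups = "drop")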

Deviation Predictions

Model-predicted and observed deviations from the target velocity band, for each band, training condition, and fit method.

Code
 post_dat_l |> 
  mutate(vbLab = factor(paste0(x,"-",x+200),levels=levels(testAvg$vb))) |>
  ggplot(aes(x=condit,y=dist,fill=vbLab)) + 
  stat_bar + 
  #facet_wrap(~Resp)
  ggh4x::facet_nested_wrap(rename_fm(Fit_Method)~Resp, axes = "all",ncol=3,scales="free")

Code
c_post <- post_dat_avg %>%
    group_by(id, condit, Model, Fit_Method, rank) %>%
    slice_head(n = 1) |>
    ggplot(aes(y=log(c), x = Fit_Method,col=condit)) + stat_pointinterval(position=position_dodge(.2)) +
    ggh4x::facet_nested_wrap(~Model) + labs(title="c parameter") +
  theme(legend.title = element_blank(), legend.position="right",plot.title=element_text(hjust=.4))

lr_post <- post_dat_avg %>%
    group_by(id, condit, Model, Fit_Method, rank) %>%
    slice_head(n = 1) |>
    ggplot(aes(y=lr, x = Fit_Method,col=condit)) + stat_pointinterval(position=position_dodge(.4)) +
    ggh4x::facet_nested_wrap(~Model) + labs(title="learning rate parameter") +
  theme(legend.title = element_blank(), legend.position = "none",plot.title=element_text(hjust=.5))
c_post + lr_post
Figure 2: Posterior Distributions of \(c\) and \(lr\) parameters. Points represent median values, thicker intervals represent 66% credible intervals and thin intervals represent 95% credible intervals around the median. Note that the y axes of the plots for the c parameter are scaled logarithmically.
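The same posteriors can be summarized numerically; a minimal sketch that keeps one row per posterior draw (mirroring the slice_head step above) and reports median parameter values per condition, model, and fit method:

Code
post_dat_avg |>
  group_by(id, condit, Model, Fit_Method, rank) |>
  slice_head(n = 1) |>
  group_by(condit, Model, Fit_Method) |>
  summarise(median_c = median(c), median_lr = median(lr), .groups = "drop")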

Accounting for individual patterns

Code
# could compute best model for each posterior parameter - examine consistency
# then I'd have an error bar for each subject in the model error diff. figure

tid1 <- post_dat  |> group_by(id,condit,Model,Fit_Method,x) |> 
  summarise(y1=mean(y), pred1=mean(pred), mean_error=abs(y1-pred1)) |>
  group_by(id,condit,Model,Fit_Method) |> 
  summarise(mean_error=mean(mean_error)) |> 
  arrange(id,condit,Fit_Method) |>
  round_tibble(1) 

best_id <- tid1 |> 
  group_by(id,condit,Fit_Method) |> mutate(best=ifelse(mean_error==min(mean_error),1,0)) 

lowest_error_model <- best_id %>%
  group_by(id, condit,Fit_Method) %>%
  summarise(Best_Model = Model[which.min(mean_error)],
            Lowest_error = min(mean_error),
            differential = min(mean_error) - max(mean_error)) %>%
  ungroup()


error_difference<- best_id %>%
  select(id, condit, Model,Fit_Method, mean_error) %>%
  pivot_wider(names_from = Model, values_from = c(mean_error)) %>%
  mutate(Error_difference = (ALM - EXAM))

full_comparison <- lowest_error_model |> left_join(error_difference, by=c("id","condit","Fit_Method"))  |> 
  group_by(condit,Fit_Method,Best_Model) |> mutate(nGrp=n(), model_rank = nGrp - rank(Error_difference) ) |> 
  arrange(Fit_Method,-Error_difference)

full_comparison |> filter(Fit_Method=="Test_Train") |> 
  ungroup() |>
  mutate(id = reorder(id, Error_difference)) %>%
  ggplot(aes(y=id,x=Error_difference,fill=Best_Model))+
  geom_col() +
  ggh4x::facet_grid2(~condit,axes="all",scales="free_y", independent = "y")+
  labs(fill="Best Model",x="Mean Model Error Difference (ALM - EXAM)",y="Participant")



# full_comparison |> filter(Fit_Method=="Test_Train") |> 
#   ungroup() |>
#   mutate(id = reorder(id, Error_difference)) |>
#   left_join(post_dat_avg |> filter(x==100) |> select(-x) |> ungroup(), by=c("id","condit")) |>
#   ggplot(aes(y=id,x=c,fill=Best_Model))+
#   stat_pointinterval(position=position_dodge(.1))
Figure 3: Difference in model errors for each participant, with models fit to both train and test data. Positive values favor EXAM, while negative values favor ALM.
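A quick tally of how many participants in each condition are best fit by each model (under the Test & Train fit method) follows directly from full_comparison:

Code
full_comparison |>
  filter(Fit_Method == "Test_Train") |>
  ungroup() |>
  count(condit, Best_Model)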

Subjects with biggest differential favoring ALM

Code
# ids with the largest ALM advantage in the Varied (vAlm) and Constant (cAlm) conditions
vAlm <- c(307,331,197); cAlm <- c(372,173,157)

post_dat_l |> filter(id %in% c(vAlm,cAlm), Fit_Method=="Test_Train") |> 
   mutate(x=as.factor(x), Resp=as.factor(Resp)) |>
  group_by(id,condit,Fit_Method,Model,Resp) |>
   mutate(flab=paste0("Subject: ",id)) |>
  ggplot(aes(x = Resp, y = val, fill=x)) + 
  stat_bar_sd + ggh4x::facet_nested_wrap(condit~flab, axes = "all",ncol=3) +
  labs(title="Subjects with biggest differential favoring ALM",
       y="X Velocity",fill="Target Velocity") +
   guides(fill = guide_legend(nrow = 1)) + 
  theme(legend.position = "bottom",axis.title.x = element_blank())

Subjects with biggest differential favoring EXAM

Code
# ids with the largest EXAM advantage in the Varied (vExam) and Constant (cExam) conditions
vExam <- c(312,334,295); cExam <- c(132,366,415)

post_dat_l |> filter(id %in% c(vExam,cExam), Fit_Method=="Test_Train") |> 
   mutate(x=as.factor(x), Resp=as.factor(Resp)) |>
  group_by(id,condit,Fit_Method,Model,Resp) |>
   mutate(flab=paste0("Subject: ",id)) |>
  ggplot(aes(x = Resp, y = val, fill=x)) + 
  stat_bar_sd + ggh4x::facet_nested_wrap(condit~flab, axes = "all",ncol=3) +
  labs(title="Subjects with biggest differential favoring EXAM",
       y="X Velocity",fill="Target Velocity") +
   guides(fill = guide_legend(nrow = 1)) + 
  theme(legend.position = "bottom",axis.title.x = element_blank())