---
title: HTW Model e2
author: Thomas Gorman
date: last-modified #"`r Sys.Date()`"
page-layout: full
lightbox: true
categories: [Modeling, ALM, EXAM, R]
toc: false
code-fold: true
code-tools: true
execute:
  warning: false
  eval: true
---
```{r}
# Attach every package used in this document (same order as before, so masking is unchanged)
pacman::p_load(
  dplyr, purrr, tidyr, ggplot2, data.table, here, patchwork, conflicted,
  stringr, future, furrr, knitr, reactable, flextable, ggstance, htmltools, ggdist
)
# conflict_prefer_all("dplyr", quiet = TRUE)

# Large penalty against scientific notation when printing numbers
options(scipen = 999)

# Source the project helper scripts Functions/<name>.R into the global environment
walk(
  c("Display_Functions", "fun_alm", "fun_indv_fit", "fun_model"),
  \(helper) source(here::here("Functions", paste0(helper, ".R")))
)
```
### Modelling Results
```{r}
#| cache: false
#ds <- readRDS(here::here("data/e1_md_11-06-23.rds")) |> as.data.table()

# Experiment 2 data used by the model simulations, held as a data.table
ds <- as.data.table(readRDS(here::here("data/e2_md_02-23-24.rds")))
nbins <- 3  # number of training blocks

# Raw experiment 2 data, summarised separately for the testing and training stages
fd <- readRDS(here("data/e2_08-21-23.rds"))

test <- filter(fd, expMode2 == "Test")
testAvg <- test |>
  group_by(id, condit, vb, bandInt, bandType, tOrder) |>
  summarise(
    nHits = sum(dist == 0),  # must come before `dist = mean(dist)`, which overwrites dist
    vx = mean(vx),
    dist = mean(dist),
    sdist = mean(sdist),
    n = n(),
    Percent_Hit = nHits / n
  )

# Training stage: each subject's trials are cut into `nbins` equal-width blocks of trial number
trainAvg <- fd |>
  filter(expMode2 == "Train") |>
  group_by(id) |>
  mutate(
    tr = trial,
    x = vb,
    Block = case_when(
      expMode2 == "Train" ~ cut(tr,
        breaks = seq(1, max(tr), length.out = nbins + 1),
        include.lowest = TRUE, labels = FALSE
      ),
      expMode2 == "Test" ~ 4  # never matches after the Train filter above
    )
  ) |>
  group_by(id, condit, vb, x, Block) |>
  summarise(dist = mean(dist), y = mean(vx))

# Network node values shared by the ALM/EXAM simulations. `<<-` assigns outside the chunk
# environment -- presumably so functions from Functions/*.R can see them; confirm.
input_layer <<- output_layer <<- c(100, 350, 600, 800, 1000, 1200)
ids2 <- c(1, 66, 36)

#file_name <- "e2_n_iter_50_ntry_200_2506"
#file_name <- "n_iter_400_ntry_100_2944"
#file_name <- "e2_n_iter_100_ntry_200_3436"
file_name <- "e2_n_iter_200_ntry_300_2344"

# Each .rds in the run folder is one ABC fit: element 1 is a list of tables stacked with
# rbindlist(); the "Model" and "Fit_Method" elements label every row of that fit.
ind_fits <- map(list.files(here("data/abc_reject/", file_name), full.names = TRUE), readRDS)
ind_fits_df <- rbindlist(map(ind_fits, \(abc_fit) {
  rbindlist(abc_fit[[1]]) |>
    mutate(Model = abc_fit[["Model"]], Fit_Method = abc_fit[["Fit_Method"]])
}))
```
```{r}
#| cache: false
# Simulate ALM or EXAM predictions for one subject from their best posterior draws.
#
# Model        "EXAM" selects full_sim_exam; any other value selects full_sim_alm.
# post_samples data.table of posterior draws for a single subject (columns id, c, lr, mean_error).
# data         data.table of all subjects' trials (columns id, condit, expMode2, tr, x, y).
# num_samples  number of lowest-mean_error draws to simulate (capped at the draws available).
# return_dat   which stage(s) to return: "test_data", "train_data" or "train_data, test_data".
#
# Returns a data.table with one row per (draw, trial) of the requested stage(s), holding the
# observed y and the simulated prediction `pred`.
generate_data <- function(Model, post_samples, data, num_samples = 1,
                          return_dat = "train_data, test_data") {
  # Scalar `if` instead of ifelse(): ifelse() is vectorised, only returns a function via an
  # internal fast path, and yields NA (not a simulator) when Model is NA.
  stopifnot(length(Model) == 1L, !is.na(Model))
  simulation_function <- if (Model == "EXAM") full_sim_exam else full_sim_alm

  # Filter data for the specific id; copy() keeps it independent of `data` (avoids
  # invalidating data.table's selfref)
  sbj_data <- copy(data[id == post_samples$id[1]])

  target_data <- switch(return_dat,
    "test_data" = copy(sbj_data[expMode2 == "Test"]),
    "train_data" = copy(sbj_data[expMode2 == "Train"]),
    "train_data, test_data" = copy(sbj_data[expMode2 %in% c("Test", "Train")]),
    # Without a default, switch() returned NULL and every target column silently vanished
    stop("Unknown `return_dat`: ", return_dat, call. = FALSE)
  )

  # Keep the best draws. Capping at nrow() avoids the all-NA rows that `1:num_samples`
  # produced when fewer draws were available than requested.
  n_keep <- min(num_samples, nrow(post_samples))
  if (n_keep < 1) {
    stop("No posterior samples to simulate for id ", post_samples$id[1], call. = FALSE)
  }
  post_samples <- post_samples[order(mean_error)][seq_len(n_keep), .(c, lr, mean_error, rank = .I)]

  simulated_data_list <- lapply(seq_len(nrow(post_samples)), function(i) {
    params <- post_samples[i]
    sim_data <- simulation_function(sbj_data, params$c, params$lr,
                                    input_layer = input_layer,
                                    output_layer = output_layer,
                                    return_dat = return_dat)
    data.table(
      id = sbj_data$id[1], condit = sbj_data$condit[1],
      expMode2 = target_data$expMode2, Model = Model, tr = target_data$tr,
      y = target_data$y, x = target_data$x, c = params$c,
      lr = params$lr, mean_error = params$mean_error, rank = i,
      pred = sim_data
    )
  })

  result_dt <- rbindlist(simulated_data_list)
  setcolorder(result_dt, c("id", "condit", "expMode2", "tr", "c", "lr", "x", "y", "pred"))
  result_dt
}
#future::plan(multisession)

# One row per subject x model x fit method, with that combination's posterior draws nested in `data`
nestSbjModelFit <- ind_fits_df %>% nest(.by = c(id, Model, Fit_Method))
# post_dat <- nestSbjModelFit |> mutate(pp=furrr::future_pmap(list(id,Model,Fit_Method,data), ~{
# generate_data(..2, ..4 |> mutate(id=..1), ds, num_samples = 50, return_dat="test_data")
# })) |>
# select(Fit_Method,pp,-data) |>
# unnest(pp) |> filter(expMode2=="Test") |> as.data.table()
# saveRDS(post_dat, here("data/model_cache/post_dat_e2.rds"))
post_dat <- readRDS(here("data/model_cache/post_dat_e2.rds"))

# Distance of a response `val` from the target band [lower, lower + width]; 0 inside the band.
dist_from_band <- function(val, lower, width = 200) {
  upper <- lower + width
  case_when(
    val >= lower & val <= upper ~ 0,
    val < lower ~ abs(lower - val),
    val > upper ~ abs(val - upper),
    TRUE ~ NA_real_
  )
}

# Testing stage: average observed and predicted responses per subject/model/fit/band/draw
post_dat_avg <- post_dat |>
  group_by(id, condit, Model, Fit_Method, x, c, lr, rank) |>
  mutate(error2 = y - pred) |>
  summarise(y = mean(y), pred = mean(pred), error = y - pred, error2 = mean(error2)) |>
  as.data.table()
setorder(post_dat_avg, id, x, rank)

# Long format: one row per observed (y) or predicted (pred) value
post_dat_l <- melt(post_dat_avg,
  id.vars = c("id", "condit", "Model", "Fit_Method", "x", "c", "lr", "rank", "error"),
  measure.vars = c("pred", "y"), variable.name = "Resp", value.name = "val"
)
post_dat_l[, Resp := fifelse(Resp == "y", "Observed",
                             fifelse(Model == "ALM", "ALM", "EXAM"))]
setorder(post_dat_l, id, Resp)
#rm(post_dat_avg)
post_dat_l <- post_dat_l |> mutate(dist = dist_from_band(val, x))

# organize training data predictions
pd_train <- nestSbjModelFit |>
  mutate(pp = furrr::future_pmap(list(id, Model, Fit_Method, data), ~ {
    generate_data(..2, ..4 |> mutate(id = ..1), ds, num_samples = 20, return_dat = "train_data")
  })) |>
  select(Fit_Method, pp, -data) |>
  unnest(pp) |>
  as.data.table() |>
  filter(expMode2 == "Train")
#saveRDS(pd_train, here("data/model_cache/pd_train.rds"))
#pd_train <- readRDS(here("data/model_cache/pd_train.rds"))

nbins <- 3
pd_train <- pd_train |>
  group_by(id, condit, Model, Fit_Method) |>
  mutate(Block = cut(tr, breaks = seq(1, max(tr), length.out = nbins + 1),
                     include.lowest = TRUE, labels = FALSE)) |>
  # group_by() turned the data.table into a grouped tibble: setorder() would reorder it in place
  # (leaving the stored group row indices stale) and melt() would not dispatch to data.table.
  ungroup() |>
  as.data.table()
setorder(pd_train, id, x, Block, rank)

pd_train_l <- melt(pd_train,
  id.vars = c("id", "condit", "Model", "Block", "Fit_Method", "x", "c", "lr", "rank"),
  measure.vars = c("pred", "y"), variable.name = "Resp", value.name = "val"
)
pd_train_l[, Resp := fifelse(Resp == "y", "Observed",
                             fifelse(Model == "ALM", "ALM", "EXAM"))]
setorder(pd_train_l, id, Block, Resp)
pd_train_l <- pd_train_l |> mutate(dist = dist_from_band(val, x))
#plan(sequential)
```
### Group level aggregations
```{r}
#| eval: true
#| label: tbl-htw-modelError
#| tbl-cap: "Mean model errors predicting testing data, aggregated over all participants and velocity bands. Note that Fit Method refers to how model parameters were optimized, while error values reflect mean absolute error for the 6 testing bands"
# Aggregate posterior-predictive error tables (abc_tables() comes from the sourced helpers)
post_tabs <- abc_tables(post_dat, post_dat_l)

# Mean error by fit method x model (rows) and training condition (columns)
post_tabs$agg_pred_full |>
  mutate(Fit_Method = rename_fm(Fit_Method)) |>
  flextable::tabulator(
    rows = c("Fit_Method", "Model"),
    columns = c("condit"),
    ME = as_paragraph(mean_error)  # header was ` ME `, padded with spaces
  ) |>
  as_flextable()
#post_tabs$agg_pred_full |> pander::pandoc.table()
```
```{r fig.height=12,fig.width=11}
#| label: fig-htw-resid-pred
#| column: page-inset-right
#| fig-cap: A) Model residuals for the training (A1) and testing (A2) stages, for each combination of training condition, fit method, and model. Residuals reflect the difference between observed and predicted values. Lower values indicate better model fit. Note that y axes are scaled differently between facets. B) Full posterior predictive distributions vs. observed data from participants. Points represent median values, thicker intervals represent 66% credible intervals and thin intervals represent 95% credible intervals around the median.
#| fig-height: 19
#| fig-width: 12
##| layout: [[45,-5, 45], [100]]
##| fig-subcap: ["Model Residuals - training data", "Model Residuals - testing data","Full posterior predictive distributions vs. observed data from participants."]
# A1) Training stage: per subject/model/fit method/block, observed mean minus predicted mean.
# The plotted values are absolute (unsquared) errors, so the axis is MAE, not RMSE.
# NOTE(review): assumes `stat_bar` (defined in Functions/) summarises with the mean.
train_resid <- pd_train |>
  group_by(id, condit, Model, Fit_Method, Block) |>
  summarise(y = mean(y), pred = mean(pred), error = y - pred) |>
  ggplot(aes(x = Block, y = abs(error), fill = Model)) +
  stat_bar +
  ggh4x::facet_nested_wrap(rename_fm(Fit_Method) ~ condit, scales = "free", ncol = 2) +
  scale_fill_manual(values = wes_palette("AsteroidCity2")) +
  labs(title = "Model Residual Errors - Training Stage", y = "Mean Absolute Error",
       x = "Training Block") +
  theme(legend.title = element_blank(), legend.position = "top")

# A2) Testing stage: mean |y - pred| per posterior draw, then averaged over draws.
test_resid <- post_dat |>
  group_by(id, condit, x, Model, Fit_Method, rank) |>
  summarize(error = mean(abs(y - pred))) |>
  group_by(id, condit, x, Model, Fit_Method) |>
  summarize(error = mean(error)) |>
  ungroup() |>
  mutate(
    vbLab = paste0(x, "-", x + 200),
    # Order bands by their numeric lower bound, not alphabetically ("1000-1200" < "350-550")
    vbLab = factor(vbLab, levels = unique(vbLab[order(x)]))
  ) |>
  ggplot(aes(x = vbLab, y = abs(error), fill = Model)) +
  stat_bar +
  scale_fill_manual(values = wes_palette("AsteroidCity2")) +
  ggh4x::facet_nested_wrap(rename_fm(Fit_Method) ~ condit, axes = "all", ncol = 2, scales = "free") +
  labs(title = "Model Residual Errors - Testing Stage", y = "Mean Absolute Error",
       x = "Velocity Band") +
  theme(axis.text.x = element_text(angle = 45, hjust = 0.5, vjust = 0.5))

# B) Posterior predictive distributions vs. observed responses per velocity band
group_pred <- post_dat_l |>
  mutate(vbLab = factor(paste0(x, "-", x + 200), levels = levels(testAvg$vb))) |>
  ggplot(aes(x = val, y = vbLab, col = Resp)) +
  stat_pointinterval(position = position_dodge(.5), alpha = .9) +
  scale_color_manual(values = wes_palette("AsteroidCity2")) +
  ggh4x::facet_nested_wrap(rename_fm(Fit_Method) ~ condit, axes = "all", ncol = 2, scales = "free") +
  labs(title = "Posterior Predictions - Testing Stage", y = "Velocity Band (lower bound)",
       x = "X Velocity") +
  theme(legend.title = element_blank(),
        axis.text.y = element_text(angle = 45, hjust = 0.5, vjust = 0.5))

# plot_annotation() is added with `+`; `&` is reserved for applying the theme to every panel
((train_resid | test_resid) / group_pred) +
  plot_layout(heights = c(1, 1.5)) +
  plot_annotation(tag_levels = list(c("A1", "A2", "B")), tag_suffix = ") ") &
  theme(plot.tag.position = c(0, 1))
```
## Deviation Predictions
```{r}
#| fig-height: 12
#| fig-width: 11
# Distance from the target band, by condition and velocity band, for observed and predicted responses.
# `scales` is spelled out (the old `scale =` only worked through partial argument matching).
post_dat_l |>
  mutate(vbLab = factor(paste0(x, "-", x + 200), levels = levels(testAvg$vb))) |>
  ggplot(aes(x = condit, y = dist, fill = vbLab)) +
  stat_bar +
  #facet_wrap(~Resp)
  ggh4x::facet_nested_wrap(rename_fm(Fit_Method) ~ Resp, axes = "all", ncol = 3, scales = "free")
```
```{r}
#| label: fig-htw-post-dist
#| fig-cap: Posterior distributions of the $c$ and $lr$ parameters. Points represent median values, thicker intervals represent 66% credible intervals and thin intervals represent 95% credible intervals around the median. Note that the $c$ parameter is plotted on a natural-log scale, i.e. as $\log(c)$.
#| fig-height: 7
#| fig-width: 11
# post_dat_avg has one row per velocity band (x) for every posterior draw (rank), and each draw
# carries a single c/lr pair, so the first row of each draw is enough for parameter plots.
param_draws <- post_dat_avg |>
  group_by(id, condit, Model, Fit_Method, rank) |>
  slice_head(n = 1)

c_post <- ggplot(param_draws, aes(y = log(c), x = Fit_Method, col = condit)) +
  stat_pointinterval(position = position_dodge(.2)) +
  ggh4x::facet_nested_wrap(~Model) +
  labs(title = "c parameter") +
  theme(legend.title = element_blank(), legend.position = "right",
        plot.title = element_text(hjust = .4))

lr_post <- ggplot(param_draws, aes(y = lr, x = Fit_Method, col = condit)) +
  stat_pointinterval(position = position_dodge(.4)) +
  ggh4x::facet_nested_wrap(~Model) +
  labs(title = "learning rate parameter") +
  theme(legend.title = element_blank(), legend.position = "none",
        plot.title = element_text(hjust = .5))

c_post + lr_post
```
### Accounting for individual patterns
```{r fig.width=11, fig.height=10}
#| eval: false
#| include: false
#| label: fig-htw-indv-pred
#| fig-cap: Model predictions alongside observed data for a subset of individual participants. A) 3 constant and 3 varied participants fit to both the test and training data. B) 3 constant and 3 varied participants fit to only the training data.
#| fig-height: 13
#| fig-width: 12
cId_tr <- c(137, 181, 11)
vId_tr <- c(14, 193, 47)
cId_tt <- c(11, 93, 35)
vId_tt <- c(1, 14, 74)
# filter(id %in% (filter(bestTestEXAM,group_rank<=9, Fit_Method=="Test")

# Observed vs. predicted x velocity bars for the given subjects under one fit method;
# `bars` is the bar layer to draw (stat_bar or stat_bar_sd).
plot_indv <- function(ids, fit_method, title, bars) {
  post_dat_l |>
    filter(id %in% ids, Fit_Method == fit_method) |>
    mutate(x = as.factor(x), Resp = as.factor(Resp), flab = paste0("Subject: ", id)) |>
    group_by(id, condit, Fit_Method, Model, Resp) |>
    ggplot(aes(x = Resp, y = val, fill = x)) +
    bars +
    ggh4x::facet_nested_wrap(condit ~ flab, axes = "all", ncol = 3) +
    labs(title = title, y = "X Velocity", fill = "Target Velocity") +
    guides(fill = guide_legend(nrow = 1)) +
    theme(legend.position = "bottom", axis.title.x = element_blank())
}

testIndv <- plot_indv(c(cId_tt, vId_tt), "Test_Train",
                      "Individual Participant fits from Test & Train Fitting Method",
                      stat_bar_sd)
trainIndv <- plot_indv(c(cId_tr, vId_tr), "Train",
                       "Individual Participant fits from Train Only Fitting Method",
                       stat_bar)

(testIndv / trainIndv) +
  plot_annotation(tag_levels = list(c("A", "B")), tag_suffix = ") ") &
  theme(plot.tag.position = c(0, 1))
```
```{r}
#| label: fig-htw-best-model
#| fig-cap: Difference in model errors for each participant, with models fit to both train and test data. Positive values favor EXAM, while negative values favor ALM.
#| fig-height: 9
#| fig-width: 11
# could compute best model for each posterior parameter - examine consistency
# then I'd have an error bar for each subject in the model error diff. figure

# Per-subject model error: average observed (y) and predicted values within each velocity
# band, take their absolute difference, then average over bands (rounded to 1 decimal).
tid1 <- post_dat |>
  group_by(id, condit, Model, Fit_Method, x) |>
  summarise(y1 = mean(y), pred1 = mean(pred), mean_error = abs(y1 - pred1)) |>
  group_by(id, condit, Model, Fit_Method) |>
  summarise(mean_error = mean(mean_error)) |>
  arrange(id, condit, Fit_Method) |>
  round_tibble(1)

# best = 1 for the model(s) with the lowest error within each subject x fit method
best_id <- tid1 |>
  group_by(id, condit, Fit_Method) |>
  mutate(best = as.numeric(mean_error == min(mean_error)))

# Winning model, its error, and (min - max) error per subject x fit method
lowest_error_model <- best_id |>
  group_by(id, condit, Fit_Method) |>
  summarise(
    Best_Model = Model[which.min(mean_error)],
    Lowest_error = min(mean_error),
    differential = min(mean_error) - max(mean_error)
  ) |>
  ungroup()

# One column per model; positive Error_difference means ALM erred more (EXAM favoured)
error_difference <- best_id |>
  select(id, condit, Model, Fit_Method, mean_error) |>
  pivot_wider(names_from = Model, values_from = c(mean_error)) |>
  mutate(Error_difference = ALM - EXAM)

full_comparison <- lowest_error_model |>
  left_join(error_difference, by = c("id", "condit", "Fit_Method")) |>
  group_by(condit, Fit_Method, Best_Model) |>
  mutate(nGrp = n(), model_rank = nGrp - rank(Error_difference)) |>
  arrange(Fit_Method, -Error_difference)

full_comparison |>
  filter(Fit_Method == "Test_Train") |>
  ungroup() |>
  mutate(id = reorder(id, Error_difference)) |>
  ggplot(aes(y = id, x = Error_difference, fill = Best_Model)) +
  geom_col() +
  ggh4x::facet_grid2(~condit, axes = "all", scales = "free_y", independent = "y") +
  labs(fill = "Best Model", x = "Mean Model Error Difference (ALM - EXAM)", y = "Participant")
# full_comparison |> filter(Fit_Method=="Test_Train") |>
# ungroup() |>
# mutate(id = reorder(id, Error_difference)) |>
# left_join(post_dat_avg |> filter(x==100) |> select(-x) |> ungroup(), by=c("id","condit")) |>
# ggplot(aes(y=id,x=c,fill=Best_Model))+
# stat_pointinterval(position=position_dodge(.1))
```
#### Subjects with biggest differential favoring ALM
```{r}
#| fig-height: 6
#| fig-width: 12
# Subject ids whose Test_Train fits most favour ALM
# (presumably v* = varied-condition and c* = constant-condition subjects -- verify)
vAlm <- c(307, 331, 197)
cAlm <- c(372, 173, 157)

post_dat_l |>
  filter(id %in% c(vAlm, cAlm), Fit_Method == "Test_Train") |>
  mutate(
    x = as.factor(x),
    Resp = as.factor(Resp),
    flab = paste0("Subject: ", id)
  ) |>
  group_by(id, condit, Fit_Method, Model, Resp) |>
  ggplot(aes(x = Resp, y = val, fill = x)) +
  stat_bar_sd +
  ggh4x::facet_nested_wrap(condit ~ flab, axes = "all", ncol = 3) +
  labs(
    title = "Subjects with biggest differential favoring ALM",
    y = "X Velocity",
    fill = "Target Velocity"
  ) +
  guides(fill = guide_legend(nrow = 1)) +
  theme(legend.position = "bottom", axis.title.x = element_blank())
```
#### Subjects with biggest differential favoring EXAM
```{r}
#| fig-height: 6
#| fig-width: 12
# Subject ids whose Test_Train fits most favour EXAM. Named *Exam so they no longer
# overwrite the ALM-favouring vAlm/cAlm ids defined in the previous chunk.
vExam <- c(312, 334, 295)
cExam <- c(132, 366, 415)

post_dat_l |>
  filter(id %in% c(vExam, cExam), Fit_Method == "Test_Train") |>
  mutate(x = as.factor(x), Resp = as.factor(Resp)) |>
  group_by(id, condit, Fit_Method, Model, Resp) |>
  mutate(flab = paste0("Subject: ", id)) |>
  ggplot(aes(x = Resp, y = val, fill = x)) +
  stat_bar_sd +
  ggh4x::facet_nested_wrap(condit ~ flab, axes = "all", ncol = 3) +
  labs(title = "Subjects with biggest differential favoring EXAM",
       y = "X Velocity", fill = "Target Velocity") +
  guides(fill = guide_legend(nrow = 1)) +
  theme(legend.position = "bottom", axis.title.x = element_blank())
```