Project Setup

Start by loading the COVID Tracking Project (CTP) states daily dataset from the API. Documentation is available at https://covidtracking.com/data/api.

I’ll also load the libraries we’ll need for the project.

# load libraries
library(tidyverse)
library(jsonlite)
library(lubridate)
library(pracma)
library(tidymodels)
library(naniar)
# load and view
ctp <- fromJSON("https://api.covidtracking.com/v1/states/daily.json")
names(ctp)
##  [1] "date"                        "state"                      
##  [3] "positive"                    "probableCases"              
##  [5] "negative"                    "pending"                    
##  [7] "totalTestResultsSource"      "totalTestResults"           
##  [9] "hospitalizedCurrently"       "hospitalizedCumulative"     
## [11] "inIcuCurrently"              "inIcuCumulative"            
## [13] "onVentilatorCurrently"       "onVentilatorCumulative"     
## [15] "recovered"                   "dataQualityGrade"           
## [17] "lastUpdateEt"                "dateModified"               
## [19] "checkTimeEt"                 "death"                      
## [21] "hospitalized"                "dateChecked"                
## [23] "totalTestsViral"             "positiveTestsViral"         
## [25] "negativeTestsViral"          "positiveCasesViral"         
## [27] "deathConfirmed"              "deathProbable"              
## [29] "totalTestEncountersViral"    "totalTestsPeopleViral"      
## [31] "totalTestsAntibody"          "positiveTestsAntibody"      
## [33] "negativeTestsAntibody"       "totalTestsPeopleAntibody"   
## [35] "positiveTestsPeopleAntibody" "negativeTestsPeopleAntibody"
## [37] "totalTestsPeopleAntigen"     "positiveTestsPeopleAntigen" 
## [39] "totalTestsAntigen"           "positiveTestsAntigen"       
## [41] "fips"                        "positiveIncrease"           
## [43] "negativeIncrease"            "total"                      
## [45] "totalTestResultsIncrease"    "posNeg"                     
## [47] "deathIncrease"               "hospitalizedIncrease"       
## [49] "hash"                        "commercialScore"            
## [51] "negativeRegularScore"        "negativeScore"              
## [53] "positiveScore"               "score"                      
## [55] "grade"

Data Preparation

There are a lot of columns here, but we don’t need all of them. I’m going to focus on new positive cases, new tests, new deaths, and the number of people currently hospitalized, along with the state and date variables.

Before selecting these, I’m just going to confirm that state and date uniquely identify the observations in the data.

# check uniqueness of state and date
if (nrow(ctp %>% distinct(state, date)) / nrow(ctp) != 1) {
  print("Go find your duplicates")
} else{
    print("There's one observation for each state and date.")
}
## [1] "There's one observation for each state and date."
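
If the check had failed, the next step would be to list the offending rows. Here is a minimal sketch in base R on made-up data (the `toy` data frame and its values are hypothetical, not from CTP):

```r
# toy data with one duplicated state/date pair (values are made up)
toy <- data.frame(
  state    = c("AK", "AK", "AL", "AL"),
  date     = c(20201206, 20201206, 20201206, 20201205),
  positive = c(10, 12, 20, 18),
  stringsAsFactors = FALSE
)

# flag every row whose state/date key occurs more than once
key    <- toy[, c("state", "date")]
is_dup <- duplicated(key) | duplicated(key, fromLast = TRUE)

# keep only the rows that share a key with another row
dups <- toy[is_dup, ]
dups
```

Using `duplicated` in both directions returns all copies of a repeated key, not just the second and later ones, which makes it easier to compare the conflicting rows side by side.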

Now just pull in the six variables we need. I’m also going to do some simple renaming and date formatting to prepare for analysis. I’ll also check whether there are negative values. If there are, I’ll replace them with NA.

# get only the variables that we need
ctp <- ctp %>%
  select(state, date, totalTestResultsIncrease, positiveIncrease, deathIncrease, hospitalizedCurrently) %>%
  
  # rename variables to make them a bit shorter and easier
  rename(cases_new = positiveIncrease,
         tests_new = totalTestResultsIncrease,
         deaths_new = deathIncrease,
         hosp_current = hospitalizedCurrently) %>%
  
  # make sure date is in date format
  mutate(date = ymd(date))

knitr::kable(head(ctp))
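| state | date | tests_new | cases_new | deaths_new | hosp_current |
|---|---|---|---|---|---|
| AK | 2020-12-06 | 10545 | 757 | 0 | 164 |
| AL | 2020-12-06 | 7880 | 2288 | 12 | 1927 |
| AR | 2020-12-06 | 14704 | 1542 | 40 | 1076 |
| AS | 2020-12-06 | 0 | 0 | 0 | NA |
| AZ | 2020-12-06 | 20586 | 5376 | 25 | 2977 |
| CA | 2020-12-06 | 293071 | 30075 | 85 | 10624 |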
# check for negative values
nrow(ctp %>%
  filter(cases_new < 0 | deaths_new < 0 | tests_new < 0))
## [1] 119
# replace negative values with NA
ctp <- ctp %>%
  replace_with_na_if(is.numeric, ~.x < 0)

# check again for negative values
nrow(ctp %>%
  filter(cases_new < 0 | deaths_new < 0 | tests_new < 0))
## [1] 0
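
To make the effect of this step concrete, here is a small base R sketch of the same idea on made-up numbers (the `toy` data frame is hypothetical, and this is a stand-in for `replace_with_na_if`, not its implementation):

```r
# made-up daily values, including two negative corrections
toy <- data.frame(
  state      = c("AK", "AL", "AR"),
  cases_new  = c(757, -12, 1542),
  deaths_new = c(0, 12, -3),
  stringsAsFactors = FALSE
)

# for every numeric column, set negative entries to NA
num_cols <- sapply(toy, is.numeric)
toy[num_cols] <- lapply(toy[num_cols], function(x) {
  x[which(x < 0)] <- NA
  x
})

# later sums need na.rm = TRUE, otherwise a single NA makes the total NA
total_cases <- sum(toy$cases_new, na.rm = TRUE)
total_cases
```

Because the negatives become NA rather than being dropped, the rows stay in the data and each state keeps one observation per date.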

We’re also going to load data on states’ populations, which is useful for normalizing results at the state level. In general, we’re going to focus on cases, deaths, etc. per unit of population.

# get data from CTP GitHub page
state_pop <- read_csv("https://raw.githubusercontent.com/COVID19Tracking/associated-data/master/us_census_data/us_census_2018_population_estimates_states.csv") %>%
  select(state, population)

# join to our main dataset
ctp <- ctp %>%
  left_join(state_pop, by = "state")

knitr::kable(head(ctp))
| state | date | tests_new | cases_new | deaths_new | hosp_current | population |
|---|---|---|---|---|---|---|
| AK | 2020-12-06 | 10545 | 757 | 0 | 164 | 737438 |
| AL | 2020-12-06 | 7880 | 2288 | 12 | 1927 | 4887871 |
| AR | 2020-12-06 | 14704 | 1542 | 40 | 1076 | 3013825 |
| AS | 2020-12-06 | 0 | 0 | 0 | NA | NA |
| AZ | 2020-12-06 | 20586 | 5376 | 25 | 2977 | 7171646 |
| CA | 2020-12-06 | 293071 | 30075 | 85 | 10624 | 39557045 |
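
Notice that AS has no population after the join. A quick way to list every state that failed to match is sketched below in base R, with `merge(..., all.x = TRUE)` standing in for `left_join`; the two small data frames are hypothetical and reuse a few values from the table above:

```r
# a few rows of the main data and of the population lookup
ctp_toy <- data.frame(
  state     = c("AK", "AL", "AS"),
  cases_new = c(757, 2288, 0),
  stringsAsFactors = FALSE
)
pop_toy <- data.frame(
  state      = c("AK", "AL"),
  population = c(737438, 4887871),
  stringsAsFactors = FALSE
)

# left join: keep every row of ctp_toy, fill population with NA when no match
joined <- merge(ctp_toy, pop_toy, by = "state", all.x = TRUE)

# states that did not find a population value
unmatched <- unique(joined$state[is.na(joined$population)])
unmatched
```

Rows without a population will produce NA in any per-capita calculation, so it is worth knowing which ones they are before normalizing.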

Exploratory Analysis

Before diving into modeling, we need to get a sense of what the data looks like. We’ll use the ggplot2 package to explore how trends in new cases, deaths, tests, and hospitalizations have varied over time. First, we’ll look at the national level, using the group_by and summarize functions to get a single observation for the US on each date that is the sum of the values in individual states and territories.

# first, need to get data to a national level
national_ctp <- ctp %>%
  
  # group by date since we want one row for each date
  group_by(date) %>%
  
  # sum our main columns to get country-level variables
  summarize_if(is.numeric, sum, na.rm = TRUE) %>%
  mutate(entity = "United States")

# I'll be using `tail` from now on to look at the data, because the data is 
# better populated more recently than it was in March when CTP began
knitr::kable(tail(national_ctp))
| date | tests_new | cases_new | deaths_new | hosp_current | population | entity |
|---|---|---|---|---|---|---|
| 2020-12-01 | 2340996 | 176753 | 2473 | 98777 | 330362587 | United States |
| 2020-12-02 | 1470464 | 195796 | 2733 | 100322 | 330362587 | United States |
| 2020-12-03 | 1828230 | 210204 | 2706 | 100755 | 330362587 | United States |
| 2020-12-04 | 1854869 | 224831 | 2563 | 101276 | 330362587 | United States |
| 2020-12-05 | 2169756 | 211073 | 2445 | 101190 | 330362587 | United States |
| 2020-12-06 | 1634532 | 176771 | 1138 | 101487 | 330362587 | United States |

Now that we have a national-level dataset, let’s begin plotting. Start with a simple bar plot of cases, then deaths, and then all metrics together.

# cases plot
cases.plot <- national_ctp %>%
  ggplot(aes(x = date, y = cases_new)) +
  geom_col(color = "orange") +
  labs(title = "New COVID-19 Cases Reported by Date in the United States")
cases.plot

# deaths plot
deaths.plot <- national_ctp %>%
  ggplot(aes(x = date, y = deaths_new)) +
  geom_col(color = "gray") +
  labs(title = "New COVID-19 Deaths Reported by Date in the United States")
deaths.plot

# plot all metrics together by reshaping the data to long format
all_metrics.plot <- national_ctp %>%
  pivot_longer(cols = -c(date, entity, population),
               values_to = "value",
               names_to = "metric") %>%
  ggplot(aes(x = date, y = value, color = metric)) +
  geom_col() +
  facet_wrap(~metric, scales = "free") +
  theme(legend.position = "none") +
  labs(title = "COVID-19 Metrics by Date in the United States")
all_metrics.plot

From looking at these plots, particularly the individual cases and deaths plots, we can see that the data is pretty noisy. On many dates the number of deaths, for example, is relatively low, only to double the next day. The reason for this is reporting. State health departments generally process fewer cases, deaths, and other related measures on weekends, so the days following weekend days (Monday and Tuesday) generally have lower values. If we’re trying to get a sense of trends and eventually predict the number of deaths on a given day in the future, this makes things difficult.

The solution, however, is simple. We’ll take a 7-day rolling average of our main values to smooth out the day-of-week variation in reporting, using the movavg function from the pracma package. The loop also uses !!sym so that mutate recognizes a string as a column name. Read more about how this works here: https://stackoverflow.com/questions/57136322/what-does-the-operator-mean-in-r-particularly-in-the-context-symx.
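
As a small illustration of the !!sym pattern on its own, here is a sketch that doubles two columns by name inside a loop. The `toy` tibble and the `_x2` suffix are made up for the example; `sym` is called from rlang explicitly so the snippet does not depend on which package re-exports it:

```r
library(dplyr)
library(rlang)

# made-up data with the same column naming style as the post
toy <- tibble(cases_new = c(1, 2, 3), deaths_new = c(10, 20, 30))

for (v in c("cases_new", "deaths_new")) {
  toy <- toy %>%
    # left side: build the new column name from a string
    # right side: turn the string v into a column reference
    mutate(!!sym(paste0(v, "_x2")) := (!!sym(v)) * 2)
}

toy
```

Without `!!sym(v)`, mutate would treat `v` as a literal character string rather than as the column it names.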

# use a loop to create the 7-day average variables
# if the state has had fewer than 7 days of reporting, don't calculate an average
for (v in c("tests_new", "cases_new", "deaths_new", "hosp_current")){

  ctp <- ctp %>%
    group_by(state) %>%
    arrange(date) %>%
    mutate(day_count = row_number()) %>%
             
    mutate(!!sym(paste0(v, "_7d")) := ifelse(day_count >= 7, movavg(!!sym(v), 7, type = "s"), NA)) %>%
    ungroup()
}
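
To see why a 7-day window removes the weekly reporting pattern, consider a made-up series that repeats the same weekday/weekend shape. This sketch uses base R's `stats::filter` as a stand-in for `movavg` (an assumption for illustration, not the code used above), and the `daily` values are invented:

```r
# two weeks of made-up daily counts with a weekend dip
daily <- rep(c(100, 120, 130, 125, 110, 60, 55), times = 2)

# trailing 7-day mean: each value averages the current day and the 6 before it
avg_7d <- as.numeric(stats::filter(daily, rep(1 / 7, 7), sides = 1))

# the first 6 days have no full window, so they are NA
data.frame(day = seq_along(daily), daily = daily, avg_7d = avg_7d)
```

Every full window contains exactly one of each day of the week, so the weekend dip is always included once and the average stays flat even though the daily values swing between 55 and 130.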

Now when we plot the values we see a much smoother trend.

# plot of 7-day average for all metrics
national_7d <- ctp %>%
  select(state, date, ends_with("7d")) %>%
  group_by(date) %>%
  summarize_if(is.numeric, sum, na.rm = TRUE)
  
all_metrics7d.plot <- national_7d %>%
  pivot_longer(cols = -c(date), # pivot longer to plot all metrics together
               values_to = "value",
               names_to = "metric") %>%
  ggplot(aes(x = date, y = value, color = metric)) +
  geom_col() +
  facet_wrap(~metric, scales = "free") +
  theme(legend.position = "none") +
  labs(title = "COVID-19 Metrics (7-Day Average) by Date in the United States")

all_metrics7d.plot

We can even plot the raw values and 7-day averages over each other.

# overlay column and line chart
combined_cases.plot <- ctp %>%
  
  select(state, date, cases_new, cases_new_7d) %>%
  group_by(date) %>%
  summarize_if(is.numeric, sum, na.rm = TRUE) %>%
  
  ggplot() +
  geom_col(aes(x = date, y = cases_new), fill = "orange", alpha = 0.5)+
  geom_line(aes(x = date, y = cases_new_7d), color = "orange", alpha = 1) +
  labs(title = "New COVID-19 Cases Reported by Date in the United States",
       subtitle = "Line is the 7-day average; bars are daily values",
       x = "Date",
       y = "")

combined_cases.plot

If you look at the shape of the cases, hospitalizations, and deaths plots, you can see that they are related: they have relative peaks (local maxima) at similar places but with a lag (cases before hospitalizations before deaths). Here are two charts, one showing cases and deaths and the other showing hospitalizations and deaths on the same plot. ggplot2 doesn’t support two independent y-axes; a secondary axis must be a one-to-one transformation of the primary axis. So keep in mind how we set the scale transformation: deaths are divided by a ratio to draw them on the cases (or hospitalizations) scale, and the secondary axis multiplies by the same ratio to label them in deaths.

# assign ratios for axis transformation
max_deaths <- max(national_7d$deaths_new_7d, na.rm = TRUE)
max_cases <- max(national_7d$cases_new_7d, na.rm = TRUE)
max_hosps <- max(national_7d$hosp_current_7d, na.rm = TRUE)

case_death_ratio <- max_deaths / max_cases
hosp_death_ratio <- max_deaths / max_hosps
# plot cases and deaths together, with a secondary axis
cases_deaths.plot <- national_7d %>%
  ggplot() +
  geom_line(aes(x = date, y = cases_new_7d, color = "Cases")) +
  geom_line(aes(x = date, y = deaths_new_7d / case_death_ratio, color = "Deaths")) +
  
  scale_y_continuous(
    name = "New Cases",
    sec.axis = sec_axis(~.*case_death_ratio, name="New Deaths"))
cases_deaths.plot

# hospitalizations and deaths
hosp_deaths.plot <- national_7d %>%
  ggplot() +
  geom_line(aes(x = date, y = hosp_current_7d, color = "Hospitalizations")) +
  geom_line(aes(x = date, y = deaths_new_7d / hosp_death_ratio, color = "Deaths")) +
  
  scale_y_continuous(
    name = "Current Hospitalizations",
    sec.axis = sec_axis(~.*hosp_death_ratio, name="New Deaths"))
hosp_deaths.plot
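
The arithmetic behind the secondary axis can be checked on a few made-up numbers. In this sketch the `cases` and `deaths` vectors are hypothetical; the ratio is built the same way as `case_death_ratio` above:

```r
# made-up 7-day averages
cases  <- c(1000, 5000, 20000)
deaths <- c(20, 80, 200)

# ratio of the two maxima, as in case_death_ratio
ratio <- max(deaths) / max(cases)

# dividing by the ratio stretches deaths onto the cases scale for plotting
deaths_on_cases_scale <- deaths / ratio

# the secondary axis multiplies by the ratio, which recovers the real values
deaths_recovered <- deaths_on_cases_scale * ratio

data.frame(cases, deaths, deaths_on_cases_scale, deaths_recovered)
```

Because the ratio is taken from the two maxima, the peaks of both lines touch the top of the plot, which makes the lag between them easy to see but says nothing about their relative magnitude.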

Modeling

From here, we’re going to focus on understanding what factors (tests, cases, hospitalizations) best predict deaths and on what lag. For example, does the number of cases today best predict deaths in two weeks or three? To do this, we’ll make a few transformations: keep the 7-day averages, compute the ratio of cases to tests, and express the 7-day averages per million residents.

ctp <- ctp %>%
  select(state, date, population, ends_with("_7d")) %>%
  mutate(case_test_ratio = ifelse(tests_new_7d > 0, cases_new_7d / tests_new_7d, NA)) %>%
  mutate_at(~ . / (population / 1000000), .vars = vars(contains("_7d"))) %>%
  rename_with(.fn = ~paste0(.,"_perM"), .cols = (contains("_7d")))
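
As a worked example of the per-million scaling, here is the same division applied by hand to two states. The sketch borrows the daily case counts and populations for AK and CA from the tables earlier in this post, purely for illustration; the code above applies the scaling to the 7-day averages instead:

```r
# daily new cases and 2018 population for two states (from the tables above)
cases      <- c(AK = 757, CA = 30075)
population <- c(AK = 737438, CA = 39557045)

# same formula as in mutate_at: value / (population / 1,000,000)
per_million <- cases / (population / 1000000)
round(per_million, 1)
```

California reports about 40 times as many cases as Alaska on this date, but once population is taken into account Alaska’s rate is the higher of the two, which is why the models work with per-million values.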

Next, we’re going to define functions for data preparation, regression, cross-validation, and plotting of results. You’ll see how these work in a little bit.

# create function for lagging cases, hospitalizations, testing ratios, and deaths
lag_vars <- function(df, case_lag, hosp_lag, test_lag, death_lag) {

  # start from the data frame passed in as df
  df <- df %>%
    mutate(month = month(date),
           month = log(month)) %>%
    arrange(state, date) %>%
    group_by(state) %>%
    # lag each predictor within its own state
    mutate(cases_lag_7d_perM = lag(cases_new_7d_perM, n = case_lag),
           hosp_lag_7d_perM = lag(hosp_current_7d_perM, n = hosp_lag),
           testing_lag_7d_perM = lag(case_test_ratio, n = test_lag),
           death_lag_7d_perM = lag(deaths_new_7d_perM, n = death_lag)) %>%
    ungroup()
  
  df
}

# function for fitting a linear regression
fit_linear_model <- function(df, f){
  # model specification: ordinary least squares via lm
  lm_mod <- linear_reg() %>%
    set_engine("lm") %>%
    set_mode("regression")
  
  lm_fit <- lm_mod %>%
    fit(as.formula(f), data = df)
  
  # return the original model dataframe with predicted values (.pred) attached
  prediction <- predict(lm_fit, df)
  bind_cols(df, prediction)
}

# function for getting the cross-validation results for RMSE
fit_cv_model <- function(df, n_folds, f) {
  
  # set engine again
  lm_mod <- linear_reg() %>%
      set_engine("lm") %>%
      set_mode("regression")
  
  # cv folds
  set.seed(0)
  folds <- vfold_cv(data = df, v = n_folds)
  
  # cv results
  cv_fit <- fit_resamples(lm_mod,
                          as.formula(f),
                          resamples = folds,
                          control = control_resamples(save_pred = TRUE))
  
  # show rmse across the folds
  cv_res <- cv_fit %>%
    collect_metrics() 
  
  return(cv_res)

}

# plot predictions for selected states
plot_state_results <- function(df, states) {
  df %>%
    filter(state %in% states) %>%
    
    ggplot() +
    geom_line(aes(x = date, y = deaths_new_7d_perM, color = "Actual")) +
    geom_line(aes(x = date, y = .pred, color = "Predicted")) +
    facet_wrap(~state, scales = "free") +
    labs(y = "Daily Deaths (7-Day Average) per Million",
           x = "Date",
           title = "COVID-19 Deaths in Selected States",
           subtitle = "Results of linear model predicting deaths based on cases, tests, and hospitalizations") + 
      theme(legend.title = element_blank())
}

# plot national results
plot_national_results <- function(df) {
  df %>%
    group_by(date) %>%
    summarize(predicted = sum(.pred*(population/1000000), na.rm = TRUE),
              actual = sum(deaths_new_7d_perM*(population/1000000), na.rm = TRUE)) %>%

    ggplot() +
      geom_line(aes(x = date, y = actual, color = "Actual")) +
      geom_line(aes(x = date, y = predicted, color = "Predicted")) +
      labs(y = "Daily Deaths (7-Day Average)",
           x = "Date",
           title = "COVID-19 Deaths in the United States",
           subtitle = "Results of linear model predicting deaths based on cases, tests, and hospitalizations") + 
      theme(legend.title = element_blank())
}
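
The key step in lag_vars is that each predictor is shifted back within its own state, so one state’s values never leak into another’s. Here is a base R sketch of that grouped lag on made-up data (the `toy` data frame and the `lag_n` helper are hypothetical, standing in for dplyr’s `group_by` plus `lag`):

```r
# two states, four days each (values are made up)
toy <- data.frame(
  state = rep(c("AK", "AL"), each = 4),
  day   = rep(1:4, times = 2),
  cases = c(1, 2, 3, 4, 10, 20, 30, 40),
  stringsAsFactors = FALSE
)

# shift a vector back by n positions, padding the start with NA
lag_n <- function(v, n) c(rep(NA, n), v)[seq_along(v)]

# apply the lag separately within each state
toy$cases_lag2 <- ave(toy$cases, toy$state, FUN = function(v) lag_n(v, 2))
toy
```

The first rows of each state are NA because there is nothing earlier to look back to. With a 14-day lag, the same thing happens to the first 14 days of every state, and those rows cannot be used to fit the model.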

Model 1: Linear fit using 14-day-lagged cases

Fit a simple linear model of deaths on cases two weeks ago. How did this model seem to do?

# prepare data
model_df <- lag_vars(ctp, case_lag = 14, hosp_lag = 14, test_lag = 14, death_lag = 14)

# define model formula and name
lm_formula <- "deaths_new_7d_perM ~ cases_lag_7d_perM"
model_name <- "1a: Cases (14)"

# fit linear regression model
results_df <- fit_linear_model(df = model_df, f = lm_formula)

# get cross-validation RMSE
cv_res <- fit_cv_model(df = model_df, n_folds = 10, f = lm_formula) %>%
  mutate(model = model_name)

cv_res
## # A tibble: 2 x 7
##   .metric .estimator  mean     n std_err .config              model        
##   <chr>   <chr>      <dbl> <int>   <dbl> <fct>                <chr>        
## 1 rmse    standard   3.09     10  0.0739 Preprocessor1_Model1 1a: Cases (1…
## 2 rsq     standard   0.333    10  0.0113 Preprocessor1_Model1 1a: Cases (1…
# add to an overall metrics comparison
model_comp <- tibble()
model_comp <- model_comp %>%
  bind_rows(cv_res)

# plot predictions for a few states
plot_state_results(results_df, states = c("AZ", "CA", "FL", "NY", "ND", "MA"))

# plot predictions for the country overall
plot_national_results(results_df)

The fit is pretty bad early on in the pandemic. What if we added a control for month?

# define model formula and name
lm_formula <- "deaths_new_7d_perM ~ cases_lag_7d_perM + month"
model_name <- "1b: Cases (14) + month"

# fit linear regression model
results_df <- fit_linear_model(df = model_df, f = lm_formula)

# get cross-validation RMSE
cv_res <- fit_cv_model(df = model_df, n_folds = 10, f = lm_formula) %>%
  mutate(model = model_name)

cv_res
## # A tibble: 2 x 7
##   .metric .estimator  mean     n std_err .config           model           
##   <chr>   <chr>      <dbl> <int>   <dbl> <fct>             <chr>           
## 1 rmse    standard   2.93     10 0.0653  Preprocessor1_Mo… 1b: Cases (14) …
## 2 rsq     standard   0.399    10 0.00934 Preprocessor1_Mo… 1b: Cases (14) …
# add to overall metrics comparison
model_comp <- model_comp %>%
  bind_rows(cv_res)

# plot national results
plot_national_results(results_df)

That looks a little bit better and the RMSE went down a bit.
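
For readers who want to see what the cross-validated RMSE is measuring, here is a rough base R sketch of 10-fold cross-validation with `lm`. It is a conceptual stand-in, not what `fit_resamples` does internally, and the simulated data and all names in it are made up:

```r
set.seed(0)

# simulated data: y depends linearly on x, plus noise with sd = 1
n   <- 200
toy <- data.frame(x = runif(n, 0, 100))
toy$y <- 0.02 * toy$x + rnorm(n, sd = 1)

# assign each row to one of k folds at random
k       <- 10
fold_id <- sample(rep(1:k, length.out = n))

# for each fold: fit on the other nine folds, score on the held-out fold
fold_rmse <- sapply(1:k, function(i) {
  train <- toy[fold_id != i, ]
  test  <- toy[fold_id == i, ]
  fit   <- lm(y ~ x, data = train)
  pred  <- predict(fit, newdata = test)
  sqrt(mean((test$y - pred)^2))
})

# summarize across folds: mean RMSE and its standard error
cv_summary <- c(mean = mean(fold_rmse), std_err = sd(fold_rmse) / sqrt(k))
cv_summary
```

Since each fold is scored on rows the model did not see during fitting, the averaged RMSE is a fairer basis for comparing models than the error on the training data itself.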

Model 2: Add hospitalizations and testing ratios, months

Let’s see what happens when we add more variables to the linear regression: hospitalizations, the case-to-test ratio, and lagged deaths. Just based on my intuition, I’m going to set the lag for cases and testing at 21 days and the lag for hospitalizations and deaths at 14 days. I’m also keeping the natural log of month as a predictor, since our first model was so off early on. Let’s go!

# model stuff
lm_formula <- "deaths_new_7d_perM ~ cases_lag_7d_perM + 
        hosp_lag_7d_perM + testing_lag_7d_perM + death_lag_7d_perM +
        month"
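# label used in the comparison table; the lags actually applied are the ones
# passed to lag_vars() below (cases 21, hosp 14, test 21, deaths 14)
model_name <- "2a: Cases(21) + hosp (14) + test(14) + deaths(24) +  month"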

# lag our variables
model_df <- lag_vars(ctp, case_lag = 21, hosp_lag = 14, test_lag = 21, death_lag = 14)

# fit model using function and get predicted values
results_df <- fit_linear_model(df = model_df, f = lm_formula)

# get cross-validation RMSE
cv_res <- fit_cv_model(df = model_df, n_folds = 10, f = lm_formula) %>%
  mutate(model = model_name)

cv_res
## # A tibble: 2 x 7
##   .metric .estimator  mean     n std_err .config      model                
##   <chr>   <chr>      <dbl> <int>   <dbl> <fct>        <chr>                
## 1 rmse    standard   2.02     10  0.0833 Preprocesso… 2a: Cases(21) + hosp…
## 2 rsq     standard   0.726    10  0.0179 Preprocesso… 2a: Cases(21) + hosp…
# add to overall metrics comparison
model_comp <- model_comp %>%
  bind_rows(cv_res)

# plot national results
plot_national_results(results_df)

Model 3: Test different values for lags

Here, we’re going to loop through different values for the lag of cases, hospitalizations, testing, and deaths to find the combination that gives us the best fit, as defined by having the lowest RMSE (root mean squared error). Note that this code takes about 15 minutes to run.

While it’s intuitive that this might be telling us the actual lag between infection and reported deaths, it’s actually just seeing which values are more predictive once a model has been trained on them. For our final model, we want sufficient visibility to predict deaths up to two weeks in advance, so we won’t accept lags of less than 14 days.

start <- Sys.time()

# define formula
lm_formula <- "deaths_new_7d_perM ~ cases_lag_7d_perM + 
        hosp_lag_7d_perM + testing_lag_7d_perM + death_lag_7d_perM +
        month"

# empty tibble to populate results
all_lag_models <- tibble()

# model_df preparation except for lags
pre_model_df <- ctp %>%
  mutate(month = month(date),
         month = log(month)) %>%
  arrange(state, date) %>%
  group_by(state)

# define lag values for the loop (test combos of 14, 21, and 28 days)
lag_values <- c(14, 21, 28)

# test each combination of lags for cases, hosps, testing, and deaths
for(case_i in lag_values) {
  for(hosp_i in lag_values) {
    for(test_i in lag_values) {
      for(death_i in lag_values) {
        
        # lag variables
        model_df <- pre_model_df %>%
          mutate(cases_lag_7d_perM = lag(cases_new_7d_perM, n = case_i),
             hosp_lag_7d_perM = lag(hosp_current_7d_perM, n = hosp_i),
             testing_lag_7d_perM = lag(case_test_ratio, n = test_i),
             death_lag_7d_perM = lag(deaths_new_7d_perM, n = death_i)) %>%
          ungroup()
        
        print(paste0("Case lag = ", case_i)) # printing counter (in console only)
        print(paste0("Hospitalization lag = ", hosp_i))
        print(paste0("Testing lag = ", test_i))
        print(paste0("Deaths lag = ", death_i))
        print("-----------------------------------------")
        
        # fit model
        results_df <- fit_linear_model(df = model_df, f = lm_formula)
        
        # calculate RMSE
        cv_res <- fit_cv_model(df = model_df, n_folds = 10, f = lm_formula) %>%
          filter(.metric == "rmse")

        # add to all_models - note which model iteration it is
        cv_res$iter <- paste(case_i, hosp_i, test_i, death_i, sep = ".")
        all_lag_models <- bind_rows(all_lag_models, cv_res)
      }
    }
  }
}

Sys.time() - start
# check solution
solution <- all_lag_models %>%
  filter(mean == min(mean, na.rm = TRUE))

solution
## # A tibble: 1 x 7
##   .metric .estimator  mean     n std_err .config              iter       
##   <chr>   <chr>      <dbl> <int>   <dbl> <fct>                <chr>      
## 1 rmse    standard    1.32    10  0.0239 Preprocessor1_Model1 14.28.14.14
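
The four nested loops visit every combination of three lag values for four variables. A compact way to see the size of that search is to build the grid with `expand.grid`; this is only a sketch, and the `lag_values` and `grid` names are mine:

```r
# the three candidate lags, in days
lag_values <- c(14, 21, 28)

# every combination of lags for the four variables
grid <- expand.grid(
  case_lag  = lag_values,
  hosp_lag  = lag_values,
  test_lag  = lag_values,
  death_lag = lag_values
)

# same label format as the iter column: case.hosp.test.death
grid$iter <- paste(grid$case_lag, grid$hosp_lag,
                   grid$test_lag, grid$death_lag, sep = ".")

nrow(grid)                    # 3^4 = 81 combinations
"14.28.14.14" %in% grid$iter  # the winning combination is one of them
```

With 10 folds per combination, the search fits 810 cross-validation models on top of the 81 full fits, which is where the run time goes.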

Now that we’ve extracted the value with the lowest rmse, let’s plug it into our model and plot.

# setup
model_name <- "Optimized lag"

# lag our variables
model_df <- lag_vars(ctp, case_lag = 14, hosp_lag = 28, test_lag = 14, death_lag = 14)

# fit model
results_df <- fit_linear_model(model_df, f = lm_formula)

# calculate RMSE
cv_res <- fit_cv_model(df = model_df, n_folds = 10, f = lm_formula) %>%
  mutate(model = model_name)

# add to overall metrics comparison
model_comp <- model_comp %>%
  bind_rows(cv_res)

# plot national results
plot_national_results(results_df)

Model 4: Find the best date to cut off

As we’ve seen before, fitting predictions that match the three peaks in actual deaths (April, July, and December) is difficult. Part of this is because the limited availability of testing for COVID-19 in the spring meant that cases were severely undercounted. On top of that, many states were not reporting hospitalizations.

Because our main goal is to predict future COVID-19 deaths, it’s more important that our predictions are accurate than that they match the curve of the pandemic since the beginning. One thing we can test is whether excluding observations before a certain date can improve our accuracy. Below, we test all dates from January to present to see whether setting a more recent starting point will improve the accuracy. The best model will be the one that minimizes RMSE. We store this date in best_date. This code takes a while to run.

# prepare model_df
model_df <- lag_vars(ctp, case_lag = 14, hosp_lag = 28, test_lag = 14, death_lag = 14)

# empty df for populating the results of the upcoming date models
all_date_models <- tibble()

# set series of dates to loop through
min_date <- min(model_df$date)
max_date <- max(model_df$date)

d <- min_date
while(d < max_date) {
  
  model_df <- model_df %>% 
    filter(date > d) # selectively drop observations occurring before a certain date
  
  # fit model
  results_df <- fit_linear_model(model_df, f = lm_formula)

  # get cross-validation RMSE
  cv_res <- fit_cv_model(df = model_df, n_folds = 10, f = lm_formula)

  # calculate in-sample RMSE from the fitted values (this is what gets compared)
  error <- rmse(results_df, deaths_new_7d_perM, .pred)
      
  # add to all_date_models, noting the cut-off date
  error$iter <- d
  all_date_models <- bind_rows(all_date_models, error)
  
  d <- d + 1
}

# find solution that minimizes the RMSE
solution <- all_date_models %>%
  filter(.estimate == min(.estimate))
solution
## # A tibble: 1 x 4
##   .metric .estimator .estimate iter      
##   <chr>   <chr>          <dbl> <date>    
## 1 rmse    standard        1.14 2020-06-03
best_date <- solution$iter

Try plotting this now. We hard-code the best_date value as "2020-06-03".

# setup
model_name <- "Date Cut-off"

best_date <- as.Date("2020-06-03")

model_df <- lag_vars(ctp, case_lag = 14, hosp_lag = 28, test_lag = 14, death_lag = 14) %>%
  filter(date > best_date)

# fit model
results_df <- fit_linear_model(model_df, f = lm_formula)

# get cross-validation RMSE
cv_res <- fit_cv_model(df = model_df, n_folds = 10, f = lm_formula) %>%
  mutate(model = model_name)
cv_res
## # A tibble: 2 x 7
##   .metric .estimator  mean     n std_err .config              model       
##   <chr>   <chr>      <dbl> <int>   <dbl> <fct>                <chr>       
## 1 rmse    standard   1.14     10 0.0307  Preprocessor1_Model1 Date Cut-off
## 2 rsq     standard   0.835    10 0.00609 Preprocessor1_Model1 Date Cut-off
# add to overall metrics comparison
model_comp <- model_comp %>%
  bind_rows(cv_res)

# plot national results
plot_national_results(results_df) +
  labs(caption = paste0("Modeled since ", best_date))

# plot predictions for a few states
plot_state_results(results_df, c("AZ", "CA", "NY", "MA", "ME", "ND")) +
  labs(caption = paste0("Modeled since ", best_date))

model_comp
## # A tibble: 10 x 7
##    .metric .estimator  mean     n std_err .config     model                
##    <chr>   <chr>      <dbl> <int>   <dbl> <fct>       <chr>                
##  1 rmse    standard   3.09     10 0.0739  Preprocess… 1a: Cases (14)       
##  2 rsq     standard   0.333    10 0.0113  Preprocess… 1a: Cases (14)       
##  3 rmse    standard   2.93     10 0.0653  Preprocess… 1b: Cases (14) + mon…
##  4 rsq     standard   0.399    10 0.00934 Preprocess… 1b: Cases (14) + mon…
##  5 rmse    standard   2.02     10 0.0833  Preprocess… 2a: Cases(21) + hosp…
##  6 rsq     standard   0.726    10 0.0179  Preprocess… 2a: Cases(21) + hosp…
##  7 rmse    standard   1.32     10 0.0239  Preprocess… Optimized lag        
##  8 rsq     standard   0.850    10 0.00805 Preprocess… Optimized lag        
##  9 rmse    standard   1.14     10 0.0307  Preprocess… Date Cut-off         
## 10 rsq     standard   0.835    10 0.00609 Preprocess… Date Cut-off

Model Comparison

How have our different models fared?

model_comp %>%
  arrange(.metric, mean) %>%
  knitr::kable() 
| .metric | .estimator | mean | n | std_err | .config | model |
|---|---|---|---|---|---|---|
| rmse | standard | 1.1397906 | 10 | 0.0307319 | Preprocessor1_Model1 | Date Cut-off |
| rmse | standard | 1.3230052 | 10 | 0.0239420 | Preprocessor1_Model1 | Optimized lag |
| rmse | standard | 2.0243729 | 10 | 0.0833336 | Preprocessor1_Model1 | 2a: Cases(21) + hosp (14) + test(14) + deaths(24) + month |
| rmse | standard | 2.9331421 | 10 | 0.0653004 | Preprocessor1_Model1 | 1b: Cases (14) + month |
| rmse | standard | 3.0875804 | 10 | 0.0739182 | Preprocessor1_Model1 | 1a: Cases (14) |
| rsq | standard | 0.3330637 | 10 | 0.0112617 | Preprocessor1_Model1 | 1a: Cases (14) |
| rsq | standard | 0.3991520 | 10 | 0.0093400 | Preprocessor1_Model1 | 1b: Cases (14) + month |
| rsq | standard | 0.7260096 | 10 | 0.0178580 | Preprocessor1_Model1 | 2a: Cases(21) + hosp (14) + test(14) + deaths(24) + month |
| rsq | standard | 0.8353949 | 10 | 0.0060920 | Preprocessor1_Model1 | Date Cut-off |
| rsq | standard | 0.8504139 | 10 | 0.0080522 | Preprocessor1_Model1 | Optimized lag |