In this post I’ll describe how to build a nice cohort graph/table in R.
Updated in 2020: I incorporated some readers’ comments to make it even better (thank you!), and I fixed the layout. Happy hacking!
The data I’m using comes from a mobile app that is tagged with Snowplow Analytics, but if you follow all the steps with a similar dataset, you’ll get the same kind of result, which looks something like this:
Full code is below, or in this gist, for easier cut/paste: https://gist.github.com/zjuul/cac3f6e210cc64938bb5740d2920897f
Step 1: getting the data
What you need is a table with 2 columns.
- user ID
- date (I use a weekly bucket, “year – week”)
To retrieve this data from Snowplow, I use the following query:
appname <- "NAME" # used as selection criterion, replace that with the name of your app

q <- paste0("
  SELECT user_id,
         TO_CHAR(CONVERT_TIMEZONE('UTC', 'Europe/Amsterdam', derived_tstamp),'YYYY ww') AS yw
  FROM atomic.events
  WHERE app_id = '", appname, "'
    AND derived_tstamp >= '2018-01-01'
  GROUP BY 1,2;
")

conn <- dbConnect(driver, url)  # make the database connection - this is from the RJDBC package (library(RJDBC))
dbdata <- dbGetQuery(conn, q)   # get the data
dbDisconnect(conn)              # close the connection
The dataset can become quite large, since one row is retrieved for every user in every week they were active. In my case, the data looks like this (the actual dataset has about 85,000 rows).
                                   user_id      yw
1 dafcae8d19d42d08af6857592dba74efee3cfdf2 2018 09
2 341a909eeace656fa8e4fe24386187988ac2bc58 2018 09
3 871b95c5ff010b9adffef943ffc55d23182d0196 2018 09
4 87be7eed91b19c26bdb0bb9bdb95b06d5c1ac464 2018 09
5 14f96795fb6a7b25568a96dc72f85a268d9aa380 2018 09
6 4569480718013de76ca2c58bf7dc8a8c160a3790 2018 09
If you don’t use Snowplow, or if you want to use larger or smaller timeframes: get the data somewhere else, but shape it into a similar format.
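If your data comes from somewhere else, shaping it could look roughly like the sketch below. The `events` data frame and its `ts` column are made up for illustration, and the week label uses R's `%W` (week of the year, Monday first), which will not necessarily number weeks the same way as the `ww` format in the SQL query above.

```r
# Hypothetical raw event log: one row per event, with a user ID and a date
events <- data.frame(
  user_id = c("a", "a", "b", "a", "b", "c"),
  ts = as.Date(c("2018-01-01", "2018-01-03", "2018-01-02",
                 "2018-01-08", "2018-01-16", "2018-01-09")),
  stringsAsFactors = FALSE
)

# Bucket every event into a "year week" label
events$yw <- format(events$ts, "%Y %W")

# Keep one row per user per week, like the GROUP BY 1,2 in the query
dbdata <- unique(events[, c("user_id", "yw")])
print(dbdata)
```

Any bucket works (day, week, month), as long as the labels sort correctly as text, because the next step uses `min()` on them.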
Step 2: wrangle the data
Thanks to the excellent tidyverse package set, the actual work is done in only a couple of lines of code.
library(tidyverse)

cohort <- dbdata %>%                          # store in cohort table, get from dbdata
  group_by(user_id) %>%                       # group all users together
  mutate(first = min(yw)) %>%                 # for every user, find the first period
  group_by(first, yw) %>%                     # group by this first period + the other periods
  summarise(users = n_distinct(user_id)) %>%  # for each combination, count the number of users
  spread(yw, users)                           # and make columns with period names
The last line is actually where the magic happens, and what creates the actual cohort table. After this, it’s all cosmetics :)
Before the spread, the data looks like this:
# A tibble: 6 x 3
# Groups:   first [1]
  first   yw      users
  <chr>   <chr>   <int>
1 2018 01 2018 01  9119
2 2018 01 2018 02  5767
3 2018 01 2018 03  5628
4 2018 01 2018 04  6098
5 2018 01 2018 05  5869
6 2018 01 2018 06  5574
After, it looks like this:
> head(cohort)
# A tibble: 6 x 11
# Groups:   first [6]
  first   `2018 01` `2018 02` `2018 03` `2018 04` `2018 05` `2018 06` `2018 07` `2018 08` `2018 09`
  <chr>       <int>     <int>     <int>     <int>     <int>     <int>     <int>     <int>     <int>
1 2018 01      9119      5767      5628      6098      5869      5574      5499      5957      6151
2 2018 02        NA      2314       996      1186      1133      1069      1077      1180      1245
3 2018 03        NA        NA      1227       571       541       493       499       579       590
4 2018 04        NA        NA        NA      1077       424       397       403       483       528
5 2018 05        NA        NA        NA        NA       643       246       215       262       306
6 2018 06        NA        NA        NA        NA        NA       618       205       211       208
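If you want to try the wrangling step without a database, here is a small sketch on made-up data. It uses base R (`ave` and `tapply`) instead of the tidyverse pipeline, so it is an equivalent illustration of the idea rather than the code from this post; the user IDs and weeks are invented.

```r
# Made-up sample data: one row per user per active week
dbdata <- data.frame(
  user_id = c("a", "b", "c", "a", "d", "a", "b", "d", "e"),
  yw      = c("2018 01", "2018 01", "2018 01", "2018 02", "2018 02",
              "2018 03", "2018 03", "2018 03", "2018 03"),
  stringsAsFactors = FALSE
)

# For every user, find the first period (same idea as mutate(first = min(yw)))
dbdata$first <- ave(dbdata$yw, dbdata$user_id, FUN = min)

# Count distinct users per (first, yw) and lay the periods out as columns;
# combinations that never occur come out as NA
cohort <- tapply(
  dbdata$user_id,
  list(first = dbdata$first, yw = dbdata$yw),
  function(x) length(unique(x))
)
print(cohort)
```

The result is a small matrix with one row per cohort and one column per week: the "2018 01" cohort starts with 3 users, of whom 1 is active in week 02 and 2 in week 03.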
Step 3: cosmetics, percentages, etc
Ok, a normal cohort table is aligned to the left, and column names should reflect that. So let’s get to work, and do some R magic.
The logic is this: take the cohort table, and align the columns to the left, starting from column 2 (first column is the “first” column)
shiftrow <- function(v) {
  # put a vector in, strip off leading NA values, and place that amount at the end
  first_na_index <- min( which(!is.na(v)) )

  # return that bit to the end, and pad with NAs.
  c(v[first_na_index:length(v)], rep(NA, first_na_index-1))
}

# create a new dataframe, with shifted rows (and keep the first one)
shifted <- data.frame(
  cohort = cohort$first,
  t(apply(
    select(as.data.frame(cohort), 2:ncol(cohort)), # 2nd column to the end
    1,                                             # for every row
    shiftrow
  ))
)

# and make column names readable
# first should be "cohort" and the rest week.<number>, (padded)
colnames(shifted) <- c("cohort", sub("","week.", str_pad(1:(ncol(shifted)-1),2,pad = "0")))
This does the trick. We now have a table that looks like this:
   cohort week.01 week.02 week.03 week.04 week.05 week.06 week.07 week.08 week.09 week.10
1 2018 01    9137    5772    5640    6108    5878    5580    5503    5958    6152    4388
2 2018 02    2320     998    1189    1136    1069    1077    1180    1245     825      NA
3 2018 03    1233     570     543     493     498     579     590     400      NA      NA
4 2018 04    1081     424     398     404     483     528     327      NA      NA      NA
5 2018 05     647     246     215     263     306     218      NA      NA      NA      NA
6 2018 06     619     205     211     208     157      NA      NA      NA      NA      NA
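To see what the shift does, it helps to run it on a tiny example. The sketch below repeats the `shiftrow` logic (with a renamed local variable) and applies it to an invented 3 x 3 matrix. Note that `apply(m, 1, f)` returns each row's result as a column, which is why the result is transposed with `t()`.

```r
# Same logic as shiftrow above: move leading NAs to the end of the vector
shiftrow <- function(v) {
  first_value <- min(which(!is.na(v)))   # index of the first non-NA value
  c(v[first_value:length(v)], rep(NA, first_value - 1))
}

# One row: two leading NAs end up as two trailing NAs
one_row <- shiftrow(c(NA, NA, 1227, 571, 541))
print(one_row)

# A small invented cohort matrix, one cohort per row
m <- rbind(c(10,  5,  4),
           c(NA,  8,  3),
           c(NA, NA,  6))

# apply() over rows returns the shifted rows as columns, so transpose back
shifted_m <- t(apply(m, 1, shiftrow))
print(shifted_m)
```

After the shift, every row starts with its own first week, so column 1 holds the size of each cohort.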
Next up: percentages.. we want every week to be expressed as a percentage of week.01. Let’s do this.
We’ll create a new table for this. We divide all week-columns by week.01 of that row.
shifted_pct <- data.frame(
  cohort = shifted$cohort,                           # first column
  shifted[, 2:ncol(shifted)] / shifted[["week.01"]]  # all week columns, divided by week.01
)
Result is this: (and of course, week 1 = 100%)
   cohort week.01   week.02   week.03   week.04   week.05   week.06   week.07   week.08   week.09
1 2018 01       1 0.6317172 0.6172704 0.6684908 0.6433184 0.6107037 0.6022765 0.6520740 0.6733063
2 2018 02       1 0.4301724 0.5125000 0.4896552 0.4607759 0.4642241 0.5086207 0.5366379 0.3556034
3 2018 03       1 0.4622871 0.4403893 0.3998378 0.4038929 0.4695864 0.4785077 0.3244120        NA
4 2018 04       1 0.3922294 0.3681776 0.3737280 0.4468085 0.4884366 0.3024977        NA        NA
5 2018 05       1 0.3802164 0.3323029 0.4064915 0.4729521 0.3369397        NA        NA        NA
6 2018 06       1 0.3311793 0.3408724 0.3360258 0.2536349        NA        NA        NA        NA
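Two details of this step are easy to get wrong, so here is a small sketch on invented numbers. First, in R the `:` operator binds tighter than `+`, so `2:n + 1` means `(2:n) + 1`. Second, dividing a data frame by a vector with one value per row divides every column by that vector, which is what makes each row relative to its own week.01.

```r
# Operator precedence: ':' is evaluated before '+'
n <- 4
a <- 2:n + 1      # parsed as (2:n) + 1, so it starts at 3
b <- 2:(n + 1)    # starts at 2
print(a)
print(b)

# Invented shifted table with two cohorts
shifted <- data.frame(
  cohort  = c("2018 01", "2018 02"),
  week.01 = c(100, 50),
  week.02 = c(60, 20),
  week.03 = c(50, NA),
  stringsAsFactors = FALSE
)

# Divide all week columns by week.01 of the same row
pct <- shifted[, 2:ncol(shifted)] / shifted[["week.01"]]
print(pct)
```

Selecting the columns with `2:ncol(shifted)` keeps week.01 in the result (as 1, or 100%) and does not depend on the number of cohorts matching the number of weeks.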
Tataa.. we’re there. These tables can be printed, exported to a spreadsheet, or whatever. They look kind of neat, and aren’t too long (depending on the timeframe).
But.. I promised you a plot. And it comes (as all plots do) at the end.
Step 4: the plot
Ok, now that we have the pretty table, we’re going to reshape it again. Because ggplot likes long format.
We’re going to melt (or rather: gather) the data first, and then we’re going to write some code to produce nice labels.
# ggplot loves long data. Let's melt it. One for the absolute values, one for the pcts
plotdata_abs <- gather(shifted,     "cohort_age", "people"  ,2:ncol(shifted    ))
plotdata_pct <- gather(shifted_pct, "cohort_age", "percent" ,2:ncol(shifted_pct))

# now add some data.. we need pretty labels..
# first bit is the length of the width of the wide column (minus 1, that's the cohort name)
# that contains the absolute numbers
# last bit is the rest, those are percentages.
labelnames <- c( plotdata_abs$people[1:(ncol(shifted)-1)],
                 plotdata_pct$percent[(ncol(shifted)):(nrow(plotdata_pct))])

# we need pretty labels.
pretty_print <- function(n) {
  case_when( n <= 1  ~ sprintf("%1.0f %%", n*100),
             n >  1  ~ as.character(n),
             TRUE    ~ " ") # for NA values, skip the label
}

# create the plot data
plotdata <- data.frame(
  cohort     = plotdata_pct$cohort,
  cohort_age = plotdata_pct$cohort_age,
  percentage = plotdata_pct$percent,
  label      = pretty_print(labelnames)
)
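For readers who want to see what the long format and the labels look like, here is a base-R sketch on invented numbers. It is not the `gather` code from this post: it builds the long table by hand, stacking the week columns one after another (the same order `gather` produces), and it picks the label by week name instead of by position.

```r
# Invented shifted table: two cohorts, two weeks
shifted <- data.frame(
  cohort  = c("2018 01", "2018 02"),
  week.01 = c(100, 50),
  week.02 = c(60, NA),
  stringsAsFactors = FALSE
)
weeks <- names(shifted)[-1]

# Long format: all rows of week.01 first, then all rows of week.02
long <- data.frame(
  cohort     = rep(shifted$cohort, times = length(weeks)),
  cohort_age = rep(weeks, each = nrow(shifted)),
  people     = unlist(shifted[weeks], use.names = FALSE),
  stringsAsFactors = FALSE
)

# Percentage of the cohort's first week
long$percentage <- long$people / rep(shifted$week.01, times = length(weeks))

# Labels: absolute numbers for week.01, percentages elsewhere, blank for NA
long$label <- ifelse(
  is.na(long$percentage), " ",
  ifelse(long$cohort_age == "week.01",
         as.character(long$people),
         sprintf("%1.0f %%", long$percentage * 100))
)
print(long)
```

The first `nrow(shifted)` rows of the long table are the week.01 values, which is why the first block of labels can hold the absolute cohort sizes.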
Ok. We have long-form data with 4 columns. The cohort column is used on the y axis, and the cohort_age on the x axis.
Finally, we’re using the percentages for the fill colouring, and the label column for.. labels.
# plot (with reordered y axis, oldest group on top)
# optional: if the percentages are really low, replace the 1.0 in the first column with zero
plotdata[which(plotdata$percentage == 1), "percentage"] <- 0

ggplot(plotdata, aes(x = cohort_age, y = reorder(cohort, desc(cohort)))) +
  geom_raster(aes(fill = percentage)) +
  scale_fill_continuous(guide = "none") + # no legend (guide = FALSE is deprecated in recent ggplot2)
  geom_text(aes(label = label), color = "white") +
  xlab("cohort age") + ylab("cohort") +
  ggtitle(paste("Retention table (cohort) for",appname, "app"))

# the end
Step 5: look at the graph, and pat yourself on the back
This is the result of my dataset. Doesn’t it look pretty?
Please let me know if this was useful, and if it worked out for you!
Roel says
This is really awesome, thank you for sharing and teaching me how to do it.
Julia says
Thank you. This is great! I am wondering how you can do the fill more prominent in the plot?
jules says
Hi Julia, you can make the 100% of the first period a bit lower. For example, if all your other percentages are max 9%, do this before plotting:
plotdata[which(plotdata$percentage == 1), "percentage"] = 0.1
Bastien says
Nice approach – thanks for this.
I’ve made one modification where you create the shifted_pct.
I’ve put “shifted[,2:(nrow(shifted)+1)] / shifted[["week.01"]] # rest: divide by week.01” in order to get all the periods. On your plot, you’re missing week 10.
jules says
Fixed, thank you Bastien
sue says
Thanks! It’s perfect~, helping me a lot
roshan says
can you please share the sample data ?
Roshan says
hey i think there is some issue in wrangle the data step
it should be something like
cohort <- dbdata %>% # store in cohort table, get from dbdata
group_by(user_id) %>% # group all users together
mutate(first = min(yw)) %>% # for every user, find the first period
group_by(first, yw) %>% # group by this first period + the other periods
summarise(users = n_distinct(user_id)) %>% # for each combination, distinct count the number of users
spread(yw , users)
jules says
Thanks Roshan, that makes perfect sense. I was just using it the other day and did it exactly like you described. Can’t believe this n() has been in here so long. Thanks for paying attention!
mr_z_ro says
Thanks for this! A quick note that if you want the chart presented as the final result, you will have to include the first week in `shifted_pct`. Otherwise, your x labels will be off (week.01 will be labeled as week.02, etc, and your labels will be color coded with the rest of the chart).
In other words, you’ll want to change
```
shifted_pct <- data.frame(
  cohort = shifted$cohort, # first column
  shifted[,2:nrow(shifted)+1] / shifted[["month.01"]] # rest: divide by week.01
)
```
to
```
shifted_pct <- data.frame(
  cohort = shifted$cohort, # first column
  shifted[,1:nrow(shifted)+1] / shifted[["month.01"]] # rest: divide by week.01
)
```
jules says
Thanks! Given the popularity of this post, I think it’s time to update it – I have much better versions of this code lying around.. :)
CM says
I’d love an updated version if you have one, this is fantastic and a very common/repeated type of analysis for many of us who have landed on this page I’m sure.
Vyach says
Hihi! I think that there is a mistake in row [40]: there should be no transposing operation, i.e. no t() function should be applied.
As a result in the final Retention table the columns and rows are mixed up. For example in your table the “week 2” retention of cohort 10 (2018-09) is only 22%. But according to the data manipulations that you did previously this should be the “week 10” retention of cohort 1.
When you get rid of transposing you get the right output: the retention for all cohorts starts from pretty high value in “week 2” (> 60%) and gradually decreases in the next weeks.