Stuifbergen.com

Analytics & Growth Hacking


Doing Cohort Analysis in R (using ggplot)

9 March 2018 by jules

In this post I’ll describe how to build a nice cohort graph/table in R.

Updated in 2020: I used some readers’ comments to make it even better (thank you!), and I fixed the layout. Happy hacking!

The data source I’m using is a mobile app tagged with Snowplow Analytics, but if you follow all the steps with a similar dataset, you’ll get the same kind of result: a retention heatmap with one row per cohort and one column per week of cohort age.

Full code is below, or in this gist, for easier cut/paste: https://gist.github.com/zjuul/cac3f6e210cc64938bb5740d2920897f

Step 1: getting the data

What you need is a table with 2 columns.

  1. user ID
  2. date (I use a weekly bucket, “year – week”)

To retrieve this data from Snowplow, I use the following query:

library(RJDBC) # provides dbConnect(), dbGetQuery() and dbDisconnect() for JDBC connections

appname <- "NAME" # used as the selection criterion; replace it with the name of your app

q <- paste0("
  SELECT
  user_id,
  TO_CHAR(CONVERT_TIMEZONE('UTC', 'Europe/Amsterdam', derived_tstamp),'YYYY ww') AS yw
  FROM atomic.events
  WHERE app_id = '", appname, "' AND derived_tstamp >= '2018-01-01'
  GROUP BY 1,2;
")

# `driver` (created with JDBC()) and `url` must be defined for your own database first
conn <- dbConnect(driver, url) # make the database connection
dbdata <- dbGetQuery(conn, q)  # get the data
dbDisconnect(conn)             # close the connection

The dataset can become quite large, since it contains one row for every week in which each user was active. In my case, the data looks like this (the full result has about 85,000 rows).

                                   user_id      yw
1 dafcae8d19d42d08af6857592dba74efee3cfdf2 2018 09
2 341a909eeace656fa8e4fe24386187988ac2bc58 2018 09
3 871b95c5ff010b9adffef943ffc55d23182d0196 2018 09
4 87be7eed91b19c26bdb0bb9bdb95b06d5c1ac464 2018 09
5 14f96795fb6a7b25568a96dc72f85a268d9aa380 2018 09
6 4569480718013de76ca2c58bf7dc8a8c160a3790 2018 09

If you don’t use Snowplow, or want to use larger or smaller timeframes: get the data somewhere else, but shape it into a similar format.
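If you just want to try the steps without a database, a simulated dataset in the same two-column shape is enough. The sketch below is a stand-in and not the real app data: the number of users, the 50% weekly return probability and the `user0001`-style ids are all made up. It also uses ISO week numbers via `format()`, which may not match the `ww` numbering of the SQL query exactly.

```r
library(dplyr)

set.seed(42)                      # make the simulation reproducible
n_users <- 500                    # hypothetical number of app users
start   <- as.Date("2018-01-01")  # a Monday, so 7-day steps line up with ISO weeks

# every simulated user gets a random signup week between 0 and 9
signup_week <- sample(0:9, n_users, replace = TRUE)

dbdata <- tibble(
  user_id = rep(sprintf("user%04d", seq_len(n_users)), each = 10),
  week    = rep(0:9, times = n_users),
  signup  = rep(signup_week, each = 10)
) %>%
  filter(week >= signup) %>%                          # no activity before signup
  filter(week == signup | runif(n()) < 0.5) %>%       # always active in week one, then a 50% chance
  mutate(yw = format(start + 7 * week, "%G %V")) %>%  # ISO year and ISO week, e.g. "2018 01"
  select(user_id, yw) %>%
  as.data.frame()

head(dbdata)
```

Because every simulated user is active in their signup week, each one ends up in exactly one cohort, just like with real data.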

Step 2: wrangle the data

Thanks to the excellent tidyverse package set, the actual work is done in only a couple of lines of code.

library(tidyverse)

cohort <- dbdata %>%                         # store in cohort table, get from dbdata
  group_by(user_id) %>%                      # group all users together
  mutate(first = min(yw)) %>%                # for every user, find the first period
  group_by(first, yw) %>%                    # group by this first period + the other periods
  summarise(users = n_distinct(user_id)) %>% # for each combination, count the number of users
  spread(yw, users)                          # and make columns with period names 


The last line is actually where the magic happens, and what creates the actual cohort table. After this, it’s all cosmetics :)

Before the spread, the data looks like this:

# A tibble: 6 x 3
# Groups:   first [1]
    first      yw users
    <chr>   <chr> <int>
1 2018 01 2018 01  9119
2 2018 01 2018 02  5767
3 2018 01 2018 03  5628
4 2018 01 2018 04  6098
5 2018 01 2018 05  5869
6 2018 01 2018 06  5574

After, it looks like this:

> head(cohort)
# A tibble: 6 x 11
# Groups:   first [6]
    first `2018 01` `2018 02` `2018 03` `2018 04` `2018 05` `2018 06` `2018 07` `2018 08` `2018 09`
    <chr>     <int>     <int>     <int>     <int>     <int>     <int>     <int>     <int>     <int>
1 2018 01      9119      5767      5628      6098      5869      5574      5499      5957      6151
2 2018 02        NA      2314       996      1186      1133      1069      1077      1180      1245
3 2018 03        NA        NA      1227       571       541       493       499       579       590
4 2018 04        NA        NA        NA      1077       424       397       403       483       528
5 2018 05        NA        NA        NA        NA       643       246       215       262       306
6 2018 06        NA        NA        NA        NA        NA       618       205       211       208
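In newer versions of tidyr, `spread()` is superseded by `pivot_wider()`. The following is a minimal sketch of the same wrangling step on a tiny made-up dataset (the user ids `a`, `b` and `c` are invented), assuming tidyr 1.1.0 or later for the `names_sort` argument and dplyr 1.0.0 or later for `.groups`.

```r
library(dplyr)
library(tidyr)

# toy input in the same shape as dbdata (made-up users)
toy <- data.frame(
  user_id = c("a", "a", "a", "b", "b", "c"),
  yw      = c("2018 01", "2018 02", "2018 03", "2018 02", "2018 03", "2018 03")
)

cohort_toy <- toy %>%
  group_by(user_id) %>%
  mutate(first = min(yw)) %>%                                   # first active period per user
  group_by(first, yw) %>%
  summarise(users = n_distinct(user_id), .groups = "drop") %>%  # users per cohort and period
  pivot_wider(names_from = yw, values_from = users,
              names_sort = TRUE)                                # keep the period columns in order

cohort_toy
```

Each toy user is the only member of their cohort, so every filled cell is 1 and the NA cells form the same lower-left triangle as in the table above.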

Step 3: cosmetics, percentages, etc

Ok, a normal cohort table is aligned to the left, and column names should reflect that. So let’s get to work, and do some R magic.

The logic is this: take the cohort table, and align the columns to the left, starting from column 2 (first column is the “first” column)

shiftrow <- function(v) {
  # put a vector in, strip off the leading NA values, and append the same number of NAs at the end
  first_value_index <- min( which(!is.na(v)) ) # position of the first non-NA value

  # return everything from that position onwards, padded with NAs
  c(v[first_value_index:length(v)], rep(NA, first_value_index-1))
}

# create a new dataframe, with shifted rows (and keep the first one)
shifted <- data.frame(
    cohort = cohort$first,
    t(apply( select(as.data.frame(cohort), 2:ncol(cohort)), # 2nd column to the end
           1, # for every row
           shiftrow ))
)

# and make column names readable
# first should be "cohort" and the rest week.<number>, (padded)
colnames(shifted) <- c("cohort", paste0("week.", str_pad(1:(ncol(shifted)-1), 2, pad = "0")))

This does the trick. We now have a table that looks like this:

   cohort week.01 week.02 week.03 week.04 week.05 week.06 week.07 week.08 week.09 week.10
1 2018 01    9137    5772    5640    6108    5878    5580    5503    5958    6152    4388
2 2018 02    2320     998    1189    1136    1069    1077    1180    1245     825      NA
3 2018 03    1233     570     543     493     498     579     590     400      NA      NA
4 2018 04    1081     424     398     404     483     528     327      NA      NA      NA
5 2018 05     647     246     215     263     306     218      NA      NA      NA      NA
6 2018 06     619     205     211     208     157      NA      NA      NA      NA      NA
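To see what the row shift does in isolation, here is a small sketch on made-up numbers. It repeats the helper so the snippet runs on its own; the matrix `m` is purely illustrative. It also shows why the `t()` is there: `apply()` over rows hands the results back as columns.

```r
# same logic as the helper above, repeated so this snippet is self-contained
shiftrow <- function(v) {
  first_value <- min(which(!is.na(v)))                   # position of the first non-NA value
  c(v[first_value:length(v)], rep(NA, first_value - 1))  # move leading NAs to the end
}

shiftrow(c(NA, NA, 1227, 571, 541))
# 1227  571  541   NA   NA

# a tiny 3 x 3 "cohort table": one row per cohort, one column per calendar week
m <- rbind(c(10,  6, 5),
           c(NA,  8, 4),
           c(NA, NA, 7))

by_column <- apply(m, 1, shiftrow)  # each shifted row comes back as a COLUMN
shifted_m <- t(by_column)           # transpose to get one row per cohort again
shifted_m
```

After the transpose, the second row reads 8, 4, NA: the second cohort starts in its own first week.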

Next up: percentages.. we want every week to be expressed as a percentage of week.01. Let’s do this.

We’ll create a new table for this. We divide all week-columns by week.01 of that row.

shifted_pct <- data.frame(
  cohort = shifted$cohort, # first column
  shifted[, 2:ncol(shifted)] / shifted[["week.01"]] # all week columns, divided by week.01
)


Result is this: (and of course, week 1 = 100%)

   cohort week.01   week.02   week.03   week.04   week.05   week.06   week.07   week.08   week.09
1 2018 01       1 0.6317172 0.6172704 0.6684908 0.6433184 0.6107037 0.6022765 0.6520740 0.6733063
2 2018 02       1 0.4301724 0.5125000 0.4896552 0.4607759 0.4642241 0.5086207 0.5366379 0.3556034
3 2018 03       1 0.4622871 0.4403893 0.3998378 0.4038929 0.4695864 0.4785077 0.3244120        NA
4 2018 04       1 0.3922294 0.3681776 0.3737280 0.4468085 0.4884366 0.3024977        NA        NA
5 2018 05       1 0.3802164 0.3323029 0.4064915 0.4729521 0.3369397        NA        NA        NA
6 2018 06       1 0.3311793 0.3408724 0.3360258 0.2536349        NA        NA        NA        NA
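Selecting the week columns is easy to get subtly wrong, because in R the `:` operator binds tighter than `+`. The sketch below shows the difference, followed by the percentage step on a tiny made-up table (`toy_shifted` and its numbers are invented for illustration).

```r
# `:` is evaluated before `+`, so these two ranges are not the same
n <- 4
2:n + 1    # 3 4 5    (the whole range 2:4 is shifted by one)
2:(n + 1)  # 2 3 4 5

# a tiny left-aligned cohort table with made-up numbers
toy_shifted <- data.frame(
  cohort  = c("2018 01", "2018 02"),
  week.01 = c(200, 100),
  week.02 = c(120,  40),
  week.03 = c(100,  NA)
)

week_cols <- 2:ncol(toy_shifted)  # every column except the cohort name

toy_pct <- data.frame(
  cohort = toy_shifted$cohort,
  toy_shifted[, week_cols] / toy_shifted[["week.01"]]  # row-wise share of week.01
)
toy_pct
```

Dividing a data frame by a vector with one element per row divides every column element-wise, so each row ends up relative to its own week.01.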

Tataa.. we’re there. These tables can be printed, exported to a spreadsheet, or whatever. They look kind of neat, and aren’t too long (depending on the timeframe).

But.. I promised you a plot. And it comes (as all plots do) at the end.

Step 4: the plot

Ok, now that we have the pretty table, we’re going to reshape it again. Because ggplot likes long format.

We’re going to melt (or rather: gather) the data first, and then we’re going to write some code to produce nice labels.

# ggplot loves long data. Let's melt it. One for the absolute values, one for the pcts
plotdata_abs <- gather(shifted,     "cohort_age", "people"  ,2:ncol(shifted    ))
plotdata_pct <- gather(shifted_pct, "cohort_age", "percent" ,2:ncol(shifted_pct))

# now add some data.. we need pretty labels..
# the first nrow(shifted) rows of the long data are week.01 for every cohort:
# those get the absolute numbers as label
# all remaining rows get their percentages
labelnames <- c( plotdata_abs$people[1:nrow(shifted)],
                 plotdata_pct$percent[(nrow(shifted)+1):(nrow(plotdata_pct))])

# we need pretty labels.
pretty_print <- function(n) {
  case_when( n <= 1  ~ sprintf("%1.0f %%", n*100),
             n >  1  ~ as.character(n),
             TRUE    ~ " ") # for NA values, skip the label
}

# create the plot data
plotdata <- data.frame(
  cohort     = plotdata_pct$cohort,
  cohort_age = plotdata_pct$cohort_age,
  percentage = plotdata_pct$percent,
  label      = pretty_print(labelnames)
)
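The labels above are assembled by row position in the long data. As an alternative, here is a hedged sketch that computes the percentage and the label per row instead, using `pivot_longer()` from tidyr 1.0.0 or later. The toy table and the names `toy_shifted` and `toy_plotdata` are made up for illustration; this is a different route to the same kind of four-column table, not the code used for the plot in this post.

```r
library(dplyr)
library(tidyr)

# a tiny left-aligned cohort table with made-up numbers
toy_shifted <- data.frame(
  cohort  = c("2018 01", "2018 02"),
  week.01 = c(200, 100),
  week.02 = c(120,  40),
  week.03 = c(100,  NA)
)

toy_plotdata <- toy_shifted %>%
  pivot_longer(-cohort, names_to = "cohort_age", values_to = "people") %>%
  group_by(cohort) %>%
  mutate(percentage = people / people[cohort_age == "week.01"]) %>%  # share of the cohort's first week
  ungroup() %>%
  mutate(label = case_when(
    is.na(people)           ~ " ",                                   # empty cell: no label
    cohort_age == "week.01" ~ as.character(people),                  # first week: absolute number
    TRUE                    ~ sprintf("%1.0f %%", percentage * 100)  # later weeks: percentage
  ))

toy_plotdata
```

Because the label depends on the column name rather than on the size of the number, a cohort with a single user is still labelled with its count instead of a percentage.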



Ok. We have long-form data with 4 columns. The cohort column is used on the y axis, and the cohort_age on the x axis.

Finally, we’re using the percentages for the fill colouring, and the label column for.. labels.

# plot (with reordered y axis, oldest group on top)

# optional: if the percentages are really low, replace the 1.0 in the first column with zero
plotdata[which(plotdata$percentage == 1), "percentage"] <- 0

ggplot(plotdata, aes(x = cohort_age, y = reorder(cohort, desc(cohort)))) +
  geom_raster(aes(fill = percentage)) +
  scale_fill_continuous(guide = "none") + # no legend
  geom_text(aes(label = label), color = "white") +
  xlab("cohort age") + ylab("cohort") + 
  ggtitle(paste("Retention table (cohort) for",appname, "app"))

# the end

Step 5: look at the graph, and pat yourself on the back

This is the result of my dataset. Doesn’t it look pretty?

Please let me know if this was useful, and if it worked out for you!


Filed Under: Blog Tagged With: analysis, cohort, data visualisation, ggplot, r, snowplow, tutorial


Comments

  1. Roel says

    28 March 2018 at 12:28

    This is really awesome, thank you for sharing and teaching me how to do it.

  2. Julia says

    25 July 2018 at 02:10

    Thank you. This is great! I am wondering how you can do the fill more prominent in the plot?

    • jules says

      7 May 2020 at 21:37

      Hi Julia, you can make the 100% of the first period a bit lower. For example, if all your other percentages are at most 9%, do this before plotting:

      plotdata[which(plotdata$percentage == 1), "percentage"] = 0.1

  3. Bastien says

    16 December 2018 at 14:21

    Nice approach – thanks for this.
    I’ve made one modification where you create shifted_pct.

    I’ve put "shifted[,2:(nrow(shifted)+1)] / shifted[["week.01"]] # rest: divide by week.01" in order to get all the periods. On your plot, you’re missing week10.

    • jules says

      8 May 2020 at 09:18

      Fixed, thank you Bastien

  4. sue says

    19 December 2018 at 10:51

    Thanks! It’s perfect~, helping me a lot

  5. roshan says

    25 November 2020 at 16:05

    can you please share the sample data ?

  6. Roshan says

    27 November 2020 at 10:12

    hey i think there is some issue in wrangle the data step

    it should be something like

    cohort <- dbdata %>% # store in cohort table, get from dbdata
    group_by(user_id) %>% # group all users together
    mutate(first = min(yw)) %>% # for every user, find the first period
    group_by(first, yw) %>% # group by this first period + the other periods
    summarise(users = n_distinct(user_id)) %>% # for each combination, distinct count the number of users
    spread(yw , users)

    • jules says

      27 November 2020 at 12:42

      Thanks Roshan, that makes perfect sense. I was just using it the other day and did it exactly like you described. Can’t believe this n() has been in here so long. Thanks for paying attention!

  7. mr_z_ro says

    17 March 2021 at 15:28

    Thanks for this! A quick note that if you want the chart presented as the final result, you will have to include the first week in `shifted_pct`. Otherwise, your x labels will be off (week.01 will be labeled as week.02, etc, and your labels will be color coded with the rest of the chart).

    In other words, you’ll want to change

    shifted_pct <- data.frame(
      cohort = shifted$cohort, # first column
      shifted[,2:nrow(shifted)+1] / shifted[["month.01"]] # rest: divide by week.01
    )

    to

    shifted_pct <- data.frame(
      cohort = shifted$cohort, # first column
      shifted[,1:nrow(shifted)+1] / shifted[["month.01"]] # rest: divide by week.01
    )

    • jules says

      17 March 2021 at 15:44

      Thanks! Given the popularity of this post, I think it’s time to update it – I have much better versions of this code lying around.. :)

  8. CM says

    4 October 2021 at 18:07

    I’d love an updated version if you have one, this is fantastic and a very common/repeated type of analysis for many of us who have landed on this page I’m sure.

  9. Vyach says

    21 April 2022 at 12:01

    Hihi! I think that there is a mistake in raw [40]: there should be no transposing operation, i.e. no t() function should be applied.

    As a result in the final Retention table the columns and rows are mixed up. For example in your table the “week 2” retention of cohort 10 (2018-09) is only 22%. But according to the data manipulations that you did previously this should be the “week 10” retention of cohort 1.

    When you get rid of transposing you get the right output: the retention for all cohorts starts from pretty high value in “week 2” (> 60%) and gradually decreases in the next weeks.

© Copyright Jules Stuifbergen