Apache Spark [PART 24]: A Little Bit Complicated Cumulative Sum

Suppose you have a dataframe consisting of several columns, such as the following:

  • A: group indicator -> call it A value
  • B: has two different categories (b0 and b1) -> call it B value
  • C: let’s assume it contains integers -> call it C value
  • D: date and time -> call it D value
  • E: a timestamp used to order the rows (simplified to small integers in this example) -> call it E value

For the sake of clarity, let’s take a look at a simple example of the dataframe.

+---+---+---+-------------------+---+                                           
|  A|  B|  C|                  D|  E|
+---+---+---+-------------------+---+
| a0| b0|100|2018-04-01 08:00:00|  0|
| a0| b1|100|2018-04-01 15:00:00|  5|
| a0| b0|300|2018-04-01 17:00:00|  6|
| a0| b1|100|2018-04-01 17:30:00|  7|
| a0| b1|500|2018-04-01 18:00:00|  8|
| a1| b0|100|2018-04-01 09:00:00|  1|
| a1| b1|500|2018-04-01 09:30:30|  2|
| a1| b1|500|2018-04-01 09:40:40|  3|
| a1| b0|100|2018-04-01 09:50:50|  4|
| a2| b0|300|2018-04-07 08:00:00|  9|
| a2| b0|700|2018-04-07 18:00:00| 17|
| a3| b0|100|2018-04-07 15:00:00| 10|
| a0| b0|300|2018-04-15 08:00:00| 18|
| a0| b0|300|2018-04-15 08:30:30| 19|
| a0| b1|100|2018-04-15 09:00:00| 20|
| a1| b0|100|2018-04-15 17:00:00| 21|
| a1| b0|100|2018-04-15 17:30:30| 22|
| a1| b0|300|2018-04-15 18:00:00| 23|
| a1| b0|500|2018-04-15 18:30:30| 24|
| a1| b1|500|2018-04-15 19:00:00| 25|
| a1| b1|500|2018-04-15 19:30:30| 27|
+---+---+---+-------------------+---+
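
If you want to follow along, here is one way to build this sample dataframe. This is just a minimal sketch: it assumes an existing SparkSession called spark and keeps D as a plain string, which is exactly what the code later in this post expects.

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# (A, B, C, D, E) rows from the table above; only the first few are spelled out here.
rows = [
    ('a0', 'b0', 100, '2018-04-01 08:00:00', 0),
    ('a0', 'b1', 100, '2018-04-01 15:00:00', 5),
    ('a1', 'b0', 100, '2018-04-01 09:00:00', 1),
    # ... the remaining rows follow the table above
]

df = spark.createDataFrame(rows, ['A', 'B', 'C', 'D', 'E'])
df.show()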

Now, the task is to compute a cumulative sum of the C values over a daily period (1 day), separately for each B category. To make it clearer: we group the rows by A value and by the day they fall in (taken from the D value), order each group by E, and keep a running sum of C for the rows where B is b0 (daily_b0) and another one for the rows where B is b1 (daily_b1). Generally, we should end up with the following result, shown here one day at a time.

DAY = 2018-04-01
===========================================================
| A | B | C |          D        | E | daily_b0 | daily_b1 |
+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
| a0| b0|100|2018-04-01 08:00:00|  0|    100   |    0
| a1| b0|100|2018-04-01 09:00:00|  1|    100   |    0
| a1| b1|500|2018-04-01 09:30:30|  2|    100   |   500
| a1| b1|500|2018-04-01 09:40:40|  3|    100   |   1000
| a1| b0|100|2018-04-01 09:50:50|  4|    200   |   1000
| a0| b1|100|2018-04-01 15:00:00|  5|    100   |   100
| a0| b0|300|2018-04-01 17:00:00|  6|    400   |   100
| a0| b1|100|2018-04-01 17:30:00|  7|    400   |   200
| a0| b1|500|2018-04-01 18:00:00|  8|    400   |   700


DAY = 2018-04-07
===========================================================
| A | B | C |          D        | E | daily_b0 | daily_b1 |
+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
| a2| b0|300|2018-04-07 08:00:00|  9|    300   |    0
| a3| b0|100|2018-04-07 15:00:00| 10|    100   |    0
| a2| b0|700|2018-04-07 18:00:00| 17|    1000  |    0


DAY = 2018-04-15
===========================================================
| A | B | C |          D        | E | daily_b0 | daily_b1 |
+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
| a0| b0|300|2018-04-15 08:00:00| 18|    300   |    0
| a0| b0|300|2018-04-15 08:30:30| 19|    600   |    0
| a0| b1|100|2018-04-15 09:00:00| 20|    600   |   100
| a1| b0|100|2018-04-15 17:00:00| 21|    100   |    0
| a1| b0|100|2018-04-15 17:30:30| 22|    200   |    0
| a1| b0|300|2018-04-15 18:00:00| 23|    500   |    0
| a1| b0|500|2018-04-15 18:30:30| 24|    1000  |    0
| a1| b1|500|2018-04-15 19:00:00| 25|    1000  |   500
| a1| b1|500|2018-04-15 19:30:30| 27|    1000  |   1000
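
As a quick sanity check of the numbers above, here is the partition (A = a1, day = 2018-04-01) traced by hand in plain Python. This is purely illustrative; the (B, C) pairs are copied from the table, ordered by E.

# (B, C) pairs for A = 'a1' on 2018-04-01, ordered by E
rows = [('b0', 100), ('b1', 500), ('b1', 500), ('b0', 100)]

daily_b0 = daily_b1 = 0
for b, c in rows:
    if b == 'b0':
        daily_b0 += c
    else:
        daily_b1 += c
    print(daily_b0, daily_b1)
# prints 100 0, then 100 500, then 100 1000, then 200 1000, matching the rows with E = 1..4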

I hope you've got a clear picture of what the task means.

Now, the golden question is: how do we implement such a cumulative sum computation?

I’m going to make it very brief. Here’s the code.

import pyspark.sql.functions as F

from pyspark.sql import DataFrame, SparkSession, Window
from pyspark.sql.window import WindowSpec


def augment_period_window_cols(df: DataFrame, period_window_col_name: str, datetime_format: str, window_length: str) -> DataFrame:
    # Parse the D column (a date-time string) into a timestamp and bucket each row
    # into a tumbling time window of the given length (e.g. '1 day'), kept as a struct column.
    df = df.withColumn(
        period_window_col_name,
        F.window(F.to_timestamp(F.col('D'), datetime_format), window_length)
    )

    return df

def augment_period_based_sum_cols(b_category: str, partition_window: WindowSpec):
    # Rows whose B value matches the requested category contribute their C value;
    # every other row contributes 0, so the running total is simply carried forward.
    condition = F.col('B') == b_category

    sum_col = F.sum(
        F.coalesce(
            F.when(condition, F.col('C')).otherwise(F.lit(0.0)),
            F.lit(0.0)
        )
    )

    return sum_col.over(partition_window)


period_window_col_name = 'PERIOD_WINDOW'
datetime_format = 'yyyy-MM-dd HH:mm:ss'

df = augment_period_window_cols(df, period_window_col_name, datetime_format, '1 day')

# Cumulative window: one partition per (A value, daily window), ordered by E,
# covering all rows from the start of the partition up to the current row.
partition_window = (
    Window.partitionBy(['A', period_window_col_name])
          .orderBy('E')
          .rowsBetween(Window.unboundedPreceding, Window.currentRow)
)

b_categories = ['b0', 'b1']

# Build one running-sum column per B category, then drop the helper window column.
daily_augmented_columns = []
for b_category in b_categories:
    daily_augmented_columns.append(
        augment_period_based_sum_cols(b_category, partition_window)
        .alias('daily_{}'.format(b_category))
    )

df = df.select(*df.columns, *daily_augmented_columns).drop(period_window_col_name)
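
By the way, F.window produces a struct column holding the start and end timestamps of each row's bucket. If you are curious about what these daily buckets look like, you can inspect the helper column with something like the following illustrative snippet (run it before the final .drop(period_window_col_name)).

# Illustrative only: inspect the daily buckets produced by F.window
# before the helper column is dropped.
df.select('D', period_window_col_name).distinct().show(truncate=False)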

We should get the following dataframe as the result.

+---+---+---+-------------------+---+--------+--------+
|  A|  B|  C|                  D|  E|daily_b0|daily_b1|
+---+---+---+-------------------+---+--------+--------+
| a1| b0|100|2018-04-01 09:00:00|  1|   100.0|     0.0|
| a1| b1|500|2018-04-01 09:30:30|  2|   100.0|   500.0|
| a1| b1|500|2018-04-01 09:40:40|  3|   100.0|  1000.0|
| a1| b0|100|2018-04-01 09:50:50|  4|   200.0|  1000.0|
| a2| b0|300|2018-04-07 08:00:00|  9|   300.0|     0.0|
| a2| b0|700|2018-04-07 18:00:00| 17|  1000.0|     0.0|
| a0| b0|100|2018-04-01 08:00:00|  0|   100.0|     0.0|
| a0| b1|100|2018-04-01 15:00:00|  5|   100.0|   100.0|
| a0| b0|300|2018-04-01 17:00:00|  6|   400.0|   100.0|
| a0| b1|100|2018-04-01 17:30:00|  7|   400.0|   200.0|
| a0| b1|500|2018-04-01 18:00:00|  8|   400.0|   700.0|
| a0| b0|300|2018-04-15 08:00:00| 18|   300.0|     0.0|
| a0| b0|300|2018-04-15 08:30:30| 19|   600.0|     0.0|
| a0| b1|100|2018-04-15 09:00:00| 20|   600.0|   100.0|
| a1| b0|100|2018-04-15 17:00:00| 21|   100.0|     0.0|
| a1| b0|100|2018-04-15 17:30:30| 22|   200.0|     0.0|
| a1| b0|300|2018-04-15 18:00:00| 23|   500.0|     0.0|
| a1| b0|500|2018-04-15 18:30:30| 24|  1000.0|     0.0|
| a1| b1|500|2018-04-15 19:00:00| 25|  1000.0|   500.0|
| a1| b1|500|2018-04-15 19:30:30| 27|  1000.0|  1000.0|
| a3| b0|100|2018-04-07 15:00:00| 10|   100.0|     0.0|
+---+---+---+-------------------+---+--------+--------+
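
Note that the rows are not ordered by E; Spark gives no ordering guarantee for the output of select. If you'd like the output sorted the way the expected tables were written, you can simply sort for display:

# Sorting is purely for readability; it does not change the computed sums.
df.orderBy('E').show(30, truncate=False)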

Actually, you can also use another period type, such as hourly. Just replace the last argument of augment_period_window_cols with '1 hour'.
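
For example, the hourly variant would look like this; everything downstream stays the same.

# Hourly buckets instead of daily ones
df = augment_period_window_cols(df, period_window_col_name, datetime_format, '1 hour')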

Thank you for reading.