I ran into an issue when applying PySpark's crosstab function to a fairly large dataset, and I think it deserves serious attention.
Suppose we have a dataframe df with two columns, A and B. Column A has many unique values, while column B has only two: 0 and 1. In this case, we would like to count how many times each value of column B occurs for each unique value of column A.
Let’s take a look at an example for the sake of clarity.
DATAFRAME df
=============================
A            | B
=============================
unique_val_a | 0
unique_val_a | 0
unique_val_a | 0
unique_val_b | 1
unique_val_c | 1
unique_val_c | 1
unique_val_c | 1
unique_val_d | 0
unique_val_d | 0
unique_val_d | 0
unique_val_d | 1
=============================

EXPECTED OUTPUT
============================================
A_B          | 0 | 1
============================================
unique_val_a | 3 | 0
unique_val_b | 0 | 1
unique_val_c | 0 | 3
unique_val_d | 3 | 1
============================================
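To make the expected output concrete, here is a minimal plain-Python sketch of the counting that crosstab is supposed to perform on this example. It is not Spark's implementation; the helper name reference_crosstab and the fixed (0, 1) column order are my own assumptions for illustration. A sketch like this can also serve as a ground truth when checking crosstab on a small sample of the data.

```python
from collections import Counter


def reference_crosstab(rows, b_values=(0, 1)):
    """Count occurrences of each B value per unique A value.

    rows: iterable of (a, b) pairs.
    b_values: the B values to report, in column order (assumed to be 0 and 1).
    Returns a dict mapping each A value to a tuple of counts in b_values order.
    """
    pair_counts = Counter(rows)                      # how often each (a, b) pair occurs
    a_values = sorted({a for a, _ in pair_counts})   # unique A values, sorted for stable output
    # Counter returns 0 for pairs that never occur, so absent pairs count as 0
    return {a: tuple(pair_counts[(a, b)] for b in b_values) for a in a_values}


# The example dataframe above as plain (A, B) tuples
rows = (
    [("unique_val_a", 0)] * 3
    + [("unique_val_b", 1)]
    + [("unique_val_c", 1)] * 3
    + [("unique_val_d", 0)] * 3
    + [("unique_val_d", 1)]
)

for a, (zeros, ones) in reference_crosstab(rows).items():
    print(f"{a} | {zeros} | {ones}")
```

Running it prints the same four rows as the expected output table.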
Getting this result is extremely simple, thanks to PySpark's statistical functions such as crosstab. Let's see how.
# The result has one column named A_B (the values of A) plus one column per distinct value of B
result = df.crosstab('A', 'B')
From the outside, the function seems a perfect fit for our needs. However, I came across an odd result when playing with the data type of column B. At the time, I was experimenting only with the integer and string data types. Let's take a look at the code.
from pyspark.sql import functions as F
from pyspark.sql.types import IntegerType, StringType

# Cast column B to integer, then build the frequency table
int_df = df.withColumn('B', F.col('B').cast(IntegerType()))
int_df = int_df.crosstab('A', 'B')
print(int_df.count())

# Same steps with column B cast to string
str_df = df.withColumn('B', F.col('B').cast(StringType()))
str_df = str_df.crosstab('A', 'B')
print(str_df.count())
My focus was on checking whether the two data types yielded the same number of rows. I reasoned that if the row counts already differed, there was no need to compare the contents of the two dataframes (int_df and str_df), because the results would clearly not be the same.
The result? Surprisingly, the integer and string versions did not yield the same number of rows. The difference was fairly small. But the point is that the same function (crosstab) produces different output for different data types of the same column. I'm not sure of the rationale, since the documentation says nothing about how data types affect crosstab. Perhaps I need to look at the code base.
What made the case worse was that the two dataframes (int_df and str_df) seemed to contain different data, not just a different number of rows. We can check this by computing the set difference of the values of column A (which crosstab names A_B in its output) between int_df and str_df. Let's take a look.
set(int_df.select('A_B').collect()) - set(str_df.select('A_B').collect())
The code above did not yield an empty set. In my case, the two sets differed by many elements. This is clearly bad news, as it means crosstab's output depends on the data type used. Since the documentation says nothing about data types for crosstab, I think developers might use the function inappropriately without realizing it.
Out of curiosity, I created a small dataframe (the one described earlier was fairly large) and applied the same scenario (two different data types for column B) to investigate the behaviour. I expected the results to differ, as they did for the large dataframe. My expectation was wrong: both versions yielded the same result.
Based on this quick investigation with small and fairly large data, I still don't know how large a dataframe can get before crosstab stops producing the same output for different data types.
But let's set this issue aside for a while. I decided to look for a better solution.
My reasoning was that since crosstab relies on the concept of groupby, groupby combined with a few engineering tricks might be the solution. Let's take a look at the code.
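from pyspark.sql import functions as F
from pyspark.sql.types import IntegerType, StringType

# Integer B: pivot on the explicit values 0 and 1, count rows, replace nulls with 0
int_df = df.withColumn('B', F.col('B').cast(IntegerType()))
int_df = int_df.groupby('A').pivot('B', [0, 1]).count().fillna(0)
print(int_df.count())

# String B: same steps, with the pivot values given as strings to match the column type
str_df = df.withColumn('B', F.col('B').cast(StringType()))
str_df = str_df.groupby('A').pivot('B', ['0', '1']).count().fillna(0)
print(str_df.count())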
As you can see, the engineering trick is implemented with pivot. groupby('A').pivot('B', values) keeps A as the first column and turns each element of the list passed as the second argument into its own column. The count method then counts the rows for each (A, B) combination. Finally, fillna replaces the null values with zero: a combination that never occurs in the data gets null rather than 0, because count only produces values for combinations that actually exist.
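The role of fillna can be seen in a small plain-Python model of the groupby-pivot-count chain. This is a toy model under my own assumptions (the helper name pivot_count is hypothetical), not Spark code: a combination of A and B that never occurs has no count, which shows up as None here (null in Spark) until it is filled with 0.

```python
from collections import Counter


def pivot_count(rows, pivot_values, fill=None):
    """Toy model of groupby(A).pivot(B, pivot_values).count(), plus an optional fillna.

    rows: iterable of (a, b) pairs.
    pivot_values: the B values that become columns, in order.
    fill: value used for (a, b) combinations that never occur; None mimics a null.
    """
    pair_counts = Counter(rows)
    result = {}
    for a in sorted({a for a, _ in pair_counts}):
        # One "column" per pivot value; missing combinations get `fill` instead of a count
        result[a] = {
            b: pair_counts[(a, b)] if (a, b) in pair_counts else fill
            for b in pivot_values
        }
    return result


rows = [("unique_val_a", 0), ("unique_val_a", 0), ("unique_val_b", 1)]
print(pivot_count(rows, [0, 1]))
# {'unique_val_a': {0: 2, 1: None}, 'unique_val_b': {0: None, 1: 1}}
print(pivot_count(rows, [0, 1], fill=0))
# {'unique_val_a': {0: 2, 1: 0}, 'unique_val_b': {0: 0, 1: 1}}
```

Note that every unique value of A gets exactly one row, regardless of how many pivot columns end up null.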
Here comes the good news: the groupby-pivot approach yields the same number of rows for both the integer and string data types.
Now, let’s execute the set difference computation.
set(int_df.select('A').collect()) - set(str_df.select('A').collect())
Well, it resulted in an empty set, which is good.
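One caveat on this kind of check: a set difference in one direction only reveals values present in int_df but missing from str_df. To be sure both results hold exactly the same values of A, it is worth checking both directions. Here is a minimal sketch with plain Python sets; the helper name compare_keys and the sample values are hypothetical, and in PySpark the inputs would be the collected A (or A_B) values.

```python
def compare_keys(left, right):
    """Return the values found only in `left`, only in `right`, and whether both match."""
    left, right = set(left), set(right)
    only_left = left - right
    only_right = right - left
    return only_left, only_right, not (only_left or only_right)


int_keys = ["unique_val_a", "unique_val_b", "unique_val_c"]
str_keys = ["unique_val_a", "unique_val_b", "unique_val_c", "unique_val_d"]

only_int, only_str, same = compare_keys(int_keys, str_keys)
print(only_int, only_str, same)  # set() {'unique_val_d'} False
```

Here the one-directional difference (only_int) is empty even though the two key sets are not equal. In this article's groupby-pivot case the row counts already matched, so an empty one-directional difference does imply equality, but the two-way check is safer in general.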
OVERALL CHECK & OBSERVATION
I compared the results returned by crosstab and groupby_pivot with a groupby_only approach. For the sake of clarity, let's take a look at the code.
# One row per unique value of column A
gb_only = df.groupby('A').count()
print(gb_only.count())
The code above returned the same number of rows as the groupby_pivot approach, that is, one row per unique value of A. This suggests that groupby_pivot keeps every value of A and works correctly.
Last but not least, I found that the number of rows (from the count() method) returned by crosstab and groupby_pivot differed by a factor of roughly two: groupby_pivot returned about twice as many rows as crosstab.
What does this mean?
Well, in my view, this means the crosstab function might cause data loss. In the case above, the function did not retain roughly 50% of the rows it should have returned, that is, about half of the unique values of column A.
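One documented limit may be worth checking here. Spark's Scala API documentation for DataFrameStatFunctions.crosstab states that the number of distinct values for each column should be less than 1e4, and that at most 1e6 non-zero pair frequencies will be returned. I haven't confirmed that this explains the numbers above, but column A has far more than 1e4 unique values. The toy model below is plain Python with a hypothetical helper (surviving_a_values) and a tiny cap instead of 1e6. It shows how a cap on (A, B) pairs could drop whole values of A, and how the surviving values could change if the order in which pairs are produced changes, for example with a different data type for B. It models the symptom only; it is not Spark's implementation.

```python
from collections import Counter


def surviving_a_values(rows, max_pairs, order_key):
    """Toy model of a frequency table capped at `max_pairs` distinct (A, B) pairs.

    Counts distinct (A, B) pairs, keeps only the first `max_pairs` of them in
    the order given by `order_key`, and returns the A values that survive.
    `order_key` is a hypothetical stand-in for whatever order the engine
    happens to produce the grouped pairs in; it is not a Spark setting.
    """
    pair_counts = Counter(rows)
    kept_pairs = sorted(pair_counts, key=order_key)[:max_pairs]
    return {a for a, _ in kept_pairs}


# 6 distinct A values, each seen with B = 0 and B = 1, giving 12 distinct pairs
rows = [(f"val_{i}", b) for i in range(6) for b in (0, 1)]

# Two arbitrary orders, standing in for two different data types of column B
kept_1 = surviving_a_values(rows, max_pairs=6, order_key=lambda p: (p[0], p[1]))
kept_2 = surviving_a_values(rows, max_pairs=6, order_key=lambda p: (p[1], p[0]))

print(len(kept_1), len(kept_2))  # 3 6
print(sorted(kept_2 - kept_1))   # ['val_3', 'val_4', 'val_5']
```

With the first order, half of the A values disappear entirely; with the second, every A value survives but all of its B = 1 counts are dropped. On the real data, comparing df.select('A', 'B').distinct().count() with 1e6 would show whether the dataframe is anywhere near that cap.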
WHAT DO YOU THINK?
So, do you have any idea what causes this PySpark crosstab issue? I would love to hear your thoughts.