How to See Record Count Per Partition in a Spark DataFrame (i.e. Find Skew)

One of our greatest enemies in big data processing is skew in our data, i.e. records that are unevenly distributed across partitions (often because a handful of key values are far more common than the rest). It shows up in subtle ways, such as 99 out of 100 tasks finishing quickly while one lone task takes forever to complete (or worse, never does).

Skew is largely inevitable in this line of work, and we have two choices:

  • Ignore it, and live with the slowdown
  • Try to find the source of the skew, and mitigate it

Ignoring skew can sometimes be worth it, especially if the skew is mild or the performance gain isn't worth the time it would take to fix. This is particularly true for one-off or ad-hoc analysis that isn't likely to be repeated and simply needs to get done.

The rest of the time, however, we need to find out where the skew is occurring and take steps to mitigate it so we can get back to processing our big data. This post shows one way to find the source of skew in a Spark DataFrame. It won't delve into the handful of ways to mitigate it (repartitioning, distributing/clustering, isolation, etc.; our new book will), but it will help pinpoint where the issue may be.

Introducing… Spark Partition ID

Spark has a built-in function, spark_partition_id(), that returns the numeric ID of the partition each row belongs to, so you can group and aggregate on it like any other column. In our case, we'd like the .count() of records for each partition ID.

By doing a simple count grouped by partition id, and optionally sorted from smallest to largest, we can see the distribution of our data across partitions. This will help us determine if our dataset is skewed.

Python / PySpark

from pyspark.sql.functions import spark_partition_id, asc, desc

(df
    .withColumn("partitionId", spark_partition_id())  # tag each row with its partition's ID
    .groupBy("partitionId")
    .count()                                          # number of records per partition
    .orderBy(asc("count"))                            # smallest partitions first
    .show())
+-----------+-----+
|partitionId|count|
+-----------+-----+
|         21|86640|
|          4|86716|
|         19|86729|
|         13|86790|
|         31|86911|
|         25|86927|
|         24|86978|
|         15|87044|
|         10|87085|
|         18|87088|
|         17|87105|
|         22|87236|
|          5|87287|
|         29|87313|
|          2|87331|
|          8|87363|
|          1|87401|
|         16|87424|
|          9|87457|
|         14|87468|
+-----------+-----+
only showing top 20 rows

Scala / Spark

import org.apache.spark.sql.functions.{spark_partition_id, asc, desc}

df
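    .groupBy(spark_partition_id())  // group rows by the ID of the partition they live in
    .count()                        // number of records per partition
    .orderBy(asc("count"))          // smallest partitions first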
    .show()
+--------------------+-----+
|SPARK_PARTITION_ID()|count|
+--------------------+-----+
|                  21|86640|
|                   4|86716|
|                  19|86729|
|                  13|86790|
|                  31|86911|
|                  25|86927|
|                  24|86978|
|                  15|87044|
|                  10|87085|
|                  18|87088|
|                  17|87105|
|                  22|87236|
|                   5|87287|
|                  29|87313|
|                   2|87331|
|                   8|87363|
|                   1|87401|
|                  16|87424|
|                   9|87457|
|                  14|87468|
+--------------------+-----+
only showing top 20 rows

Spark SQL

First, create a version of your DataFrame with the Partition ID added as a field. You can do this in any supported language. Here it is in Scala:

import org.apache.spark.sql.functions.spark_partition_id

val df_with_id = df.withColumn("partitionId", spark_partition_id())
df_with_id.createOrReplaceTempView("df_with_id")

Then, simply run the same logic using Spark SQL (a %sql block in Zeppelin/Qubole, or spark.sql() in any supported language):

select partitionId, count(1) as num_records
from df_with_id
group by partitionId
order by num_records asc

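As you can see, the partitions of our Spark DataFrame are nice and evenly distributed: each partition shown holds roughly 86,600 to 87,500 records. No outliers here! Keep in mind that .show() prints only 20 rows by default, so with an ascending sort you only see the smallest partitions. Sort with desc("count") (or "order by num_records desc" in SQL) to check the largest ones, which is where skew would show up.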
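Eyeballing the table works when there are a few dozen partitions; with hundreds it helps to boil the counts down to a few numbers. Below is a minimal sketch in plain Python, assuming you have already collected the (partitionId, count) pairs from the grouped DataFrame. The helper name skew_report and its threshold parameter (flag any partition holding more than 2x the median record count) are hypothetical choices for illustration, not part of Spark. One caveat: partitions with no rows never appear in the grouped result, so compare the number of pairs with df.rdd.getNumPartitions() if empty partitions matter to you.

```python
from statistics import median


def skew_report(partition_counts, threshold=2.0):
    """Summarize (partition_id, record_count) pairs collected from Spark.

    `skew_report` and `threshold` are hypothetical names for this sketch:
    a partition is flagged when it holds more than `threshold` times the
    median record count. Pick a cutoff that suits your own workload.
    """
    pairs = list(partition_counts)
    if not pairs:
        raise ValueError("no partition counts to analyze")

    counts = [count for _, count in pairs]
    med = median(counts)
    largest = max(counts)

    return {
        "partitions": len(pairs),
        "min": min(counts),
        "median": med,
        "max": largest,
        # How many times bigger the largest partition is than a typical one.
        "max_to_median": largest / med if med else float("inf"),
        # Partition IDs holding far more records than a typical partition.
        "skewed_partitions": sorted(
            pid for pid, count in pairs if count > threshold * med
        ),
    }


if __name__ == "__main__":
    # With PySpark, the pairs could be collected like this (not run here):
    #   rows = (df.withColumn("partitionId", spark_partition_id())
    #             .groupBy("partitionId").count().collect())
    #   pairs = [(r["partitionId"], r["count"]) for r in rows]
    # Below: made-up counts where partition 3 is clearly skewed.
    pairs = [(0, 100), (1, 120), (2, 110), (3, 5000)]
    report = skew_report(pairs)
    print(report["skewed_partitions"])  # [3]
```

Using the median rather than the mean as the baseline keeps a single huge partition from inflating the "typical" size it is being compared against.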

Let us know if you have any other tricks in the comments!

Special thanks to this StackOverflow discussion for pointing me in the right direction!
