PySpark Tutorial: How to Use rollup() to Aggregate Data by Groups and Subtotals
In this PySpark tutorial, we’ll explore how to use the rollup() function to perform multi-level aggregations (group subtotals and grand totals).
It’s especially useful for analyzing hierarchical data grouped by multiple columns, such as salaries by country and then by employee.
1. What is rollup() in PySpark?
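The rollup() function in PySpark creates subtotals and a grand total in a single grouped aggregation, similar to SQL’s GROUP BY ... ROLLUP.
Given an ordered list of columns, such as rollup("Country", "Name"), it aggregates at every level of that hierarchy: each (Country, Name) pair, each Country, and finally the whole DataFrame. In the output, a column that has been rolled up at a given level appears as null.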
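To make those levels concrete: rollup(c1, ..., cn) aggregates over n + 1 grouping sets, which are the prefixes of its column list, ending with the empty set that produces the grand total. The plain-Python sketch below only lists those grouping sets; rollup_grouping_sets is a hypothetical helper for illustration, not part of PySpark. Calling it with the columns in both orders shows why column order matters.

```python
def rollup_grouping_sets(columns):
    """List the grouping sets that ROLLUP(columns) aggregates over.

    Hypothetical helper for illustration only: each grouping set is a
    prefix of `columns`, from the full list down to () for the grand total.
    """
    return [tuple(columns[:i]) for i in range(len(columns), -1, -1)]


print(rollup_grouping_sets(["Country", "Name"]))
# [('Country', 'Name'), ('Country',), ()]
print(rollup_grouping_sets(["Name", "Country"]))
# [('Name', 'Country'), ('Name',), ()]
```

With the columns reversed, the subtotals are per name rather than per country, so list the columns from the broadest level to the most detailed.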
2. Import Required Libraries
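from pyspark.sql import SparkSession
from pyspark.sql.functions import sum as _sum  # alias avoids shadowing Python's built-in sum()

# Create (or reuse) a SparkSession; the examples below refer to it as `spark`
spark = SparkSession.builder.appName("RollupExample").getOrCreate()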
3. Create Sample DataFrame
data = [
("Aamir Shahzad", "Pakistan", 5000),
("Ali Raza", "Pakistan", 6000),
("Bob", "USA", 5500),
("Lisa", "Canada", 7000),
("Aamir Shahzad", "Pakistan", 8000),
("Ali Raza", "Pakistan", 6500),
("Bob", "USA", 5200),
("Lisa", "Canada", 7200)
]
columns = ["Name", "Country", "Salary"]
df = spark.createDataFrame(data, columns)
print("Original DataFrame:")
df.show()
Output:
+--------------+--------+------+
|          Name| Country|Salary|
+--------------+--------+------+
| Aamir Shahzad|Pakistan|  5000|
|      Ali Raza|Pakistan|  6000|
|           Bob|     USA|  5500|
|          Lisa|  Canada|  7000|
| Aamir Shahzad|Pakistan|  8000|
|      Ali Raza|Pakistan|  6500|
|           Bob|     USA|  5200|
|          Lisa|  Canada|  7200|
+--------------+--------+------+
4. rollup() on Country and Name
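# Aggregate at three levels: (Country, Name), (Country), and the grand total
df_rollup = df.rollup("Country", "Name") \
    .agg(_sum("Salary").alias("Total_Salary"))
# Ascending sorts put nulls first by default; asc_nulls_last() keeps each
# subtotal after its detail rows and the grand total at the end
df_rollup = df_rollup.orderBy(df_rollup["Country"].asc_nulls_last(),
                              df_rollup["Name"].asc_nulls_last())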
print("Rollup Aggregation by Country and Name:")
df_rollup.show(truncate=False)
Output:
+--------+--------------+------------+
|Country |Name          |Total_Salary|
+--------+--------------+------------+
|Canada  |Lisa          |14200       |
|Canada  |null          |14200       |
|Pakistan|Aamir Shahzad |13000       |
|Pakistan|Ali Raza      |12500       |
|Pakistan|null          |25500       |
|USA     |Bob           |10700       |
|USA     |null          |10700       |
|null    |null          |50400       |
+--------+--------------+------------+
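The numbers in this table can be reproduced without Spark. The plain-Python emulation below is an illustrative sketch, not how Spark executes rollup(): it adds each row’s salary to one group per rollup level and uses None where Spark shows null. The function names rollup_sum and nulls_last are hypothetical. It prints the same eight groups, in the same order, as the table above.

```python
from collections import defaultdict

# Same sample rows as the tutorial's DataFrame: (Name, Country, Salary)
data = [
    ("Aamir Shahzad", "Pakistan", 5000),
    ("Ali Raza", "Pakistan", 6000),
    ("Bob", "USA", 5500),
    ("Lisa", "Canada", 7000),
    ("Aamir Shahzad", "Pakistan", 8000),
    ("Ali Raza", "Pakistan", 6500),
    ("Bob", "USA", 5200),
    ("Lisa", "Canada", 7200),
]
columns = ["Name", "Country", "Salary"]
rows = [dict(zip(columns, values)) for values in data]


def rollup_sum(rows, keys, value):
    """Emulate rollup(*keys).agg(sum(value)) over a list of dicts.

    At level k only the first k keys are kept; the rest become None,
    standing in for Spark's null in rolled-up columns.
    """
    totals = defaultdict(int)
    for row in rows:
        for level in range(len(keys), -1, -1):
            group = tuple(row[key] if i < level else None
                          for i, key in enumerate(keys))
            totals[group] += row[value]
    return dict(totals)


def nulls_last(group):
    """Sort key that places None after real values, like asc_nulls_last()."""
    return [(part is None, "" if part is None else part) for part in group]


result = rollup_sum(rows, ["Country", "Name"], "Salary")
for group in sorted(result, key=nulls_last):
    print(group, result[group])
```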
5. rollup() on Country Only
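df_rollup_country = df.rollup("Country") \
    .agg(_sum("Salary").alias("Total_Salary"))
# Keep the grand-total row (Country = null) at the bottom
df_rollup_country = df_rollup_country.orderBy(
    df_rollup_country["Country"].asc_nulls_last())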
print("Rollup Aggregation by Country with Grand Total:")
df_rollup_country.show(truncate=False)
Output:
+--------+------------+
|Country |Total_Salary|
+--------+------------+
|Canada  |14200       |
|Pakistan|25500       |
|USA     |10700       |
|null    |50400       |
+--------+------------+
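With a single column, rollup("Country") has only two grouping sets, (Country) and (), so the result is an ordinary per-country group-by plus one grand-total row. The plain-Python sketch below illustrates that equivalence on the same sample rows; it is not Spark’s execution plan, and the variable names are illustrative.

```python
from collections import defaultdict

# Same sample rows as the tutorial's DataFrame: (Name, Country, Salary)
data = [
    ("Aamir Shahzad", "Pakistan", 5000),
    ("Ali Raza", "Pakistan", 6000),
    ("Bob", "USA", 5500),
    ("Lisa", "Canada", 7000),
    ("Aamir Shahzad", "Pakistan", 8000),
    ("Ali Raza", "Pakistan", 6500),
    ("Bob", "USA", 5200),
    ("Lisa", "Canada", 7200),
]

# Grouping set (Country): one subtotal per country
by_country = defaultdict(int)
for _name, country, salary in data:
    by_country[country] += salary

# Grouping set (): the grand total, reported with Country = None (null)
rows = sorted(by_country.items()) + [(None, sum(by_country.values()))]
for country, total in rows:
    print(country, total)
```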
Conclusion
The rollup() function is a powerful tool for hierarchical aggregation in PySpark. One call produces results at every level of the column hierarchy you pass it, from detailed groups through subtotals to an overall total.
That makes it a good fit for building summarized reports or dashboards directly from Spark DataFrames.
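When rollup output feeds a report, the null placeholders are usually swapped for readable labels. Below is a minimal plain-Python sketch over the rows of the rollup("Country", "Name") output shown earlier; the function label_row and the labels "Subtotal", "Grand total", and "All countries" are hypothetical choices. It assumes Country and Name never contain real nulls in the source data. When they can, Spark’s pyspark.sql.functions.grouping() (1 if a column was rolled up in that row, 0 otherwise) is the reliable way to tell the two kinds of null apart.

```python
# Rows copied from the rollup("Country", "Name") output: (Country, Name, Total_Salary)
rollup_rows = [
    ("Canada", "Lisa", 14200),
    ("Canada", None, 14200),
    ("Pakistan", "Aamir Shahzad", 13000),
    ("Pakistan", "Ali Raza", 12500),
    ("Pakistan", None, 25500),
    ("USA", "Bob", 10700),
    ("USA", None, 10700),
    (None, None, 50400),
]


def label_row(country, name, total):
    """Replace rollup null placeholders with (hypothetical) report labels.

    Assumes None only ever comes from rollup, never from the source data.
    """
    if country is None:
        return ("All countries", "Grand total", total)
    if name is None:
        return (country, "Subtotal", total)
    return (country, name, total)


report = [label_row(*row) for row in rollup_rows]
for line in report:
    print(line)
```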