PySpark collect() Function Tutorial
Retrieve Entire DataFrame to Driver with Examples
Learn how to use the collect() function in PySpark to retrieve all rows from a DataFrame into the driver node for local processing or debugging.
1. What is collect() in PySpark?
The collect() function gathers all rows from a PySpark DataFrame or RDD and returns them to the driver as a Python list.
- ✅ Useful for small datasets or debugging
- ⚠️ Avoid on large datasets to prevent memory overload
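To make this concrete, here is a minimal sketch of how collect() differs from a simple preview. Note that `df` here stands for any existing DataFrame, such as the one built in step 3 below.

# Sketch only: assumes `df` is an existing DataFrame (see step 3).
df.show(5)             # prints a preview on the driver; returns None
rows = df.collect()    # pulls ALL rows to the driver as a list of Row objects
print(type(rows))      # <class 'list'>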
2. Create Spark Session
from pyspark.sql import SparkSession

spark = SparkSession.builder \
    .appName("PySpark collect() Example") \
    .getOrCreate()
3. Create a Sample DataFrame
data = [
    (1, "Aamir Shahzad", 5000),
    (2, "Ali Raza", 6000),
    (3, "Bob", 5500),
    (4, "Lisa", 7000)
]

columns = ["ID", "Name", "Salary"]

df = spark.createDataFrame(data, columns)
df.show()
+---+--------------+------+
| ID| Name|Salary|
+---+--------------+------+
| 1| Aamir Shahzad| 5000|
| 2| Ali Raza| 6000|
| 3| Bob| 5500|
| 4| Lisa| 7000|
+---+--------------+------+
4. Using collect() to Retrieve Data
collected_data = df.collect()

for row in collected_data:
    print(row)
Row(ID=1, Name='Aamir Shahzad', Salary=5000)
Row(ID=2, Name='Ali Raza', Salary=6000)
Row(ID=3, Name='Bob', Salary=5500)
Row(ID=4, Name='Lisa', Salary=7000)
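collect() works the same way on RDDs. A minimal sketch, assuming the `spark` session from step 2 is still active (the RDD contents are illustrative):

# Assumes the SparkSession `spark` created in step 2.
rdd = spark.sparkContext.parallelize([10, 20, 30])
print(rdd.collect())   # [10, 20, 30] -- plain Python values, not Row objects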
5. Accessing Individual Column Values
for row in collected_data:
    print(f"ID: {row['ID']}, Name: {row['Name']}, Salary: {row['Salary']}")
6. When to Use collect()
- To inspect data locally during development
- For exporting or logging small result sets
- Avoid in production for large datasets (safer, bounded alternatives are sketched below)
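If a dataset might be too large to collect safely, bounded alternatives such as take() or limit() keep driver memory use predictable. A minimal sketch (the row count of 2 is arbitrary):

print(df.take(2))              # first 2 rows as a list of Row objects
print(df.limit(2).collect())   # same idea via a limited DataFrame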