How to Use the persist() Function in PySpark – Cache vs Persist Explained
The persist() function in PySpark caches or stores a DataFrame's intermediate results. It's especially useful when the same DataFrame is reused multiple times and you want to avoid recomputation for performance gains. Below you'll learn how it differs from cache() and how to use it effectively.
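For context, cache() on a DataFrame is simply persist() with a default storage level (MEMORY_AND_DISK for DataFrames; the exact level varies slightly by Spark version), while persist() accepts an explicit StorageLevel. A minimal sketch, assuming a DataFrame named df like the one built in Step 1:

from pyspark.storagelevel import StorageLevel

df.cache()                               # shorthand: uses the default storage level
# df.persist(StorageLevel.MEMORY_ONLY)  # persist() lets you choose the level explicitly
# Note: these are alternatives - a DataFrame's storage level cannot be changed
# once it has been assigned, so pick one or the other.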
Step 1: Create a Sample DataFrame
from pyspark.sql import SparkSession
spark = SparkSession.builder.getOrCreate()
data = [
("Aamir Shahzad", "Engineering", 100000),
("Ali Raza", "HR", 70000),
("Bob", "Engineering", 80000),
("Lisa", "Marketing", 65000)
]
columns = ["name", "department", "salary"]
df = spark.createDataFrame(data, columns)
print("📌 Sample DataFrame:")
df.show()
Step 2: Use persist() to Store an Intermediate Result
from pyspark.storagelevel import StorageLevel
df_cached = df.filter(df.salary > 70000).persist(StorageLevel.MEMORY_AND_DISK)
print("📌 Count of high earners (cached):", df_cached.count())
print("📌 Average salary (cached):", df_cached.groupBy().avg("salary").collect())
Step 3: Optionally Unpersist to Free Memory
df_cached.unpersist()
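unpersist() also accepts an optional blocking flag; passing blocking=True waits until all cached blocks are removed before returning (a small variant sketch):

df_cached.unpersist(blocking=True)                            # block until the data is freed
print("📌 Is cached after unpersist:", df_cached.is_cached)   # False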
Summary
- persist() allows finer control than cache() by letting you specify the storage level (memory, disk, or both); common levels are shown in the sketch after this list.
- Useful for long-running or reused DataFrames to reduce recomputation.
- Always call unpersist() after you're done to free up resources.
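The levels below are the ones most commonly passed to persist(); the right choice depends on how much memory you can spare versus how expensive recomputation is (an illustrative sketch, not an exhaustive list):

from pyspark.storagelevel import StorageLevel

df.persist(StorageLevel.MEMORY_ONLY)          # keep data in memory only
# df.persist(StorageLevel.MEMORY_AND_DISK)    # spill partitions that don't fit in memory to disk
# df.persist(StorageLevel.DISK_ONLY)          # store on disk only
# df.persist(StorageLevel.MEMORY_AND_DISK_2)  # MEMORY_AND_DISK, replicated on two nodes
# (alternatives are commented out because a storage level cannot be changed once assigned)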