PySpark DataFrame sampleBy() Function - Group-Wise Sampling
In this PySpark tutorial, learn how to use the sampleBy() function to perform group-wise sampling from a DataFrame. It's ideal for stratified sampling, testing models, or creating balanced subsets of data grouped by a specific column.
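For reference, the method signature is:

DataFrame.sampleBy(col, fractions, seed=None)

Here col is the column that defines the groups (a column name or, in newer Spark versions, a Column object), fractions maps each group value to a sampling fraction between 0.0 and 1.0, and seed makes the sample reproducible.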
Step 1: Create a Sample DataFrame
from pyspark.sql import SparkSession

# Create (or reuse) a SparkSession so the example is runnable on its own
spark = SparkSession.builder.getOrCreate()

data = [
    ("Aamir Shahzad", "Engineering"),
    ("Ali Raza", "HR"),
    ("Bob", "Engineering"),
    ("Lisa", "Marketing"),
    ("Ali Raza", "HR"),
    ("Aamir Shahzad", "Engineering"),
    ("Lisa", "Marketing"),
    ("Bob", "Engineering"),
    ("Aamir Shahzad", "Engineering"),
    ("Ali Raza", "HR")
]

columns = ["name", "department"]
df = spark.createDataFrame(data, columns)
print("📌 Original DataFrame:")
df.show()
Step 2: Count per Group
print("📊 Count per name before sampling:")
df.groupBy("name").count().show()
Step 3: Apply sampleBy() with Different Sampling Fractions
The fractions dictionary maps each value of the grouping column to the fraction of its rows to keep. Any value not listed (here, "Bob") is treated as having a fraction of 0.0 and is excluded from the sample.

fractions = {
    "Lisa": 1.0,
    "Ali Raza": 0.7,
    "Aamir Shahzad": 1.0
}
sampled_df = df.sampleBy("name", fractions=fractions, seed=42)
print("📊 Sampled DataFrame:")
sampled_df.show()
Step 4: Count per Group After Sampling
print("📊 Count per name after sampling:")
sampled_df.groupBy("name").count().show()
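Because sampleBy() decides independently for each row whether to keep it, the resulting counts only approximate the requested fractions, especially on small data. As a minimal sketch, assuming the df, sampled_df, and fractions variables from the steps above, you can compare the requested and realized fractions per group:

from pyspark.sql.functions import col

before = df.groupBy("name").count().withColumnRenamed("count", "total")
after = sampled_df.groupBy("name").count().withColumnRenamed("count", "sampled")

# A left join keeps groups that were sampled away entirely (e.g., Bob with fraction 0.0)
comparison = (before.join(after, on="name", how="left")
                    .fillna(0, subset=["sampled"])
                    .withColumn("realized_fraction", col("sampled") / col("total")))
comparison.show()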
Step 5: Use a Column Object Instead of a String
Since Spark 3.0, sampleBy() also accepts a Column object as the grouping column.
from pyspark.sql.functions import col
bob_sample = df.sampleBy(col("name"), fractions={"Bob": 1.0}, seed=99)
print("📊 Sample only 'Bob' using Column object:")
bob_sample.show()
Summary
- sampleBy() is used for group-wise (stratified) sampling.
- You define a sampling fraction per group key (e.g., name or department).
- Groups omitted from the fractions dictionary default to a fraction of 0.0 and are excluded.
- Fractions are approximate: each row is kept or dropped independently.
- Great for testing, balancing ML training data, and subset analysis (see the sketch below).
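As one sketch of the "balancing ML training data" use case, you could compute a fraction per department that downsamples larger groups toward the size of the smallest one. The target_per_group and balance_fractions names below are illustrative assumptions, not part of the Spark API:

# Count rows per department on the driver (fine for a small number of groups)
counts = {row["department"]: row["count"]
          for row in df.groupBy("department").count().collect()}

# Aim for each department to contribute roughly as many rows as the smallest one
target_per_group = min(counts.values())

# Cap at 1.0 so the smallest group is kept whole
balance_fractions = {dept: min(1.0, target_per_group / n)
                     for dept, n in counts.items()}

balanced_df = df.sampleBy("department", fractions=balance_fractions, seed=7)
balanced_df.groupBy("department").count().show()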