How to select the first row of each group?

I have a DataFrame generated as follows:

df.groupBy($"Hour", $"Category")
  .agg(sum($"value") as "TotalValue")
  .sort($"Hour".asc, $"TotalValue".desc)

The results look like:

+----+--------+----------+
|Hour|Category|TotalValue|
+----+--------+----------+
|   0|   cat26|      30.9|
|   0|   cat13|      22.1|
|   0|   cat95|      19.6|
|   0|  cat105|       1.3|
|   1|   cat67|      28.5|
|   1|    cat4|      26.8|
|   1|   cat13|      12.6|
|   1|   cat23|       5.3|
|   2|   cat56|      39.6|
|   2|   cat40|      29.7|
|   2|  cat187|      27.9|
|   2|   cat68|       9.8|
|   3|    cat8|      35.6|
| ...|    ....|      ....|
+----+--------+----------+

As you can see, the DataFrame is ordered by Hour in an increasing order, then by TotalValue in a descending order.

I would like to select the top row of each group, i.e.

from the group of Hour==0 select (0,cat26,30.9)

from the group of Hour==1 select (1,cat67,28.5)

from the group of Hour==2 select (2,cat56,39.6)

and so on

So the desired output would be:

+----+--------+----------+
|Hour|Category|TotalValue|
+----+--------+----------+
|   0|   cat26|      30.9|
|   1|   cat67|      28.5|
|   2|   cat56|      39.6|
|   3|    cat8|      35.6|
| ...|     ...|       ...|
+----+--------+----------+

It might be handy to be able to select the top N rows of each group as well.

Any help is highly appreciated.


10465355

Window functions:

Something like this should do the trick:

import org.apache.spark.sql.functions.{row_number, max, broadcast}
import org.apache.spark.sql.expressions.Window

val df = sc.parallelize(Seq(
  (0,"cat26",30.9), (0,"cat13",22.1), (0,"cat95",19.6), (0,"cat105",1.3),
  (1,"cat67",28.5), (1,"cat4",26.8), (1,"cat13",12.6), (1,"cat23",5.3),
  (2,"cat56",39.6), (2,"cat40",29.7), (2,"cat187",27.9), (2,"cat68",9.8),
  (3,"cat8",35.6))).toDF("Hour", "Category", "TotalValue")

val w = Window.partitionBy($"hour").orderBy($"TotalValue".desc)

val dfTop = df.withColumn("rn", row_number().over(w)).where($"rn" === 1).drop("rn")

dfTop.show
// +----+--------+----------+
// |Hour|Category|TotalValue|
// +----+--------+----------+
// |   0|   cat26|      30.9|
// |   1|   cat67|      28.5|
// |   2|   cat56|      39.6|
// |   3|    cat8|      35.6|
// +----+--------+----------+

This method will be inefficient in case of significant data skew. This problem is tracked by SPARK-34775 and might be resolved in the future (SPARK-37099).
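The question also asks about the top N rows per group. A minimal sketch of how the same window could be reused for that, assuming a local SparkSession (spark-shell already provides `spark`) and a hypothetical `n` that is not part of the original answer:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.row_number

// Assumption: a local session for illustration; in spark-shell `spark` already exists.
val spark = SparkSession.builder().master("local[*]").appName("top-n-per-group").getOrCreate()
import spark.implicits._

// A subset of the sample data from the question.
val df = Seq(
  (0, "cat26", 30.9), (0, "cat13", 22.1), (0, "cat95", 19.6), (0, "cat105", 1.3),
  (1, "cat67", 28.5), (1, "cat4", 26.8), (1, "cat13", 12.6), (1, "cat23", 5.3)
).toDF("Hour", "Category", "TotalValue")

val n = 2 // hypothetical: how many rows to keep per Hour

val w = Window.partitionBy($"Hour").orderBy($"TotalValue".desc)

// Number the rows inside each Hour, then keep the first n of them.
val dfTopN = df
  .withColumn("rn", row_number().over(w))
  .where($"rn" <= n)
  .drop("rn")

// Keeps cat26 and cat13 for Hour 0, cat67 and cat4 for Hour 1.
dfTopN.orderBy($"Hour", $"TotalValue".desc).show()
```

Note that row_number breaks ties arbitrarily. If tied rows should all be kept, rank or dense_rank over the same window are the usual alternatives.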

Plain SQL aggregation followed by join:

Alternatively, you can join with an aggregated DataFrame:

val dfMax = df.groupBy($"hour".as("max_hour")).agg(max($"TotalValue").as("max_value"))

val dfTopByJoin = df.join(broadcast(dfMax),
    ($"hour" === $"max_hour") && ($"TotalValue" === $"max_value"))
  .drop("max_hour")
  .drop("max_value")

dfTopByJoin.show

// +----+--------+----------+
// |Hour|Category|TotalValue|
// +----+--------+----------+
// |   0|   cat26|      30.9|
// |   1|   cat67|      28.5|
// |   2|   cat56|      39.6|
// |   3|    cat8|      35.6|
// +----+--------+----------+

It will keep duplicate values (if there is more than one category per hour with the same total value). You can remove these as follows:

import org.apache.spark.sql.functions.first

dfTopByJoin
  .groupBy($"hour")
  .agg(
    first("category").alias("category"),
    first("TotalValue").alias("TotalValue"))

Using ordering over structs:

A neat, although not very well tested, trick that doesn't require joins or window functions:

import org.apache.spark.sql.functions.struct

val dfTop = df.select($"Hour", struct($"TotalValue", $"Category").alias("vs"))
  .groupBy($"hour")
  .agg(max("vs").alias("vs"))
  .select($"Hour", $"vs.Category", $"vs.TotalValue")

dfTop.show
// +----+--------+----------+
// |Hour|Category|TotalValue|
// +----+--------+----------+
// |   0|   cat26|      30.9|
// |   1|   cat67|      28.5|
// |   2|   cat56|      39.6|
// |   3|    cat8|      35.6|
// +----+--------+----------+
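The field order inside the struct matters here: structs are compared field by field, left to right, so TotalValue has to come first and Category only acts as a tie-breaker. A small sketch of that behaviour, using hypothetical data with a tie (catA, catB and catC are not from the question) and assuming a local SparkSession:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{max, struct}

// Assumption: a local session for illustration; in spark-shell `spark` already exists.
val spark = SparkSession.builder().master("local[*]").appName("struct-ordering").getOrCreate()
import spark.implicits._

// Hypothetical data: catA and catB both have the maximum TotalValue in Hour 0.
val tied = Seq((0, "catA", 10.0), (0, "catB", 10.0), (0, "catC", 5.0))
  .toDF("Hour", "Category", "TotalValue")

// max over the struct compares TotalValue first, then Category.
val tiedTop = tied
  .groupBy($"Hour")
  .agg(max(struct($"TotalValue", $"Category")).alias("vs"))
  .select($"Hour", $"vs.Category", $"vs.TotalValue")

// One row: Hour = 0, Category = catB, TotalValue = 10.0
// (the tie is resolved towards the larger Category string).
tiedTop.show()
```

Unlike row_number, this gives a deterministic winner for ties, but always exactly one row per group.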

With the Dataset API (Spark 1.6+, 2.0+):

Spark 1.6:

case class Record(Hour: Integer, Category: String, TotalValue: Double)

df.as[Record]
  .groupBy($"hour")
  .reduce((x, y) => if (x.TotalValue > y.TotalValue) x else y)
  .show

// +---+--------------+
// | _1|            _2|
// +---+--------------+
// |[0]|[0,cat26,30.9]|
// |[1]|[1,cat67,28.5]|
// |[2]|[2,cat56,39.6]|
// |[3]| [3,cat8,35.6]|
// +---+--------------+

Spark 2.0 or later:

df.as[Record]
  .groupByKey(_.Hour)
  .reduceGroups((x, y) => if (x.TotalValue > y.TotalValue) x else y)

The last two methods can leverage map-side combine and don't require a full shuffle, so most of the time they should perform better than window functions and joins. They can also be used with Structured Streaming in complete output mode.
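The Spark 2.0+ variant returns a Dataset of (key, record) pairs rather than plain rows. A minimal sketch of how it could be turned back into a DataFrame with the original columns, assuming a local SparkSession and using Int instead of Integer for Hour to keep the example simple:

```scala
import org.apache.spark.sql.SparkSession

// Assumption: Hour is an Int here (the answer above uses Integer).
case class Record(Hour: Int, Category: String, TotalValue: Double)

// Assumption: a local session for illustration; in spark-shell `spark` already exists.
val spark = SparkSession.builder().master("local[*]").appName("reduce-groups").getOrCreate()
import spark.implicits._

// A subset of the sample data from the question.
val ds = Seq(
  Record(0, "cat26", 30.9), Record(0, "cat13", 22.1),
  Record(1, "cat67", 28.5), Record(1, "cat4", 26.8)
).toDS()

val topPerHour = ds
  .groupByKey(_.Hour)
  .reduceGroups((x, y) => if (x.TotalValue > y.TotalValue) x else y)
  .map(_._2) // Dataset[(Int, Record)] -> Dataset[Record], dropping the key
  .toDF()    // columns: Hour, Category, TotalValue

// Keeps (0, cat26, 30.9) and (1, cat67, 28.5).
topPerHour.orderBy($"Hour").show()
```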

Don't use:

df.orderBy(...).groupBy(...).agg(first(...), ...)

It may seem to work (especially in local mode), but it is unreliable (see SPARK-16207, credits to Tzach Zohar for linking the relevant JIRA issue, and SPARK-30335).

The same note applies to

df.orderBy(...).dropDuplicates(...)

which internally uses an equivalent execution plan.


It looks like since Spark 1.6 it is row_number() instead of rowNumber.
About the "Don't use df.orderBy(...).groupBy(...)" note: under what circumstances can we rely on orderBy(...)? Or, if we cannot be sure that orderBy() is going to give the correct result, what alternatives do we have?
I might be overlooking something, but in general it is recommended to avoid groupByKey, instead reduceByKey should be used. Also, you'll be saving one line.
@Thomas The advice to avoid groupBy/groupByKey applies only to RDDs; you'll notice that the Dataset API doesn't even have a reduceByKey function.
Antonín Hoskovec

For Spark 2.0.2 with grouping by multiple columns:

import org.apache.spark.sql.functions.row_number
import org.apache.spark.sql.expressions.Window

val w = Window.partitionBy($"col1", $"col2", $"col3").orderBy($"timestamp".desc)

val refined_df = df.withColumn("rn", row_number().over(w)).where($"rn" === 1).drop("rn")

This code is more or less contained in Apache DataFu's dedupWithOrder method
Community

This is exactly the same as zero323's answer, but written as SQL queries.

Assuming that the DataFrame is created and registered as

df.createOrReplaceTempView("table")
//+----+--------+----------+
//|Hour|Category|TotalValue|
//+----+--------+----------+
//|0   |cat26   |30.9      |
//|0   |cat13   |22.1      |
//|0   |cat95   |19.6      |
//|0   |cat105  |1.3       |
//|1   |cat67   |28.5      |
//|1   |cat4    |26.8      |
//|1   |cat13   |12.6      |
//|1   |cat23   |5.3       |
//|2   |cat56   |39.6      |
//|2   |cat40   |29.7      |
//|2   |cat187  |27.9      |
//|2   |cat68   |9.8       |
//|3   |cat8    |35.6      |
//+----+--------+----------+

Window function:

sqlContext.sql("select Hour, Category, TotalValue from (select *, row_number() OVER (PARTITION BY Hour ORDER BY TotalValue DESC) as rn  FROM table) tmp where rn = 1").show(false)
//+----+--------+----------+
//|Hour|Category|TotalValue|
//+----+--------+----------+
//|1   |cat67   |28.5      |
//|3   |cat8    |35.6      |
//|2   |cat56   |39.6      |
//|0   |cat26   |30.9      |
//+----+--------+----------+

Plain SQL aggregation followed by join:

sqlContext.sql("select Hour, first(Category) as Category, first(TotalValue) as TotalValue from " +
  "(select Hour, Category, TotalValue from table tmp1 " +
  "join " +
  "(select Hour as max_hour, max(TotalValue) as max_value from table group by Hour) tmp2 " +
  "on " +
  "tmp1.Hour = tmp2.max_hour and tmp1.TotalValue = tmp2.max_value) tmp3 " +
  "group by tmp3.Hour")
  .show(false)
//+----+--------+----------+
//|Hour|Category|TotalValue|
//+----+--------+----------+
//|1   |cat67   |28.5      |
//|3   |cat8    |35.6      |
//|2   |cat56   |39.6      |
//|0   |cat26   |30.9      |
//+----+--------+----------+

Using ordering over structs:

sqlContext.sql("select Hour, vs.Category, vs.TotalValue from (select Hour, max(struct(TotalValue, Category)) as vs from table group by Hour)").show(false)
//+----+--------+----------+
//|Hour|Category|TotalValue|
//+----+--------+----------+
//|1   |cat67   |28.5      |
//|3   |cat8    |35.6      |
//|2   |cat56   |39.6      |
//|0   |cat26   |30.9      |
//+----+--------+----------+

The Dataset approach and the "don't use" warnings are the same as in the original answer.


Heedo Lee

You can use the max_by() function, available since Spark 3.0!

https://spark.apache.org/docs/3.0.0-preview/api/sql/index.html#max_by

val df = sc.parallelize(Seq(
  (0,"cat26",30.9), (0,"cat13",22.1), (0,"cat95",19.6), (0,"cat105",1.3),
  (1,"cat67",28.5), (1,"cat4",26.8), (1,"cat13",12.6), (1,"cat23",5.3),
  (2,"cat56",39.6), (2,"cat40",29.7), (2,"cat187",27.9), (2,"cat68",9.8),
  (3,"cat8",35.6))).toDF("Hour", "Category", "TotalValue")

// Register the DataFrame as a SQL temporary view
df.createOrReplaceTempView("table")

// Using SQL
val result = spark.sql("select Hour, max_by(Category, TotalValue) AS Category, max(TotalValue) as TotalValue FROM table group by Hour order by Hour")

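// or using the DataFrame API (named differently so both variants can coexist)
import org.apache.spark.sql.functions.{expr, max}

val resultDf = df.groupBy("Hour").
  agg(expr("max_by(Category, TotalValue)").as("Category"), max("TotalValue").as("TotalValue")).
  sort("Hour")

resultDf.show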

+----+--------+----------+
|Hour|Category|TotalValue|
+----+--------+----------+
|   0|   cat26|      30.9|
|   1|   cat67|      28.5|
|   2|   cat56|      39.6|
|   3|    cat8|      35.6|
+----+--------+----------+

Rubber Duck

The pattern is: group by keys => do something to each group (e.g. reduce) => convert back to a DataFrame.

I found the DataFrame abstraction a bit cumbersome in this case, so I used RDD functionality:

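import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Row

// reduceFunction: (Row, Row) => Row is supplied by the caller and picks the row to keep
val rdd: RDD[Row] = originalDf
  .rdd
  .groupBy(row => row.getAs[String]("grouping_row"))
  .map(iterableTuple => {
    iterableTuple._2.reduce(reduceFunction)
  })

val productDf = sqlContext.createDataFrame(rdd, originalDf.schema)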
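The snippet above leaves reduceFunction undefined. One possible concrete version for the data in the question is sketched below. It assumes a local SparkSession, a hypothetical reduceFunction that keeps the row with the larger TotalValue, and it swaps groupBy + reduce for keyBy + reduceByKey (as suggested in the comments on the accepted answer) so that values are combined per key before the shuffle:

```scala
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Row, SparkSession}

// Assumption: a local session for illustration; in spark-shell `spark` already exists.
val spark = SparkSession.builder().master("local[*]").appName("rdd-reduce").getOrCreate()
import spark.implicits._

// A subset of the sample data from the question.
val originalDf = Seq(
  (0, "cat26", 30.9), (0, "cat13", 22.1),
  (1, "cat67", 28.5), (1, "cat4", 26.8)
).toDF("Hour", "Category", "TotalValue")

// Hypothetical reduce function: keep the row with the larger TotalValue.
def reduceFunction(a: Row, b: Row): Row =
  if (a.getAs[Double]("TotalValue") >= b.getAs[Double]("TotalValue")) a else b

val reduced: RDD[Row] = originalDf.rdd
  .keyBy(row => row.getAs[Int]("Hour"))            // (Hour, Row) pairs
  .reduceByKey((a, b) => reduceFunction(a, b))     // one winning Row per Hour
  .values

// Back to a DataFrame with the original schema.
val productDf = spark.createDataFrame(reduced, originalDf.schema)

// Keeps (0, cat26, 30.9) and (1, cat67, 28.5).
productDf.orderBy($"Hour").show()
```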

Eyal

You can do this easily with Apache DataFu (the implementation is similar to Antonin's answer).

import datafu.spark.DataFrameOps._

val df = sc.parallelize(Seq(
  (0,"cat26",30.9), (0,"cat13",22.1), (0,"cat95",19.6), (0,"cat105",1.3),
  (1,"cat67",28.5), (1,"cat4",26.8), (1,"cat13",12.6), (1,"cat23",5.3),
  (2,"cat56",39.6), (2,"cat40",29.7), (2,"cat187",27.9), (2,"cat68",9.8),
  (3,"cat8",35.6))).toDF("Hour", "Category", "TotalValue")

df.dedupWithOrder($"Hour", $"TotalValue".desc).show

which will result in

+----+--------+----------+
|Hour|Category|TotalValue|
+----+--------+----------+
|   0|   cat26|      30.9|
|   3|    cat8|      35.6|
|   1|   cat67|      28.5|
|   2|   cat56|      39.6|
+----+--------+----------+

(yes, the result will not be ordered by Hour, but you can always do that later if it's important)

There's also an API - dedupTopN - for taking the top N rows. And another API - dedupWithCombiner - when you expect a large number of rows per grouping.

(full disclosure - I'm part of the DataFu project)


elghoto

The solution below does only one groupBy and extracts the rows of your DataFrame that contain the maximum value in one shot. There is no need for further joins or windows.

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.DataFrame

//df is the dataframe with Day, Category, TotalValue

implicit val dfEnc = RowEncoder(df.schema)

val res: DataFrame = df.groupByKey{(r) => r.getInt(0)}.mapGroups[Row]{(day: Int, rows: Iterator[Row]) => rows.maxBy{(r) => r.getDouble(2)}}

But it shuffles everything first. It is hardly an improvement (maybe not worse than window functions, depending on the data).
You have a groupBy in the first place, which will trigger a shuffle. It's not worse than a window function, because a window function is going to evaluate the window for every single row in the DataFrame.
randal25

A nice way of doing this with the DataFrame API is to use argmax logic, like so:

  val df = Seq(
    (0,"cat26",30.9), (0,"cat13",22.1), (0,"cat95",19.6), (0,"cat105",1.3),
    (1,"cat67",28.5), (1,"cat4",26.8), (1,"cat13",12.6), (1,"cat23",5.3),
    (2,"cat56",39.6), (2,"cat40",29.7), (2,"cat187",27.9), (2,"cat68",9.8),
    (3,"cat8",35.6)).toDF("Hour", "Category", "TotalValue")

  df.groupBy($"Hour")
    .agg(max(struct($"TotalValue", $"Category")).as("argmax"))
    .select($"Hour", $"argmax.*").show

 +----+----------+--------+
 |Hour|TotalValue|Category|
 +----+----------+--------+
 |   1|      28.5|   cat67|
 |   3|      35.6|    cat8|
 |   2|      39.6|   cat56|
 |   0|      30.9|   cat26|
 +----+----------+--------+

Shubham Agrawal

You can do it like this:

   val data = df.groupBy("Hour").agg(first("Hour").as("_1"),first("Category").as("Category"),first("TotalValue").as("TotalValue")).drop("Hour")

data.withColumnRenamed("_1","Hour").show