我正在尝试使用对现有列集的 groupby 聚合在 Pyspark 中创建一个新的列表列。下面提供了一个示例输入数据框:
------------------------
id | date | value
------------------------
1 |2014-01-03 | 10
1 |2014-01-04 | 5
1 |2014-01-05 | 15
1 |2014-01-06 | 20
2 |2014-02-10 | 100
2 |2014-03-11 | 500
2 |2014-04-15 | 1500
预期的输出是:
id | value_list
------------------------
1 | [10, 5, 15, 20]
2 | [100, 500, 1500]
列表中的值按日期排序。
我尝试使用 collect_list 如下:
from pyspark.sql import functions as F
ordered_df = input_df.orderBy(['id','date'],ascending = True)
grouped_df = ordered_df.groupby("id").agg(F.collect_list("value"))
但是即使我在聚合之前按日期对输入数据帧进行排序,collect_list 也不能保证顺序。
有人可以通过保留基于第二个(日期)变量的顺序来帮助如何进行聚合吗?
from pyspark.sql import functions as F
from pyspark.sql import Window
w = Window.partitionBy('id').orderBy('date')
sorted_list_df = input_df.withColumn(
'sorted_list', F.collect_list('value').over(w)
)\
.groupBy('id')\
.agg(F.max('sorted_list').alias('sorted_list'))
Window
用户提供的示例通常并不能真正解释发生了什么,所以让我为您剖析一下。
如您所知,将 collect_list
与 groupBy
一起使用将产生一个无序值列表。这是因为根据您的数据分区方式,Spark 会在找到组中的一行后立即将值附加到您的列表中。然后,顺序取决于 Spark 如何计划您对 executor 的聚合。
Window
函数允许您控制这种情况,按特定值对行进行分组,以便您可以对每个结果组执行操作 over
:
w = Window.partitionBy('id').orderBy('date')
partitionBy - 您想要具有相同 id 的行的组/分区
orderBy - 您希望组中的每一行按日期排序
一旦定义了 Window 的范围——“具有相同 id
的行,按 date
排序”——就可以使用它对其执行操作,在本例中为 collect_list
:
F.collect_list('value').over(w)
此时,您创建了一个新列 sorted_list
,其中包含按日期排序的有序值列表,但每个 id
仍有重复的行。要删除您想要 groupBy
id
的重复行并为每个组保留 max
值:
.groupBy('id')\
.agg(F.max('sorted_list').alias('sorted_list'))
如果您将日期和值都收集为列表,则可以使用和 udf
根据日期对结果列进行排序,然后仅保留结果中的值。
import operator
import pyspark.sql.functions as F
# create list column
grouped_df = input_df.groupby("id") \
.agg(F.collect_list(F.struct("date", "value")) \
.alias("list_col"))
# define udf
def sorter(l):
res = sorted(l, key=operator.itemgetter(0))
return [item[1] for item in res]
sort_udf = F.udf(sorter)
# test
grouped_df.select("id", sort_udf("list_col") \
.alias("sorted_list")) \
.show(truncate = False)
+---+----------------+
|id |sorted_list |
+---+----------------+
|1 |[10, 5, 15, 20] |
|2 |[100, 500, 1500]|
+---+----------------+
array_sort
的引入,这是最好的方法,因为它不需要 UDF 的开销。
您可以使用 sort_array
功能。如果您将日期和值都收集为列表,则可以使用 sort_array
对结果列进行排序并仅保留所需的列。
import operator
import pyspark.sql.functions as F
grouped_df = input_df.groupby("id") \
.agg(F.sort_array(F.collect_list(F.struct("date", "value"))) \
.alias("collected_list")) \
.withColumn("sorted_list",col("collected_list.value")) \
.drop("collected_list")
.show(truncate=False)
+---+----------------+
|id |sorted_list |
+---+----------------+
|1 |[10, 5, 15, 20] |
|2 |[100, 500, 1500]|
+---+----------------+ ```````
问题是针对 PySpark,但对 Scala Spark 也有帮助。
让我们准备测试数据框:
import org.apache.spark.sql.functions._
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import org.apache.spark.sql.expressions.{ Window, UserDefinedFunction}
import java.sql.Date
import java.time.LocalDate
val spark: SparkSession = ...
// Out test data set
val data: Seq[(Int, Date, Int)] = Seq(
(1, Date.valueOf(LocalDate.parse("2014-01-03")), 10),
(1, Date.valueOf(LocalDate.parse("2014-01-04")), 5),
(1, Date.valueOf(LocalDate.parse("2014-01-05")), 15),
(1, Date.valueOf(LocalDate.parse("2014-01-06")), 20),
(2, Date.valueOf(LocalDate.parse("2014-02-10")), 100),
(2, Date.valueOf(LocalDate.parse("2014-02-11")), 500),
(2, Date.valueOf(LocalDate.parse("2014-02-15")), 1500)
)
// Create dataframe
val df: DataFrame = spark.createDataFrame(data)
.toDF("id", "date", "value")
df.show()
//+---+----------+-----+
//| id| date|value|
//+---+----------+-----+
//| 1|2014-01-03| 10|
//| 1|2014-01-04| 5|
//| 1|2014-01-05| 15|
//| 1|2014-01-06| 20|
//| 2|2014-02-10| 100|
//| 2|2014-02-11| 500|
//| 2|2014-02-15| 1500|
//+---+----------+-----+
使用 UDF
// Group by id and aggregate date and value to new column date_value
val grouped = df.groupBy(col("id"))
.agg(collect_list(struct("date", "value")) as "date_value")
grouped.show()
grouped.printSchema()
// +---+--------------------+
// | id| date_value|
// +---+--------------------+
// | 1|[[2014-01-03,10],...|
// | 2|[[2014-02-10,100]...|
// +---+--------------------+
// udf to extract data from Row, sort by needed column (date) and return value
val sortUdf: UserDefinedFunction = udf((rows: Seq[Row]) => {
rows.map { case Row(date: Date, value: Int) => (date, value) }
.sortBy { case (date, value) => date }
.map { case (date, value) => value }
})
// Select id and value_list
val r1 = grouped.select(col("id"), sortUdf(col("date_value")).alias("value_list"))
r1.show()
// +---+----------------+
// | id| value_list|
// +---+----------------+
// | 1| [10, 5, 15, 20]|
// | 2|[100, 500, 1500]|
// +---+----------------+
使用窗口
val window = Window.partitionBy(col("id")).orderBy(col("date"))
val sortedDf = df.withColumn("values_sorted_by_date", collect_list("value").over(window))
sortedDf.show()
//+---+----------+-----+---------------------+
//| id| date|value|values_sorted_by_date|
//+---+----------+-----+---------------------+
//| 1|2014-01-03| 10| [10]|
//| 1|2014-01-04| 5| [10, 5]|
//| 1|2014-01-05| 15| [10, 5, 15]|
//| 1|2014-01-06| 20| [10, 5, 15, 20]|
//| 2|2014-02-10| 100| [100]|
//| 2|2014-02-11| 500| [100, 500]|
//| 2|2014-02-15| 1500| [100, 500, 1500]|
//+---+----------+-----+---------------------+
val r2 = sortedDf.groupBy(col("id"))
.agg(max("values_sorted_by_date").as("value_list"))
r2.show()
//+---+----------------+
//| id| value_list|
//+---+----------------+
//| 1| [10, 5, 15, 20]|
//| 2|[100, 500, 1500]|
//+---+----------------+
为了确保对每个 id 进行排序,我们可以使用 sortWithinPartitions:
from pyspark.sql import functions as F
ordered_df = (
input_df
.repartition(input_df.id)
.sortWithinPartitions(['date'])
)
grouped_df = ordered_df.groupby("id").agg(F.collect_list("value"))
我尝试了 TMichel 方法并没有为我工作。当我执行最大聚合时,我没有取回列表的最高值。所以对我有用的是以下内容:
def max_n_values(df, key, col_name, number):
'''
Returns the max n values of a spark dataframe
partitioned by the key and ranked by the col_name
'''
w2 = Window.partitionBy(key).orderBy(f.col(col_name).desc())
output = df.select('*',
f.row_number().over(w2).alias('rank')).filter(
f.col('rank') <= number).drop('rank')
return output
def col_list(df, key, col_to_collect, name, score):
w = Window.partitionBy(key).orderBy(f.col(score).desc())
list_df = df.withColumn(name, f.collect_set(col_to_collect).over(w))
size_df = list_df.withColumn('size', f.size(name))
output = max_n_values(df=size_df,
key=key,
col_name='size',
number=1)
return output
从 Spark 2.4 开始,@mtoto 的答案中创建的 collect_list(ArrayType) 可以使用 SparkSQL 的内置函数 transform 和 array_sort 进行后处理(不需要 udf):
from pyspark.sql.functions import collect_list, expr, struct
df.groupby('id') \
.agg(collect_list(struct('date','value')).alias('value_list')) \
.withColumn('value_list', expr('transform(array_sort(value_list), x -> x.value)')) \
.show()
+---+----------------+
| id| value_list|
+---+----------------+
| 1| [10, 5, 15, 20]|
| 2|[100, 500, 1500]|
+---+----------------+
注意:如果需要降序,请将 array_sort(value_list)
更改为 sort_array(value_list, False)
警告: 如果项目(在collect_list 中)必须按多个字段(列)以混合顺序(即orderBy('col1', desc('col2'))
)排序,则array_sort() 和sort_array() 将不起作用。
在 Spark SQL 世界中,这个问题的答案是:
SELECT
browser, max(list)
from (
SELECT
id,
COLLECT_LIST(value) OVER (PARTITION BY id ORDER BY date DESC) as list
FROM browser_count
GROUP BYid, value, date)
Group by browser;
如果你想在这里使用 spark sql,你可以如何实现这一点。假设表名(或临时视图)是 temp_table
。
select
t1.id,
collect_list(value) as value_list
(Select * from temp_table order by id,date) t1
group by 1
作为 ShadyStego 的补充,我一直在测试 Spark 上 sortWithinPartitions 和 GroupBy 的使用,发现它的性能比 Window 函数或 UDF 好得多。尽管如此,使用此方法时,每个分区都会出现一次错误排序的问题,但可以轻松解决。我在这里展示它Spark (pySpark) groupBy misordering first element on collect_list。
此方法在大型 DataFrame 上特别有用,但如果驱动程序内存不足,则可能需要大量分区。
不定期副业成功案例分享
F.collect_list('value').over(w)
将创建一个从 1 到 24 的新列大小,即 1000 万 * 24 次。然后通过从每个组中获取最大的行来做另一组。collect_set
而不是collect_list
,这将不起作用。