
How to flatten a struct in a Spark dataframe?

I have a dataframe with the following structure:

 |-- data: struct (nullable = true)
 |    |-- id: long (nullable = true)
 |    |-- keyNote: struct (nullable = true)
 |    |    |-- key: string (nullable = true)
 |    |    |-- note: string (nullable = true)
 |    |-- details: map (nullable = true)
 |    |    |-- key: string
 |    |    |-- value: string (valueContainsNull = true)

How is it possible to flatten the structure and create a new dataframe like this:

     |-- id: long (nullable = true)
     |-- keyNote: struct (nullable = true)
     |    |-- key: string (nullable = true)
     |    |-- note: string (nullable = true)
     |-- details: map (nullable = true)
     |    |-- key: string
     |    |-- value: string (valueContainsNull = true)

Is there something like explode, but for structs?

The answers at stackoverflow.com/questions/37471346/… were also helpful.
A nice solution is also presented here: stackoverflow.com/questions/47285871/…

2 revs user6022341

This should work in Spark 1.6 or later:

df.select(df.col("data.*"))

or

df.select(df.col("data.id"), df.col("data.keyNote"), df.col("data.details"))

Exception in thread "main" org.apache.spark.sql.AnalysisException: No such struct field *
However, selecting each column explicitly, as in df.select(df.col1, df.col2, df.col3), works, so I will accept this answer.
I was just editing, but that is strange: I can use *. Maybe a version issue?
Yeah, maybe. I'm using Spark 1.6.1 and Scala 2.10.
How would you select key or note under the nested struct keyNote?
amza

Here is a function that does what you want and can handle multiple nested columns that contain fields with the same name:

import pyspark.sql.functions as F

def flatten_df(nested_df):
    flat_cols = [c[0] for c in nested_df.dtypes if c[1][:6] != 'struct']
    nested_cols = [c[0] for c in nested_df.dtypes if c[1][:6] == 'struct']

    flat_df = nested_df.select(flat_cols +
                               [F.col(nc+'.'+c).alias(nc+'_'+c)
                                for nc in nested_cols
                                for c in nested_df.select(nc+'.*').columns])
    return flat_df

Before:

root
 |-- x: string (nullable = true)
 |-- y: string (nullable = true)
 |-- foo: struct (nullable = true)
 |    |-- a: float (nullable = true)
 |    |-- b: float (nullable = true)
 |    |-- c: integer (nullable = true)
 |-- bar: struct (nullable = true)
 |    |-- a: float (nullable = true)
 |    |-- b: float (nullable = true)
 |    |-- c: integer (nullable = true)

After:

root
 |-- x: string (nullable = true)
 |-- y: string (nullable = true)
 |-- foo_a: float (nullable = true)
 |-- foo_b: float (nullable = true)
 |-- foo_c: integer (nullable = true)
 |-- bar_a: float (nullable = true)
 |-- bar_b: float (nullable = true)
 |-- bar_c: integer (nullable = true)

Pratik Anurag

For Spark 2.4.5, df.select(df.col("data.*")) throws org.apache.spark.sql.AnalysisException: No such struct field *.

This works instead:

df.select($"data.*")

This works with Spark 3.1.0 too, but it doesn't preserve the parent name (data, or whichever parent is selected) in the output columns, and it doesn't descend into further nested structs.
When I select with df.select("data.*"), I get n*n rows (n duplicated rows for each row). My dataframe has n-2 distinct rows, so adding distinct gives n-2 results. But I want the n rows that are actually present in my data. How can I achieve this with the select command above?
The dollar sign can be omitted :)
Tested with DBR 9.1, Spark 3.1.2, and it works: df.select("data.*")
federicojasson

This flatten_df version flattens the dataframe at every nesting level, using a stack to avoid recursive calls:

from pyspark.sql.functions import col


def flatten_df(nested_df):
    stack = [((), nested_df)]
    columns = []

    while len(stack) > 0:
        parents, df = stack.pop()

        flat_cols = [
            col(".".join(parents + (c[0],))).alias("_".join(parents + (c[0],)))
            for c in df.dtypes
            if c[1][:6] != "struct"
        ]

        nested_cols = [
            c[0]
            for c in df.dtypes
            if c[1][:6] == "struct"
        ]

        columns.extend(flat_cols)

        for nested_col in nested_cols:
            projected_df = df.select(nested_col + ".*")
            stack.append((parents + (nested_col,), projected_df))

    return nested_df.select(columns)

Example:

from pyspark.sql.types import StringType, StructField, StructType


schema = StructType([
    StructField("some", StringType()),

    StructField("nested", StructType([
        StructField("nestedchild1", StringType()),
        StructField("nestedchild2", StringType())
    ])),

    StructField("renested", StructType([
        StructField("nested", StructType([
            StructField("nestedchild1", StringType()),
            StructField("nestedchild2", StringType())
        ]))
    ]))
])

data = [
    {
        "some": "value1",
        "nested": {
            "nestedchild1": "value2",
            "nestedchild2": "value3",
        },
        "renested": {
            "nested": {
                "nestedchild1": "value4",
                "nestedchild2": "value5",
            }
        }
    }
]

df = spark.createDataFrame(data, schema)
flat_df = flatten_df(df)
print(flat_df.collect())

Prints:

[Row(some=u'value1', renested_nested_nestedchild1=u'value4', renested_nested_nestedchild2=u'value5', nested_nestedchild1=u'value2', nested_nestedchild2=u'value3')]

This doesn't seem to recurse into nested structs inside arrays.
@malthe It won't. I don't think it's feasible to do that, actually. Assuming you use the array index as column name (e.g. array.0.field, array.1.field, ...), you'll have to know the length of the array beforehand. All these solutions iterate the dataframe structure, which is known at the driver.
I ended up figuring out how to do it and posted a script here: stackoverflow.com/a/66482320/647151.
Oh, so the idea was to keep the array but transform the structures it contains. Nice!
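To illustrate the point about needing the array length beforehand: if a maximum length is known (or assumed), the index-based columns can be generated on the driver. This is only a minimal sketch with hypothetical names (indexed_struct_exprs, an array column called items with fields id and name). It builds selectExpr strings and never looks at the data. As far as I know, indexing past the end of an array yields null only when ANSI mode is off, so treat that as an assumption to verify on your Spark version.

```python
def indexed_struct_exprs(array_col, fields, max_len, sep="_"):
    """Build selectExpr strings such as 'items[0].id AS items_0_id'.

    array_col, fields and max_len are supplied by the caller; nothing
    here inspects the data, so max_len must be known up front.
    """
    exprs = []
    for i in range(max_len):
        for field in fields:
            alias = sep.join([array_col, str(i), field])
            exprs.append(f"{array_col}[{i}].{field} AS {alias}")
    return exprs


exprs = indexed_struct_exprs("items", ["id", "name"], max_len=2)
for e in exprs:
    print(e)

# Hypothetical use with a real dataframe (not run here):
# df.selectExpr("some_key", *exprs)
```

Arrays longer than max_len are silently truncated by this approach, which is the practical reason the answers above stick to structs.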
Aydin K.

I generalized the solution from stecos a bit more so the flattening can be done on more than two struct layers deep:

from pyspark.sql.functions import col

def flatten_df(nested_df, layers):
    flat_cols = []
    nested_cols = []
    flat_df = []

    flat_cols.append([c[0] for c in nested_df.dtypes if c[1][:6] != 'struct'])
    nested_cols.append([c[0] for c in nested_df.dtypes if c[1][:6] == 'struct'])

    flat_df.append(nested_df.select(flat_cols[0] +
                               [col(nc+'.'+c).alias(nc+'_'+c)
                                for nc in nested_cols[0]
                                for c in nested_df.select(nc+'.*').columns])
                  )
    for i in range(1, layers):
        print (flat_cols[i-1])
        flat_cols.append([c[0] for c in flat_df[i-1].dtypes if c[1][:6] != 'struct'])
        nested_cols.append([c[0] for c in flat_df[i-1].dtypes if c[1][:6] == 'struct'])

        flat_df.append(flat_df[i-1].select(flat_cols[i] +
                                [col(nc+'.'+c).alias(nc+'_'+c)
                                    for nc in nested_cols[i]
                                    for c in flat_df[i-1].select(nc+'.*').columns])
        )

    return flat_df[-1]

just call with:

my_flattened_df = flatten_df(my_df_having_nested_structs, 3)

(The second parameter is the number of struct layers to flatten; in my case it's 3.)
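
The number of layers does not have to be guessed; it can be derived from the schema. The following is a rough sketch, assuming the schema is available as the nested dict that df.schema.jsonValue() returns in PySpark. The helper name struct_depth is hypothetical, and the example dict is written by hand (real output also carries nullable and metadata keys, which are ignored here).

```python
def struct_depth(dtype):
    """Return how many struct levels are nested inside dtype.

    dtype is either a type name such as "string" or a dict such as
    {"type": "struct", "fields": [...]}. Structs inside arrays or maps
    are not counted, matching the struct-only flattening above.
    """
    if isinstance(dtype, dict) and dtype.get("type") == "struct":
        child_depths = [struct_depth(f["type"]) for f in dtype["fields"]]
        return 1 + max(child_depths, default=0)
    return 0


# Hand-written stand-in for df.schema.jsonValue()
leaf_struct = {"type": "struct", "fields": [
    {"name": "nestedchild1", "type": "string"},
    {"name": "nestedchild2", "type": "string"},
]}
schema_json = {"type": "struct", "fields": [
    {"name": "some", "type": "string"},
    {"name": "nested", "type": leaf_struct},
    {"name": "renested", "type": {"type": "struct", "fields": [
        {"name": "nested", "type": leaf_struct},
    ]}},
]}

# The root itself is a struct, so subtract one to get the column nesting.
layers = struct_depth(schema_json) - 1
print(layers)  # 2
```

With a real dataframe this would be called as flatten_df(my_df, struct_depth(my_df.schema.jsonValue()) - 1).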


Domenico Di Nicola

A slightly more compact and efficient implementation:

There is no need to build separate lists of flat and nested columns and then iterate over them. Each field is handled according to its type (struct or not) as it is visited.

If the column is nested (a struct), it is expanded with .* and pushed onto the stack; otherwise it is accessed with dot notation (parent.child) and aliased with . replaced by _ (parent_child).

from pyspark.sql.functions import col

def flatten_df(nested_df):
    stack = [((), nested_df)]
    columns = []
    while len(stack) > 0:
        parents, df = stack.pop()
        for column_name, column_type in df.dtypes:
            if column_type[:6] == "struct":
                projected_df = df.select(column_name + ".*")
                stack.append((parents + (column_name,), projected_df))
            else:
                columns.append(col(".".join(parents + (column_name,))).alias("_".join(parents + (column_name,))))
    return nested_df.select(columns)

Welcome to Stack Overflow. Code-only answers are discouraged on Stack Overflow because they don't explain how it solves the problem. Please edit your answer to explain what this code does and how it is more efficient than the other answers as you say it is, so that it is useful to other users with similar issues and they can learn from it.
Worked perfectly for me. You would certainly get more likes if you added more explanation.
What is the data type of "parents", @Domenico Di Nicola?
I had problems when column_name includes a dot. I know that for an explicit column name I can escape it with backticks, but I could not do that in the code above.
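On the last two comments: parents is a tuple of column-name strings, and joining it with "." makes a column literally named a.b indistinguishable from field b inside struct a. One possible workaround is to backtick-quote each segment before joining. This is a sketch only; the helper names (quote_part, quoted_path, flat_alias) are hypothetical, and doubling an embedded backtick is how I understand Spark SQL escapes it, so verify that on your version.

```python
def quote_part(name):
    """Wrap one path segment in backticks, doubling any embedded backtick."""
    return "`" + name.replace("`", "``") + "`"


def quoted_path(parents, column_name):
    """Join the segments of a nested path, quoting each one separately."""
    return ".".join(quote_part(p) for p in parents + (column_name,))


def flat_alias(parents, column_name, sep="_"):
    """Alias for the flattened column; dots inside names become sep too."""
    return sep.join(parents + (column_name,)).replace(".", sep)


parents = ("meta.info",)           # a struct column whose name contains a dot
print(quoted_path(parents, "id"))  # `meta.info`.`id`
print(flat_alias(parents, "id"))   # meta_info_id
```

In the function above, the idea would be to use col(quoted_path(parents, column_name)).alias(flat_alias(parents, column_name)) and to quote the name in the ".*" projection as well, e.g. df.select(quote_part(column_name) + ".*"). Note that the alias can collide: a column named a.b with field c and a struct a with nested b.c both become a_b_c.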
Narahari B M

A PySpark solution to flatten a nested df with both struct and array types at any level of depth. It improves on this answer: https://stackoverflow.com/a/56533459/7131019

from pyspark.sql.types import *
from pyspark.sql import functions as f

def flatten_structs(nested_df):
    stack = [((), nested_df)]
    columns = []

    while len(stack) > 0:
        
        parents, df = stack.pop()
        
        flat_cols = [
            f.col(".".join(parents + (c[0],))).alias("_".join(parents + (c[0],)))
            for c in df.dtypes
            if c[1][:6] != "struct"
        ]

        nested_cols = [
            c[0]
            for c in df.dtypes
            if c[1][:6] == "struct"
        ]
        
        columns.extend(flat_cols)

        for nested_col in nested_cols:
            projected_df = df.select(nested_col + ".*")
            stack.append((parents + (nested_col,), projected_df))
        
    return nested_df.select(columns)

def flatten_array_struct_df(df):
    
    array_cols = [
            c[0]
            for c in df.dtypes
            if c[1][:5] == "array"
        ]
    
    while len(array_cols) > 0:
        
        for array_col in array_cols:
            # explode yields one row per array element, so the row count can grow
            df = df.withColumn(array_col, f.explode(f.col(array_col)))
            
        df = flatten_structs(df)
        
        array_cols = [
            c[0]
            for c in df.dtypes
            if c[1][:5] == "array"
        ]
    return df

flat_df = flatten_array_struct_df(df)

This works very well, and I take the 'responsibility' of using it carefully, since flattening an array of structs can produce duplicate rows, as others said.
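To make the duplicate-row warning concrete, here is a small pure-Python imitation of what explode does to row counts when more than one array column is exploded in sequence. It is not Spark code; explode_rows is a hypothetical stand-in operating on a list of dicts, and the sample row is made up.

```python
def explode_rows(rows, column):
    """Mimic explode on a list of dicts: one output row per array element.

    Rows whose array is empty or None are dropped, which is also what
    Spark's explode does (explode_outer would keep them).
    """
    out = []
    for row in rows:
        for item in row[column] or []:
            new_row = dict(row)
            new_row[column] = item
            out.append(new_row)
    return out


rows = [{"id": 1, "tags": ["a", "b"], "scores": [10, 20, 30]}]
step1 = explode_rows(rows, "tags")
step2 = explode_rows(step1, "scores")
print(len(rows), len(step1), len(step2))  # 1 2 6
```

Every non-array value (id here) is repeated on all six rows, so any later aggregation over those columns has to account for the multiplication.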
Amrish Mishra

You can use this approach if you only have to convert struct types. I would not suggest converting arrays, as that could lead to duplicate records.

from pyspark.sql.functions import col
from pyspark.sql.types import StructType


def flatten_schema(schema, prefix=""):
    return_schema = []
    for field in schema.fields:
        if isinstance(field.dataType, StructType):
            if prefix:
                return_schema = return_schema + flatten_schema(field.dataType, "{}.{}".format(prefix, field.name))
            else:
                return_schema = return_schema + flatten_schema(field.dataType, field.name)
        else:
            if prefix:
                field_path = "{}.{}".format(prefix, field.name)
                return_schema.append(col(field_path).alias(field_path.replace(".", "_")))
            else:
                return_schema.append(field.name)
    return return_schema

You can use it as

new_schema = flatten_schema(df.schema)
df1 = df.select(new_schema)
df1.show()

V_K

Based on https://stackoverflow.com/a/49532496/17250408, here is a solution for struct and array fields with multilevel nesting:

from pyspark.sql.functions import col, explode


def type_cols(df_dtypes, filter_type):
    cols = []
    for col_name, col_type in df_dtypes:
        if col_type.startswith(filter_type):
            cols.append(col_name)
    return cols


def flatten_df(nested_df, sep='_'):
    nested_cols = type_cols(nested_df.dtypes, "struct")
    flatten_cols = [fc for fc, _ in nested_df.dtypes if fc not in nested_cols]
    for nc in nested_cols:
        for cc in nested_df.select(f"{nc}.*").columns:
            if sep is None:
                flatten_cols.append(col(f"{nc}.{cc}").alias(f"{cc}"))
            else:
                flatten_cols.append(col(f"{nc}.{cc}").alias(f"{nc}{sep}{cc}"))
    return nested_df.select(flatten_cols)


def explode_df(nested_df):
    nested_cols = type_cols(nested_df.dtypes, "array")
    exploded_df = nested_df
    for nc in nested_cols:
        exploded_df = exploded_df.withColumn(nc, explode(col(nc)))
    return exploded_df


def flatten_explode_df(nested_df):
    df = nested_df
    struct_cols = type_cols(nested_df.dtypes, "struct")
    array_cols = type_cols(nested_df.dtypes, "array")
    if struct_cols:
        df = flatten_df(df)
        return flatten_explode_df(df)
    if array_cols:
        df = explode_df(df)
        return flatten_explode_df(df)
    return df


df = flatten_explode_df(nested_df)

Tshilidzi Mudau

An easy way is to use SQL: build a SQL query string that aliases the nested columns as flat ones.

1. Retrieve the dataframe schema (df.schema()).

2. Transform the schema to SQL (for (field : schema().fields()) ...).

3. Query: val newDF = sqlContext.sql("SELECT " + sqlGenerated + " FROM source")

Here is an example in Java.

(I prefer the SQL way, since you can easily test it in spark-shell and it's cross-language.)
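
The schema-to-SQL step could look roughly like this in Python. It is a sketch under assumptions, not the linked Java example: the schema is taken as the nested dict that df.schema.jsonValue() returns in PySpark, the helper names (flat_select_items, flatten_sql) are hypothetical, and the dict below is a hand-written, trimmed version of the question's schema.

```python
def flat_select_items(dtype, path=()):
    """Yield 'a.b.c AS a_b_c' items for every non-struct leaf under dtype."""
    for field in dtype["fields"]:
        field_path = path + (field["name"],)
        field_type = field["type"]
        if isinstance(field_type, dict) and field_type.get("type") == "struct":
            # Descend into nested structs; maps and arrays are kept as-is.
            yield from flat_select_items(field_type, field_path)
        else:
            yield f"{'.'.join(field_path)} AS {'_'.join(field_path)}"


def flatten_sql(schema_json, table):
    """Build the full SELECT statement against the given view name."""
    return "SELECT " + ", ".join(flat_select_items(schema_json)) + " FROM " + table


# Hand-written stand-in for df.schema.jsonValue() (nullable/metadata omitted)
schema_json = {"type": "struct", "fields": [
    {"name": "data", "type": {"type": "struct", "fields": [
        {"name": "id", "type": "long"},
        {"name": "keyNote", "type": {"type": "struct", "fields": [
            {"name": "key", "type": "string"},
            {"name": "note", "type": "string"},
        ]}},
        {"name": "details", "type": {
            "type": "map", "keyType": "string", "valueType": "string",
            "valueContainsNull": True}},
    ]}},
]}

query = flatten_sql(schema_json, "source")
print(query)
```

With a real session the dataframe would first be registered, e.g. df.createOrReplaceTempView("source"), and the string passed to spark.sql(query). Unlike the output the question asks for, this flattens keyNote as well.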


sri hari kali charan Tummala

The following worked for me in Spark SQL:

import org.apache.log4j.{Level, Logger}
import org.apache.spark.sql.SparkSession

object StackOverFlowQuestion {
  def main(args: Array[String]): Unit = {

    val logger = Logger.getLogger("FlattenTest")
    Logger.getLogger("org").setLevel(Level.WARN)
    Logger.getLogger("akka").setLevel(Level.WARN)

    val spark = SparkSession.builder()
      .appName("FlattenTest")
      .config("spark.sql.warehouse.dir", "C:\\Temp\\hive")
      .master("local[2]")
      //.enableHiveSupport()
      .getOrCreate()
    import spark.implicits._

    val stringTest =
      """{
                               "total_count": 123,
                               "page_size": 20,
                               "another_id": "gdbfdbfdbd",
                               "sen": [{
                                "id": 123,
                                "ses_id": 12424343,
                                "columns": {
                                    "blah": "blah",
                                    "count": 1234
                                },
                                "class": {},
                                "class_timestamps": {},
                                "sentence": "spark is good"
                               }]
                            }
                             """
    val result = List(stringTest)
    val githubRdd=spark.sparkContext.makeRDD(result)
    val gitHubDF=spark.read.json(githubRdd)
    gitHubDF.show()
    gitHubDF.printSchema()

    // registerTempTable is deprecated since Spark 2.0
    gitHubDF.createOrReplaceTempView("JsonTable")

    spark.sql("with cte as" +
      "(" +
      "select explode(sen) as senArray  from JsonTable" +
      "), cte_2 as" +
      "(" +
      "select senArray.ses_id,senArray.ses_id,senArray.columns.* from cte" +
      ")" +
      "select * from cte_2"
    ).show()

    spark.stop()
  }
}

Output:

+----------+---------+--------------------+-----------+
|another_id|page_size|                 sen|total_count|
+----------+---------+--------------------+-----------+
|gdbfdbfdbd|       20|[[[blah, 1234], 1...|        123|
+----------+---------+--------------------+-----------+

root
 |-- another_id: string (nullable = true)
 |-- page_size: long (nullable = true)
 |-- sen: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- columns: struct (nullable = true)
 |    |    |    |-- blah: string (nullable = true)
 |    |    |    |-- count: long (nullable = true)
 |    |    |-- id: long (nullable = true)
 |    |    |-- sentence: string (nullable = true)
 |    |    |-- ses_id: long (nullable = true)
 |-- total_count: long (nullable = true)

+--------+--------+----+-----+
|  ses_id|  ses_id|blah|count|
+--------+--------+----+-----+
|12424343|12424343|blah| 1234|
+--------+--------+----+-----+

Raptor0009

This is for Spark with Scala:

val totalMainArrayBuffer = collection.mutable.ArrayBuffer[String]()

def flatten_df_Struct(dfTemp: org.apache.spark.sql.DataFrame, dfTotalOuter: org.apache.spark.sql.DataFrame): org.apache.spark.sql.DataFrame = {
  // Each dtypes entry prints as "(name,type)". Strip the parentheses and test only the
  // type part, in case a column name itself contains the word Struct.
  val totalStructCols = dfTemp.dtypes.map(x => x.toString.substring(1, x.toString.size - 1)).filter(_.split(",", 2)(1).contains("Struct"))
  val mainArrayBuffer = collection.mutable.ArrayBuffer[String]()
  for (totalStructCol <- totalStructCols) {
    // Expand the struct one level to learn the names of its fields
    val tempArrayBuffer = collection.mutable.ArrayBuffer[String]()
    tempArrayBuffer += s"${totalStructCol.split(",")(0)}.*"
    val columnsInside = dfTemp.selectExpr(tempArrayBuffer: _*).columns
    for (column <- columnsInside)
      mainArrayBuffer += s"${totalStructCol.split(",")(0)}.${column} as ${totalStructCol.split(",")(0)}_${column}"
  }
  val nonStructCols = dfTemp.selectExpr(mainArrayBuffer: _*).dtypes.map(x => x.toString.substring(1, x.toString.size - 1)).filter(!_.split(",", 2)(1).contains("Struct"))
  // Replace _ with . to recover the original select path of an already nested column
  for (nonStructCol <- nonStructCols)
    totalMainArrayBuffer += s"${nonStructCol.split(",")(0).replace("_", ".")} as ${nonStructCol.split(",")(0)}"
  // Recurse until no struct columns remain
  dfTemp.selectExpr(mainArrayBuffer: _*).dtypes.map(x => x.toString.substring(1, x.toString.size - 1)).filter(_.split(",", 2)(1).contains("Struct")).size match {
    case value if value == 0 => dfTotalOuter.selectExpr(totalMainArrayBuffer: _*)
    case _ => flatten_df_Struct(dfTemp.selectExpr(mainArrayBuffer: _*), dfTotalOuter)
  }
}


def flatten_df(dfTemp: org.apache.spark.sql.DataFrame): org.apache.spark.sql.DataFrame = {
  var totalArrayBuffer = collection.mutable.ArrayBuffer[String]()
  // Top-level columns that are not structs are selected unchanged
  val totalNonStructCols = dfTemp.dtypes.map(x => x.toString.substring(1, x.toString.size - 1)).filter(!_.split(",", 2)(1).contains("Struct"))
  for (totalNonStructCol <- totalNonStructCols)
    totalArrayBuffer += s"${totalNonStructCol.split(",")(0)}"
  totalMainArrayBuffer.clear()
  flatten_df_Struct(dfTemp, dfTemp) // flattened schema is now in totalMainArrayBuffer
  totalArrayBuffer = totalArrayBuffer ++ totalMainArrayBuffer
  dfTemp.selectExpr(totalArrayBuffer: _*)
}


flatten_df(dfTotal.withColumn("tempStruct",lit(5))).printSchema



Input file:

{"num1":1,"num2":2,"bool1":true,"bool2":false,"double1":4.5,"double2":5.6,"str1":"a","str2":"b","arr1":[3,4,5],"map1":{"cool":1,"okay":2,"normal":3},"carInfo":{"Engine":{"Make":"sa","Power":{"IC":"900","battery":"165"},"Redline":"11500"} ,"Tyres":{"Make":"Pirelli","Compound":"c1","Life":"120"}}}
{"num1":3,"num2":4,"bool1":false,"bool2":false,"double1":4.2,"double2":5.5,"str1":"u","str2":"n","arr1":[6,7,9],"map1":{"fast":1,"medium":2,"agressive":3},"carInfo":{"Engine":{"Make":"na","Power":{"IC":"800","battery":"150"},"Redline":"10000"} ,"Tyres":{"Make":"Pirelli","Compound":"c2","Life":"100"}}}
{"num1":8,"num2":4,"bool1":true,"bool2":true,"double1":5.7,"double2":7.5,"str1":"t","str2":"k","arr1":[11,12,23],"map1":{"preserve":1,"medium":2,"fast":3},"carInfo":{"Engine":{"Make":"ta","Power":{"IC":"950","battery":"170"},"Redline":"12500"} ,"Tyres":{"Make":"Pirelli","Compound":"c3","Life":"80"}}}
{"num1":7,"num2":9,"bool1":false,"bool2":true,"double1":33.2,"double2":7.5,"str1":"b","str2":"u","arr1":[12,14,5],"map1":{"normal":1,"preserve":2,"agressive":3},"carInfo":{"Engine":{"Make":"pa","Power":{"IC":"920","battery":"160"},"Redline":"11800"} ,"Tyres":{"Make":"Pirelli","Compound":"c4","Life":"70"}}}

Before:

root
 |-- arr1: array (nullable = true)
 |    |-- element: long (containsNull = true)
 |-- bool1: boolean (nullable = true)
 |-- bool2: boolean (nullable = true)
 |-- carInfo: struct (nullable = true)
 |    |-- Engine: struct (nullable = true)
 |    |    |-- Make: string (nullable = true)
 |    |    |-- Power: struct (nullable = true)
 |    |    |    |-- IC: string (nullable = true)
 |    |    |    |-- battery: string (nullable = true)
 |    |    |-- Redline: string (nullable = true)
 |    |-- Tyres: struct (nullable = true)
 |    |    |-- Compound: string (nullable = true)
 |    |    |-- Life: string (nullable = true)
 |    |    |-- Make: string (nullable = true)
 |-- double1: double (nullable = true)
 |-- double2: double (nullable = true)
 |-- map1: struct (nullable = true)
 |    |-- agressive: long (nullable = true)
 |    |-- cool: long (nullable = true)
 |    |-- fast: long (nullable = true)
 |    |-- medium: long (nullable = true)
 |    |-- normal: long (nullable = true)
 |    |-- okay: long (nullable = true)
 |    |-- preserve: long (nullable = true)
 |-- num1: long (nullable = true)
 |-- num2: long (nullable = true)
 |-- str1: string (nullable = true)
 |-- str2: string (nullable = true)

After:

root
 |-- arr1: array (nullable = true)
 |    |-- element: long (containsNull = true)
 |-- bool1: boolean (nullable = true)
 |-- bool2: boolean (nullable = true)
 |-- double1: double (nullable = true)
 |-- double2: double (nullable = true)
 |-- num1: long (nullable = true)
 |-- num2: long (nullable = true)
 |-- str1: string (nullable = true)
 |-- str2: string (nullable = true)
 |-- map1_agressive: long (nullable = true)
 |-- map1_cool: long (nullable = true)
 |-- map1_fast: long (nullable = true)
 |-- map1_medium: long (nullable = true)
 |-- map1_normal: long (nullable = true)
 |-- map1_okay: long (nullable = true)
 |-- map1_preserve: long (nullable = true)
 |-- carInfo_Engine_Make: string (nullable = true)
 |-- carInfo_Engine_Redline: string (nullable = true)
 |-- carInfo_Tyres_Compound: string (nullable = true)
 |-- carInfo_Tyres_Life: string (nullable = true)
 |-- carInfo_Tyres_Make: string (nullable = true)
 |-- carInfo_Engine_Power_IC: string (nullable = true)
 |-- carInfo_Engine_Power_battery: string (nullable = true)

Tried for 2 Levels, it worked


Raj ks

We used https://github.com/lvhuyen/SparkAid. It works to any level of nesting:

from sparkaid import flatten

flatten(df_nested_B).printSchema()