I have a dataframe with the following structure:
|-- data: struct (nullable = true)
| |-- id: long (nullable = true)
| |-- keyNote: struct (nullable = true)
| | |-- key: string (nullable = true)
| | |-- note: string (nullable = true)
| |-- details: map (nullable = true)
| | |-- key: string
| | |-- value: string (valueContainsNull = true)
How is it possible to flatten the structure and create a new dataframe:
|-- id: long (nullable = true)
|-- keyNote: struct (nullable = true)
| |-- key: string (nullable = true)
| |-- note: string (nullable = true)
|-- details: map (nullable = true)
| |-- key: string
| |-- value: string (valueContainsNull = true)
Is there something like explode, but for structs?
This should work in Spark 1.6 or later:
df.select(df.col("data.*"))
or
df.select(df.col("data.id"), df.col("data.keyNote"), df.col("data.details"))
Here is a function that does what you want and that can deal with multiple nested columns containing columns with the same name:
import pyspark.sql.functions as F
def flatten_df(nested_df):
    """Expand one level of struct columns into `<parent>_<child>` columns.

    Non-struct columns are kept as-is; every child of a struct column is
    promoted to a top-level column named with the parent prefix, so columns
    with the same name inside different structs do not collide.
    """
    plain, structs = [], []
    for name, dtype in nested_df.dtypes:
        (structs if dtype.startswith('struct') else plain).append(name)
    expanded = [F.col(parent + '.' + child).alias(parent + '_' + child)
                for parent in structs
                for child in nested_df.select(parent + '.*').columns]
    return nested_df.select(plain + expanded)
Before:
root
|-- x: string (nullable = true)
|-- y: string (nullable = true)
|-- foo: struct (nullable = true)
| |-- a: float (nullable = true)
| |-- b: float (nullable = true)
| |-- c: integer (nullable = true)
|-- bar: struct (nullable = true)
| |-- a: float (nullable = true)
| |-- b: float (nullable = true)
| |-- c: integer (nullable = true)
After:
root
|-- x: string (nullable = true)
|-- y: string (nullable = true)
|-- foo_a: float (nullable = true)
|-- foo_b: float (nullable = true)
|-- foo_c: integer (nullable = true)
|-- bar_a: float (nullable = true)
|-- bar_b: float (nullable = true)
|-- bar_c: integer (nullable = true)
For Spark 2.4.5,
while df.select(df.col("data.*"))
will throw an org.apache.spark.sql.AnalysisException: No such struct field * in ...
exception,
this will work:
df.select($"data.*")
Note that df.select("data.*") only flattens one level: it selects the children of data (or whatever parent is selected) and doesn't descend if there are further nested structs.
This flatten_df
version flattens the dataframe at every layer level, using a stack to avoid recursive calls:
from pyspark.sql.functions import col
def flatten_df(nested_df):
    """Fully flatten all struct columns at any depth, without recursion.

    Walks the schema with an explicit stack of (parent-name tuple, projected
    dataframe) pairs; each leaf is selected by its dotted path and aliased
    with the path joined by underscores.
    """
    to_visit = [((), nested_df)]
    selected = []
    while to_visit:
        prefix, frame = to_visit.pop()
        for name, dtype in frame.dtypes:
            path = prefix + (name,)
            if dtype.startswith("struct"):
                # Defer struct children: project them so .dtypes exposes the fields.
                to_visit.append((path, frame.select(name + ".*")))
            else:
                selected.append(col(".".join(path)).alias("_".join(path)))
    return nested_df.select(selected)
Example:
from pyspark.sql.types import StringType, StructField, StructType
# Example schema: a flat string field, a struct, and a struct nested inside
# another struct ("renested") to exercise multi-level flattening.
schema = StructType([
StructField("some", StringType()),
StructField("nested", StructType([
StructField("nestedchild1", StringType()),
StructField("nestedchild2", StringType())
])),
StructField("renested", StructType([
StructField("nested", StructType([
StructField("nestedchild1", StringType()),
StructField("nestedchild2", StringType())
]))
]))
])
# One sample row matching the schema above.
data = [
{
"some": "value1",
"nested": {
"nestedchild1": "value2",
"nestedchild2": "value3",
},
"renested": {
"nested": {
"nestedchild1": "value4",
"nestedchild2": "value5",
}
}
}
]
# `spark` is assumed to be an active SparkSession (e.g. a notebook/shell),
# and `flatten_df` the stack-based function defined above.
df = spark.createDataFrame(data, schema)
flat_df = flatten_df(df)
print(flat_df.collect())
Prints:
[Row(some=u'value1', renested_nested_nestedchild1=u'value4', renested_nested_nestedchild2=u'value5', nested_nestedchild1=u'value2', nested_nestedchild2=u'value3')]
I generalized the solution from stecos a bit more so the flattening can be done on more than two struct layers deep:
def flatten_df(nested_df, layers):
    """Flatten struct columns in `nested_df`, one nesting level per pass.

    Parameters
    ----------
    nested_df : pyspark.sql.DataFrame
        Dataframe possibly containing struct columns.
    layers : int
        Number of nesting levels to flatten (pass the maximum struct depth).

    Returns
    -------
    pyspark.sql.DataFrame
        Dataframe where each struct child (down to `layers` levels) is a
        top-level column named `<parent>_<child>`.

    Fixes over the original: removed the leftover debug `print` of the
    previous pass's columns, and replaced the lists of intermediate
    dataframes with a single rolling variable (only the last one was used).
    """
    df = nested_df
    for _ in range(layers):
        # Already-flat columns keep their names; every struct column is
        # expanded into one aliased column per child field.
        flat_cols = [name for name, dtype in df.dtypes
                     if not dtype.startswith('struct')]
        nested_cols = [name for name, dtype in df.dtypes
                       if dtype.startswith('struct')]
        df = df.select(flat_cols +
                       [col(nc + '.' + c).alias(nc + '_' + c)
                        for nc in nested_cols
                        for c in df.select(nc + '.*').columns])
    return df
just call with:
my_flattened_df = flatten_df(my_df_having_nested_structs, 3)
(second parameter is the level of layers to be flattened, in my case it's 3)
A little more compact and efficient implementation:
There is no need to build separate lists of flat and nested columns and iterate over them separately. You act on each field based on its type: if the column is nested (a struct) you push it on the stack to flatten it later (.*), otherwise you select it with dot notation (parent.child) and alias it with . replaced by _ (parent_child).
def flatten_df(nested_df):
    """Flatten every struct column, joining nested field names with '_'.

    Single pass over each frame's dtypes: struct columns are pushed onto a
    pending stack for later expansion, leaf columns are turned directly
    into aliased select expressions.
    """
    selected = []
    pending = [((), nested_df)]
    while pending:
        parent_path, frame = pending.pop()
        for name, dtype in frame.dtypes:
            full_path = parent_path + (name,)
            if dtype[:6] == "struct":
                pending.append((full_path, frame.select(name + ".*")))
            else:
                selected.append(
                    col(".".join(full_path)).alias("_".join(full_path)))
    return nested_df.select(selected)
PySpark solution to flatten nested df with both struct and array types with any level of depth. This is improved on this: https://stackoverflow.com/a/56533459/7131019
from pyspark.sql.types import *
from pyspark.sql import functions as f
def flatten_structs(nested_df):
    """Flatten every struct column of `nested_df`, at any nesting depth.

    Array (and map) columns are left untouched — pair this with
    `flatten_array_struct_df` to explode arrays. Nested field names are
    joined with underscores, e.g. `a.b.c` becomes column `a_b_c`.

    Fix over the original: removed the `array_cols` list that was computed
    on every stack iteration but never used.
    """
    # Stack of (tuple-of-parent-names, projected dataframe) still to visit.
    stack = [((), nested_df)]
    columns = []
    while stack:
        parents, df = stack.pop()
        for name, dtype in df.dtypes:
            path = parents + (name,)
            if dtype.startswith("struct"):
                # Defer struct children; project so .dtypes sees the fields.
                stack.append((path, df.select(name + ".*")))
            else:
                # Leaf (including arrays/maps): select by dotted path,
                # alias with the underscore-joined path.
                columns.append(f.col(".".join(path)).alias("_".join(path)))
    return nested_df.select(columns)
def flatten_array_struct_df(df):
    """Fully flatten a dataframe containing nested structs and arrays.

    Each array column is exploded into one row per element — note this can
    duplicate rows when several arrays are present — then any structs
    exposed by the explode are flattened; the loop repeats until no array
    columns remain.

    Fix over the original: removed the `cols_to_select` list that was
    built per array column but never used.
    """
    array_cols = [name for name, dtype in df.dtypes
                  if dtype.startswith("array")]
    while array_cols:
        for array_col in array_cols:
            df = df.withColumn(array_col, f.explode(f.col(array_col)))
        df = flatten_structs(df)
        # Exploding may surface new arrays from inside structs; re-scan.
        array_cols = [name for name, dtype in df.dtypes
                      if dtype.startswith("array")]
    return df
# Example usage: fully flatten `df` (arrays exploded, structs expanded).
flat_df = flatten_array_struct_df(df)
Exploding an array of struct can produce duplicate rows, as others said. You can use this approach if you have to convert only struct types; I would not suggest converting the arrays, as that could lead to duplicate records.
from pyspark.sql.functions import col
from pyspark.sql.types import StructType
def flatten_schema(schema, prefix=""):
    """Recursively build a flat select-list for all leaf fields of `schema`.

    Struct fields are descended into; nested leaves become Column objects
    aliased with their dotted path rewritten using underscores, while
    top-level leaves are returned as plain name strings.
    """
    selection = []
    for field in schema.fields:
        path = "{}.{}".format(prefix, field.name) if prefix else field.name
        if isinstance(field.dataType, StructType):
            selection.extend(flatten_schema(field.dataType, path))
        elif prefix:
            selection.append(col(path).alias(path.replace(".", "_")))
        else:
            selection.append(field.name)
    return selection
You can use it as
# Build the flattened select-list from the dataframe's schema, then apply it.
new_schema = flatten_schema(df.schema)
# Bug fix: the original called `df.select(se)` with the undefined name `se`;
# the intended argument is the select-list built above.
df1 = df.select(new_schema)
df1.show()
Based on https://stackoverflow.com/a/49532496/17250408 here is solution for struct
and array
fields with multilevel nesting
from pyspark.sql.functions import col, explode
def type_cols(df_dtypes, filter_type):
    """Return names of columns whose dtype string starts with `filter_type`.

    `df_dtypes` is a sequence of (column_name, dtype_string) pairs, as
    produced by `DataFrame.dtypes`; e.g. filter_type="struct" matches
    dtype strings like "struct<x:int>".
    """
    return [name for name, dtype in df_dtypes
            if dtype.startswith(filter_type)]
def flatten_df(nested_df, sep='_'):
    """Expand one level of struct columns.

    With the default separator each child becomes `<parent><sep><child>`;
    passing sep=None keeps only the child name (beware collisions).
    """
    struct_cols = type_cols(nested_df.dtypes, "struct")
    selection = [name for name, _ in nested_df.dtypes
                 if name not in struct_cols]
    for parent in struct_cols:
        for child in nested_df.select(f"{parent}.*").columns:
            alias = f"{child}" if sep is None else f"{parent}{sep}{child}"
            selection.append(col(f"{parent}.{child}").alias(alias))
    return nested_df.select(selection)
def explode_df(nested_df):
    """Explode every array column into one row per element.

    Note: exploding more than one array multiplies rows (cartesian effect).
    """
    result = nested_df
    for array_col in type_cols(nested_df.dtypes, "array"):
        result = result.withColumn(array_col, explode(col(array_col)))
    return result
def flatten_explode_df(nested_df):
    """Alternately flatten structs and explode arrays until the schema is flat.

    Structs are expanded first; arrays only once no structs remain at the
    current level. Recurses until neither kind of column is left.
    """
    if type_cols(nested_df.dtypes, "struct"):
        return flatten_explode_df(flatten_df(nested_df))
    if type_cols(nested_df.dtypes, "array"):
        return flatten_explode_df(explode_df(nested_df))
    return nested_df
# Example usage: produce a fully flat dataframe from `nested_df`.
df = flatten_explode_df(nested_df)
An easy way is to use SQL, you could build a SQL query string to alias nested column as flat ones.
Retrieve data-frame schema (df.schema())
Transform schema to SQL (for (field : schema().fields()) ...
Query: val newDF = sqlContext.sql("SELECT " + sqlGenerated + " FROM source")
Here is an example in Java.
(I prefer SQL way, so you can easily test it on Spark-shell and it's cross-language).
below worked for me in spark sql
import org.apache.spark.sql._
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs._
import org.apache.http.client.methods.HttpGet
import org.apache.http.impl.client.DefaultHttpClient
import org.apache.spark.sql.{DataFrame, SaveMode, SparkSession}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.StructType
import org.apache.log4j.{Level, Logger}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.types._
import org.apache.spark.storage.StorageLevel
import org.apache.spark.sql.functions.{explode, expr, posexplode, when}
object StackOverFlowQuestion {
// Demonstrates flattening a nested JSON document with Spark SQL:
// explode the `sen` array in a CTE, then select struct children with `.*`.
def main(args: Array[String]): Unit = {
val logger = Logger.getLogger("FlattenTest")
// Silence noisy framework logging.
Logger.getLogger("org").setLevel(Level.WARN)
Logger.getLogger("akka").setLevel(Level.WARN)
// Local two-core session; Hive support is not needed for this demo.
val spark = SparkSession.builder()
.appName("FlattenTest")
.config("spark.sql.warehouse.dir", "C:\\Temp\\hive")
.master("local[2]")
//.enableHiveSupport()
.getOrCreate()
import spark.implicits._
// Sample payload: scalar fields plus an array of structs containing a
// nested `columns` struct.
val stringTest =
"""{
"total_count": 123,
"page_size": 20,
"another_id": "gdbfdbfdbd",
"sen": [{
"id": 123,
"ses_id": 12424343,
"columns": {
"blah": "blah",
"count": 1234
},
"class": {},
"class_timestamps": {},
"sentence": "spark is good"
}]
}
"""
val result = List(stringTest)
// Feed the JSON string through an RDD so spark.read.json infers the schema.
val githubRdd=spark.sparkContext.makeRDD(result)
val gitHubDF=spark.read.json(githubRdd)
gitHubDF.show()
gitHubDF.printSchema()
gitHubDF.registerTempTable("JsonTable")
// cte: one row per element of `sen`; cte_2: promote struct children.
// NOTE(review): `senArray.ses_id` is selected twice — the first occurrence
// was likely meant to be `senArray.id`; the duplicated column matches the
// output shown below, so left unchanged.
spark.sql("with cte as" +
"(" +
"select explode(sen) as senArray from JsonTable" +
"), cte_2 as" +
"(" +
"select senArray.ses_id,senArray.ses_id,senArray.columns.* from cte" +
")" +
"select * from cte_2"
).show()
spark.stop()
}
}
output:-
+----------+---------+--------------------+-----------+
|another_id|page_size| sen|total_count|
+----------+---------+--------------------+-----------+
|gdbfdbfdbd| 20|[[[blah, 1234], 1...| 123|
+----------+---------+--------------------+-----------+
root
|-- another_id: string (nullable = true)
|-- page_size: long (nullable = true)
|-- sen: array (nullable = true)
| |-- element: struct (containsNull = true)
| | |-- columns: struct (nullable = true)
| | | |-- blah: string (nullable = true)
| | | |-- count: long (nullable = true)
| | |-- id: long (nullable = true)
| | |-- sentence: string (nullable = true)
| | |-- ses_id: long (nullable = true)
|-- total_count: long (nullable = true)
+--------+--------+----+-----+
| ses_id| ses_id|blah|count|
+--------+--------+----+-----+
|12424343|12424343|blah| 1234|
+--------+--------+----+-----+
This is for scala spark.
// Accumulates the fully-qualified "path as alias" select expressions across
// the recursive calls below. Cleared by flatten_df before each run.
val totalMainArrayBuffer=collection.mutable.ArrayBuffer[String]()
// Expands one level of struct columns per call; when no structs remain,
// selects the accumulated expressions from the original outer dataframe.
def flatten_df_Struct(dfTemp:org.apache.spark.sql.DataFrame,dfTotalOuter:org.apache.spark.sql.DataFrame):org.apache.spark.sql.DataFrame=
{
//dfTemp.printSchema
val totalStructCols=dfTemp.dtypes.map(x => x.toString.substring(1,x.toString.size-1)).filter(_.split(",",2)(1).contains("Struct")) // match on the type part only, in case a column name itself contains the word "Struct"
val mainArrayBuffer=collection.mutable.ArrayBuffer[String]()
for(totalStructCol <- totalStructCols)
{
val tempArrayBuffer=collection.mutable.ArrayBuffer[String]()
tempArrayBuffer+=s"${totalStructCol.split(",")(0)}.*"
//tempArrayBuffer.toSeq.toDF.show(false)
// Enumerate the struct's children, then alias each as parent_child.
val columnsInside=dfTemp.selectExpr(tempArrayBuffer:_*).columns
for(column <- columnsInside)
mainArrayBuffer+=s"${totalStructCol.split(",")(0)}.${column} as ${totalStructCol.split(",")(0)}_${column}"
//mainArrayBuffer.toSeq.toDF.show(false)
}
//dfTemp.selectExpr(mainArrayBuffer:_*).printSchema
// Record the leaves reached at this level into the shared accumulator.
val nonStructCols=dfTemp.selectExpr(mainArrayBuffer:_*).dtypes.map(x => x.toString.substring(1,x.toString.size-1)).filter(!_.split(",",2)(1).contains("Struct")) // match on the type part only, in case a column name itself contains the word "Struct"
for (nonStructCol <- nonStructCols)
totalMainArrayBuffer+=s"${nonStructCol.split(",")(0).replace("_",".")} as ${nonStructCol.split(",")(0)}" // rebuild the dotted path from the underscore alias for the final select. NOTE(review): this assumes no original column name contains "_".
// Recurse while any struct columns remain after this level's expansion.
dfTemp.selectExpr(mainArrayBuffer:_*).dtypes.map(x => x.toString.substring(1,x.toString.size-1)).filter(_.split(",",2)(1).contains("Struct")).size
match {
case value if value ==0 => dfTotalOuter.selectExpr(totalMainArrayBuffer:_*)
case _ => flatten_df_Struct(dfTemp.selectExpr(mainArrayBuffer:_*),dfTotalOuter)
}
}
// Entry point: keeps top-level non-struct columns as-is, then appends the
// flattened struct leaves collected by flatten_df_Struct.
def flatten_df(dfTemp:org.apache.spark.sql.DataFrame):org.apache.spark.sql.DataFrame=
{
var totalArrayBuffer=collection.mutable.ArrayBuffer[String]()
val totalNonStructCols=dfTemp.dtypes.map(x => x.toString.substring(1,x.toString.size-1)).filter(!_.split(",",2)(1).contains("Struct")) // match on the type part only, in case a column name itself contains the word "Struct"
for (totalNonStructCol <- totalNonStructCols)
totalArrayBuffer+=s"${totalNonStructCol.split(",")(0)}"
// Reset the shared accumulator so repeated calls don't leak state.
totalMainArrayBuffer.clear
flatten_df_Struct(dfTemp,dfTemp) // flattened schema is now in totalMainArrayBuffer
totalArrayBuffer=totalArrayBuffer++totalMainArrayBuffer
dfTemp.selectExpr(totalArrayBuffer:_*)
}
// Example: add a dummy literal column, flatten, and inspect the schema.
flatten_df(dfTotal.withColumn("tempStruct",lit(5))).printSchema
File
{"num1":1,"num2":2,"bool1":true,"bool2":false,"double1":4.5,"double2":5.6,"str1":"a","str2":"b","arr1":[3,4,5],"map1":{"cool":1,"okay":2,"normal":3},"carInfo":{"Engine":{"Make":"sa","Power":{"IC":"900","battery":"165"},"Redline":"11500"} ,"Tyres":{"Make":"Pirelli","Compound":"c1","Life":"120"}}}
{"num1":3,"num2":4,"bool1":false,"bool2":false,"double1":4.2,"double2":5.5,"str1":"u","str2":"n","arr1":[6,7,9],"map1":{"fast":1,"medium":2,"agressive":3},"carInfo":{"Engine":{"Make":"na","Power":{"IC":"800","battery":"150"},"Redline":"10000"} ,"Tyres":{"Make":"Pirelli","Compound":"c2","Life":"100"}}}
{"num1":8,"num2":4,"bool1":true,"bool2":true,"double1":5.7,"double2":7.5,"str1":"t","str2":"k","arr1":[11,12,23],"map1":{"preserve":1,"medium":2,"fast":3},"carInfo":{"Engine":{"Make":"ta","Power":{"IC":"950","battery":"170"},"Redline":"12500"} ,"Tyres":{"Make":"Pirelli","Compound":"c3","Life":"80"}}}
{"num1":7,"num2":9,"bool1":false,"bool2":true,"double1":33.2,"double2":7.5,"str1":"b","str2":"u","arr1":[12,14,5],"map1":{"normal":1,"preserve":2,"agressive":3},"carInfo":{"Engine":{"Make":"pa","Power":{"IC":"920","battery":"160"},"Redline":"11800"} ,"Tyres":{"Make":"Pirelli","Compound":"c4","Life":"70"}}}
Before:
root
|-- arr1: array (nullable = true)
| |-- element: long (containsNull = true)
|-- bool1: boolean (nullable = true)
|-- bool2: boolean (nullable = true)
|-- carInfo: struct (nullable = true)
| |-- Engine: struct (nullable = true)
| | |-- Make: string (nullable = true)
| | |-- Power: struct (nullable = true)
| | | |-- IC: string (nullable = true)
| | | |-- battery: string (nullable = true)
| | |-- Redline: string (nullable = true)
| |-- Tyres: struct (nullable = true)
| | |-- Compound: string (nullable = true)
| | |-- Life: string (nullable = true)
| | |-- Make: string (nullable = true)
|-- double1: double (nullable = true)
|-- double2: double (nullable = true)
|-- map1: struct (nullable = true)
| |-- agressive: long (nullable = true)
| |-- cool: long (nullable = true)
| |-- fast: long (nullable = true)
| |-- medium: long (nullable = true)
| |-- normal: long (nullable = true)
| |-- okay: long (nullable = true)
| |-- preserve: long (nullable = true)
|-- num1: long (nullable = true)
|-- num2: long (nullable = true)
|-- str1: string (nullable = true)
|-- str2: string (nullable = true)
After:
root
|-- arr1: array (nullable = true)
| |-- element: long (containsNull = true)
|-- bool1: boolean (nullable = true)
|-- bool2: boolean (nullable = true)
|-- double1: double (nullable = true)
|-- double2: double (nullable = true)
|-- num1: long (nullable = true)
|-- num2: long (nullable = true)
|-- str1: string (nullable = true)
|-- str2: string (nullable = true)
|-- map1_agressive: long (nullable = true)
|-- map1_cool: long (nullable = true)
|-- map1_fast: long (nullable = true)
|-- map1_medium: long (nullable = true)
|-- map1_normal: long (nullable = true)
|-- map1_okay: long (nullable = true)
|-- map1_preserve: long (nullable = true)
|-- carInfo_Engine_Make: string (nullable = true)
|-- carInfo_Engine_Redline: string (nullable = true)
|-- carInfo_Tyres_Compound: string (nullable = true)
|-- carInfo_Tyres_Life: string (nullable = true)
|-- carInfo_Tyres_Make: string (nullable = true)
|-- carInfo_Engine_Power_IC: string (nullable = true)
|-- carInfo_Engine_Power_battery: string (nullable = true)
Tried for 2 Levels, it worked
We used https://github.com/lvhuyen/SparkAid It works to any Level
# Third-party helper (https://github.com/lvhuyen/SparkAid): flattens
# nested structs at any depth with a single call.
from sparkaid import flatten
flatten(df_nested_B).printSchema()
Success story sharing