Moving towards the DataFrame API using the Spark Connect gRPC API in .NET

All Spark Connect Posts

Code

Goal of this post

There are two goals for this post: the first is to take a look at Apache Arrow and see how we can do things like show the output from DataFrame.Show; the second is to start creating objects that look more familiar to us, i.e. the DataFrame API.

I want to take it in small steps. I know for certain that this sort of syntax is possible:

var spark = SparkSession
    .Builder
    .Remote("http://localhost:15002")
    .GetOrCreate();

var dataFrame = spark.Range(1000);
dataFrame.Show();

var dataFrame2 = dataFrame
    .WithColumn("Hello",
        Lit("Hello From Spark, via Spark Connect"));

but, as I said, let’s start closer to the gRPC API and take gentle steps in the right direction.

Relations

So far we have created all the relations on the fly and specified all the parameters, but this gets boring quite quickly, so here is an example of the Range relation wrapped in a method:

public static Relation RangeRelation(long end, long step = 1)
{
    return new Relation
    {
        Range = new Range
        {
            Step = step,
            End = end
        }
    };
}

so we can now call Range like:

var dataFrameLeft = RangeRelation(100);

Now looking at what else we can do, a common operation is WithColumn:

public static Relation WithColumn(this Relation input, string columnName, string value)
{
    var newColumn = new Expression.Types.Alias
    {
        Expr = new Expression
        {
            Literal = new Expression.Types.Literal
            {
                String = value
            }
        },
        Name = { columnName }
    };

    return new Relation
    {
        WithColumns = new WithColumns
        {
            Input = input,
            Aliases = { newColumn }
        }
    };
}

(I’ve created some extension methods, hence the this Relation parameter.)

WithColumn is made up of an expression that holds the value of the Literal (anyone for Functions.Lit?), a name for the new column, and the input relation that we want to add the column to. Hopefully, in your head you are thinking “so a Relation is like a DataFrame then”.

We can call this like:


var dataFrameLeft = RangeRelation(100).WithColumn("hello", "lit_string_value");
var dataFrameRight = RangeRelation(500).WithColumn("bye", "oooh");
dataFrameRight = dataFrameRight.WithColumn("and", "another");

Changing the second parameter from a string to an Expression doesn’t seem too much of a leap and we will get there eventually. This is starting to look closer to the DataFrame API, nice!
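
As a taste of that next step, here is a sketch of what an Expression-based overload plus a Lit helper might look like; this isn’t in the code yet, it is just my guess at the obvious shape it would take:

// hypothetical helper mirroring Functions.Lit - wraps a string value in a literal Expression
public static Expression Lit(string value) =>
    new()
    {
        Literal = new Expression.Types.Literal
        {
            String = value
        }
    };

// hypothetical overload of WithColumn that takes any Expression rather than just a string literal
public static Relation WithColumn(this Relation input, string columnName, Expression value)
{
    return new Relation
    {
        WithColumns = new WithColumns
        {
            Input = input,
            Aliases =
            {
                new Expression.Types.Alias
                {
                    Expr = value,
                    Name = { columnName }
                }
            }
        }
    };
}

With that in place the earlier call would read RangeRelation(100).WithColumn("hello", Lit("lit_string_value")).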

The next thing is to use the Relation objects that we have created and do something with them. Here I use the Join relation to join the two relations/DataFrames together and then use ShowString to generate the string representation of the table:

Plan = new Plan
{
    Root = new Relation
    {
        ShowString = new ShowString
        {
            Truncate = 100,
            NumRows = 30,
            Input = new Relation
            {
                Join = new Join
                {
                    JoinType = Join.Types.JoinType.Inner,
                    Left = dataFrameLeft,
                    Right = dataFrameRight
                }
            }
        }
    }
}

Again, wrapping the ShowString with a method won’t be a problem.
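
Something like this would do; the ToShowString name and the default values are mine, not anything from the API:

// wraps a relation in a ShowString relation, mirroring the plan built above
public static Relation ToShowString(this Relation input, int numRows = 30, int truncate = 100)
{
    return new Relation
    {
        ShowString = new ShowString
        {
            Truncate = truncate,
            NumRows = numRows,
            Input = input
        }
    };
}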

ExecutePlanResponse

So far we have done things like use Range to create a DataFrame and then the DataFrameWriter command to save that data as a parquet file. That is nice, but we often want to do things like run DataFrame.Show() or SparkSession.Sql, keep hold of the resulting DataFrame, and then do something with it later on.

If we look at the ExecutePlanResponse what we get is a series of responses (like an array) of different possible things:

  1. Schema
  2. Arrow Batch (the data)
  3. SqlCommandOutput (think of a Databricks notebook where the output is saved as _sqldf; this is what we get, a reference to the DataFrame)
  4. Metrics
  5. ObservedMetrics

According to the current Scala code (https://github.com/apache/spark/blob/master/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala#L66), the order will always be the same, but it might change in the future, so you probably shouldn’t rely on the responses arriving in that order.

Schema

This is returned in response.ResponseStream.Current.Schema, and the Struct property contains the schema. I am not sure if this is always the case, but in my few tests it looks to be right.
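
In the Execute method later in the post, grabbing the schema is just this small piece of the response loop:

// capture the schema struct from the response if it is present
if (current.Schema?.Struct != null)
{
    schema = current.Schema.Struct;
    Console.WriteLine(schema.PrettyPrint());
}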

Arrow Batch

If we have an ArrowBatch then we will need to parse it to get the data out and into something we can read. There is a library for reading the Apache Arrow format in .NET, so I would use that. In the sample accompanying the post I show how to convert a column of strings into a .NET List<object> so we can iterate over the values and print them out, see here.
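
The reader used later in the post comes from the Apache.Arrow NuGet package (as far as I know this is the official .NET implementation); a minimal sketch of reading a batch looks like this:

// assumes the Apache.Arrow NuGet package is referenced
using Apache.Arrow;
using Apache.Arrow.Ipc;

// batch.Data is the raw Arrow IPC stream sent back by Spark Connect
var reader = new ArrowStreamReader(batch.Data.Memory);
var recordBatch = reader.ReadNextRecordBatch();
Console.WriteLine($"columns: {recordBatch.ColumnCount}, rows: {recordBatch.Length}");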

If we want to show the actual output of the ShowString command then we will need to decode the data into rows and print them out; see the full example for this, but believe me, it is awesome.

SqlCommandOutput

This is the returned DataFrame that is the result of the command; some commands like ShowString won’t return anything, while some like Sql will. This is effectively optional and you shouldn’t rely on it always being there.

If it is there then it will be a valid Relation:

if (current.SqlCommandResult != null)
{
    resultDataframe = current.SqlCommandResult.Relation;
}

Metrics

Metrics are things like rows written:

if (current.Metrics != null)
    foreach (var metric in current.Metrics.Metrics_)
        foreach (var value in metric.ExecutionMetrics)
            Console.WriteLine($"Metric: {metric.Name}, {value.Key}, {value.Value.Name} = {value.Value.Value}");

outputs:

Metric: LocalTableScan, numOutputRows, number of output rows = 0

Observed Metrics

I’ve not seen any observed metrics yet; I’ll keep my eye out for them, and this should display them:

if (current.ObservedMetrics != null)
    foreach (var metric in current.ObservedMetrics)
        foreach (var value in metric.Values)
            Console.WriteLine($"ObservedMetric: {metric.Name} = {value}");

Decoding Arrow Batch

This is what I decided to do. I am not really happy with having a massive switch statement to cast the array into the correct type, but I checked the Go implementation and it ignores the arrow batch, and I checked the Rust implementation and that just seemed to be able to use the data as is. I’ll revisit this at some point and make it nicer, but for now let’s do this:

public List<Row> ToDataset(ExecutePlanResponse.Types.ArrowBatch batch)
{
    var rows = new List<Row>();
    var reader = new ArrowStreamReader(batch.Data.Memory);

    var recordBatch = reader.ReadNextRecordBatch();
    //TODO - do we need to handle multiple record batches??

    // decode each Arrow array (one per column) into a list of .NET objects
    var columnData = new List<IList<object>>();

    foreach (var array in recordBatch.Arrays)
    {
        var items = FromArray(array);
        columnData.Add(items);
    }

    // pivot the columnar data into rows
    for (var i = 0; i < recordBatch.Length; i++)
    {
        var row = new Row
        {
            Values = new List<object>(columnData.Count)
        };

        foreach (var column in columnData) row.Values.Add(column[i]);

        rows.Add(row);
    }

    return rows;
}

private static IList<object> FromArray(IArrowArray array)
{
    var items = new List<object>(array.Length);

    switch (array.Data.DataType)
    {
        //TODO add all types and implement all these \0/ but think of something smarter first
        case Int16Type int16Type:
            break;
        case Int32Type int32Type:
            break;
        case Int64Type int64Type:

            foreach (var item in (array as Int64Array).Values) items.Add(item);

            break;

        case StringType stringType:

            var stringArray = array as StringArray;

            var lastOffset = 0;

            // the value offsets tell us where each string ends in the underlying byte buffer
            foreach (var offset in stringArray.ValueOffsets)
            {
                if (offset == 0) continue;

                items.Add(Encoding.UTF8.GetString(stringArray.Values.Slice(lastOffset, offset - lastOffset)));
                lastOffset = offset;
            }

            break;
        default:

            throw new ArgumentOutOfRangeException();
    }

    return items;
}

I know I have used an object called Dataset; it isn’t the same one as we see in Scala, but I stared at the screen for a good 53 seconds and couldn’t think of a better name, so I went with it. I am sure I will change it at some point.
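
For reference, Dataset and Row are not part of the gRPC API, they are just small holder types in my code. A rough sketch of their shape (the real ones in the sample may differ slightly) is:

// simple holder types used by the decoding code
public class Row
{
    public List<object> Values { get; set; } = new();
}

public class Dataset
{
    public Dataset(DataType.Types.Struct? schema, List<Row> rows)
    {
        Schema = schema;
        Rows = rows;
    }

    public DataType.Types.Struct? Schema { get; }
    public List<Row> Rows { get; }

    // the AsString method shown below also lives here
}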

Anyway, using this we change the data from the columnar Apache Arrow format to a row-based format. Think about what your needs are; if you don’t need to do this then don’t, as it takes time and uses memory.
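
If all you need is a single column, you can read it straight out of the Arrow array and skip the pivot entirely; a rough sketch, assuming the first column is a 64-bit integer column, is:

// read column 0 directly from the record batch, no row objects needed
// (assumes we already know the column is an Int64 column)
var idColumn = (Int64Array)recordBatch.Column(0);
for (var i = 0; i < idColumn.Length; i++)
{
    Console.WriteLine(idColumn.GetValue(i));
}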

Now we have our rows in a dataset of sorts, so let’s create a simple method to print out the values to the console:

public string AsString()
{
    var builder = new StringBuilder();
    foreach (var row in Rows)
    {
        foreach (var value in row.Values) builder.AppendFormat("{0},", value);

        // trim the single trailing comma left after the last value
        if (row.Values.Count > 0) builder.Remove(builder.Length - 1, 1);

        builder.Append('\n');
    }

    return builder.ToString();
}

I have also created some extension methods that take the JSON representation of the gRPC API’s objects and pretty-print them (using JsonConvert with Formatting.Indented):


public static class PrettyPrintExtensions
{
    public static string PrettyPrint(this Relation src) => Format(src.ToString());

    public static string PrettyPrint(this DataType.Types.Struct src) => Format(src.ToString());

    private static string Format(string json) => JsonConvert.SerializeObject(JsonConvert.DeserializeObject(json), Formatting.Indented);
}

Now putting it together, what do we have?

var dataFrameLeft = RangeRelation(100).WithColumn("hello", "lit_string_value");
var dataFrameRight = RangeRelation(500).WithColumn("bye", "oooh");
dataFrameRight = dataFrameRight.WithColumn("and", "another");

Console.WriteLine(dataFrameRight.PrettyPrint());

which outputs:

{
  "withColumns": {
    "input": {
      "withColumns": {
        "input": {
          "range": {
            "end": "500",
            "step": "1"
          }
        },
        "aliases": [
          {
            "expr": {
              "literal": {
                "string": "oooh"
              }
            },
            "name": [
              "bye"
            ]
          }
        ]
      }
    },
    "aliases": [
      {
        "expr": {
          "literal": {
            "string": "another"
          }
        },
        "name": [
          "and"
        ]
      }
    ]
  }
}

Then we build an ExecutePlanRequest that joins the two relations and wraps them in a ShowString:

var request = new ExecutePlanRequest
{
    SessionId = sessionId,
    UserContext = new UserContext(),
    ClientType = ".NET Cool",

    Plan = new Plan
    {
        Root = new Relation
        {
            ShowString = new ShowString
            {
                Truncate = 100,
                NumRows = 30,
                Input = new Relation
                {
                    Join = new Join
                    {
                        JoinType = Join.Types.JoinType.Inner,
                        Left = dataFrameLeft,
                        Right = dataFrameRight
                    }
                }
            }
        }
    }
};

await request.Execute(client, headers);

with Execute looking like:

public static async Task<Relation?> Execute(this ExecutePlanRequest executePlanRequest, SparkConnectService.SparkConnectServiceClient client, Metadata headers)
{
    var response = client.ExecutePlan(executePlanRequest, headers);
    DataType.Types.Struct? schema = null;

    Relation? resultDataframe = null;

    while (await response.ResponseStream.MoveNext())
    {
        var current = response.ResponseStream.Current;
        if (current.ResponseId != null) Console.WriteLine($"Response ID: {current.ResponseId}");

        if (current.Schema?.Struct != null)
        {
            schema = current.Schema.Struct;
            Console.WriteLine(schema.PrettyPrint());
        }

        if (current.ArrowBatch != null)
        {
            var rows = new ArrowBatchWrapper().ToDataset(current.ArrowBatch);
            var dataset = new Dataset(schema, rows);
            Console.WriteLine($"We have an arrow batch: {dataset.Rows.Count}");
            Console.WriteLine(dataset.AsString());
        }

        if (current.Metrics != null)
            foreach (var metric in current.Metrics.Metrics_)
                foreach (var value in metric.ExecutionMetrics)
                    Console.WriteLine($"Metric: {metric.Name}, {value.Key}, {value.Value.Name} = {value.Value.Value}");

        if (current.ObservedMetrics != null)
            foreach (var metric in current.ObservedMetrics)
                foreach (var value in metric.Values)
                    Console.WriteLine($"ObservedMetric: {metric.Name} = {value}");

        if (current.SqlCommandResult != null)
        {
            resultDataframe = current.SqlCommandResult.Relation;
        }
    }

    if (resultDataframe != null)
    {
        Console.WriteLine("**** Result data frame ****");
        Console.WriteLine(resultDataframe.PrettyPrint());
    }

    return resultDataframe;
}

We get this: the ShowString output is sent to us as an arrow batch, which we can decode and just do a Console.WriteLine on:

Response ID: 754399c1-91ba-40b6-8cbe-cb3d5c24b99f
{
  "fields": [
    {
      "name": "show_string",
      "dataType": {
        "string": {}
      }
    }
  ]
}
Response ID: 08cbbb62-732b-455a-83a6-ca073ecac121
We have an arrow batch: 1
+---+----------------+---+----+-------+
| id|           hello| id| bye|    and|
+---+----------------+---+----+-------+
|  0|lit_string_value|  0|oooh|another|
|  1|lit_string_value|  0|oooh|another|
|  2|lit_string_value|  0|oooh|another|
|  3|lit_string_value|  0|oooh|another|
|  4|lit_string_value|  0|oooh|another|
|  5|lit_string_value|  0|oooh|another|
|  6|lit_string_value|  0|oooh|another|
|  7|lit_string_value|  0|oooh|another|
|  8|lit_string_value|  0|oooh|another|
|  9|lit_string_value|  0|oooh|another|
| 10|lit_string_value|  0|oooh|another|
| 11|lit_string_value|  0|oooh|another|
| 12|lit_string_value|  0|oooh|another|
| 13|lit_string_value|  0|oooh|another|
| 14|lit_string_value|  0|oooh|another|
| 15|lit_string_value|  0|oooh|another|
| 16|lit_string_value|  0|oooh|another|
| 17|lit_string_value|  0|oooh|another|
| 18|lit_string_value|  0|oooh|another|
| 19|lit_string_value|  0|oooh|another|
| 20|lit_string_value|  0|oooh|another|
| 21|lit_string_value|  0|oooh|another|
| 22|lit_string_value|  0|oooh|another|
| 23|lit_string_value|  0|oooh|another|
| 24|lit_string_value|  0|oooh|another|
| 25|lit_string_value|  0|oooh|another|
| 26|lit_string_value|  0|oooh|another|
| 27|lit_string_value|  0|oooh|another|
| 28|lit_string_value|  0|oooh|another|
| 29|lit_string_value|  0|oooh|another|
+---+----------------+---+----+-------+
only showing top 30 rows

Response ID: 6de00ead-14ba-40e4-817c-17e7d05e4731
Metric: LocalTableScan, numOutputRows, number of output rows = 0

Then if we go back to running a command that returns a Relation and no ArrowBatch:

request = new ExecutePlanRequest
{
    SessionId = sessionId,
    UserContext = new UserContext(),
    ClientType = ".NET Cool",

    Plan = new Plan
    {
        Command = new Command()
        {
            SqlCommand = new SqlCommand()
            {
                Sql = "with a as (SELECT *, 'newcol' as a FROM range(5)), b as (select *, 'another' as b from range(15)) select * from a right outer join b on a.id = b.id "
            }
        }
    }
};

var returnedRelation = await request.Execute(client, headers);

output:

Response ID: 08175db3-bf95-48dd-86be-4a1d9de4fa9e
Response ID: e807b3b2-bfbb-4410-bc7f-e3a3e5991920
Metric: BroadcastHashJoin, numOutputRows, number of output rows = 0
Metric: BroadcastExchange, broadcastTime, time to broadcast = 0
Metric: BroadcastExchange, buildTime, time to build = 0
Metric: BroadcastExchange, collectTime, time to collect = 0
Metric: BroadcastExchange, numOutputRows, number of output rows = 0
Metric: BroadcastExchange, dataSize, data size = 0
Metric: Range, numOutputRows, number of output rows = 0
Metric: Range, numOutputRows, number of output rows = 0
**** Result data frame ****
{
  "sql": {
    "query": "with a as (SELECT *, 'newcol' as a FROM range(5)), b as (select *, 'another' as b from range(15)) select * from a right outer join b on a.id = b.id "
  }
}

Then if we take the relation from the Sql and run explain on it like:


var analyzePlanRequest = new AnalyzePlanRequest()
{
    SessionId = sessionId,
    UserContext = new UserContext(),
    ClientType = ".NET Cool",
    Explain = new AnalyzePlanRequest.Types.Explain()
    {
        ExplainMode = AnalyzePlanRequest.Types.Explain.Types.ExplainMode.Extended,
        Plan = new Plan()
        {
            Root = returnedRelation
        }
    }
};

var response = client.AnalyzePlan(analyzePlanRequest, headers);
Console.WriteLine(response.Explain.ExplainString);

we get:

== Parsed Logical Plan ==
CTE [a, b]
:  :- 'SubqueryAlias a
:  :  +- 'Project [*, newcol AS a#86]
:  :     +- 'UnresolvedTableValuedFunction [range], [5]
:  +- 'SubqueryAlias b
:     +- 'Project [*, another AS b#87]
:        +- 'UnresolvedTableValuedFunction [range], [15]
+- 'Project [*]
   +- 'Join RightOuter, ('a.id = 'b.id)
      :- 'UnresolvedRelation [a], [], false
      +- 'UnresolvedRelation [b], [], false

== Analyzed Logical Plan ==
id: bigint, a: string, id: bigint, b: string
WithCTE
:- CTERelationDef 4, false
:  +- SubqueryAlias a
:     +- Project [id#88L, newcol AS a#86]
:        +- Range (0, 5, step=1, splits=None)
:- CTERelationDef 5, false
:  +- SubqueryAlias b
:     +- Project [id#89L, another AS b#87]
:        +- Range (0, 15, step=1, splits=None)
+- Project [id#88L, a#86, id#89L, b#87]
   +- Join RightOuter, (id#88L = id#89L)
      :- SubqueryAlias a
      :  +- CTERelationRef 4, true, [id#88L, a#86]
      +- SubqueryAlias b
         +- CTERelationRef 5, true, [id#89L, b#87]

== Optimized Logical Plan ==
Join RightOuter, (id#88L = id#89L)
:- Project [id#88L, newcol AS a#86]
:  +- Range (0, 5, step=1, splits=None)
+- Project [id#89L, another AS b#87]
   +- Range (0, 15, step=1, splits=None)

== Physical Plan ==
AdaptiveSparkPlan isFinalPlan=false
+- BroadcastHashJoin [id#88L], [id#89L], RightOuter, BuildLeft, false
   :- BroadcastExchange HashedRelationBroadcastMode(List(input[0, bigint, false]),false), [plan_id=153]
   :  +- Project [id#88L, newcol AS a#86]
   :     +- Range (0, 5, step=1, splits=10)
   +- Project [id#89L, another AS b#87]
      +- Range (0, 15, step=1, splits=10)

Finally we can call ShowString on the relation we got back from Sql:

await returnedRelation.Show(100, 3, sessionId, client, headers);

which uses this extension method:

public static async Task Show(this Relation src, int truncate, int numRows, string sessionId, SparkConnectService.SparkConnectServiceClient client, Metadata headers)
{
    var request = new ExecutePlanRequest
    {
        SessionId = sessionId,
        UserContext = new UserContext(),
        ClientType = ".NET Cool",

        Plan = new Plan
        {
            Root = new Relation
            {
                ShowString = new ShowString
                {
                    Truncate = truncate,
                    NumRows = numRows,
                    Input = src
                }
            }
        }
    };

    await request.Execute(client, headers);
}

and results in this output:

**** Show ****
Response ID: 7865c173-c7a2-469c-87ea-567a270ee9fd
{
  "fields": [
    {
      "name": "show_string",
      "dataType": {
        "string": {}
      }
    }
  ]
}
Response ID: edcc9d8d-bf2d-4f50-9f07-d82e7fb340e8
We have an arrow batch: 1
+---+------+---+-------+
| id|     a| id|      b|
+---+------+---+-------+
|  0|newcol|  0|another|
|  1|newcol|  1|another|
|  2|newcol|  2|another|
+---+------+---+-------+
only showing top 3 rows

Response ID: 00b888f2-df59-4c13-b34d-2c556ffe148a
Metric: LocalTableScan, numOutputRows, number of output rows = 0

So there you are: by storing the results of the calls and by starting to create some more familiar objects, we can start to move towards the DataFrame API.
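
To give a flavour of where this could go next, here is a rough sketch of how a DataFrame wrapper might hold onto a Relation plus the session plumbing so that the calls at the top of the post become possible. None of this exists yet; the names and shape are just my guess at a next step:

// hypothetical wrapper: a DataFrame is just a Relation plus enough context
// (session id, client, headers) to execute plans against the server
public class DataFrame
{
    private readonly Relation _relation;
    private readonly string _sessionId;
    private readonly SparkConnectService.SparkConnectServiceClient _client;
    private readonly Metadata _headers;

    public DataFrame(Relation relation, string sessionId, SparkConnectService.SparkConnectServiceClient client, Metadata headers)
    {
        _relation = relation;
        _sessionId = sessionId;
        _client = client;
        _headers = headers;
    }

    // transformations wrap the current relation in a new one and return a new DataFrame
    public DataFrame WithColumn(string columnName, string value) =>
        new(_relation.WithColumn(columnName, value), _sessionId, _client, _headers);

    // actions build a plan and execute it, exactly as the extension methods above do
    public Task Show(int numRows = 20, int truncate = 20) =>
        _relation.Show(truncate, numRows, _sessionId, _client, _headers);
}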