Examples
This page provides examples for all types of PySpark UDFs supported by Sail. For more information about the API, please refer to the Spark documentation and user guide.
INFO
In the code below, spark refers to a Spark client session connected to the Sail server. Refer to the Getting Started guide for how to set this up.
>>> from pyspark.sql import SparkSession
>>> spark = SparkSession.builder.remote("sc://localhost:50051").getOrCreate()

Python UDFs
You can define a Python scalar UDF by wrapping a Python lambda function with udf(), or using the @udf decorator.
>>> from pyspark.sql.functions import udf
>>>
>>> square = udf(lambda x: x**2, "integer")
>>> spark.range(3).select(square("id")).show()
+------------+
|<lambda>(id)|
+------------+
|           0|
|           1|
|           4|
+------------+
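A Python scalar UDF is invoked once per row, unlike the Pandas and Arrow UDFs below, which receive whole batches. The wrapped function itself is ordinary Python, so its logic can be checked locally (a minimal illustration, independent of Sail):

```python
# The function wrapped by udf() is called once per row value.
square = lambda x: x ** 2

# Local equivalent of applying the UDF to the values 0, 1, 2:
results = [square(x) for x in range(3)]
assert results == [0, 1, 4]
```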
>>> @udf("string")
... def upper(s):
...     return s.upper() if s is not None else s
...
>>> spark.sql("SELECT 'hello' as v").select(upper("v")).show()
+--------+
|upper(v)|
+--------+
|   HELLO|
+--------+

The UDF can be registered and used in SQL queries.
>>> from pyspark.sql.functions import udf
>>>
>>> @udf("integer")
... def cube(x):
...     return x**3
...
>>> spark.udf.register("cube", cube)
<function cube at ...>
>>> spark.sql("SELECT cube(5)").show()
+-------+
|cube(5)|
+-------+
|    125|
+-------+

Pandas UDFs
Pandas Scalar UDFs
You can define a Pandas scalar UDF via the @pandas_udf decorator. The UDF takes a batch of data as a Pandas object and must return a Pandas object of the same length as the input. Alternatively, the UDF can take an iterator of Pandas objects and return an iterator of Pandas objects, one for each input batch.
>>> from typing import Iterator
>>> import pandas as pd
>>> from pyspark.sql.functions import pandas_udf
>>>
>>> @pandas_udf("string")
... def upper(s: pd.Series) -> pd.Series:
...     return s.str.upper()
...
>>> spark.sql("SELECT 'hello' as v").select(upper("v")).show()
+--------+
|upper(v)|
+--------+
|   HELLO|
+--------+
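The batch semantics can be seen outside of Spark by calling the function on a plain Pandas Series (a local illustration only; Sail handles the batching for you):

```python
import pandas as pd

# A Pandas scalar UDF receives each batch of column values as a Series
# and must return a Series of the same length.
def upper(s: pd.Series) -> pd.Series:
    return s.str.upper()

batch = pd.Series(["hello", "world"])
result = upper(batch)
assert list(result) == ["HELLO", "WORLD"]
assert len(result) == len(batch)
```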
>>> @pandas_udf("long")
... def square(iterator: Iterator[pd.Series]) -> Iterator[pd.Series]:
...     for s in iterator:
...         yield s**2
...
>>> spark.range(3).select(square("id")).show()
+----------+
|square(id)|
+----------+
|         0|
|         1|
|         4|
+----------+

Pandas Aggregate UDFs
You can define a Pandas UDF for aggregation. The UDF returns a single value for each input group.
>>> from pyspark.sql.functions import PandasUDFType, pandas_udf
>>>
>>> @pandas_udf("int", PandasUDFType.GROUPED_AGG)
... def least(v):
...     return v.min()
...
>>> df = spark.createDataFrame([(1, "Hello"), (2, "Hi"), (5, "Hello"), (0, "Hello")], ["n", "text"])
>>> df.groupBy(df.text).agg(least(df.n)).sort("text").show()
+-----+--------+
| text|least(n)|
+-----+--------+
|Hello|       0|
|   Hi|       2|
+-----+--------+

The Pandas UDF can be registered for use in SQL queries, in the same way as the Python scalar UDF.
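The grouped-aggregate semantics can be mimicked locally with a plain pandas groupby: the function receives each group's column values as one Series and returns a single scalar (an illustration only; Sail performs the grouping in the query engine):

```python
import pandas as pd

# Same least() logic as the UDF above, applied per group.
def least(v):
    return v.min()

df = pd.DataFrame({"n": [1, 2, 5, 0], "text": ["Hello", "Hi", "Hello", "Hello"]})
out = df.groupby("text")["n"].apply(least)
assert out["Hello"] == 0
assert out["Hi"] == 2
```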
Pandas Map UDFs
You can define a Pandas UDF to transform data partitions using mapInPandas().
>>> def transform(iterator):
...     for pdf in iterator:
...         yield pdf[pdf.text.str.len() > 2]
...
>>> df = spark.createDataFrame([(1, "Hello"), (2, "Hi")], ["id", "text"])
>>> df.mapInPandas(transform, schema=df.schema).show()
+---+-----+
| id| text|
+---+-----+
|  1|Hello|
+---+-----+

Pandas Grouped or Co-grouped Map UDFs
You can define a Pandas UDF to transform grouped data or co-grouped data using applyInPandas().
>>> import pandas as pd
>>>
>>> def center(pdf):
...     return pdf.assign(v=pdf.v - pdf.v.mean())
...
>>> df = spark.createDataFrame([(1, 1.0), (1, 2.0), (2, 3.0)], ["id", "v"])
>>> df.groupby("id").applyInPandas(center, schema="id long, v double").sort("id", "v").show()
+---+----+
| id|   v|
+---+----+
|  1|-0.5|
|  1| 0.5|
|  2| 0.0|
+---+----+

>>> import pandas as pd
>>>
>>> def summary(key, lpdf, rpdf):
...     (key,) = key
...     return pd.DataFrame({
...         "id": [key],
...         "diff": [lpdf.v.mean() - rpdf.v.mean()],
...         "count": [lpdf.v.count() + rpdf.v.count()],
...     })
...
>>> df1 = spark.createDataFrame([(1, 1.0), (1, 2.0), (2, 3.0)], ["id", "v"])
>>> df2 = spark.createDataFrame([(1, 5.0)], ["id", "v"])
>>> df1.groupby("id").cogroup(df2.groupby("id")).applyInPandas(
... summary, schema="id long, diff double, count long"
... ).sort("id").show()
+---+----+-----+
| id|diff|count|
+---+----+-----+
|  1|-3.5|    3|
|  2|NULL|    1|
+---+----+-----+

Arrow UDFs
Arrow Scalar UDFs
You can define an Arrow scalar UDF via the @arrow_udf decorator. The UDF takes a batch of data as a pyarrow.Array object and must return a pyarrow.Array object of the same length as the input. Alternatively, the UDF can take an iterator of pyarrow.Array objects and return an iterator of pyarrow.Array objects, one for each input batch.
>>> import pyarrow as pa
>>> import pyarrow.compute as pc
>>> from pyspark.sql.functions import arrow_udf
>>>
>>> @arrow_udf("long")
... def square(v: pa.Array) -> pa.Array:
...     return pc.multiply(v, v)
...
>>> spark.range(3).select(square("id")).show()
+----------+
|square(id)|
+----------+
|         0|
|         1|
|         4|
+----------+
>>> from typing import Iterator
>>>
>>> @arrow_udf("long")
... def increment(it: Iterator[pa.Array]) -> Iterator[pa.Array]:
...     for v in it:
...         yield pc.add(v, 1)
...
>>> spark.range(3).select(increment("id")).show()
+-------------+
|increment(id)|
+-------------+
|            1|
|            2|
|            3|
+-------------+

Arrow Aggregate UDFs
You can define an Arrow UDF for aggregation. The UDF returns a single value for each input group.
>>> import pyarrow as pa
>>> import pyarrow.compute as pc
>>> from pyspark.sql.functions import arrow_udf
>>>
>>> @arrow_udf("double")
... def average(v: pa.Array) -> float:
...     return pc.mean(v).as_py()
...
>>> df = spark.createDataFrame([(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ["id", "v"])
>>> df.groupby("id").agg(average("v")).sort("id").show()
+---+----------+
| id|average(v)|
+---+----------+
|  1|       1.5|
|  2|       6.0|
+---+----------+

Arrow Map UDFs
You can define an Arrow UDF to transform data partitions using mapInArrow(). This is similar to mapInPandas() but the input and output are Arrow record batches.
>>> import pyarrow.compute as pc
>>>
>>> def transform(iterator):
...     for batch in iterator:
...         yield batch.filter(pc.utf8_length(pc.field("text")) > 2)
...
>>> df = spark.createDataFrame([(1, "Hello"), (2, "Hi")], ["id", "text"])
>>> df.mapInArrow(transform, schema=df.schema).show()
+---+-----+
| id| text|
+---+-----+
|  1|Hello|
+---+-----+

Arrow Grouped or Co-grouped Map UDFs
You can define an Arrow UDF to transform grouped data or co-grouped data using applyInArrow(). This is similar to applyInPandas() but the input and output are Arrow record batches.
>>> import pyarrow as pa
>>> import pyarrow.compute as pc
>>>
>>> def center(t):
...     v = t.column("v")
...     norm = pc.subtract(v, pc.mean(v))
...     return t.set_column(1, "v", norm)
...
>>> df = spark.createDataFrame([(1, 1.0), (1, 2.0), (2, 3.0)], ["id", "v"])
>>> df.groupby("id").applyInArrow(center, schema="id long, v double").sort("id", "v").show()
+---+----+
| id|   v|
+---+----+
|  1|-0.5|
|  1| 0.5|
|  2| 0.0|
+---+----+

>>> import pyarrow as pa
>>> import pyarrow.compute as pc
>>>
>>> def diff(l, r):
...     return pa.table({
...         "id": l.column("id").slice(0, 1),
...         "d": [pc.subtract(pc.mean(l.column("v")), pc.mean(r.column("v"))).as_py()],
...     })
...
>>> df1 = spark.createDataFrame([(1, 1.0), (1, 2.0)], ["id", "v"])
>>> df2 = spark.createDataFrame([(1, 5.0)], ["id", "v"])
>>> df1.groupby("id").cogroup(df2.groupby("id")).applyInArrow(
... diff, schema="id long, d double"
... ).show()
+---+----+
| id|   d|
+---+----+
|  1|-3.5|
+---+----+

Python UDTFs
You can define a Python UDTF class that produces multiple rows for each input row.
>>> from pyspark.sql.functions import lit, udtf
>>>
>>> class Fibonacci:
...     def eval(self, n: int):
...         a, b = 0, 1
...         for _ in range(n):
...             a, b = b, a + b
...             yield (a,)
...
>>> fibonacci = udtf(Fibonacci, returnType="value: integer")
>>> fibonacci(lit(5)).show()
+-----+
|value|
+-----+
|    1|
|    1|
|    2|
|    3|
|    5|
+-----+

The UDTF can be registered and used in SQL queries.
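Since eval() is ordinary Python, the generator logic can be verified locally before registering a UDTF; each yielded tuple becomes one output row (a local illustration, independent of Sail):

```python
# The Fibonacci UDTF class from above, exercised without Spark.
class Fibonacci:
    def eval(self, n: int):
        a, b = 0, 1
        for _ in range(n):
            a, b = b, a + b
            yield (a,)

rows = list(Fibonacci().eval(5))
assert rows == [(1,), (1,), (2,), (3,), (5,)]
```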
>>> from pyspark.sql.functions import udtf
>>>
>>> @udtf(returnType="word: string")
... class Tokenizer:
...     def eval(self, text: str):
...         for word in text.split():
...             yield (word.strip().lower(),)
...
>>> spark.udtf.register("tokenize", Tokenizer)
<...UserDefinedTableFunction object at ...>
>>> spark.sql("SELECT * FROM tokenize('Hello world')").show()
+-----+
| word|
+-----+
|hello|
|world|
+-----+
>>> spark.sql(
... "SELECT id, word FROM VALUES (1, 'Hello world'), (2, 'Hi') AS t(id, text), "
... "LATERAL tokenize(text) "
... "ORDER BY id, word"
... ).show()
+---+-----+
| id| word|
+---+-----+
|  1|hello|
|  1|world|
|  2|   hi|
+---+-----+

Arrow UDTFs
You can define an Arrow UDTF class that produces multiple rows for each input row.
>>> import pyarrow as pa
>>> import pyarrow.compute as pc
>>> from pyspark.sql.functions import arrow_udtf, lit
>>>
>>> @arrow_udtf(returnType="v int")
... class Square:
...     def eval(self, x: pa.Array):
...         yield pa.table({"v": pc.multiply(x, x)})
...
>>> Square(lit(3)).show()
+---+
|  v|
+---+
|  9|
+---+