Examples
This page provides examples for all types of PySpark UDFs supported by Sail. For more information about the API, please refer to the Spark documentation and user guide.
Python UDF
You can define a Python scalar UDF by wrapping a Python lambda function with udf(), or by using the @udf decorator.
>>> from pyspark.sql.functions import udf
>>>
>>> square = udf(lambda x: x**2, "integer")
>>> spark.range(3).select(square("id")).show()
+------------+
|<lambda>(id)|
+------------+
|           0|
|           1|
|           4|
+------------+
>>> @udf("string")
... def upper(s):
...     return s.upper() if s is not None else s
...
>>> spark.sql("SELECT 'hello' as v").select(upper("v")).show()
+--------+
|upper(v)|
+--------+
|   HELLO|
+--------+
The UDF can be registered and used in SQL queries.
>>> from pyspark.sql.functions import udf
>>>
>>> @udf("integer")
... def cube(x):
...     return x**3
...
>>> spark.udf.register("cube", cube)
<function cube at ...>
>>> spark.sql("SELECT cube(5)").show()
+-------+
|cube(5)|
+-------+
|    125|
+-------+
Pandas UDF
You can define a Pandas scalar UDF via the @pandas_udf decorator. The UDF takes a batch of data as a Pandas object and must return a Pandas object of the same length as the input. The UDF can also take an iterator of Pandas objects and return an iterator of Pandas objects, one for each input batch.
>>> from typing import Iterator
>>> import pandas as pd
>>> from pyspark.sql.functions import pandas_udf
>>>
>>> @pandas_udf("string")
... def upper(s):
...     return s.str.upper()
...
>>> spark.sql("SELECT 'hello' as v").select(upper("v")).show()
+--------+
|upper(v)|
+--------+
|   HELLO|
+--------+
>>> @pandas_udf("long")
... def square(iterator: Iterator[pd.Series]) -> Iterator[pd.Series]:
...     for s in iterator:
...         yield s**2
...
>>> spark.range(3).select(square("id")).show()
+----------+
|square(id)|
+----------+
|         0|
|         1|
|         4|
+----------+
You can define a Pandas UDF for aggregation. The UDF returns a single value for each input group.
>>> from pyspark.sql.functions import PandasUDFType, pandas_udf
>>>
>>> @pandas_udf("int", PandasUDFType.GROUPED_AGG)
... def least(v):
...     return v.min()
...
>>> df = spark.createDataFrame([(1, "Hello"), (2, "Hi"), (5, "Hello"), (0, "Hello")], ["n", "text"])
>>> df.groupBy(df.text).agg(least(df.n)).sort("text").show()
+-----+--------+
| text|least(n)|
+-----+--------+
|Hello|       0|
|   Hi|       2|
+-----+--------+
The Pandas UDF can be registered for use in SQL queries, in the same way as the Python scalar UDF.
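For example, a minimal sketch along the following lines reuses the upper Pandas UDF defined above. The SQL function name pandas_upper and the column alias v are chosen here for illustration, and the output shown is what standard PySpark would print for this query.
>>> _ = spark.udf.register("pandas_upper", upper)  # "pandas_upper" is an illustrative name
>>> spark.sql("SELECT pandas_upper('hello') AS v").show()
+-----+
|    v|
+-----+
|HELLO|
+-----+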
You can define a Pandas UDF to transform data partitions using mapInPandas().
>>> def transform(iterator):
...     for pdf in iterator:
...         yield pdf[pdf.text.str.len() > 2]
...
>>> df = spark.createDataFrame([(1, "Hello"), (2, "Hi")], ["id", "text"])
>>> df.mapInPandas(transform, schema=df.schema).show()
+---+-----+
| id| text|
+---+-----+
|  1|Hello|
+---+-----+
You can define a Pandas UDF to transform grouped data or co-grouped data using applyInPandas().
>>> import pandas as pd
>>>
>>> def center(pdf):
...     return pdf.assign(v=pdf.v - pdf.v.mean())
...
>>> def summary(key, lpdf, rpdf):
...     (key,) = key
...     return pd.DataFrame({
...         "id": [key],
...         "diff": [lpdf.v.mean() - rpdf.v.mean()],
...         "count": [lpdf.v.count() + rpdf.v.count()],
...     })
...
>>> df1 = spark.createDataFrame([(1, 1.0), (1, 2.0), (2, 3.0)], ["id", "v"])
>>> df2 = spark.createDataFrame([(1, 5.0)], ["id", "v"])
>>>
>>> df1.groupby("id").applyInPandas(center, schema="id long, v double").sort("id", "v").show()
+---+----+
| id|   v|
+---+----+
|  1|-0.5|
|  1| 0.5|
|  2| 0.0|
+---+----+
>>> df1.groupby("id").cogroup(df2.groupby("id")).applyInPandas(
...     summary, schema="id long, diff double, count long"
... ).sort("id").show()
+---+----+-----+
| id|diff|count|
+---+----+-----+
|  1|-3.5|    3|
|  2|NULL|    1|
+---+----+-----+
Arrow UDF
You can define an Arrow UDF to transform data partitions using mapInArrow(). This is similar to mapInPandas(), but the input and output are Arrow record batches.
>>> import pyarrow.compute as pc
>>>
>>> def transform(iterator):
...     for batch in iterator:
...         yield batch.filter(pc.utf8_length(pc.field("text")) > 2)
...
>>> df = spark.createDataFrame([(1, "Hello"), (2, "Hi")], ["id", "text"])
>>> df.mapInArrow(transform, schema=df.schema).show()
+---+-----+
| id| text|
+---+-----+
|  1|Hello|
+---+-----+
Python UDTF
You can define a Python UDTF class that produces multiple rows for each input row.
>>> from pyspark.sql.functions import lit, udtf
>>>
>>> class Fibonacci:
...     def eval(self, n: int):
...         a, b = 0, 1
...         for _ in range(n):
...             a, b = b, a + b
...             yield (a,)
...
>>> fibonacci = udtf(Fibonacci, returnType="value: integer")
>>> fibonacci(lit(5)).show()
+-----+
|value|
+-----+
|    1|
|    1|
|    2|
|    3|
|    5|
+-----+
The UDTF can be registered and used in SQL queries.
>>> from pyspark.sql.functions import udtf
>>>
>>> @udtf(returnType="word: string")
... class Tokenizer:
...     def eval(self, text: str):
...         for word in text.split():
...             yield (word.strip().lower(),)
...
>>> spark.udtf.register("tokenize", Tokenizer)
<...UserDefinedTableFunction object at ...>
>>> spark.sql("SELECT * FROM tokenize('Hello world')").show()
+-----+
| word|
+-----+
|hello|
|world|
+-----+
>>> spark.sql(
...     "SELECT id, word FROM VALUES (1, 'Hello world'), (2, 'Hi') AS t(id, text), "
...     "LATERAL tokenize(text) "
...     "ORDER BY id, word"
... ).show()
+---+-----+
| id| word|
+---+-----+
|  1|hello|
|  1|world|
|  2|   hi|
+---+-----+