diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 33771011fe364..86d0405c678a7 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -83,10 +83,23 @@ private[spark] case class PythonFunction( */ private[spark] case class ChainedPythonFunctions(funcs: Seq[PythonFunction]) +/** + * Enumerate the type of command that will be sent to the Python worker + */ +private[spark] object PythonEvalType { + val NON_UDF = 0 + val SQL_BATCHED_UDF = 1 + val SQL_PANDAS_UDF = 2 +} + private[spark] object PythonRunner { def apply(func: PythonFunction, bufferSize: Int, reuse_worker: Boolean): PythonRunner = { new PythonRunner( - Seq(ChainedPythonFunctions(Seq(func))), bufferSize, reuse_worker, false, Array(Array(0))) + Seq(ChainedPythonFunctions(Seq(func))), + bufferSize, + reuse_worker, + PythonEvalType.NON_UDF, + Array(Array(0))) } } @@ -100,7 +113,7 @@ private[spark] class PythonRunner( funcs: Seq[ChainedPythonFunctions], bufferSize: Int, reuse_worker: Boolean, - isUDF: Boolean, + evalType: Int, argOffsets: Array[Array[Int]]) extends Logging { @@ -309,8 +322,8 @@ private[spark] class PythonRunner( } dataOut.flush() // Serialized command: - if (isUDF) { - dataOut.writeInt(1) + dataOut.writeInt(evalType) + if (evalType != PythonEvalType.NON_UDF) { dataOut.writeInt(funcs.length) funcs.zip(argOffsets).foreach { case (chained, offsets) => dataOut.writeInt(offsets.length) @@ -324,7 +337,6 @@ private[spark] class PythonRunner( } } } else { - dataOut.writeInt(0) val command = funcs.head.funcs.head.command dataOut.writeInt(command.length) dataOut.write(command) diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index d5c2a7518b18f..addd7d115a03d 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -81,6 +81,12 @@ class SpecialLengths(object): NULL = -5 +class PythonEvalType(object): + NON_UDF = 0 + SQL_BATCHED_UDF = 1 + SQL_PANDAS_UDF = 2 + + class Serializer(object): def dump_stream(self, iterator, stream): @@ -187,8 +193,14 @@ class ArrowSerializer(FramedSerializer): Serializes an Arrow stream. """ - def dumps(self, obj): - raise NotImplementedError + def dumps(self, batch): + import pyarrow as pa + import io + sink = io.BytesIO() + writer = pa.RecordBatchFileWriter(sink, batch.schema) + writer.write_batch(batch) + writer.close() + return sink.getvalue() def loads(self, obj): import pyarrow as pa @@ -199,6 +211,55 @@ def __repr__(self): return "ArrowSerializer" +class ArrowPandasSerializer(ArrowSerializer): + """ + Serializes Pandas.Series as Arrow data. + """ + + def __init__(self): + super(ArrowPandasSerializer, self).__init__() + + def dumps(self, series): + """ + Make an ArrowRecordBatch from a Pandas Series and serialize. Input is a single series or + a list of series accompanied by an optional pyarrow type to coerce the data to. + """ + import pyarrow as pa + # Make input conform to [(series1, type1), (series2, type2), ...] 
+ if not isinstance(series, (list, tuple)) or \ + (len(series) == 2 and isinstance(series[1], pa.DataType)): + series = [series] + series = ((s, None) if not isinstance(s, (list, tuple)) else s for s in series) + + # If a nullable integer series has been promoted to floating point with NaNs, need to cast + # NOTE: this is not necessary with Arrow >= 0.7 + def cast_series(s, t): + if t is None or s.dtype == t.to_pandas_dtype(): + return s + else: + return s.fillna(0).astype(t.to_pandas_dtype(), copy=False) + + arrs = [pa.Array.from_pandas(cast_series(s, t), mask=s.isnull(), type=t) for s, t in series] + batch = pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in xrange(len(arrs))]) + return super(ArrowPandasSerializer, self).dumps(batch) + + def loads(self, obj): + """ + Deserialize an ArrowRecordBatch to an Arrow table and return as a list of pandas.Series + followed by a dictionary containing length of the loaded batches. + """ + import pyarrow as pa + reader = pa.RecordBatchFileReader(pa.BufferReader(obj)) + batches = [reader.get_batch(i) for i in xrange(reader.num_record_batches)] + # NOTE: a 0-parameter pandas_udf will produce an empty batch that can have num_rows set + num_rows = sum((batch.num_rows for batch in batches)) + table = pa.Table.from_batches(batches) + return [c.to_pandas() for c in table.itercolumns()] + [{"length": num_rows}] + + def __repr__(self): + return "ArrowPandasSerializer" + + class BatchedSerializer(Serializer): """ diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 0e76182e0e02d..007b418082d07 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2032,7 +2032,7 @@ class UserDefinedFunction(object): .. versionadded:: 1.3 """ - def __init__(self, func, returnType, name=None): + def __init__(self, func, returnType, name=None, vectorized=False): if not callable(func): raise TypeError( "Not a function or callable (__call__ is not defined): " @@ -2046,6 +2046,7 @@ def __init__(self, func, returnType, name=None): self._name = name or ( func.__name__ if hasattr(func, '__name__') else func.__class__.__name__) + self._vectorized = vectorized @property def returnType(self): @@ -2077,7 +2078,7 @@ def _create_judf(self): wrapped_func = _wrap_function(sc, self.func, self.returnType) jdt = spark._jsparkSession.parseDataType(self.returnType.json()) judf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonFunction( - self._name, wrapped_func, jdt) + self._name, wrapped_func, jdt, self._vectorized) return judf def __call__(self, *cols): @@ -2111,6 +2112,22 @@ def wrapper(*args): return wrapper +def _create_udf(f, returnType, vectorized): + + def _udf(f, returnType=StringType(), vectorized=vectorized): + udf_obj = UserDefinedFunction(f, returnType, vectorized=vectorized) + return udf_obj._wrapped() + + # decorator @udf, @udf(), @udf(dataType()), or similar with @pandas_udf + if f is None or isinstance(f, (str, DataType)): + # If DataType has been passed as a positional argument + # for decorator use it as a returnType + return_type = f or returnType + return functools.partial(_udf, returnType=return_type, vectorized=vectorized) + else: + return _udf(f=f, returnType=returnType, vectorized=vectorized) + + @since(1.3) def udf(f=None, returnType=StringType()): """Creates a :class:`Column` expression representing a user defined function (UDF). 
@@ -2142,18 +2159,26 @@ def udf(f=None, returnType=StringType()): | 8| JOHN DOE| 22| +----------+--------------+------------+ """ - def _udf(f, returnType=StringType()): - udf_obj = UserDefinedFunction(f, returnType) - return udf_obj._wrapped() + return _create_udf(f, returnType=returnType, vectorized=False) - # decorator @udf, @udf() or @udf(dataType()) - if f is None or isinstance(f, (str, DataType)): - # If DataType has been passed as a positional argument - # for decorator use it as a returnType - return_type = f or returnType - return functools.partial(_udf, returnType=return_type) + +@since(2.3) +def pandas_udf(f=None, returnType=StringType()): + """ + Creates a :class:`Column` expression representing a user defined function (UDF) that accepts + `Pandas.Series` as input arguments and outputs a `Pandas.Series` of the same length. + + :param f: python function if used as a standalone function + :param returnType: a :class:`pyspark.sql.types.DataType` object + + # TODO: doctest + """ + import inspect + # If function "f" does not define the optional kwargs, then wrap with a kwargs placeholder + if inspect.getargspec(f).keywords is None: + return _create_udf(lambda *a, **kwargs: f(*a), returnType=returnType, vectorized=True) else: - return _udf(f=f, returnType=returnType) + return _create_udf(f, returnType=returnType, vectorized=True) blacklist = ['map', 'since', 'ignore_unicode_prefix'] diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 3d87ccfc03ddd..3d9d5d1175f51 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3122,6 +3122,203 @@ def test_filtered_frame(self): self.assertTrue(pdf.empty) +@unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed") +class VectorizedUDFTests(ReusedPySparkTestCase): + + @classmethod + def setUpClass(cls): + ReusedPySparkTestCase.setUpClass() + cls.spark = SparkSession(cls.sc) + + @classmethod + def tearDownClass(cls): + ReusedPySparkTestCase.tearDownClass() + cls.spark.stop() + + def test_vectorized_udf_basic(self): + from pyspark.sql.functions import pandas_udf, col + df = self.spark.range(10).select( + col('id').cast('string').alias('str'), + col('id').cast('int').alias('int'), + col('id').alias('long'), + col('id').cast('float').alias('float'), + col('id').cast('double').alias('double'), + col('id').cast('boolean').alias('bool')) + f = lambda x: x + str_f = pandas_udf(f, StringType()) + int_f = pandas_udf(f, IntegerType()) + long_f = pandas_udf(f, LongType()) + float_f = pandas_udf(f, FloatType()) + double_f = pandas_udf(f, DoubleType()) + bool_f = pandas_udf(f, BooleanType()) + res = df.select(str_f(col('str')), int_f(col('int')), + long_f(col('long')), float_f(col('float')), + double_f(col('double')), bool_f(col('bool'))) + self.assertEquals(df.collect(), res.collect()) + + def test_vectorized_udf_null_boolean(self): + from pyspark.sql.functions import pandas_udf, col + data = [(True,), (True,), (None,), (False,)] + schema = StructType().add("bool", BooleanType()) + df = self.spark.createDataFrame(data, schema) + bool_f = pandas_udf(lambda x: x, BooleanType()) + res = df.select(bool_f(col('bool'))) + self.assertEquals(df.collect(), res.collect()) + + def test_vectorized_udf_null_byte(self): + from pyspark.sql.functions import pandas_udf, col + data = [(None,), (2,), (3,), (4,)] + schema = StructType().add("byte", ByteType()) + df = self.spark.createDataFrame(data, schema) + byte_f = pandas_udf(lambda x: x, ByteType()) + res = df.select(byte_f(col('byte'))) + 
self.assertEquals(df.collect(), res.collect()) + + def test_vectorized_udf_null_short(self): + from pyspark.sql.functions import pandas_udf, col + data = [(None,), (2,), (3,), (4,)] + schema = StructType().add("short", ShortType()) + df = self.spark.createDataFrame(data, schema) + short_f = pandas_udf(lambda x: x, ShortType()) + res = df.select(short_f(col('short'))) + self.assertEquals(df.collect(), res.collect()) + + def test_vectorized_udf_null_int(self): + from pyspark.sql.functions import pandas_udf, col + data = [(None,), (2,), (3,), (4,)] + schema = StructType().add("int", IntegerType()) + df = self.spark.createDataFrame(data, schema) + int_f = pandas_udf(lambda x: x, IntegerType()) + res = df.select(int_f(col('int'))) + self.assertEquals(df.collect(), res.collect()) + + def test_vectorized_udf_null_long(self): + from pyspark.sql.functions import pandas_udf, col + data = [(None,), (2,), (3,), (4,)] + schema = StructType().add("long", LongType()) + df = self.spark.createDataFrame(data, schema) + long_f = pandas_udf(lambda x: x, LongType()) + res = df.select(long_f(col('long'))) + self.assertEquals(df.collect(), res.collect()) + + def test_vectorized_udf_null_float(self): + from pyspark.sql.functions import pandas_udf, col + data = [(3.0,), (5.0,), (-1.0,), (None,)] + schema = StructType().add("float", FloatType()) + df = self.spark.createDataFrame(data, schema) + float_f = pandas_udf(lambda x: x, FloatType()) + res = df.select(float_f(col('float'))) + self.assertEquals(df.collect(), res.collect()) + + def test_vectorized_udf_null_double(self): + from pyspark.sql.functions import pandas_udf, col + data = [(3.0,), (5.0,), (-1.0,), (None,)] + schema = StructType().add("double", DoubleType()) + df = self.spark.createDataFrame(data, schema) + double_f = pandas_udf(lambda x: x, DoubleType()) + res = df.select(double_f(col('double'))) + self.assertEquals(df.collect(), res.collect()) + + def test_vectorized_udf_null_string(self): + from pyspark.sql.functions import pandas_udf, col + data = [("foo",), (None,), ("bar",), ("bar",)] + schema = StructType().add("str", StringType()) + df = self.spark.createDataFrame(data, schema) + str_f = pandas_udf(lambda x: x, StringType()) + res = df.select(str_f(col('str'))) + self.assertEquals(df.collect(), res.collect()) + + def test_vectorized_udf_zero_parameter(self): + from pyspark.sql.functions import pandas_udf + import pandas as pd + df = self.spark.range(10) + f0 = pandas_udf(lambda **kwargs: pd.Series(1).repeat(kwargs['length']), LongType()) + res = df.select(f0()) + self.assertEquals(df.select(lit(1)).collect(), res.collect()) + + def test_vectorized_udf_datatype_string(self): + from pyspark.sql.functions import pandas_udf, col + df = self.spark.range(10).select( + col('id').cast('string').alias('str'), + col('id').cast('int').alias('int'), + col('id').alias('long'), + col('id').cast('float').alias('float'), + col('id').cast('double').alias('double'), + col('id').cast('boolean').alias('bool')) + f = lambda x: x + str_f = pandas_udf(f, 'string') + int_f = pandas_udf(f, 'integer') + long_f = pandas_udf(f, 'long') + float_f = pandas_udf(f, 'float') + double_f = pandas_udf(f, 'double') + bool_f = pandas_udf(f, 'boolean') + res = df.select(str_f(col('str')), int_f(col('int')), + long_f(col('long')), float_f(col('float')), + double_f(col('double')), bool_f(col('bool'))) + self.assertEquals(df.collect(), res.collect()) + + def test_vectorized_udf_complex(self): + from pyspark.sql.functions import pandas_udf, col, expr + df = self.spark.range(10).select( + 
col('id').cast('int').alias('a'), + col('id').cast('int').alias('b'), + col('id').cast('double').alias('c')) + add = pandas_udf(lambda x, y: x + y, IntegerType()) + power2 = pandas_udf(lambda x: 2 ** x, IntegerType()) + mul = pandas_udf(lambda x, y: x * y, DoubleType()) + res = df.select(add(col('a'), col('b')), power2(col('a')), mul(col('b'), col('c'))) + expected = df.select(expr('a + b'), expr('power(2, a)'), expr('b * c')) + self.assertEquals(expected.collect(), res.collect()) + + def test_vectorized_udf_exception(self): + from pyspark.sql.functions import pandas_udf, col + df = self.spark.range(10) + raise_exception = pandas_udf(lambda x: x * (1 / 0), LongType()) + with QuietTest(self.sc): + with self.assertRaisesRegexp(Exception, 'division( or modulo)? by zero'): + df.select(raise_exception(col('id'))).collect() + + def test_vectorized_udf_invalid_length(self): + from pyspark.sql.functions import pandas_udf, col + import pandas as pd + df = self.spark.range(10) + raise_exception = pandas_udf(lambda: pd.Series(1), LongType()) + with QuietTest(self.sc): + with self.assertRaisesRegexp( + Exception, + 'Result vector from pandas_udf was not the required length'): + df.select(raise_exception()).collect() + + def test_vectorized_udf_mix_udf(self): + from pyspark.sql.functions import pandas_udf, udf, col + df = self.spark.range(10) + row_by_row_udf = udf(lambda x: x, LongType()) + pd_udf = pandas_udf(lambda x: x, LongType()) + with QuietTest(self.sc): + with self.assertRaisesRegexp( + Exception, + 'Can not mix vectorized and non-vectorized UDFs'): + df.select(row_by_row_udf(col('id')), pd_udf(col('id'))).collect() + + def test_vectorized_udf_chained(self): + from pyspark.sql.functions import pandas_udf, col + df = self.spark.range(10).toDF('x') + f = pandas_udf(lambda x: x + 1, LongType()) + g = pandas_udf(lambda x: x - 1, LongType()) + res = df.select(g(f(col('x')))) + self.assertEquals(df.collect(), res.collect()) + + def test_vectorized_udf_wrong_return_type(self): + from pyspark.sql.functions import pandas_udf, col + df = self.spark.range(10).toDF('x') + f = pandas_udf(lambda x: x * 1.0, StringType()) + with QuietTest(self.sc): + with self.assertRaisesRegexp( + Exception, + 'Invalid.*type.*string'): + df.select(f(col('x'))).collect() + + if __name__ == "__main__": from pyspark.sql.tests import * if xmlrunner: diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 51bf7bef49763..2a5a1b68f4a2c 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -1582,6 +1582,33 @@ def convert(self, obj, gateway_client): register_input_converter(DateConverter()) +def toArrowType(dt): + """ Convert Spark data type to pyarrow type + """ + import pyarrow as pa + if type(dt) == BooleanType: + arrow_type = pa.bool_() + elif type(dt) == ByteType: + arrow_type = pa.int8() + elif type(dt) == ShortType: + arrow_type = pa.int16() + elif type(dt) == IntegerType: + arrow_type = pa.int32() + elif type(dt) == LongType: + arrow_type = pa.int64() + elif type(dt) == FloatType: + arrow_type = pa.float32() + elif type(dt) == DoubleType: + arrow_type = pa.float64() + elif type(dt) == DecimalType: + arrow_type = pa.decimal(dt.precision, dt.scale) + elif type(dt) == StringType: + arrow_type = pa.string() + else: + raise TypeError("Unsupported type in conversion to Arrow: " + str(dt)) + return arrow_type + + def _test(): import doctest from pyspark.context import SparkContext diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index baaa3fe074e9a..0e35cf7be6240 100644 
--- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -30,7 +30,9 @@ from pyspark.taskcontext import TaskContext from pyspark.files import SparkFiles from pyspark.serializers import write_with_length, write_int, read_long, \ - write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, BatchedSerializer + write_long, read_int, SpecialLengths, PythonEvalType, UTF8Deserializer, PickleSerializer, \ + BatchedSerializer, ArrowPandasSerializer +from pyspark.sql.types import toArrowType from pyspark import shuffle pickleSer = PickleSerializer() @@ -58,9 +60,12 @@ def read_command(serializer, file): return command -def chain(f, g): - """chain two function together """ - return lambda *a: g(f(*a)) +def chain(f, g, eval_type): + """chain two functions together """ + if eval_type == PythonEvalType.SQL_PANDAS_UDF: + return lambda *a, **kwargs: g(f(*a, **kwargs), **kwargs) + else: + return lambda *a: g(f(*a)) def wrap_udf(f, return_type): @@ -71,7 +76,21 @@ def wrap_udf(f, return_type): return lambda *a: f(*a) -def read_single_udf(pickleSer, infile): +def wrap_pandas_udf(f, return_type): + arrow_return_type = toArrowType(return_type) + + def verify_result_length(*a): + kwargs = a[-1] + result = f(*a[:-1], **kwargs) + if len(result) != kwargs["length"]: + raise RuntimeError("Result vector from pandas_udf was not the required length: " + "expected %d, got %d\nUse input vector length or kwargs['length']" + % (kwargs["length"], len(result))) + return result, arrow_return_type + return lambda *a: verify_result_length(*a) + + +def read_single_udf(pickleSer, infile, eval_type): num_arg = read_int(infile) arg_offsets = [read_int(infile) for i in range(num_arg)] row_func = None @@ -80,17 +99,22 @@ def read_single_udf(pickleSer, infile): if row_func is None: row_func = f else: - row_func = chain(row_func, f) + row_func = chain(row_func, f, eval_type) # the last returnType will be the return type of UDF - return arg_offsets, wrap_udf(row_func, return_type) + if eval_type == PythonEvalType.SQL_PANDAS_UDF: + # A pandas_udf will take kwargs as the last argument + arg_offsets = arg_offsets + [-1] + return arg_offsets, wrap_pandas_udf(row_func, return_type) + else: + return arg_offsets, wrap_udf(row_func, return_type) -def read_udfs(pickleSer, infile): +def read_udfs(pickleSer, infile, eval_type): num_udfs = read_int(infile) udfs = {} call_udf = [] for i in range(num_udfs): - arg_offsets, udf = read_single_udf(pickleSer, infile) + arg_offsets, udf = read_single_udf(pickleSer, infile, eval_type) udfs['f%d' % i] = udf args = ["a[%d]" % o for o in arg_offsets] call_udf.append("f%d(%s)" % (i, ", ".join(args))) @@ -102,7 +126,12 @@ def read_udfs(pickleSer, infile): mapper = eval(mapper_str, udfs) func = lambda _, it: map(mapper, it) - ser = BatchedSerializer(PickleSerializer(), 100) + + if eval_type == PythonEvalType.SQL_PANDAS_UDF: + ser = ArrowPandasSerializer() + else: + ser = BatchedSerializer(PickleSerializer(), 100) + # profiling is not supported for UDF return func, None, ser, ser @@ -159,11 +188,11 @@ def main(infile, outfile): _broadcastRegistry.pop(bid) _accumulatorRegistry.clear() - is_sql_udf = read_int(infile) - if is_sql_udf: - func, profiler, deserializer, serializer = read_udfs(pickleSer, infile) - else: + eval_type = read_int(infile) + if eval_type == PythonEvalType.NON_UDF: func, profiler, deserializer, serializer = read_command(pickleSer, infile) + else: + func, profiler, deserializer, serializer = read_udfs(pickleSer, infile, eval_type) init_time = time.time() diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala new file mode 100644 index 0000000000000..f8bdc1e14eebc --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.python + +import org.apache.spark.TaskContext +import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType, PythonRunner} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.arrow.{ArrowConverters, ArrowPayload} +import org.apache.spark.sql.types.StructType + +/** + * A physical plan that evaluates a [[PythonUDF]], + */ +case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], child: SparkPlan) + extends EvalPythonExec(udfs, output, child) { + + protected override def evaluate( + funcs: Seq[ChainedPythonFunctions], + bufferSize: Int, + reuseWorker: Boolean, + argOffsets: Array[Array[Int]], + iter: Iterator[InternalRow], + schema: StructType, + context: TaskContext): Iterator[InternalRow] = { + val inputIterator = ArrowConverters.toPayloadIterator( + iter, schema, conf.arrowMaxRecordsPerBatch, context).map(_.asPythonSerializable) + + // Output iterator for results from Python. 
+ val outputIterator = new PythonRunner( + funcs, bufferSize, reuseWorker, PythonEvalType.SQL_PANDAS_UDF, argOffsets) + .compute(inputIterator, context.partitionId(), context) + + val outputRowIterator = ArrowConverters.fromPayloadIterator( + outputIterator.map(new ArrowPayload(_)), context) + + // Verify that the output schema is correct + val schemaOut = StructType.fromAttributes(output.drop(child.output.length).zipWithIndex + .map { case (attr, i) => attr.withName(s"_$i") }) + assert(schemaOut.equals(outputRowIterator.schema), + s"Invalid schema from pandas_udf: expected $schemaOut, got ${outputRowIterator.schema}") + + outputRowIterator + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala index 3e176e2cde5bd..2978eac50554d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala @@ -17,153 +17,78 @@ package org.apache.spark.sql.execution.python -import java.io.File - import scala.collection.JavaConverters._ -import scala.collection.mutable.ArrayBuffer import net.razorvine.pickle.{Pickler, Unpickler} -import org.apache.spark.{SparkEnv, TaskContext} -import org.apache.spark.api.python.{ChainedPythonFunctions, PythonRunner} -import org.apache.spark.rdd.RDD +import org.apache.spark.TaskContext +import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType, PythonRunner} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.SparkPlan -import org.apache.spark.sql.types.{DataType, StructField, StructType} -import org.apache.spark.util.Utils - +import org.apache.spark.sql.types.{StructField, StructType} /** - * A physical plan that evaluates a [[PythonUDF]], one partition of tuples at a time. - * - * Python evaluation works by sending the necessary (projected) input data via a socket to an - * external Python process, and combine the result from the Python process with the original row. - * - * For each row we send to Python, we also put it in a queue first. For each output row from Python, - * we drain the queue to find the original input row. Note that if the Python process is way too - * slow, this could lead to the queue growing unbounded and spill into disk when run out of memory. - * - * Here is a diagram to show how this works: - * - * Downstream (for parent) - * / \ - * / socket (output of UDF) - * / \ - * RowQueue Python - * \ / - * \ socket (input of UDF) - * \ / - * upstream (from child) - * - * The rows sent to and received from Python are packed into batches (100 rows) and serialized, - * there should be always some rows buffered in the socket or Python process, so the pulling from - * RowQueue ALWAYS happened after pushing into it. 
+ * A physical plan that evaluates a [[PythonUDF]] */ case class BatchEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], child: SparkPlan) - extends SparkPlan { - - def children: Seq[SparkPlan] = child :: Nil - - override def producedAttributes: AttributeSet = AttributeSet(output.drop(child.output.length)) - - private def collectFunctions(udf: PythonUDF): (ChainedPythonFunctions, Seq[Expression]) = { - udf.children match { - case Seq(u: PythonUDF) => - val (chained, children) = collectFunctions(u) - (ChainedPythonFunctions(chained.funcs ++ Seq(udf.func)), children) - case children => - // There should not be any other UDFs, or the children can't be evaluated directly. - assert(children.forall(_.find(_.isInstanceOf[PythonUDF]).isEmpty)) - (ChainedPythonFunctions(Seq(udf.func)), udf.children) - } - } - - protected override def doExecute(): RDD[InternalRow] = { - val inputRDD = child.execute().map(_.copy()) - val bufferSize = inputRDD.conf.getInt("spark.buffer.size", 65536) - val reuseWorker = inputRDD.conf.getBoolean("spark.python.worker.reuse", defaultValue = true) - - inputRDD.mapPartitions { iter => - EvaluatePython.registerPicklers() // register pickler for Row - - // The queue used to buffer input rows so we can drain it to - // combine input with output from Python. - val queue = HybridRowQueue(TaskContext.get().taskMemoryManager(), - new File(Utils.getLocalDir(SparkEnv.get.conf)), child.output.length) - TaskContext.get().addTaskCompletionListener({ ctx => - queue.close() - }) - - val (pyFuncs, inputs) = udfs.map(collectFunctions).unzip - - // flatten all the arguments - val allInputs = new ArrayBuffer[Expression] - val dataTypes = new ArrayBuffer[DataType] - val argOffsets = inputs.map { input => - input.map { e => - if (allInputs.exists(_.semanticEquals(e))) { - allInputs.indexWhere(_.semanticEquals(e)) - } else { - allInputs += e - dataTypes += e.dataType - allInputs.length - 1 - } - }.toArray - }.toArray - val projection = newMutableProjection(allInputs, child.output) - val schema = StructType(dataTypes.map(dt => StructField("", dt))) - val needConversion = dataTypes.exists(EvaluatePython.needConversionInPython) - - // enable memo iff we serialize the row with schema (schema and class should be memorized) - val pickle = new Pickler(needConversion) - // Input iterator to Python: input rows are grouped so we send them in batches to Python. - // For each row, add it to the queue. - val inputIterator = iter.map { inputRow => - queue.add(inputRow.asInstanceOf[UnsafeRow]) - val row = projection(inputRow) - if (needConversion) { - EvaluatePython.toJava(row, schema) - } else { - // fast path for these types that does not need conversion in Python - val fields = new Array[Any](row.numFields) - var i = 0 - while (i < row.numFields) { - val dt = dataTypes(i) - fields(i) = EvaluatePython.toJava(row.get(i, dt), dt) - i += 1 - } - fields - } - }.grouped(100).map(x => pickle.dumps(x.toArray)) - - val context = TaskContext.get() - - // Output iterator for results from Python. 
- val outputIterator = new PythonRunner(pyFuncs, bufferSize, reuseWorker, true, argOffsets) - .compute(inputIterator, context.partitionId(), context) - - val unpickle = new Unpickler - val mutableRow = new GenericInternalRow(1) - val joined = new JoinedRow - val resultType = if (udfs.length == 1) { - udfs.head.dataType + extends EvalPythonExec(udfs, output, child) { + + protected override def evaluate( + funcs: Seq[ChainedPythonFunctions], + bufferSize: Int, + reuseWorker: Boolean, + argOffsets: Array[Array[Int]], + iter: Iterator[InternalRow], + schema: StructType, + context: TaskContext): Iterator[InternalRow] = { + EvaluatePython.registerPicklers() // register pickler for Row + + val dataTypes = schema.map(_.dataType) + val needConversion = dataTypes.exists(EvaluatePython.needConversionInPython) + + // enable memo iff we serialize the row with schema (schema and class should be memorized) + val pickle = new Pickler(needConversion) + // Input iterator to Python: input rows are grouped so we send them in batches to Python. + // For each row, add it to the queue. + val inputIterator = iter.map { row => + if (needConversion) { + EvaluatePython.toJava(row, schema) } else { - StructType(udfs.map(u => StructField("", u.dataType, u.nullable))) - } - val resultProj = UnsafeProjection.create(output, output) - outputIterator.flatMap { pickedResult => - val unpickledBatch = unpickle.loads(pickedResult) - unpickledBatch.asInstanceOf[java.util.ArrayList[Any]].asScala - }.map { result => - val row = if (udfs.length == 1) { - // fast path for single UDF - mutableRow(0) = EvaluatePython.fromJava(result, resultType) - mutableRow - } else { - EvaluatePython.fromJava(result, resultType).asInstanceOf[InternalRow] + // fast path for these types that does not need conversion in Python + val fields = new Array[Any](row.numFields) + var i = 0 + while (i < row.numFields) { + val dt = dataTypes(i) + fields(i) = EvaluatePython.toJava(row.get(i, dt), dt) + i += 1 } - resultProj(joined(queue.remove(), row)) + fields + } + }.grouped(100).map(x => pickle.dumps(x.toArray)) + + // Output iterator for results from Python. + val outputIterator = new PythonRunner( + funcs, bufferSize, reuseWorker, PythonEvalType.SQL_BATCHED_UDF, argOffsets) + .compute(inputIterator, context.partitionId(), context) + + val unpickle = new Unpickler + val mutableRow = new GenericInternalRow(1) + val resultType = if (udfs.length == 1) { + udfs.head.dataType + } else { + StructType(udfs.map(u => StructField("", u.dataType, u.nullable))) + } + outputIterator.flatMap { pickedResult => + val unpickledBatch = unpickle.loads(pickedResult) + unpickledBatch.asInstanceOf[java.util.ArrayList[Any]].asScala + }.map { result => + if (udfs.length == 1) { + // fast path for single UDF + mutableRow(0) = EvaluatePython.fromJava(result, resultType) + mutableRow + } else { + EvaluatePython.fromJava(result, resultType).asInstanceOf[InternalRow] } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExec.scala new file mode 100644 index 0000000000000..860dc78c1dd1b --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExec.scala @@ -0,0 +1,142 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.python + +import java.io.File + +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.{SparkEnv, TaskContext} +import org.apache.spark.api.python.ChainedPythonFunctions +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.types.{DataType, StructField, StructType} +import org.apache.spark.util.Utils + + +/** + * A physical plan that evaluates a [[PythonUDF]], one partition of tuples at a time. + * + * Python evaluation works by sending the necessary (projected) input data via a socket to an + * external Python process, and combine the result from the Python process with the original row. + * + * For each row we send to Python, we also put it in a queue first. For each output row from Python, + * we drain the queue to find the original input row. Note that if the Python process is way too + * slow, this could lead to the queue growing unbounded and spill into disk when run out of memory. + * + * Here is a diagram to show how this works: + * + * Downstream (for parent) + * / \ + * / socket (output of UDF) + * / \ + * RowQueue Python + * \ / + * \ socket (input of UDF) + * \ / + * upstream (from child) + * + * The rows sent to and received from Python are packed into batches (100 rows) and serialized, + * there should be always some rows buffered in the socket or Python process, so the pulling from + * RowQueue ALWAYS happened after pushing into it. + */ +abstract class EvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], child: SparkPlan) + extends SparkPlan { + + def children: Seq[SparkPlan] = child :: Nil + + override def producedAttributes: AttributeSet = AttributeSet(output.drop(child.output.length)) + + private def collectFunctions(udf: PythonUDF): (ChainedPythonFunctions, Seq[Expression]) = { + udf.children match { + case Seq(u: PythonUDF) => + val (chained, children) = collectFunctions(u) + (ChainedPythonFunctions(chained.funcs ++ Seq(udf.func)), children) + case children => + // There should not be any other UDFs, or the children can't be evaluated directly. 
+ assert(children.forall(_.find(_.isInstanceOf[PythonUDF]).isEmpty)) + (ChainedPythonFunctions(Seq(udf.func)), udf.children) + } + } + + protected def evaluate( + funcs: Seq[ChainedPythonFunctions], + bufferSize: Int, + reuseWorker: Boolean, + argOffsets: Array[Array[Int]], + iter: Iterator[InternalRow], + schema: StructType, + context: TaskContext): Iterator[InternalRow] + + protected override def doExecute(): RDD[InternalRow] = { + val inputRDD = child.execute().map(_.copy()) + val bufferSize = inputRDD.conf.getInt("spark.buffer.size", 65536) + val reuseWorker = inputRDD.conf.getBoolean("spark.python.worker.reuse", defaultValue = true) + + inputRDD.mapPartitions { iter => + val context = TaskContext.get() + + // The queue used to buffer input rows so we can drain it to + // combine input with output from Python. + val queue = HybridRowQueue(context.taskMemoryManager(), + new File(Utils.getLocalDir(SparkEnv.get.conf)), child.output.length) + context.addTaskCompletionListener { ctx => + queue.close() + } + + val (pyFuncs, inputs) = udfs.map(collectFunctions).unzip + + // flatten all the arguments + val allInputs = new ArrayBuffer[Expression] + val dataTypes = new ArrayBuffer[DataType] + val argOffsets = inputs.map { input => + input.map { e => + if (allInputs.exists(_.semanticEquals(e))) { + allInputs.indexWhere(_.semanticEquals(e)) + } else { + allInputs += e + dataTypes += e.dataType + allInputs.length - 1 + } + }.toArray + }.toArray + val projection = newMutableProjection(allInputs, child.output) + val schema = StructType(dataTypes.zipWithIndex.map { case (dt, i) => + StructField(s"_$i", dt) + }) + + // Add rows to queue to join later with the result. + val projectedRowIter = iter.map { inputRow => + queue.add(inputRow.asInstanceOf[UnsafeRow]) + projection(inputRow) + } + + val outputRowIterator = evaluate( + pyFuncs, bufferSize, reuseWorker, argOffsets, projectedRowIter, schema, context) + + val joined = new JoinedRow + val resultProj = UnsafeProjection.create(output, output) + + outputRowIterator.map { outputRow => + resultProj(joined(queue.remove(), outputRow)) + } + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala index 882a5ce1a663e..fec456d86dbe2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala @@ -138,7 +138,16 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper { val resultAttrs = udfs.zipWithIndex.map { case (u, i) => AttributeReference(s"pythonUDF$i", u.dataType)() } - val evaluation = BatchEvalPythonExec(validUdfs, child.output ++ resultAttrs, child) + + val evaluation = validUdfs.partition(_.vectorized) match { + case (vectorizedUdfs, plainUdfs) if plainUdfs.isEmpty => + ArrowEvalPythonExec(vectorizedUdfs, child.output ++ resultAttrs, child) + case (vectorizedUdfs, plainUdfs) if vectorizedUdfs.isEmpty => + BatchEvalPythonExec(plainUdfs, child.output ++ resultAttrs, child) + case _ => + throw new IllegalArgumentException("Can not mix vectorized and non-vectorized UDFs") + } + attributeMap ++= validUdfs.zip(resultAttrs) evaluation } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala index 7ebbdb9846cce..84a6d9e5be59c 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala @@ -28,7 +28,8 @@ case class PythonUDF( name: String, func: PythonFunction, dataType: DataType, - children: Seq[Expression]) + children: Seq[Expression], + vectorized: Boolean) extends Expression with Unevaluable with NonSQLExpression with UserDefinedExpression { override def toString: String = s"$name(${children.mkString(", ")})" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala index 0d39c8ff980f2..a30a80acf5c23 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala @@ -28,10 +28,11 @@ import org.apache.spark.sql.types.DataType case class UserDefinedPythonFunction( name: String, func: PythonFunction, - dataType: DataType) { + dataType: DataType, + vectorized: Boolean) { def builder(e: Seq[Expression]): PythonUDF = { - PythonUDF(name, func, dataType, e) + PythonUDF(name, func, dataType, e, vectorized) } /** Returns a [[Column]] that will evaluate to calling this UDF with the given input. */ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala index bbd9484271a3e..153e6e1f88c70 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala @@ -105,5 +105,8 @@ class DummyUDF extends PythonFunction( broadcastVars = null, accumulator = null) -class MyDummyPythonUDF - extends UserDefinedPythonFunction(name = "dummyUDF", func = new DummyUDF, dataType = BooleanType) +class MyDummyPythonUDF extends UserDefinedPythonFunction( + name = "dummyUDF", + func = new DummyUDF, + dataType = BooleanType, + vectorized = false)
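
As a quick orientation for the patch above: the Python worker protocol now receives an integer eval type (NON_UDF = 0, SQL_BATCHED_UDF = 1, SQL_PANDAS_UDF = 2) instead of a boolean flag, and SQL_PANDAS_UDF routes data through the new ArrowPandasSerializer so a UDF receives and returns pandas.Series backed by Arrow record batches rather than pickled batches of 100 rows. Below is a minimal usage sketch of the new `pandas_udf` API, adapted from the VectorizedUDFTests added in this diff; the local SparkSession setup is illustrative only (not part of the patch), and pandas plus pyarrow must be installed on the workers:

    from pyspark.sql import SparkSession
    from pyspark.sql.functions import pandas_udf, col
    from pyspark.sql.types import LongType

    spark = SparkSession.builder.master("local[*]").getOrCreate()

    # The wrapped function receives its arguments as pandas.Series and must
    # return a Series of the same length; evaluation is vectorized over Arrow
    # record batches instead of row-by-row pickling.
    plus_one = pandas_udf(lambda s: s + 1, LongType())

    df = spark.range(10).toDF('x')
    df.select(plus_one(col('x'))).show()

    spark.stop()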