diff --git a/examples/src/scriptsTest/scala/water/sparkling/scripts/ScriptsTestSuite.scala b/examples/src/scriptsTest/scala/water/sparkling/scripts/ScriptsTestSuite.scala index 60c7be497..03a9e082d 100644 --- a/examples/src/scriptsTest/scala/water/sparkling/scripts/ScriptsTestSuite.scala +++ b/examples/src/scriptsTest/scala/water/sparkling/scripts/ScriptsTestSuite.scala @@ -50,7 +50,6 @@ class BasicInterpreterTests extends ScriptsTestHelper { } test("Test Spark API call via interpreter") { - val inspections = new ScriptInspections() inspections.addTermToCheck("num1") inspections.addTermToCheck("num2") @@ -65,6 +64,20 @@ class BasicInterpreterTests extends ScriptsTestHelper { assert(result.realTermValues.get("num1").get == "2", "Value of term \"num\" should be 2") assert(result.realTermValues.get("num2").get == "2", "Value of term \"num\" should be 3") } + + test("[SW-386] Test Spark API exposed implicit conversions (https://issues.scala-lang.org/browse/SI-9734 and https://issues.apache.org/jira/browse/SPARK-13456)") { + val inspections = new ScriptInspections() + inspections.addTermToCheck("count") + val result = launchCode( + """ + |import spark.implicits._ + |case class Person(id: Long) + |val ds = Seq(Person(0), Person(1)).toDS + |val count = ds.count + """.stripMargin, inspections) + assert(result.codeExecutionStatus == CodeResults.Success, "Problem during interpreting the script!") + assert(result.realTermValues.get("count").get == "2", "Value of term \"count\" should be 2") + } } diff --git a/repl/build.gradle b/repl/build.gradle index 8ffc91e5d..46f913559 100644 --- a/repl/build.gradle +++ b/repl/build.gradle @@ -7,4 +7,7 @@ dependencies { compile "org.apache.spark:spark-repl_${scalaBaseVersion}:${sparkVersion}" // H2O Scala API compile "ai.h2o:h2o-scala_${scalaBaseVersion}:${h2oVersion}" + + testCompile "org.scalatest:scalatest_${scalaBaseVersion}:2.2.1" + testCompile "junit:junit:4.11" } diff --git 
a/repl/src/main/scala/org/apache/spark/repl/h2o/H2OIMainHelper.scala b/repl/src/main/scala/org/apache/spark/repl/h2o/H2OIMainHelper.scala index a68c3581c..e28d2d253 100644 --- a/repl/src/main/scala/org/apache/spark/repl/h2o/H2OIMainHelper.scala +++ b/repl/src/main/scala/org/apache/spark/repl/h2o/H2OIMainHelper.scala @@ -44,6 +44,10 @@ trait H2OIMainHelper { fieldSessionNames.set(naming, new SessionNames { override def line = "intp_id_" + sessionId + propOr("line") }) + + // FIX for SW-386 + // We need to patch OuterScopes regexp to correctly recognize classes generated by H2O REPL + PatchUtils.PatchManager.patch("SW-386", Thread.currentThread().getContextClassLoader) } def setClassLoaderToSerializers(classLoader: ClassLoader): Unit = { @@ -59,7 +63,7 @@ trait H2OIMainHelper { } def initializeClassLoader(sc: SparkContext): Unit = { - if(!_initialized){ + if (!_initialized) { if (Main.interp != null) { // Application has been started using SparkShell script. // Set the original interpreter classloader as the fallback class loader for all diff --git a/repl/src/main/scala/org/apache/spark/repl/h2o/PatchUtils.scala b/repl/src/main/scala/org/apache/spark/repl/h2o/PatchUtils.scala new file mode 100644 index 000000000..538b36f7b --- /dev/null +++ b/repl/src/main/scala/org/apache/spark/repl/h2o/PatchUtils.scala @@ -0,0 +1,83 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. 
package org.apache.spark.repl.h2o

/**
 * Runtime patch utilities.
 *
 * Provides reflection-based monkey-patching of Scala singleton objects
 * (e.g., Spark internals) at runtime. Used to adapt Spark's REPL class
 * recognition to the H2O REPL (see SW-386).
 */
private[repl] object PatchUtils {

  // A patcher accepts the singleton instance and its defining class and
  // returns true if patching was successful.
  type Patcher = (AnyRef, Class[_]) => Boolean

  // Actual patch definition: given a classloader, apply the patch and
  // report success.
  type Patch = ClassLoader => Boolean

  /**
   * Patch given object (Scala singleton).
   *
   * @param fullClassName class name of the object (without the trailing '$')
   * @param classloader   classloader to use for loading the object definition
   * @param patcher       actual patcher applied to the loaded singleton
   * @return true if patching was successful else false
   * @throws ClassNotFoundException if the object class cannot be loaded
   */
  def patchObject(fullClassName: String, classloader: ClassLoader, patcher: Patcher): Boolean = {
    // Scala compiles `object X` into a class `X$` holding the singleton instance.
    val clz = Class.forName(fullClassName + "$", false, classloader)
    val module = getModule(clz)

    // Patch it
    patcher(module, clz)
  }

  /** Return the singleton instance stored in the object's static `MODULE$` field. */
  def getModule(objectClass: Class[_]): AnyRef = {
    val f = objectClass.getField("MODULE$")
    f.get(null)
  }

  val OUTER_SCOPES_CLASS = "org.apache.spark.sql.catalyst.encoders.OuterScopes"

  // Like Spark's default REPL-class regexp, but additionally accepts the
  // `intp_id_<n>` prefix the H2O REPL prepends to generated line objects.
  val OUTER_SCOPE_REPL_REGEX = """^((?:intp_id_\d+)??\$line(?:\d+)\.\$read)(?:\$\$iw)+$""".r

  // Patch Spark OuterScopes definition: replace its private `REPLClass` regexp
  // so classes generated by the H2O REPL are recognized (SW-386).
  val patchOuterScopes: Patch = (classLoader: ClassLoader) => {
    val patcher: Patcher = (obj: AnyRef, clz: Class[_]) => {
      val f = clz.getDeclaredField("REPLClass")
      f.setAccessible(true)
      f.set(obj, OUTER_SCOPE_REPL_REGEX)
      true
    }

    patchObject(OUTER_SCOPES_CLASS, classLoader, patcher)
  }

  /**
   * Manages all runtime patches in the system, keyed by JIRA id.
   *
   * Note: if necessary it should accept environment configuration and
   * apply a patch only if it is applicable for the given environment
   * (e.g., Scala 2.10 + Spark 2.0).
   */
  object PatchManager {

    // jiraId -> (human-readable description, patch implementation)
    private val patches: Map[String, (String, Patch)] = Map(
      "SW-386" ->
        ("Patches OuterScope to replace default REPL regexp by one which understand H2O REPL", patchOuterScopes)
    )

    /**
     * Apply the patch registered under the given JIRA id.
     *
     * @return true if the patch exists and reports success, false for an
     *         unknown id. Exceptions thrown by the patch itself propagate.
     */
    def patch(jiraId: String, classLoader: ClassLoader): Boolean =
      patches.get(jiraId).exists { case (_, p) => p(classLoader) }

    /** Human-readable description of a registered patch, or "NOT FOUND". */
    def patchInfo(jiraId: String): String =
      patches.get(jiraId).fold("NOT FOUND")(_._1)
  }
}
package org.apache.spark.repl.h2o

import org.junit.runner.RunWith
import org.scalatest._
import org.scalatest.junit.JUnitRunner

import scala.util.matching.Regex

/**
 * Test that patcher is patching.
 */
@RunWith(classOf[JUnitRunner])
class PatchUtilsTestSuite extends FunSuite with BeforeAndAfterAll {

  val EXAMPLE_CLASS_NAME = "intp_id_12$line1.$read$$iw$$iw"
  val EXAMPLE_RESULT_AFTER_PATCH = "intp_id_12$line1.$read"
  val FAILED_MATCH = "FAIL"

  // Apply `regex` to `input` (full match with one capture group) and check
  // that the captured group equals `exp`; non-matches yield FAILED_MATCH.
  def assertMatch(regex: Regex, input: String, exp: String): Unit = {
    val captured = regex.unapplySeq(input).flatMap(_.headOption).getOrElse(FAILED_MATCH)
    assert(captured == exp)
  }

  test("Test new regexp for OuterScopes") {
    val regex = PatchUtils.OUTER_SCOPE_REPL_REGEX
    assertMatch(regex, EXAMPLE_CLASS_NAME, EXAMPLE_RESULT_AFTER_PATCH)
    assertMatch(regex, "$line1.$read$$iw$$iw", "$line1.$read")
  }

  test("[SW-386] Test patched OuterScopes") {
    // Default regexp fails for our class names with intp_id prefix
    assertMatch(getRegexp(), EXAMPLE_CLASS_NAME, FAILED_MATCH)

    PatchUtils.PatchManager.patch("SW-386", Thread.currentThread().getContextClassLoader)

    // After patching, the H2O REPL class name must be recognized
    assertMatch(getRegexp(), EXAMPLE_CLASS_NAME, EXAMPLE_RESULT_AFTER_PATCH)
  }

  // Read the current (possibly patched) REPLClass regexp out of Spark's
  // OuterScopes singleton via reflection.
  def getRegexp(): Regex = {
    val objectClass = Class.forName(PatchUtils.OUTER_SCOPES_CLASS + "$")
    val field = objectClass.getDeclaredField("REPLClass")
    field.setAccessible(true)
    field.get(PatchUtils.getModule(objectClass)).asInstanceOf[Regex]
  }
}