Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,13 @@

package org.apache.streampark.common.conf

import org.apache.streampark.common.util.{CommandUtils, Logger}
import org.apache.streampark.common.util.{CommandUtils, Logger, SparkEnvUtils}
import org.apache.streampark.common.util.Implicits._

import org.apache.commons.io.FileUtils

import java.io.File
import java.nio.charset.StandardCharsets
import java.util.function.Consumer
import java.util.regex.Pattern

Expand All @@ -35,36 +38,19 @@ class SparkVersion(val sparkHome: String) extends Serializable with Logger {

private[this] lazy val SPARK_SCALA_VERSION_PATTERN = Pattern.compile("Using\\sScala\\sversion\\s(\\d+\\.\\d+)")

val (version, scalaVersion) = {
var sparkVersion: String = null
var scalaVersion: String = null
val cmd = List(s"export SPARK_HOME=$sparkHome&&$sparkHome/bin/spark-submit --version")
val buffer = new mutable.StringBuilder
private[this] lazy val SPARK_RELEASE_VERSION_PATTERN = Pattern.compile("^Spark\\s+(\\d+\\.\\d+\\.\\d+)")

CommandUtils.execute(
sparkHome,
cmd,
new Consumer[String]() {
override def accept(out: String): Unit = {
buffer.append(out).append("\n")
val matcher = SPARK_VERSION_PATTERN.matcher(out)
if (matcher.find) {
sparkVersion = matcher.group(1)
} else {
val matcher1 = SPARK_SCALA_VERSION_PATTERN.matcher(out)
if (matcher1.find) {
scalaVersion = matcher1.group(1)
}
}
}
})
private[this] lazy val SPARK_CORE_JAR_PATTERN =
Pattern.compile("^spark-core_(\\d+\\.\\d+)-(\\d+\\.\\d+\\.\\d+)\\.jar$")

logInfo(buffer.toString())
if (sparkVersion == null || scalaVersion == null) {
throw new IllegalStateException(s"[StreamPark] parse spark version failed. $buffer")
}
buffer.clear()
(sparkVersion, scalaVersion)
val (version, scalaVersion) = {
parseFromSparkCoreJar()
.orElse(parseFromReleaseFile())
.orElse(parseFromSparkSubmit())
.getOrElse(
throw new IllegalStateException(
s"[StreamPark] parse spark version failed for sparkHome: $sparkHome. " +
"Please check whether $SPARK_HOME/jars/spark-core_*.jar or RELEASE exists."))
}

lazy val majorVersion: String = {
Expand All @@ -79,19 +65,22 @@ class SparkVersion(val sparkHome: String) extends Serializable with Logger {

lazy val fullVersion: String = s"${version}_$scalaVersion"

/** Resolved JAVA_HOME for Spark CLI and SparkLauncher, based on spark-env.sh or auto-detection. */
lazy val javaHome: Option[String] = SparkEnvUtils.resolveJavaHome(sparkHome, version)

lazy val sparkLib: File = {
require(sparkHome != null, "[StreamPark] sparkHome must not be null.")
require(new File(sparkHome).exists(), "[StreamPark] sparkHome must be exists.")
val lib = new File(s"$sparkHome/jars")
require(
lib.exists() && lib.isDirectory,
s"[StreamPark] $sparkHome/lib must be exists and must be directory.")
s"[StreamPark] $sparkHome/jars must be exists and must be directory.")
lib
}

def checkVersion(throwException: Boolean = true): Boolean = {
version.split("\\.").map(_.trim.toInt) match {
case Array(v, _, _) if v == 2 || v == 3 => true
case Array(v, _, _) if v == 2 || v == 3 || v == 4 => true
case _ =>
if (throwException) {
throw new UnsupportedOperationException(s"Unsupported spark version: $version")
Expand All @@ -101,12 +90,100 @@ class SparkVersion(val sparkHome: String) extends Serializable with Logger {
}
}

private def parseFromSparkCoreJar(): Option[(String, String)] = {
val jarsDir = new File(s"$sparkHome/jars")
if (!jarsDir.exists() || !jarsDir.isDirectory) {
None
} else {
jarsDir.listFiles().collectFirst {
case file if SPARK_CORE_JAR_PATTERN.matcher(file.getName).matches() =>
val matcher = SPARK_CORE_JAR_PATTERN.matcher(file.getName)
matcher.matches()
val parsed = matcher.group(2) -> matcher.group(1)
logInfo(s"Spark version parsed from spark-core jar name: ${parsed._1}, scala: ${parsed._2}")
parsed
}
}
}

private def parseFromReleaseFile(): Option[(String, String)] = {
val releaseFile = new File(s"$sparkHome/RELEASE")
if (!releaseFile.exists()) {
None
} else {
val firstLine = FileUtils.readFileToString(releaseFile, StandardCharsets.UTF_8).trim.split("\n").headOption.getOrElse("")
val matcher = SPARK_RELEASE_VERSION_PATTERN.matcher(firstLine)
if (matcher.find()) {
parseFromSparkCoreJar().map { case (_, scalaVer) =>
val parsed = matcher.group(1) -> scalaVer
logInfo(s"Spark version parsed from RELEASE file: ${parsed._1}, scala: ${parsed._2}")
parsed
}
} else {
None
}
}
}

private def hintSparkVersion(): String = {
parseFromSparkCoreJar().map(_._1).orElse {
val releaseFile = new File(s"$sparkHome/RELEASE")
if (!releaseFile.exists()) {
None
} else {
val firstLine =
FileUtils.readFileToString(releaseFile, StandardCharsets.UTF_8).trim.split("\n").headOption.getOrElse("")
val matcher = SPARK_RELEASE_VERSION_PATTERN.matcher(firstLine)
if (matcher.find()) Some(matcher.group(1)) else None
}
}.getOrElse("3.0.0")
}

private def parseFromSparkSubmit(): Option[(String, String)] = {
var sparkVersion: String = null
var scalaVersion: String = null
val javaHomeExport = SparkEnvUtils
.resolveJavaHome(sparkHome, hintSparkVersion())
.map(javaHome => s"export JAVA_HOME=$javaHome&&")
.getOrElse("")
val cmd = List(s"export SPARK_HOME=$sparkHome&&${javaHomeExport}$sparkHome/bin/spark-submit --version")
val buffer = new mutable.StringBuilder

CommandUtils.execute(
sparkHome,
cmd,
new Consumer[String]() {
override def accept(out: String): Unit = {
buffer.append(out).append("\n")
val matcher = SPARK_VERSION_PATTERN.matcher(out)
if (matcher.find) {
sparkVersion = matcher.group(1)
} else {
val matcher1 = SPARK_SCALA_VERSION_PATTERN.matcher(out)
if (matcher1.find) {
scalaVersion = matcher1.group(1)
}
}
}
})

logInfo(buffer.toString())
if (sparkVersion != null && scalaVersion != null) {
logInfo(s"Spark version parsed from spark-submit: $sparkVersion, scala: $scalaVersion")
buffer.clear()
Some(sparkVersion -> scalaVersion)
} else {
None
}
}

override def toString: String =
s"""
|----------------------------------------- spark version -----------------------------------
| sparkHome : $sparkHome
| sparkVersion : $version
| scalaVersion : $scalaVersion
| javaHome : ${javaHome.getOrElse("not resolved")}
|-------------------------------------------------------------------------------------------
|""".stripMargin

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.streampark.common.util

import java.io.File
import java.nio.charset.StandardCharsets
import java.util.regex.Pattern

import scala.util.Try

object SparkEnvUtils extends Logger {

private[this] lazy val JAVA_HOME_PATTERN =
Pattern.compile("""(?:^|\n)\s*(?:export\s+)?JAVA_HOME\s*=\s*(?:["']([^"']+)["']|(\S+))""")

/** Minimum Java major version required by the given Spark version string. */
def requiredJavaMajorVersion(sparkVersion: String): Int = {
sparkVersion.split("\\.").headOption.flatMap(v => Try(v.trim.toInt).toOption) match {
case Some(major) if major >= 4 => 17
case _ => 8
}
}

/**
* Resolve JAVA_HOME for Spark CLI and SparkLauncher.
*
* Resolution order:
* 1. `$SPARK_HOME/conf/spark-env.sh`
* 2. process environment `JAVA_HOME`
* 3. system auto-detection (macOS `/usr/libexec/java_home`, common Linux paths)
*/
def resolveJavaHome(sparkHome: String, sparkVersion: String): Option[String] = {
val minVersion = requiredJavaMajorVersion(sparkVersion)
parseJavaHomeFromSparkEnv(sparkHome)
.filter(isValidJavaHome)
.orElse(Option(System.getenv("JAVA_HOME")).filter(isValidJavaHome))
.orElse(detectSystemJavaHome(minVersion).filter(isValidJavaHome))
}

def parseJavaHomeFromSparkEnv(sparkHome: String): Option[String] = {
val sparkEnvFile = new File(sparkHome, "conf/spark-env.sh")
if (!sparkEnvFile.exists()) {
None
} else {
val content = org.apache.commons.io.FileUtils.readFileToString(sparkEnvFile, StandardCharsets.UTF_8)
extractJavaHome(content)
}
}

private[util] def extractJavaHome(content: String): Option[String] = {
val matcher = JAVA_HOME_PATTERN.matcher(content)
var result: Option[String] = None
while (matcher.find() && result.isEmpty) {
val value = Option(matcher.group(1)).getOrElse(matcher.group(2))
if (value != null && value.nonEmpty && !value.startsWith("#")) {
result = Some(value.trim)
}
}
result
}

private def detectSystemJavaHome(minMajor: Int): Option[String] = {
val os = System.getProperty("os.name", "").toLowerCase
if (os.contains("mac")) {
Try {
val (code, output) = CommandUtils.execute(s"/usr/libexec/java_home -v $minMajor 2>/dev/null")
if (code == 0 && output.trim.nonEmpty) Some(output.trim) else None
}.getOrElse(None)
} else {
val candidates = List(
Option(System.getenv(s"JAVA${minMajor}_HOME")),
Option(s"/usr/lib/jvm/java-$minMajor-openjdk"),
Option(s"/usr/lib/jvm/java-$minMajor-openjdk-amd64"),
Option(s"/usr/lib/jvm/java-$minMajor"))
.flatten
.filter(isValidJavaHome)
candidates.headOption
}
}

private def isValidJavaHome(javaHome: String): Boolean = {
javaHome != null && javaHome.nonEmpty && new File(javaHome, "bin/java").exists()
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.streampark.common.conf

import org.apache.commons.io.FileUtils
import org.junit.jupiter.api.{AfterEach, Test}
import org.junit.jupiter.api.Assertions.{assertEquals, assertTrue}
import org.junit.jupiter.api.io.TempDir

import java.io.File
import java.nio.charset.StandardCharsets
import java.nio.file.Path

class SparkVersionTest {

@TempDir
private var tempDir: Path = _

private var sparkHome: File = _

@AfterEach
def cleanup(): Unit = {
if (sparkHome != null && sparkHome.exists()) {
FileUtils.deleteDirectory(sparkHome)
}
}

@Test
def parseSparkVersionFromJarWithoutRunningSparkSubmit(): Unit = {
sparkHome = tempDir.resolve("spark-4.1.2").toFile
val jarsDir = new File(sparkHome, "jars")
jarsDir.mkdirs()
new File(jarsDir, "spark-core_2.13-4.1.2.jar").createNewFile()
FileUtils.writeStringToFile(
new File(sparkHome, "RELEASE"),
"Spark 4.1.2 (git revision f0bb2e6a47d) built for Hadoop 3.4.2\n",
StandardCharsets.UTF_8)

val sparkVersion = new SparkVersion(sparkHome.getAbsolutePath)

assertEquals("4.1.2", sparkVersion.version)
assertEquals("2.13", sparkVersion.scalaVersion)
assertEquals("4.1", sparkVersion.majorVersion)
assertTrue(sparkVersion.checkVersion(false))
}

}
Loading
Loading