`

SparkSQL DF.agg 执行过程解析

阅读更多
在上一篇文章前, 我一直没看懂为什么下面的代码就能得到max或者avg或者min的值:
malePPL.agg(Map("height" -> "max", "sex" -> "count")).show


数据是
身高 性别
这样的一个组合大概有几百万个值

刚开始是使用reducebykey去做计算, 后来发现网上有agg里面直接进行排序获取值的做法, 特地看了一下为什么传进去一个Map(column -> Expression)就能得到想要的结果

首先还是直接进到agg的方法里面:
  /**
   * (Scala-specific) Aggregates on the entire [[DataFrame]] without groups.
   * {{{
   *   // df.agg(...) is a shorthand for df.groupBy().agg(...)
   *   df.agg(Map("age" -> "max", "salary" -> "avg"))
   *   df.groupBy().agg(Map("age" -> "max", "salary" -> "avg"))
   * }}}
   * @group dfops
   * @since 1.3.0
   */
  def agg(exprs: Map[String, String]): DataFrame = groupBy().agg(exprs)


看到他是执行groupBy返回对象的agg方法, 可以看到groupBy是一个GroupData:
  @scala.annotation.varargs
  def groupBy(cols: Column*): GroupedData = {
    GroupedData(this, cols.map(_.expr), GroupedData.GroupByType)
  }


GroupedData的agg方法:

  def agg(exprs: Map[String, String]): DataFrame = {
    toDF(exprs.map { case (colName, expr) =>
      strToExpr(expr)(df(colName).expr)
    }.toSeq)
  }


可以看到他是使用toDF方法构建一个DataFrame, 看一下strToExpr里面其实是做了一个unresolvedFunction:
  private[this] def strToExpr(expr: String): (Expression => Expression) = {
    val exprToFunc: (Expression => Expression) = {
      (inputExpr: Expression) => expr.toLowerCase match {
        // We special handle a few cases that have alias that are not in function registry.
        case "avg" | "average" | "mean" =>
          UnresolvedFunction("avg", inputExpr :: Nil, isDistinct = false)
        case "stddev" | "std" =>
          UnresolvedFunction("stddev", inputExpr :: Nil, isDistinct = false)
        // Also special handle count because we need to take care count(*).
        case "count" | "size" =>
          // Turn count(*) into count(1)
          inputExpr match {
            case s: Star => Count(Literal(1)).toAggregateExpression()
            case _ => Count(inputExpr).toAggregateExpression()
          }
        case name => UnresolvedFunction(name, inputExpr :: Nil, isDistinct = false)
      }
    }
    (inputExpr: Expression) => exprToFunc(inputExpr)
  }



看一下toDF是怎么写的:
private[this] def toDF(aggExprs: Seq[Expression]): DataFrame = {
    val aggregates = if (df.sqlContext.conf.dataFrameRetainGroupColumns) {
      groupingExprs ++ aggExprs
    } else {
      aggExprs
    }

    val aliasedAgg = aggregates.map(alias)

    groupType match {
      case GroupedData.GroupByType =>
        DataFrame(
          df.sqlContext, Aggregate(groupingExprs, aliasedAgg, df.logicalPlan))
      case GroupedData.RollupType =>
        DataFrame(
          df.sqlContext, Rollup(groupingExprs, df.logicalPlan, aliasedAgg))
      case GroupedData.CubeType =>
        DataFrame(
          df.sqlContext, Cube(groupingExprs, df.logicalPlan, aliasedAgg))
      case GroupedData.PivotType(pivotCol, values) =>
        val aliasedGrps = groupingExprs.map(alias)
        DataFrame(
          df.sqlContext, Pivot(aliasedGrps, pivotCol, values, aggExprs, df.logicalPlan))
    }
  }

在groupBy方法里面我们其实可以看到传入的grouptype是GroupedData.GroupByType

所以这里会去执行:
case GroupedData.GroupByType =>
        DataFrame(
          df.sqlContext, Aggregate(groupingExprs, aliasedAgg, df.logicalPlan))

Aggregate方法继承自UnaryNode, 也就是一个LogicPlan
case class Aggregate(
    groupingExpressions: Seq[Expression],
    aggregateExpressions: Seq[NamedExpression],
    child: LogicalPlan)
  extends UnaryNode {

  override lazy val resolved: Boolean = {
    val hasWindowExpressions = aggregateExpressions.exists ( _.collect {
        case window: WindowExpression => window
      }.nonEmpty
    )

    !expressions.exists(!_.resolved) && childrenResolved && !hasWindowExpressions
  }

  override def output: Seq[Attribute] = aggregateExpressions.map(_.toAttribute)
}


这个logicplan包含了我们传入的表达式, 比如说hight-> max这样的。 经过这几步后, 一个DataFrame被创建了, 按照之前的那片文章来看, DF会做下面这几步去优化logicplan直到一个可执行的物理计划为止: (包含对unresolvedFunction的优化)
1.通过Sqlparse 转成unresolvedLogicplan
2.通过Analyzer转成 resolvedLogicplan
3.通过optimizer转成 optimzedLogicplan
4.通过sparkplanner转成physicalLogicplan
5.通过prepareForExecution 转成executable logicplan
6.通过toRDD等方法执行executedplan去调用tree的doExecute

既然这样, 那么我们看一下unresolvedFunction是怎么会和max min avg等expression关联起来的, 进入analyzer, 看到SQLContext里面创建Analyzer时候传入了一个registry:
protected[sql] lazy val functionRegistry: FunctionRegistry = FunctionRegistry.builtin.copy()

  protected[sql] lazy val analyzer: Analyzer =
    new Analyzer(catalog, functionRegistry, conf) {
      override val extendedResolutionRules =
        ExtractPythonUDFs ::
        PreInsertCastAndRename ::
        (if (conf.runSQLOnFile) new ResolveDataSource(self) :: Nil else Nil)

      override val extendedCheckRules = Seq(
        datasources.PreWriteCheck(catalog)
      )
    }


在这个FunctionRegistry里面包含了所有的expression:
object FunctionRegistry {

  type FunctionBuilder = Seq[Expression] => Expression

  val expressions: Map[String, (ExpressionInfo, FunctionBuilder)] = Map(
    // misc non-aggregate functions
    expression[Abs]("abs"),
    expression[CreateArray]("array"),
    expression[Coalesce]("coalesce"),
    expression[Explode]("explode"),
    expression[Greatest]("greatest"),
    expression[If]("if"),
    expression[IsNaN]("isnan"),
    expression[IsNull]("isnull"),
    expression[IsNotNull]("isnotnull"),
    expression[Least]("least"),
    expression[Coalesce]("nvl"),
    expression[Rand]("rand"),
    expression[Randn]("randn"),
    expression[CreateStruct]("struct"),
    expression[CreateNamedStruct]("named_struct"),
    expression[Sqrt]("sqrt"),
    expression[NaNvl]("nanvl"),

    // math functions
    expression[Acos]("acos"),
    expression[Asin]("asin"),
    expression[Atan]("atan"),
    expression[Atan2]("atan2"),
    expression[Bin]("bin"),
    expression[Cbrt]("cbrt"),
    expression[Ceil]("ceil"),
    expression[Ceil]("ceiling"),
    expression[Cos]("cos"),
    expression[Cosh]("cosh"),
    expression[Conv]("conv"),
    expression[EulerNumber]("e"),
    expression[Exp]("exp"),
    expression[Expm1]("expm1"),
    expression[Floor]("floor"),
    expression[Factorial]("factorial"),
    expression[Hypot]("hypot"),
    expression[Hex]("hex"),
    expression[Logarithm]("log"),
    expression[Log]("ln"),
    expression[Log10]("log10"),
    expression[Log1p]("log1p"),
    expression[Log2]("log2"),
    expression[UnaryMinus]("negative"),
    expression[Pi]("pi"),
    expression[Pow]("pow"),
    expression[Pow]("power"),
    expression[Pmod]("pmod"),
    expression[UnaryPositive]("positive"),
    expression[Rint]("rint"),
    expression[Round]("round"),
    expression[ShiftLeft]("shiftleft"),
    expression[ShiftRight]("shiftright"),
    expression[ShiftRightUnsigned]("shiftrightunsigned"),
    expression[Signum]("sign"),
    expression[Signum]("signum"),
    expression[Sin]("sin"),
    expression[Sinh]("sinh"),
    expression[Tan]("tan"),
    expression[Tanh]("tanh"),
    expression[ToDegrees]("degrees"),
    expression[ToRadians]("radians"),

    // aggregate functions
    expression[HyperLogLogPlusPlus]("approx_count_distinct"),
    expression[Average]("avg"),
    expression[Corr]("corr"),
    expression[Count]("count"),
    expression[First]("first"),
    expression[First]("first_value"),
    expression[Last]("last"),
    expression[Last]("last_value"),
    expression[Max]("max"),
    expression[Average]("mean"),
    expression[Min]("min"),
    expression[StddevSamp]("stddev"),
    expression[StddevPop]("stddev_pop"),
    expression[StddevSamp]("stddev_samp"),
    expression[Sum]("sum"),
    expression[VarianceSamp]("variance"),
    expression[VariancePop]("var_pop"),
    expression[VarianceSamp]("var_samp"),
    expression[Skewness]("skewness"),
    expression[Kurtosis]("kurtosis"),

    // string functions
    expression[Ascii]("ascii"),
    expression[Base64]("base64"),
    expression[Concat]("concat"),
    expression[ConcatWs]("concat_ws"),
    expression[Encode]("encode"),
    expression[Decode]("decode"),
    expression[FindInSet]("find_in_set"),
    expression[FormatNumber]("format_number"),
    expression[GetJsonObject]("get_json_object"),
    expression[InitCap]("initcap"),
    expression[JsonTuple]("json_tuple"),
    expression[Lower]("lcase"),
    expression[Lower]("lower"),
    expression[Length]("length"),
    expression[Levenshtein]("levenshtein"),
    expression[RegExpExtract]("regexp_extract"),
    expression[RegExpReplace]("regexp_replace"),
    expression[StringInstr]("instr"),
    expression[StringLocate]("locate"),
    expression[StringLPad]("lpad"),
    expression[StringTrimLeft]("ltrim"),
    expression[FormatString]("format_string"),
    expression[FormatString]("printf"),
    expression[StringRPad]("rpad"),
    expression[StringRepeat]("repeat"),
    expression[StringReverse]("reverse"),
    expression[StringTrimRight]("rtrim"),
    expression[SoundEx]("soundex"),
    expression[StringSpace]("space"),
    expression[StringSplit]("split"),
    expression[Substring]("substr"),
    expression[Substring]("substring"),
    expression[SubstringIndex]("substring_index"),
    expression[StringTranslate]("translate"),
    expression[StringTrim]("trim"),
    expression[UnBase64]("unbase64"),
    expression[Upper]("ucase"),
    expression[Unhex]("unhex"),
    expression[Upper]("upper"),

    // datetime functions
    expression[AddMonths]("add_months"),
    expression[CurrentDate]("current_date"),
    expression[CurrentTimestamp]("current_timestamp"),
    expression[CurrentTimestamp]("now"),
    expression[DateDiff]("datediff"),
    expression[DateAdd]("date_add"),
    expression[DateFormatClass]("date_format"),
    expression[DateSub]("date_sub"),
    expression[DayOfMonth]("day"),
    expression[DayOfYear]("dayofyear"),
    expression[DayOfMonth]("dayofmonth"),
    expression[FromUnixTime]("from_unixtime"),
    expression[FromUTCTimestamp]("from_utc_timestamp"),
    expression[Hour]("hour"),
    expression[LastDay]("last_day"),
    expression[Minute]("minute"),
    expression[Month]("month"),
    expression[MonthsBetween]("months_between"),
    expression[NextDay]("next_day"),
    expression[Quarter]("quarter"),
    expression[Second]("second"),
    expression[ToDate]("to_date"),
    expression[ToUnixTimestamp]("to_unix_timestamp"),
    expression[ToUTCTimestamp]("to_utc_timestamp"),
    expression[TruncDate]("trunc"),
    expression[UnixTimestamp]("unix_timestamp"),
    expression[WeekOfYear]("weekofyear"),
    expression[Year]("year"),

    // collection functions
    expression[Size]("size"),
    expression[SortArray]("sort_array"),
    expression[ArrayContains]("array_contains"),

    // misc functions
    expression[Crc32]("crc32"),
    expression[Md5]("md5"),
    expression[Sha1]("sha"),
    expression[Sha1]("sha1"),
    expression[Sha2]("sha2"),
    expression[SparkPartitionID]("spark_partition_id"),
    expression[InputFileName]("input_file_name"),
    expression[MonotonicallyIncreasingID]("monotonically_increasing_id")
  )



这样当Analyzer在执行execute方法, 对所有的node进行Rule的时候, 有一个Rule叫ResolveFunctions, 下面是analyzer里面定义的batch:
  lazy val batches: Seq[Batch] = Seq(
    Batch("Substitution", fixedPoint,
      CTESubstitution,
      WindowsSubstitution),
    Batch("Resolution", fixedPoint,
      ResolveRelations ::
      ResolveReferences ::
      ResolveGroupingAnalytics ::
      ResolvePivot ::
      ResolveUpCast ::
      ResolveSortReferences ::
      ResolveGenerate ::
      ResolveFunctions ::
      ResolveAliases ::
      ExtractWindowExpressions ::
      GlobalAggregates ::
      ResolveAggregateFunctions ::
      HiveTypeCoercion.typeCoercionRules ++
      extendedResolutionRules : _*),
    Batch("Nondeterministic", Once,
      PullOutNondeterministic,
      ComputeCurrentTime),
    Batch("UDF", Once,
      HandleNullInputsForUDF),
    Batch("Cleanup", fixedPoint,
      CleanupAliases)
  )


在ResolveFunctions 是这样定义的:
  object ResolveFunctions extends Rule[LogicalPlan] {
    def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
      case q: LogicalPlan =>
        q transformExpressions {
          case u if !u.childrenResolved => u // Skip until children are resolved.
          case u @ UnresolvedFunction(name, children, isDistinct) =>
            withPosition(u) {
              registry.lookupFunction(name, children) match {
                // DISTINCT is not meaningful for a Max or a Min.
                case max: Max if isDistinct =>
                  AggregateExpression(max, Complete, isDistinct = false)
                case min: Min if isDistinct =>
                  AggregateExpression(min, Complete, isDistinct = false)
                // We get an aggregate function, we need to wrap it in an AggregateExpression.
                case agg: AggregateFunction => AggregateExpression(agg, Complete, isDistinct)
                // This function is not an aggregate function, just return the resolved one.
                case other => other
              }
            }
        }
    }
  }


看到这个方法会对所有的expression进行遍历:
registry.lookupFunction(name, children) match{
...
}

如果我们传入的是max或者min, 或者不属于这两者的, 那么直接就能返回aggregateexpression:
AggregateExpression(max, Complete, isDistinct = false)
AggregateExpression(min, Complete, isDistinct = false)
AggregateExpression(agg, Complete, isDistinct)
这样我们传入的max min就被registryFunction里面的expression代替了, 继续通过其他Rule执行来变成resolvedaggreFunction。

可以看到我们定义的max min或者avg其实在构建DataFrame的时候已经在其最总的执行计划里面了, 就不难理解为什么我们这样传入参数就能得到这些结果。

根据测试结果, 传入agg的expression的方法远比rdd计算获取结果快的多。 目前来看, 如果能用agg这样去获取想要的结果, 那么就不要用rdd去进行计算了。

如果有什么不对的地方, 请指正

ps:可以试一下传入的参数不在registryFunction里面的话会由checkAnalysis(resolvedAggregate)这个方法发现及抛出异常


分享到:
评论

相关推荐

    sparkSQL基本操作.zip

    例如,`df.createOrReplaceTempView("tempView")`后,可以使用`spark.sql("SELECT * FROM tempView WHERE column = 'value'")`执行SQL查询。 6. **DataFrame优化**:Spark SQL使用 Catalyst 编译器优化查询计划,...

    Agg的.NET移植Agg-Sharp.zip

    示例代码:using MatterHackers.Agg.UI; using System; namespace MatterHackers.Agg { public class HelloWorld : SystemWindow { public HelloWorld() : base(640, 480) { // add the ...

    Spark SQL编程初级实践-答案-实验报告-纠正版

    val avgAge = df.agg(avg("age")).first().getDouble(0) // 11. 计算 age 的最小值 val minAge = df.agg(min("age")).first().getInt(0) ``` #### 四、实验总结 通过本次实验,学生不仅掌握了Spark SQL的基本操作...

    Python数据分析常用方法手册.pdf

    - 分组计算:`df.groupby('column').agg(func)`按列分组并应用聚合函数。 - 聚合操作:`df.sum()`, `df.mean()`, `df.median()`等。 - 数据透视表:`pd.pivot_table(data, values=None, index=None, columns=None...

    agg-2.5 AGG是一个开源、高效的跨平台2D图形库

    •如果要用AGG的控件和窗体,要加入[AGG]\src\ctrl\*.cpp和[AGG]\src\platform\<OS>\*.cpp,头文件在[AGG]\include\ctrl和[AGG]\include\platform里 •如果要用到TrueType字体显示,要加入[AGG]\font_win32_tt目录下...

    django-rest-framework-aggregates:将Django模型查询集的聚合功能公开给DRF API

    此渲染器将覆盖对api v2 .agg端点的调用的默认行为。 支持GET调用以列出以下格式的端点: endpoint.agg/?aggregate[Count]=(field to count) endpoint.agg/?aggregate[Sum]=(field to sum) endpoint.agg/?...

    pandas 基础代码

    - 分组聚合:`df.groupby('group_column').agg(func)` 8. **合并与连接** - 合并DataFrame:`pd.concat([df1, df2], axis=0/1)` - 内连接:`df1.merge(df2, on='common_column')` - 左/右/外连接:`left_on`, `...

    jiexi_数据解析_

    描述中提到的“对采集到的数据进行解析,可以直接更换数据,采集”进一步强调了解析过程对于数据操作和更新的重要性。 数据解析通常是通过编程语言实现的,例如Python,它具有强大的数据处理库,如Pandas、Numpy和...

    pandas 对group进行聚合的例子

    接着,我们可以使用`.agg()`方法对每个分组执行聚合操作。聚合操作可以包括求和、平均值、最大值、最小值等。在示例中,`df.groupby('A').agg('min')`计算了每组中'B'和'C'列的最小值。`df.groupby('A').agg(['min',...

    详谈pandas中agg函数和apply函数的区别

    - `agg`: 这个函数主要用于**聚合运算**,它接受一个或多个函数作为参数,对DataFrame或GroupBy对象的列执行聚合操作。这些函数通常会减少数据的维度,如计算平均值、总和、最大值等。`agg`的目的是将一维数组(如...

    Pandas面试题.pdf

    aggregated = df.groupby('col').agg(custom_agg) ``` 47. **如何使用Pandas的MultiIndex来处理高维数据?** - 使用`.stack()`和`.unstack()`: ```python stacked = df.stack() unstacked = stacked.unstack...

    pandas使用工作技能总结

    - `df.groupby('A')['B'].agg(['mean', 'sum'])`:对每个分组执行多个操作。 **2.4 和 fillna 连用** - `df.groupby('A').fillna(method='ffill')`:向前填充缺失值。 #### 3. 注意事项 - 在使用 groupby 时,需要...

    Python数据科学速查表 - Pandas 基础

    - 分组聚合:`groupby().agg(['function1', 'function2'])` ### 8. 时间序列分析 Pandas支持时间序列数据处理,包括日期和时间操作。 - 创建日期索引:`pd.date_range(start, end)` - 将列转换为日期:`df['date_...

    agg_v2.0.0.apk

    agg_v2.0.0.apk

    Python pandas数据转化.docx

    result = grouped.B.agg(['sum', 'max']) ``` ### 4. 数据合并 #### `pd.merge()` 合并两个DataFrame,类似于SQL中的JOIN操作,可以按共同的列进行连接。 ```python df2 = pd.DataFrame({'A': [1, 2], 'D': [10, ...

    PySpark_Day06:SQL and DataFrames.pdf

    首先,我们需要创建一个 SparkSession 对象,它是 SparkSQL 的入口点,可以用来创建 DataFrame 和执行 SQL 查询。 ```python from pyspark.sql import SparkSession spark = SparkSession.builder.appName('...

    agg学习手册

    ### AGG学习手册知识点解析 #### 一、AGG简介 **AGG**(Anti-Grain Geometry)是一个开源的2D图形库,以其高效、跨平台的特点而著称。相较于GDI+(Graphics Device Interface Plus),AGG不仅提供了更为灵活的编程...

    agg-2.4-2.1.i386.rpm

    ( agg-2.4-2.1.i386.rpm )

Global site tag (gtag.js) - Google Analytics