


Stratified Sampling and Stratified Random Sampling in Spark SQL

Author: admin · Category: Internet / IT Industry


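I recently needed to write some Spark SQL that samples a specified number of rows from a dataset, and only then found out that sampling a fixed number of rows with TABLESAMPLE is actually not random.

Because the dataset is large, my first approach was to take the first n rows by row_number over a window with random ordering. That ran slowly, so I remembered the TABLESAMPLE clause, which can take a number of rows directly. It turned out to be extremely fast: sampling a table with hundreds of millions of rows finished within a few seconds. That made me curious enough to read the documentation and the source code.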


The TABLESAMPLE statement is used to sample the table. It supports the following sampling methods:

- TABLESAMPLE(x ROWS): Sample the table down to the given number of rows.
- TABLESAMPLE(x PERCENT): Sample the table down to the given percentage. Note that percentages are defined as a number between 0 and 100.
- TABLESAMPLE(BUCKET x OUT OF y): Sample the table down to a x out of y fraction.

Note: TABLESAMPLE returns the approximate number of rows or fraction requested.


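In the source code, when the parser matches SampleByRowsContext it calls Limit(expression(ctx.expression), query). In other words, TABLESAMPLE(x ROWS) follows the same logic as LIMIT x, which explains both the speed and the lack of randomness.

Only SampleByPercentileContext, that is TABLESAMPLE(x PERCENT), implements real random sampling.

So if the randomness of the sample matters, stick to percentage sampling (SampleByPercentileContext) or to a window function.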
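The difference between the two code paths can be illustrated with a small pure-Python model (Python is used only so that the sketch runs without a Spark cluster). It assumes that percentage sampling keeps each row independently with probability x/100; that is an assumption about the sampler, not something shown in the source excerpt, and both function names are hypothetical.

```python
import random

def sample_rows_like_limit(rows, n):
    """Model of TABLESAMPLE(n ROWS) as described above: simply the first n rows."""
    return rows[:n]

def sample_percent_bernoulli(rows, percent, seed):
    """Model of TABLESAMPLE(x PERCENT), ASSUMING each row is kept
    independently with probability percent / 100 (Bernoulli sampling)."""
    rng = random.Random(seed)
    p = percent / 100.0
    return [row for row in rows if rng.random() < p]

# Pretend the table is stored in ascending id order.
rows = list(range(10_000))

head = sample_rows_like_limit(rows, 1000)
bern = sample_percent_bernoulli(rows, 10, seed=2077)

# The LIMIT-style "sample" never looks past the start of the table.
print(head[-1])  # 999

# The Bernoulli sample covers the whole table, but its size is only
# approximately 10% of the input, never guaranteed to be exactly 1000.
print(len(bern), min(bern), max(bern))
```

Under this model the row-count variant is fast because it stops after n rows, while the percentage variant has to visit every row and returns only an approximate row count, which matches the note in the documentation.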

Summary: random sampling methods in Spark SQL



Sampling with a window function and random ordering

WITH RankedData AS (
  SELECT *,
         row_number() OVER (ORDER BY rand(2077)) AS rn
  FROM your_table
)
SELECT *
FROM RankedData
WHERE rn <= 1000
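What this query does can be modelled in a few lines of plain Python: give every row a random key, sort by that key, and keep the first n rows. The sketch below is only an illustration of the idea; the function name is hypothetical and Python's random generator is not the one Spark uses, so the selected rows will differ from what rand(2077) picks.

```python
import random

def sample_n_by_random_order(rows, n, seed):
    """Model of row_number() OVER (ORDER BY rand(seed)) ... WHERE rn <= n."""
    rng = random.Random(seed)
    keyed = [(rng.random(), row) for row in rows]  # one random key per row
    keyed.sort(key=lambda pair: pair[0])           # global sort: the expensive step
    return [row for _, row in keyed[:n]]           # rn <= n

rows = list(range(100_000))
picked = sample_n_by_random_order(rows, 1000, seed=2077)
print(len(picked))  # 1000
```

Unlike percentage sampling, this returns exactly n rows (when the table has at least n), and every row has the same chance of being chosen. The price is a sort over the whole input. In Spark, a window without PARTITION BY also moves all rows into a single partition, which is a likely reason the query is slow on large tables.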





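Stratified sampling is common in data science. To keep the sample random within each class, we usually stratify on the label y. When time also matters, and the sample should be random across time as well, we usually stratify on two levels: month + y label, or date + y label.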


WITH RankedData AS (
  SELECT *,
         row_number() OVER (PARTITION BY stratum_col ORDER BY rand(2077)) AS rn
  FROM your_table
)
SELECT *
FROM RankedData
WHERE rn <= 100 -- take 100 rows from each stratum
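The fixed-count query above can be modelled in plain Python as: group rows by stratum, shuffle each group, and keep at most n rows from each. This is a sketch under my own assumptions (hypothetical function name, toy data with a binary label y), not the way Spark executes the window.

```python
import random
from collections import defaultdict

def sample_n_per_stratum(rows, key, n, seed):
    """Model of row_number() OVER (PARTITION BY key ORDER BY rand(seed)) <= n."""
    rng = random.Random(seed)
    strata = defaultdict(list)
    for row in rows:
        strata[key(row)].append(row)       # PARTITION BY
    sample = []
    for stratum in sorted(strata):         # fixed iteration order for repeatability
        members = strata[stratum]
        rng.shuffle(members)               # ORDER BY rand()
        sample.extend(members[:n])         # rn <= n
    return sample

# Toy data: 100 positive rows (y = 1) and 1900 negative rows (y = 0).
rows = [{"id": i, "y": 1 if i % 20 == 0 else 0} for i in range(2000)]

picked = sample_n_per_stratum(rows, key=lambda r: r["y"], n=100, seed=2077)
print(len(picked))  # 200

# A stratum smaller than n is returned whole: 100 positives + 150 negatives.
capped = sample_n_per_stratum(rows, key=lambda r: r["y"], n=150, seed=2077)
print(len(capped))  # 250
```

Note that a fixed count per stratum balances the classes rather than preserving their original ratio, and that small strata are simply taken in full.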




WITH RankedData AS (
  SELECT *,
         row_number() OVER (PARTITION BY stratum_col ORDER BY rand()) AS rn,
         count(*) OVER (PARTITION BY stratum_col) AS total_count
  FROM your_table
)
SELECT *
FROM RankedData
WHERE rn <= total_count * 0.1 -- take 10% of the rows from each stratum

Stratifying on two fields


WITH RankedData AS (
  SELECT *,
         row_number() OVER (PARTITION BY stratum_col1, stratum_col2 ORDER BY rand()) AS rn,
         count(*) OVER (PARTITION BY stratum_col1, stratum_col2) AS total_count
  FROM your_table
)
SELECT *
FROM RankedData
WHERE rn <= total_count * 0.1 -- take 10% of the rows from each stratum
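The proportional, two-field variant can be modelled the same way in plain Python, using a (month, y) tuple as the stratum key. The function name and the toy data are hypothetical. The condition rn <= total_count * 0.1 is expressed with integer arithmetic so that the model does not depend on floating-point rounding.

```python
import random
from collections import defaultdict

def sample_percent_per_stratum(rows, key, percent, seed):
    """Model of rn <= total_count * percent / 100 within each stratum."""
    rng = random.Random(seed)
    strata = defaultdict(list)
    for row in rows:
        strata[key(row)].append(row)            # PARTITION BY field1, field2
    sample = []
    for stratum in sorted(strata):
        members = strata[stratum]
        rng.shuffle(members)                    # ORDER BY rand()
        keep = len(members) * percent // 100    # largest rn with rn <= total * percent / 100
        sample.extend(members[:keep])
    return sample

# Two months with 180 negatives and 20 positives each,
# plus a tiny stratum of 5 positive rows in a third month.
rows = []
for month in ("2024-01", "2024-02"):
    for i in range(200):
        rows.append({"month": month, "y": 1 if i % 10 == 0 else 0})
rows += [{"month": "2024-03", "y": 1} for _ in range(5)]

picked = sample_percent_per_stratum(
    rows, key=lambda r: (r["month"], r["y"]), percent=10, seed=2077
)

counts = defaultdict(int)
for r in picked:
    counts[(r["month"], r["y"])] += 1
print(dict(counts))
```

One consequence worth knowing: a stratum with fewer than 10 rows contributes nothing at a 10% rate, because no row number satisfies rn <= total_count * 0.1. If tiny strata must be represented, a lower bound such as "at least one row per stratum" has to be added explicitly.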

Appendix: relevant source code:

/**
 * Add a [[Sample]] to a logical plan.
 *
 * This currently supports the following sampling methods:
 * - TABLESAMPLE(x ROWS): Sample the table down to the given number of rows.
 * - TABLESAMPLE(x PERCENT) [REPEATABLE (y)]: Sample the table down to the given percentage with
 * seed y. Note that percentages are defined as a number between 0 and 100.
 * - TABLESAMPLE(BUCKET x OUT OF y) [REPEATABLE (z)]: Sample the table down to a x divided by
 * y fraction with seed z.
 */
private def withSample(ctx: SampleContext, query: LogicalPlan): LogicalPlan = withOrigin(ctx) {
  // Create a sampled plan if we need one.
  def sample(fraction: Double, seed: Long): Sample = {
    // The range of fraction accepted by Sample is [0, 1]. Because Hive's block sampling
    // function takes X PERCENT as the input and the range of X is [0, 100], we need to
    // adjust the fraction.
    val eps = RandomSampler.roundingEpsilon
    validate(fraction >= 0.0 - eps && fraction <= 1.0 + eps,
      s"Sampling fraction ($fraction) must be on interval [0, 1]",
      ctx)
    Sample(0.0, fraction, withReplacement = false, seed, query)
  }

  if (ctx.sampleMethod() == null) {
    throw QueryParsingErrors.emptyInputForTableSampleError(ctx)
  }

  val seed = if (ctx.seed != null) {
    ctx.seed.getText.toLong
  } else {
    (math.random() * 1000).toLong
  }

  ctx.sampleMethod() match {
    case ctx: SampleByRowsContext =>
      Limit(expression(ctx.expression), query)

    case ctx: SampleByPercentileContext =>
      val fraction = ctx.percentage.getText.toDouble
      val sign = if (ctx.negativeSign == null) 1 else -1
      sample(sign * fraction / 100.0d, seed)

    case ctx: SampleByBytesContext =>
      val bytesStr = ctx.bytes.getText
      if (bytesStr.matches("[0-9]+[bBkKmMgG]")) {
        throw QueryParsingErrors.tableSampleByBytesUnsupportedError("byteLengthLiteral", ctx)
      } else {
        throw QueryParsingErrors.invalidByteLengthLiteralError(bytesStr, ctx)
      }

    case ctx: SampleByBucketContext if ctx.ON() != null =>
      if (ctx.identifier != null) {
        throw QueryParsingErrors.tableSampleByBytesUnsupportedError(
          "BUCKET x OUT OF y ON colname", ctx)
      } else {
        throw QueryParsingErrors.tableSampleByBytesUnsupportedError(
          "BUCKET x OUT OF y ON function", ctx)
      }

    case ctx: SampleByBucketContext =>
      sample(ctx.numerator.getText.toDouble / ctx.denominator.getText.toDouble, seed)
  }
}
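To make the arithmetic in this method easier to follow, here is a small Python paraphrase of how the PERCENT and BUCKET variants turn their arguments into a sampling fraction. It is a reading aid rather than Spark code: the function names are hypothetical, and EPS is a placeholder value I chose for illustration (the real tolerance is whatever RandomSampler.roundingEpsilon is set to).

```python
EPS = 1e-6  # placeholder tolerance, assumed for this sketch only

def validate_fraction(fraction):
    """Mirror of the validate(...) call: the fraction must lie in [0, 1]."""
    if not (0.0 - EPS <= fraction <= 1.0 + EPS):
        raise ValueError(
            f"Sampling fraction ({fraction}) must be on interval [0, 1]"
        )
    return fraction

def percent_to_fraction(percent, negative=False):
    """TABLESAMPLE(x PERCENT): x is in [0, 100], the sampler wants [0, 1]."""
    sign = -1 if negative else 1
    return validate_fraction(sign * percent / 100.0)

def bucket_to_fraction(numerator, denominator):
    """TABLESAMPLE(BUCKET x OUT OF y): the fraction is x / y."""
    return validate_fraction(numerator / denominator)

print(percent_to_fraction(10))   # 0.1
print(bucket_to_fraction(1, 4))  # 0.25
```

Only these two variants end up as a random sample with a fraction and a seed. The ROWS variant never reaches this conversion, since it is turned into a Limit node directly.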

Updated: 2024-05-07 17:04:45