From 791207eff534e25def614afa00b14245625e55e8 Mon Sep 17 00:00:00 2001 From: Sebastian Baunsgaard Date: Mon, 1 Jun 2026 14:05:11 +0000 Subject: [PATCH 1/4] Tune compressed matmul fast paths and Spark execution decisions Mixes two related performance changes: refined compressed multiply heuristics, and a Spark-vs-CP decision refresh on the Hop layer. CLALib matmul changes: - CLALibMMChain: for XtXv with few col groups and a wide-enough matrix, compute X' * X via leftMultByTransposeSelf and finish with a regular matrix multiply against v. Cheaper than chaining when the X' * X path can stay compressed - CLALibTSMM: make leftMultByTransposeSelf return its result and add a public (cmb, k) overload so MMChain can call it; fall back to the uncompressed tsmm when a ColGroupUncompressed is present - CLALibRightMultBy: stop forcing decompression for ASDC / ASDCZero inputs; they have working preAggregate paths that beat the dense fallback - CLALibCompAgg: fix blklen rounding so the last partition is not short by k rows on parallel aggregates Spark/CP exec-decision refresh (Hop, UnaryOp, BinaryOp): - Hop: new helpers hasSparkOutput() and isScalarOrVectorBellowBlockSize() shared between unary and binary decision points - UnaryOp.optFindExecType: replace the inline chain of negations with isDisallowedSparkOps(), allow Frame outputs, and pull unary ops into Spark whenever the input already has a Spark output - BinaryOp.optFindExecType: same kind of restructuring; allow matrix-or-frame outputs to be pulled into Spark when exactly one operand is a scalar or small vector Instruction-side adjustments: - VariableCPInstruction (CAST_AS_MATRIX from frame): use the parallel MatrixBlockFromFrame.convertToMatrixBlock(fin, k) path instead of the single-threaded DataConverter helper - ParameterizedBuiltinCPInstruction (transformdecode): call the parallel decoder.decode(data, out, k) overload using InfrastructureAnalyzer.getLocalParallelism() --- .../java/org/apache/sysds/hops/BinaryOp.java | 36 ++++++++++------
src/main/java/org/apache/sysds/hops/Hop.java | 11 +++++ .../java/org/apache/sysds/hops/UnaryOp.java | 34 ++++++++++----- .../runtime/compress/lib/CLALibCompAgg.java | 2 +- .../runtime/compress/lib/CLALibMMChain.java | 6 +++ .../compress/lib/CLALibRightMultBy.java | 14 +++--- .../runtime/compress/lib/CLALibTSMM.java | 43 +++++++++++++------ .../cp/ParameterizedBuiltinCPInstruction.java | 2 +- .../cp/VariableCPInstruction.java | 3 +- 9 files changed, 108 insertions(+), 43 deletions(-) diff --git a/src/main/java/org/apache/sysds/hops/BinaryOp.java b/src/main/java/org/apache/sysds/hops/BinaryOp.java index dc7edf76e50..5accb497501 100644 --- a/src/main/java/org/apache/sysds/hops/BinaryOp.java +++ b/src/main/java/org/apache/sysds/hops/BinaryOp.java @@ -763,8 +763,8 @@ protected ExecType optFindExecType(boolean transitive) { checkAndSetForcedPlatform(); - DataType dt1 = getInput().get(0).getDataType(); - DataType dt2 = getInput().get(1).getDataType(); + final DataType dt1 = getInput(0).getDataType(); + final DataType dt2 = getInput(1).getDataType(); if( _etypeForced != null ) { setExecType(_etypeForced); @@ -812,18 +812,28 @@ else if ( dt1 == DataType.SCALAR && dt2 == DataType.MATRIX ) { checkAndSetInvalidCPDimsAndSize(); } - //spark-specific decision refinement (execute unary scalar w/ spark input and + // spark-specific decision refinement (execute unary scalar w/ spark input and // single parent also in spark because it's likely cheap and reduces intermediates) - if(transitive && _etype == ExecType.CP && _etypeForced != ExecType.CP && _etypeForced != ExecType.FED && - getDataType().isMatrix() // output should be a matrix - && (dt1.isScalar() || dt2.isScalar()) // one side should be scalar - && supportsMatrixScalarOperations() // scalar operations - && !(getInput().get(dt1.isScalar() ? 1 : 0) instanceof DataOp) // input is not checkpoint - && getInput().get(dt1.isScalar() ? 
1 : 0).getParent().size() == 1 // unary scalar is only parent - && !HopRewriteUtils.isSingleBlock(getInput().get(dt1.isScalar() ? 1 : 0)) // single block triggered exec - && getInput().get(dt1.isScalar() ? 1 : 0).optFindExecType() == ExecType.SPARK) { - // pull unary scalar operation into spark - _etype = ExecType.SPARK; + if(transitive // we allow transitive Spark operations. continue sequences of spark operations + && _etype == ExecType.CP // The instruction is currently in CP + && _etypeForced != ExecType.CP // not forced CP + && _etypeForced != ExecType.FED // not federated + && (getDataType().isMatrix() || getDataType().isFrame()) // output should be a matrix or frame + ) { + final boolean v1 = getInput(0).isScalarOrVectorBellowBlockSize(); + final boolean v2 = getInput(1).isScalarOrVectorBellowBlockSize(); + final boolean left = v1 == true; // left side is the vector or scalar + final Hop sparkIn = getInput(left ? 1 : 0); + if((v1 ^ v2) // XOR only one side is allowed to be a vector or a scalar. + && (supportsMatrixScalarOperations() || op == OpOp2.APPLY_SCHEMA) // supported operation + && sparkIn.getParent().size() == 1 // only one parent + && !HopRewriteUtils.isSingleBlock(sparkIn) // single block triggered exec + && sparkIn.optFindExecType() == ExecType.SPARK // input was spark op. 
+ && !(sparkIn instanceof DataOp) // input is not checkpoint + ) { + // pull operation into spark + _etype = ExecType.SPARK; + } } if( OptimizerUtils.ALLOW_BINARY_UPDATE_IN_PLACE && diff --git a/src/main/java/org/apache/sysds/hops/Hop.java b/src/main/java/org/apache/sysds/hops/Hop.java index 86749d44c1c..675fbb380a1 100644 --- a/src/main/java/org/apache/sysds/hops/Hop.java +++ b/src/main/java/org/apache/sysds/hops/Hop.java @@ -1045,6 +1045,12 @@ public final String toString() { // ======================================================================================== + protected boolean isScalarOrVectorBellowBlockSize(){ + return getDataType().isScalar() || (dimsKnown() && + (( _dc.getRows() == 1 && _dc.getCols() < ConfigurationManager.getBlocksize()) + || _dc.getCols() == 1 && _dc.getRows() < ConfigurationManager.getBlocksize())); + } + protected boolean isVector() { return (dimsKnown() && (_dc.getRows() == 1 || _dc.getCols() == 1) ); } @@ -1629,6 +1635,11 @@ protected void setMemoryAndComputeEstimates(Lop lop) { lop.setComputeEstimate(ComputeCost.getHOPComputeCost(this)); } + protected boolean hasSparkOutput(){ + return (this.optFindExecType() == ExecType.SPARK + || (this instanceof DataOp && ((DataOp)this).hasOnlyRDD())); + } + /** * Set parse information. 
* diff --git a/src/main/java/org/apache/sysds/hops/UnaryOp.java b/src/main/java/org/apache/sysds/hops/UnaryOp.java index b3475edfbae..73e24eb17e2 100644 --- a/src/main/java/org/apache/sysds/hops/UnaryOp.java +++ b/src/main/java/org/apache/sysds/hops/UnaryOp.java @@ -366,7 +366,11 @@ protected double computeOutputMemEstimate( long dim1, long dim2, long nnz ) } else { sparsity = OptimizerUtils.getSparsity(dim1, dim2, nnz); } - return OptimizerUtils.estimateSizeExactSparsity(dim1, dim2, sparsity, getDataType()); + + if(getDataType() == DataType.FRAME) + return OptimizerUtils.estimateSizeExactFrame(dim1, dim2); + else + return OptimizerUtils.estimateSizeExactSparsity(dim1, dim2, sparsity); } @Override @@ -463,6 +467,13 @@ public boolean isMetadataOperation() { || _op == OpOp1.CAST_AS_LIST; } + private boolean isDisallowedSparkOps(){ + return isCumulativeUnaryOperation() + || isCastUnaryOperation() + || _op==OpOp1.MEDIAN + || _op==OpOp1.IQM; + } + @Override protected ExecType optFindExecType(boolean transitive) { @@ -493,19 +504,22 @@ else if ( getInput().get(0).areDimsBelowThreshold() || getInput().get(0).isVecto checkAndSetInvalidCPDimsAndSize(); } + //spark-specific decision refinement (execute unary w/ spark input and //single parent also in spark because it's likely cheap and reduces intermediates) - if( _etype == ExecType.CP && _etypeForced != ExecType.CP - && getInput().get(0).optFindExecType() == ExecType.SPARK - && getDataType().isMatrix() - && !isCumulativeUnaryOperation() && !isCastUnaryOperation() - && _op!=OpOp1.MEDIAN && _op!=OpOp1.IQM - && !(getInput().get(0) instanceof DataOp) //input is not checkpoint - && getInput().get(0).getParent().size()==1 ) //unary is only parent - { + if(_etype == ExecType.CP // currently CP instruction + && _etype != ExecType.SPARK /// currently not SP. 
+ && _etypeForced != ExecType.CP // not forced as CP instruction + && getInput(0).hasSparkOutput() // input is a spark instruction + && (getDataType().isMatrix() || getDataType().isFrame()) // output is a matrix or frame + && !isDisallowedSparkOps() // is invalid spark instruction + // && !(getInput().get(0) instanceof DataOp) // input is not checkpoint + // && getInput(0).getParent().size() <= 1// unary is only parent + ) { //pull unary operation into spark _etype = ExecType.SPARK; } + //mark for recompile (forever) setRequiresRecompileIfNecessary(); @@ -520,7 +534,7 @@ && getInput().get(0).getParent().size()==1 ) //unary is only parent } else { setRequiresRecompileIfNecessary(); } - + return _etype; } diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibCompAgg.java b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibCompAgg.java index 99693635a9b..948a78f96af 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibCompAgg.java +++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibCompAgg.java @@ -486,7 +486,7 @@ private static List> generateUnaryAggregateOverlappingFuture final ArrayList tasks = new ArrayList<>(); final int nCol = m1.getNumColumns(); final int nRow = m1.getNumRows(); - final int blklen = Math.max(64, nRow / k); + final int blklen = Math.max(64, (nRow + k) / k); final List groups = m1.getColGroups(); final boolean shouldFilter = CLALibUtils.shouldPreFilter(groups); if(shouldFilter) { diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibMMChain.java b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibMMChain.java index d82d58e323e..cc7953f8c5d 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibMMChain.java +++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibMMChain.java @@ -30,6 +30,7 @@ import org.apache.sysds.runtime.compress.colgroup.AColGroup; import org.apache.sysds.runtime.compress.colgroup.ColGroupConst; import 
org.apache.sysds.runtime.functionobjects.Multiply; +import org.apache.sysds.runtime.instructions.InstructionUtils; import org.apache.sysds.runtime.matrix.data.LibMatrixBincell; import org.apache.sysds.runtime.matrix.data.LibMatrixReorg; import org.apache.sysds.runtime.matrix.data.MatrixBlock; @@ -95,6 +96,11 @@ public static MatrixBlock mmChain(CompressedMatrixBlock x, MatrixBlock v, Matrix if(x.isEmpty()) return returnEmpty(x, out); + if(ctype == ChainType.XtXv && x.getColGroups().size() < 5 && x.getNumColumns()> 30){ + MatrixBlock tmp = CLALibTSMM.leftMultByTransposeSelf(x, k); + return tmp.aggregateBinaryOperations(tmp, v, out, InstructionUtils.getMatMultOperator(k)); + } + // Morph the columns to efficient types for the operation. x = filterColGroups(x); double preFilterTime = t.stop(); diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibRightMultBy.java b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibRightMultBy.java index f14d6833d95..ce06262b9a5 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibRightMultBy.java +++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibRightMultBy.java @@ -31,6 +31,8 @@ import org.apache.sysds.conf.DMLConfig; import org.apache.sysds.runtime.compress.CompressedMatrixBlock; import org.apache.sysds.runtime.compress.colgroup.AColGroup; +import org.apache.sysds.runtime.compress.colgroup.ASDC; +import org.apache.sysds.runtime.compress.colgroup.ASDCZero; import org.apache.sysds.runtime.compress.colgroup.ColGroupConst; import org.apache.sysds.runtime.compress.colgroup.ColGroupUncompressed; import org.apache.sysds.runtime.compress.colgroup.indexes.ColIndexFactory; @@ -71,10 +73,10 @@ public static MatrixBlock rightMultByMatrix(CompressedMatrixBlock m1, MatrixBloc if(m2 instanceof CompressedMatrixBlock) m2 = ((CompressedMatrixBlock) m2).getUncompressed("Uncompressed right side of right MM", k); - if(betterIfDecompressed(m1)) { - // perform uncompressed multiplication. 
- return decompressingMatrixMult(m1, m2, k); - } + // if(betterIfDecompressed(m1)) { + // // perform uncompressed multiplication. + // return decompressingMatrixMult(m1, m2, k); + // } if(!allowOverlap) { LOG.trace("Overlapping output not allowed in call to Right MM"); @@ -143,7 +145,9 @@ private static MatrixBlock decompressingMatrixMult(CompressedMatrixBlock m1, Mat private static boolean betterIfDecompressed(CompressedMatrixBlock m) { for(AColGroup g : m.getColGroups()) { - if(!(g instanceof ColGroupUncompressed) && g.getNumValues() * 2 >= m.getNumRows()) { + // TODO add subpport for decompressing RMM to ASDC and ASDCZero + if(!(g instanceof ColGroupUncompressed || g instanceof ASDC || g instanceof ASDCZero) && + g.getNumValues() * 2 >= m.getNumRows()) { return true; } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibTSMM.java b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibTSMM.java index a1d47a9b150..d0396b63810 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibTSMM.java +++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibTSMM.java @@ -31,6 +31,7 @@ import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.compress.CompressedMatrixBlock; import org.apache.sysds.runtime.compress.colgroup.AColGroup; +import org.apache.sysds.runtime.compress.colgroup.ColGroupUncompressed; import org.apache.sysds.runtime.matrix.data.LibMatrixMult; import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.runtime.util.CommonThreadPool; @@ -42,6 +43,10 @@ private CLALibTSMM() { // private constructor } + public static MatrixBlock leftMultByTransposeSelf(CompressedMatrixBlock cmb, int k) { + return leftMultByTransposeSelf(cmb, new MatrixBlock(), k); + } + /** * Self left Matrix multiplication (tsmm) * @@ -51,24 +56,32 @@ private CLALibTSMM() { * @param ret The output matrix to put the result into * @param k The parallelization degree allowed */ - public static void 
leftMultByTransposeSelf(CompressedMatrixBlock cmb, MatrixBlock ret, int k) { + public static MatrixBlock leftMultByTransposeSelf(CompressedMatrixBlock cmb, MatrixBlock ret, int k) { + final int numColumns = cmb.getNumColumns(); + final int numRows = cmb.getNumRows(); + if(cmb.isEmpty()) + return new MatrixBlock(numColumns, numColumns, true); + // create output matrix block + if(ret == null) + ret = new MatrixBlock(numColumns, numColumns, false); + else + ret.reset(numColumns, numColumns, false); + ret.allocateDenseBlock(); final List groups = cmb.getColGroups(); - final int numColumns = cmb.getNumColumns(); - if(groups.size() >= numColumns) { + if(groups.size() >= numColumns || containsUncompressedColGroup(groups)) { MatrixBlock m = cmb.getUncompressed("TSMM to many columngroups", k); LibMatrixMult.matrixMultTransposeSelf(m, ret, true, k); - return; + return ret; } - final int numRows = cmb.getNumRows(); final boolean shouldFilter = CLALibUtils.shouldPreFilter(groups); final boolean overlapping = cmb.isOverlapping(); if(shouldFilter) { final double[] constV = new double[numColumns]; final List filteredGroups = CLALibUtils.filterGroups(groups, constV); tsmmColGroups(filteredGroups, ret, numRows, overlapping, k); - addCorrectionLayer(filteredGroups, ret, numRows, numColumns, constV); + addCorrectionLayer(filteredGroups, ret, numRows, numColumns, constV, k); } else { @@ -77,17 +90,23 @@ public static void leftMultByTransposeSelf(CompressedMatrixBlock cmb, MatrixBloc ret.setNonZeros(LibMatrixMult.copyUpperToLowerTriangle(ret)); ret.examSparsity(); + return ret; + } + + private static boolean containsUncompressedColGroup(List groups) { + for(AColGroup g : groups) + if(g instanceof ColGroupUncompressed) + return true; + return false; } private static void addCorrectionLayer(List filteredGroups, MatrixBlock result, int nRows, int nCols, - double[] constV) { + double[] constV, int k) { final double[] retV = result.getDenseBlockValues(); final double[] filteredColSum = 
CLALibUtils.getColSum(filteredGroups, nCols, nRows); addCorrectionLayer(constV, filteredColSum, nRows, retV); } - - private static void tsmmColGroups(List groups, MatrixBlock ret, int nRows, boolean overlapping, int k) { if(k <= 1) tsmmColGroupsSingleThread(groups, ret, nRows); @@ -136,12 +155,12 @@ private static void tsmmColGroupsMultiThread(List groups, MatrixBlock public static void addCorrectionLayer(double[] constV, double[] filteredColSum, int nRow, double[] ret) { final int nColRow = constV.length; - for(int row = 0; row < nColRow; row++){ + for(int row = 0; row < nColRow; row++) { int offOut = nColRow * row; final double v1l = constV[row]; final double v2l = filteredColSum[row] + constV[row] * nRow; - for(int col = row; col < nColRow; col++){ - ret[offOut + col] += v1l * filteredColSum[col] + v2l * constV[col]; + for(int col = row; col < nColRow; col++) { + ret[offOut + col] += v1l * filteredColSum[col] + v2l * constV[col]; } } } diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java index 119589a3033..e53958ac4b8 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java @@ -352,7 +352,7 @@ else if(opcode.equalsIgnoreCase(Opcodes.TRANSFORMDECODE.toString())) { // compute transformdecode Decoder decoder = DecoderFactory .createDecoder(getParameterMap().get("spec"), colnames, null, meta, data.getNumColumns()); - FrameBlock fbout = decoder.decode(data, new FrameBlock(decoder.getSchema())); + FrameBlock fbout = decoder.decode(data, new FrameBlock(decoder.getSchema()), InfrastructureAnalyzer.getLocalParallelism()); fbout.setColumnNames(Arrays.copyOfRange(colnames, 0, fbout.getNumColumns())); // release locks diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/VariableCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/VariableCPInstruction.java index 359df747e7b..0f707b74412 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/VariableCPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/VariableCPInstruction.java @@ -44,6 +44,7 @@ import org.apache.sysds.runtime.controlprogram.parfor.util.IDSequence; import org.apache.sysds.runtime.data.TensorBlock; import org.apache.sysds.runtime.frame.data.FrameBlock; +import org.apache.sysds.runtime.frame.data.lib.MatrixBlockFromFrame; import org.apache.sysds.runtime.instructions.Instruction; import org.apache.sysds.runtime.instructions.InstructionUtils; import org.apache.sysds.runtime.instructions.ooc.TeeOOCInstruction; @@ -923,7 +924,7 @@ private void processCastAsMatrixVariableInstruction(ExecutionContext ec) { switch( getInput1().getDataType() ) { case FRAME: { FrameBlock fin = ec.getFrameInput(getInput1().getName()); - MatrixBlock out = DataConverter.convertToMatrixBlock(fin); + MatrixBlock out = MatrixBlockFromFrame.convertToMatrixBlock(fin, k); ec.releaseFrameInput(getInput1().getName()); ec.setMatrixOutput(output.getName(), out); break; From c12eccad9f8fa480573cb7d1b00397412c452496 Mon Sep 17 00:00:00 2001 From: Sebastian Baunsgaard Date: Mon, 8 Jun 2026 15:14:40 +0000 Subject: [PATCH 2/4] Fix race in parallel composite decode breaking dummycode+recode The multi-threaded DecoderComposite.decode submitted one task per decoder per row block, running all decoders concurrently. This broke the ordering dependency between decoders: recode-on-output reads the category indexes written by the dummycode decoder, so when the recode task raced ahead it read unwritten cells and produced null or the raw index instead of the original value. Parallelize over row blocks instead, running all decoders in order within each block via the sequential block decode. 
Also short-circuit to the single-threaded path when k <= 1. Fixes order-dependent failures in TransformFrameEncodeDecodeTest and TransformFrameEncodeColmapTest (dummycode single-node/hybrid) that surfaced once transformdecode started using the parallel decode path. --- .../transform/decode/DecoderComposite.java | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderComposite.java b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderComposite.java index f4bc9f8b216..f1afcfac194 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderComposite.java +++ b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderComposite.java @@ -62,17 +62,20 @@ public FrameBlock decode(MatrixBlock in, FrameBlock out) { @Override public FrameBlock decode(final MatrixBlock in, final FrameBlock out, final int k) { + if(k <= 1) + return decode(in, out); final ExecutorService pool = CommonThreadPool.get(k); out.ensureAllocatedColumns(in.getNumRows()); try { final List> tasks = new ArrayList<>(); int blz = Math.max(in.getNumRows() / k, 1000); - for(Decoder decoder : _decoders){ - for(int i = 0; i < in.getNumRows(); i += blz){ - final int start = i; - final int end = Math.min(in.getNumRows(), i + blz); - tasks.add(pool.submit(() -> decoder.decode(in, out, start, end))); - } + // Parallelize over row blocks (not over decoders): all decoders must + // run in order within a block, e.g. recode-on-output depends on the + // category indexes produced by the preceding dummycode decoder. 
+ for(int i = 0; i < in.getNumRows(); i += blz){ + final int start = i; + final int end = Math.min(in.getNumRows(), i + blz); + tasks.add(pool.submit(() -> decode(in, out, start, end))); } for(Future f : tasks) f.get(); From b6e390018925226e79f4704a1af90d184c84cacd Mon Sep 17 00:00:00 2001 From: Sebastian Baunsgaard Date: Thu, 11 Jun 2026 16:03:59 +0000 Subject: [PATCH 3/4] Add coverage tests for compressed matmul fast paths and exec-type decisions Re-enable the betterIfDecompressed gate in CLALibRightMultBy so the decompressing right-multiply path stays reachable, while still excluding ASDC/ASDCZero column groups from forced decompression. Add targeted and end-to-end tests covering the recently tuned paths: - CLALibMMChainTest: the public CLALibTSMM.leftMultByTransposeSelf overload (wide/narrow/uncompressed/empty/reuse/null) and the XtXv mm-chain fast path, including a tile-then-recompress wide-chain case. - CLALibRightMultBySDCTest: right multiply on ASDC/ASDCZero inputs is not forced to decompress, single-threaded and parallel. - DecoderCompositeTest: parallel and single-thread composite decode, exercising the dummycode+recode ordering dependency. - SparkTransitiveExecTypeTest with DML scripts: UnaryOp/BinaryOp/Hop transitive Spark exec-type pulling under a constrained memory budget. - CompressedTestBase: two parameterized e2e cases that validate the TSMM overload and the wide XtXv fast path against uncompressed results across all compression configurations. 
--- .../compress/lib/CLALibRightMultBy.java | 8 +- .../compress/CompressedTestBase.java | 46 +++ .../compress/lib/CLALibMMChainTest.java | 273 ++++++++++++++++++ .../lib/CLALibRightMultBySDCTest.java | 116 ++++++++ .../frame/transform/DecoderCompositeTest.java | 132 +++++++++ .../SparkTransitiveExecTypeTest.java | 104 +++++++ .../sparkexectype/SparkExecTypeBinary.dml | 33 +++ .../sparkexectype/SparkExecTypeUnary.dml | 31 ++ 8 files changed, 739 insertions(+), 4 deletions(-) create mode 100644 src/test/java/org/apache/sysds/test/component/compress/lib/CLALibMMChainTest.java create mode 100644 src/test/java/org/apache/sysds/test/component/compress/lib/CLALibRightMultBySDCTest.java create mode 100644 src/test/java/org/apache/sysds/test/component/frame/transform/DecoderCompositeTest.java create mode 100644 src/test/java/org/apache/sysds/test/functions/sparkexectype/SparkTransitiveExecTypeTest.java create mode 100644 src/test/scripts/functions/sparkexectype/SparkExecTypeBinary.dml create mode 100644 src/test/scripts/functions/sparkexectype/SparkExecTypeUnary.dml diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibRightMultBy.java b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibRightMultBy.java index ce06262b9a5..642b57124f1 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibRightMultBy.java +++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibRightMultBy.java @@ -73,10 +73,10 @@ public static MatrixBlock rightMultByMatrix(CompressedMatrixBlock m1, MatrixBloc if(m2 instanceof CompressedMatrixBlock) m2 = ((CompressedMatrixBlock) m2).getUncompressed("Uncompressed right side of right MM", k); - // if(betterIfDecompressed(m1)) { - // // perform uncompressed multiplication. - // return decompressingMatrixMult(m1, m2, k); - // } + if(betterIfDecompressed(m1)) { + // perform uncompressed multiplication. 
+ return decompressingMatrixMult(m1, m2, k); + } if(!allowOverlap) { LOG.trace("Overlapping output not allowed in call to Right MM"); diff --git a/src/test/java/org/apache/sysds/test/component/compress/CompressedTestBase.java b/src/test/java/org/apache/sysds/test/component/compress/CompressedTestBase.java index c1fb10d211a..a6ad0d4ee0d 100644 --- a/src/test/java/org/apache/sysds/test/component/compress/CompressedTestBase.java +++ b/src/test/java/org/apache/sysds/test/component/compress/CompressedTestBase.java @@ -60,6 +60,7 @@ import org.apache.sysds.runtime.compress.estim.CompressedSizeInfo; import org.apache.sysds.runtime.compress.estim.CompressedSizeInfoColGroup; import org.apache.sysds.runtime.compress.lib.CLALibCBind; +import org.apache.sysds.runtime.compress.lib.CLALibTSMM; import org.apache.sysds.runtime.functionobjects.Builtin; import org.apache.sysds.runtime.functionobjects.Builtin.BuiltinCode; import org.apache.sysds.runtime.functionobjects.Divide; @@ -503,6 +504,51 @@ public void testMatrixMultChain(ChainType ctype) { } } + @Test + public void testTransposeSelfLeftMultOverload() { + // Exercises the package-public CLALibTSMM.leftMultByTransposeSelf(cmb, k) entry point (used by the + // XtXv mm-chain fast path) across all compression configurations. + if(!(cmb instanceof CompressedMatrixBlock)) + return; + try { + MatrixBlock ret2 = CLALibTSMM.leftMultByTransposeSelf((CompressedMatrixBlock) cmb, _k); + MatrixBlock ucRet2 = mb.transposeSelfMatrixMultOperations(new MatrixBlock(), MMTSJType.LEFT, _k); + compareResultMatrices(ucRet2, ret2, overlappingType != OverLapping.NONE ? 256 : 2); + } + catch(Exception e) { + e.printStackTrace(); + throw new RuntimeException(bufferedToString + "\n" + e.getMessage(), e); + } + } + + @Test + public void testMatrixMultChainXtXvWide() { + // Widen the input beyond 30 columns so the XtXv fast path triggers, validating it against the + // uncompressed result for whatever compression the current configuration produces. 
+ if(!(cmb instanceof CompressedMatrixBlock)) + return; + try { + final int nCol = mb.getNumColumns(); + final int reps = (int) Math.ceil(31.0 / nCol) + 1; + MatrixBlock wide = mb; + for(int i = 1; i < reps; i++) + wide = wide.append(mb, new MatrixBlock(), true); + + MatrixBlock wideC = CompressedMatrixBlockFactory.compress(wide, _k).getLeft(); + if(!(wideC instanceof CompressedMatrixBlock)) + return; // not compressible in this configuration + + MatrixBlock vector1 = TestUtils.generateTestMatrixBlock(wide.getNumColumns(), 1, 0.9, 1.5, 1.0, 3); + MatrixBlock ucRet2 = wide.chainMatrixMultOperations(vector1, null, new MatrixBlock(), ChainType.XtXv, _k); + MatrixBlock ret2 = wideC.chainMatrixMultOperations(vector1, null, new MatrixBlock(), ChainType.XtXv, _k); + compareResultMatricesPercentDistance(ucRet2, ret2, 0.99, 0.99); + } + catch(Exception e) { + e.printStackTrace(); + throw new RuntimeException(bufferedToString + "\n" + e.getMessage(), e); + } + } + @Test public void testVectorMatrixMult() { MatrixBlock vector = TestUtils.generateTestMatrixBlock(1, rows, 0, 5, 1.0, 3); diff --git a/src/test/java/org/apache/sysds/test/component/compress/lib/CLALibMMChainTest.java b/src/test/java/org/apache/sysds/test/component/compress/lib/CLALibMMChainTest.java new file mode 100644 index 00000000000..833128ad9f0 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/component/compress/lib/CLALibMMChainTest.java @@ -0,0 +1,273 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.component.compress.lib; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +import java.util.ArrayList; +import java.util.List; +import java.util.Random; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.sysds.lops.MapMultChain.ChainType; +import org.apache.sysds.runtime.compress.CompressedMatrixBlock; +import org.apache.sysds.runtime.compress.colgroup.AColGroup; +import org.apache.sysds.runtime.compress.colgroup.ColGroupDDC; +import org.apache.sysds.runtime.compress.colgroup.ColGroupEmpty; +import org.apache.sysds.runtime.compress.colgroup.ColGroupUncompressed; +import org.apache.sysds.runtime.compress.colgroup.dictionary.Dictionary; +import org.apache.sysds.runtime.compress.colgroup.dictionary.IDictionary; +import org.apache.sysds.runtime.compress.colgroup.indexes.ColIndexFactory; +import org.apache.sysds.runtime.compress.colgroup.mapping.AMapToData; +import org.apache.sysds.runtime.compress.CompressedMatrixBlockFactory; +import org.apache.sysds.runtime.compress.lib.CLALibTSMM; +import org.apache.sysds.lops.MMTSJ.MMTSJType; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.test.TestUtils; +import org.apache.sysds.test.component.compress.mapping.MappingTestUtil; +import org.junit.BeforeClass; +import org.junit.Test; + +/** + * Targeted tests for the compressed transpose-self multiply ({@link CLALibTSMM}) and the XtXv mm-chain fast path that 
+ * was added in {@code CLALibMMChain}. The fast path triggers when the input has fewer than five column groups and more + * than thirty columns, in which case the chain is computed as {@code (t(X) %*% X) %*% v}. + */ +public class CLALibMMChainTest { + protected static final Log LOG = LogFactory.getLog(CLALibMMChainTest.class.getName()); + + @BeforeClass + public static void setup() { + Thread.currentThread().setName("main_test_" + Thread.currentThread().getId()); + } + + /** + * Build a compressed matrix backed by a single DDC column group spanning all {@code nCol} columns. This guarantees a + * single (non-uncompressed) column group, which is what triggers the mm-chain fast path for wide enough matrices. + */ + private static CompressedMatrixBlock singleDDC(int nRow, int nCol, int nVal, int seed) { + Random r = new Random(seed); + double[] dictValues = new double[nVal * nCol]; + for(int i = 0; i < dictValues.length; i++) + dictValues[i] = Math.round(r.nextDouble() * 20 - 10); + IDictionary dict = Dictionary.create(dictValues); + AMapToData data = MappingTestUtil.createRandomMap(nRow, nVal, r); + AColGroup g = ColGroupDDC.create(ColIndexFactory.create(nCol), dict, data, null); + CompressedMatrixBlock cmb = new CompressedMatrixBlock(nRow, nCol); + cmb.allocateColGroup(g); + cmb.recomputeNonZeros(); + return cmb; + } + + private static CompressedMatrixBlock uncompressedGroup(int nRow, int nCol, int seed) { + MatrixBlock mb = TestUtils.round(TestUtils.generateTestMatrixBlock(nRow, nCol, -10, 10, 1.0, seed)); + CompressedMatrixBlock cmb = new CompressedMatrixBlock(nRow, nCol); + cmb.allocateColGroup(ColGroupUncompressed.create(mb)); + cmb.recomputeNonZeros(); + return cmb; + } + + private static CompressedMatrixBlock empty(int nRow, int nCol) { + CompressedMatrixBlock cmb = new CompressedMatrixBlock(nRow, nCol); + cmb.allocateColGroup(new ColGroupEmpty(ColIndexFactory.create(nCol))); + cmb.recomputeNonZeros(); + return cmb; + } + + @Test + public void 
tsmmWideSingleThread() { + execTSMM(singleDDC(200, 40, 6, 1), 1); + } + + @Test + public void tsmmWideParallel() { + execTSMM(singleDDC(200, 40, 6, 2), 4); + } + + @Test + public void tsmmNarrowSingleThread() { + execTSMM(singleDDC(200, 8, 4, 3), 1); + } + + @Test + public void tsmmNarrowParallel() { + execTSMM(singleDDC(200, 8, 4, 4), 4); + } + + @Test + public void tsmmUncompressedGroupSingleThread() { + // A compressed block holding an uncompressed column group must fall back to the dense tsmm path. + execTSMM(uncompressedGroup(150, 12, 5), 1); + } + + @Test + public void tsmmUncompressedGroupParallel() { + execTSMM(uncompressedGroup(150, 12, 6), 4); + } + + @Test + public void tsmmEmpty() { + CompressedMatrixBlock cmb = empty(100, 13); + MatrixBlock ret = CLALibTSMM.leftMultByTransposeSelf(cmb, 1); + assertEquals(13, ret.getNumRows()); + assertEquals(13, ret.getNumColumns()); + assertTrue("empty input must produce an empty result", ret.isEmptyBlock(false)); + } + + @Test + public void tsmmRetReused() { + // A non-null ret must be reset and reused, producing the same result as a fresh allocation. + CompressedMatrixBlock cmb = singleDDC(120, 36, 5, 7); + MatrixBlock preAllocated = new MatrixBlock(3, 3, 99.0); + preAllocated.allocateDenseBlock(); + MatrixBlock cRet = CLALibTSMM.leftMultByTransposeSelf(cmb, preAllocated, 4); + MatrixBlock uRet = CompressedMatrixBlock.getUncompressed(cmb) + .transposeSelfMatrixMultOperations(new MatrixBlock(), MMTSJType.LEFT, 4); + TestUtils.compareMatricesBitAvgDistance(uRet, cRet, 0, 0); + } + + @Test + public void tsmmRetNull() { + // Explicitly exercise the null-ret allocation branch of the helper. 
+ CompressedMatrixBlock cmb = singleDDC(120, 36, 5, 8); + MatrixBlock cRet = CLALibTSMM.leftMultByTransposeSelf(cmb, null, 1); + MatrixBlock uRet = CompressedMatrixBlock.getUncompressed(cmb) + .transposeSelfMatrixMultOperations(new MatrixBlock(), MMTSJType.LEFT, 1); + TestUtils.compareMatricesBitAvgDistance(uRet, cRet, 0, 0); + } + + private static void execTSMM(CompressedMatrixBlock cmb, int k) { + try { + MatrixBlock cRet = CLALibTSMM.leftMultByTransposeSelf(cmb, k); + MatrixBlock uRet = CompressedMatrixBlock.getUncompressed(cmb) + .transposeSelfMatrixMultOperations(new MatrixBlock(), MMTSJType.LEFT, k); + assertEquals(cmb.getNumColumns(), cRet.getNumRows()); + assertEquals(cmb.getNumColumns(), cRet.getNumColumns()); + TestUtils.compareMatricesBitAvgDistance(uRet, cRet, 0, 0); + } + catch(Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + + @Test + public void mmChainFastPathSingleThread() { + // 40 columns, single column group -> XtXv fast path. + execMMChain(singleDDC(200, 40, 6, 11), 1); + } + + @Test + public void mmChainFastPathParallel() { + execMMChain(singleDDC(200, 40, 6, 12), 4); + } + + @Test + public void mmChainFastPathFewGroups() { + // Two column groups (< 5) over 40 columns still triggers the fast path. + execMMChain(twoGroups(200, 40, 13), 4); + } + + @Test + public void mmChainRegularPathNarrow() { + // Only 20 columns -> below the width threshold, exercises the regular (non fast) chain path. 
+ execMMChain(singleDDC(200, 20, 6, 14), 4); + } + + private static CompressedMatrixBlock twoGroups(int nRow, int nCol, int seed) { + final int half = nCol / 2; + Random r = new Random(seed); + List gs = new ArrayList<>(); + gs.add(ddcGroup(nRow, ColIndexFactory.create(0, half), 5, r)); + gs.add(ddcGroup(nRow, ColIndexFactory.create(half, nCol), 5, r)); + CompressedMatrixBlock cmb = new CompressedMatrixBlock(nRow, nCol); + cmb.allocateColGroupList(gs); + cmb.recomputeNonZeros(); + return cmb; + } + + private static AColGroup ddcGroup(int nRow, org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex cols, + int nVal, Random r) { + int nCol = cols.size(); + double[] dictValues = new double[nVal * nCol]; + for(int i = 0; i < dictValues.length; i++) + dictValues[i] = Math.round(r.nextDouble() * 20 - 10); + IDictionary dict = Dictionary.create(dictValues); + AMapToData data = MappingTestUtil.createRandomMap(nRow, nVal, r); + return ColGroupDDC.create(cols, dict, data, null); + } + + @Test + public void mmChainWideRecompressedDDC() { + // Mirrors the e2e CompressedTestBase#testMatrixMultChainXtXvWide flow: tile a narrow matrix until it + // exceeds the 30-column fast-path threshold, recompress it, then validate XtXv against uncompressed. 
+ execMMChainWide(TestUtils.round(TestUtils.generateTestMatrixBlock(300, 4, -10, 10, 1.0, 21)), 1); + } + + @Test + public void mmChainWideRecompressedSparse() { + execMMChainWide(TestUtils.round(TestUtils.generateTestMatrixBlock(300, 3, 1, 5, 0.2, 22)), 4); + } + + private static void execMMChainWide(MatrixBlock base, int k) { + try { + final int nCol = base.getNumColumns(); + final int reps = (int) Math.ceil(31.0 / nCol) + 1; + MatrixBlock wide = base; + for(int i = 1; i < reps; i++) + wide = wide.append(base, new MatrixBlock(), true); + assertTrue("widened matrix must exceed the fast-path threshold", wide.getNumColumns() > 30); + + MatrixBlock wideC = CompressedMatrixBlockFactory.compress(wide, k).getLeft(); + assertTrue("tiled matrix should compress", wideC instanceof CompressedMatrixBlock); + + MatrixBlock v = TestUtils.generateTestMatrixBlock(wide.getNumColumns(), 1, 0.9, 1.5, 1.0, 3); + MatrixBlock uRet = wide.chainMatrixMultOperations(v, null, new MatrixBlock(), ChainType.XtXv, k); + MatrixBlock cRet = wideC.chainMatrixMultOperations(v, null, new MatrixBlock(), ChainType.XtXv, k); + TestUtils.compareMatrices(uRet, cRet, 1e-6, "wide recompressed mm-chain result mismatch"); + } + catch(Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + + private static void execMMChain(CompressedMatrixBlock cmb, int k) { + try { + final int cols = cmb.getNumColumns(); + MatrixBlock v = TestUtils.round(TestUtils.generateTestMatrixBlock(cols, 1, -3, 3, 1.0, 42)); + MatrixBlock uncompressed = CompressedMatrixBlock.getUncompressed(cmb); + + MatrixBlock cRet = cmb.chainMatrixMultOperations(v, null, new MatrixBlock(), ChainType.XtXv, k); + MatrixBlock uRet = uncompressed.chainMatrixMultOperations(v, null, new MatrixBlock(), ChainType.XtXv, k); + + assertEquals(cols, cRet.getNumRows()); + assertEquals(1, cRet.getNumColumns()); + TestUtils.compareMatrices(uRet, cRet, 1e-6, "mm-chain result mismatch"); + } + catch(Exception e) { + e.printStackTrace(); + 
fail(e.getMessage()); + } + } +} diff --git a/src/test/java/org/apache/sysds/test/component/compress/lib/CLALibRightMultBySDCTest.java b/src/test/java/org/apache/sysds/test/component/compress/lib/CLALibRightMultBySDCTest.java new file mode 100644 index 00000000000..0aa4064b5a7 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/component/compress/lib/CLALibRightMultBySDCTest.java @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.sysds.test.component.compress.lib; + +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.sysds.runtime.compress.CompressedMatrixBlock; +import org.apache.sysds.runtime.compress.CompressedMatrixBlockFactory; +import org.apache.sysds.runtime.compress.colgroup.AColGroup; +import org.apache.sysds.runtime.compress.colgroup.ASDC; +import org.apache.sysds.runtime.compress.colgroup.ASDCZero; +import org.apache.sysds.runtime.compress.lib.CLALibRightMultBy; +import org.apache.sysds.runtime.matrix.data.LibMatrixMult; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.test.TestUtils; +import org.junit.BeforeClass; +import org.junit.Test; + +/** + * Right matrix multiply on compressed inputs that contain SDC / SDC-zeros column groups. + * + *

+ * The compressed right multiply no longer forces decompression for {@link ASDC}- or {@link ASDCZero}-backed inputs + * (they have working pre-aggregate paths). These tests build such inputs and verify the compressed right multiply still + * matches the uncompressed reference for both single-threaded and parallel execution. + *

+ */ +public class CLALibRightMultBySDCTest { + protected static final Log LOG = LogFactory.getLog(CLALibRightMultBySDCTest.class.getName()); + + @BeforeClass + public static void setup() { + Thread.currentThread().setName("main_test_" + Thread.currentThread().getId()); + } + + /** + * Build a compressed matrix dominated by a single value with a handful of exceptions per column, which compresses + * into SDC / SDC-zeros column groups. + */ + private static CompressedMatrixBlock sdcBlock(int rows, int cols, double sparsity, int seed) { + MatrixBlock mb = TestUtils.round(TestUtils.generateTestMatrixBlock(rows, cols, 1, 5, sparsity, seed)); + CompressedMatrixBlock cmb = (CompressedMatrixBlock) CompressedMatrixBlockFactory.compress(mb, 1).getLeft(); + return cmb; + } + + private static boolean containsSDC(CompressedMatrixBlock cmb) { + for(AColGroup g : cmb.getColGroups()) + if(g instanceof ASDC || g instanceof ASDCZero) + return true; + return false; + } + + @Test + public void rightMultVectorSparseSingleThread() { + execRightMult(sdcBlock(500, 6, 0.2, 21), 1, 1); + } + + @Test + public void rightMultVectorSparseParallel() { + execRightMult(sdcBlock(500, 6, 0.2, 22), 1, 4); + } + + @Test + public void rightMultMatrixSparseSingleThread() { + execRightMult(sdcBlock(500, 6, 0.2, 23), 4, 1); + } + + @Test + public void rightMultMatrixSparseParallel() { + execRightMult(sdcBlock(500, 6, 0.2, 24), 4, 4); + } + + @Test + public void rightMultWideSparseParallel() { + execRightMult(sdcBlock(500, 6, 0.2, 27), 12, 4); + } + + private static void execRightMult(CompressedMatrixBlock cmb, int rhsCols, int k) { + try { + assertTrue("test input should contain an SDC/SDCZeros column group", containsSDC(cmb)); + + final int cols = cmb.getNumColumns(); + MatrixBlock right = TestUtils.round(TestUtils.generateTestMatrixBlock(cols, rhsCols, -3, 3, 1.0, 99)); + MatrixBlock uncompressed = CompressedMatrixBlock.getUncompressed(cmb); + + MatrixBlock cRet = 
CLALibRightMultBy.rightMultByMatrix(cmb, right, null, k); + MatrixBlock uRet = LibMatrixMult.matrixMult(uncompressed, right, k); + + TestUtils.compareMatricesBitAvgDistance(uRet, CompressedMatrixBlock.getUncompressed(cRet), 1024, 1); + } + catch(Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } +} diff --git a/src/test/java/org/apache/sysds/test/component/frame/transform/DecoderCompositeTest.java b/src/test/java/org/apache/sysds/test/component/frame/transform/DecoderCompositeTest.java new file mode 100644 index 00000000000..ccba674707b --- /dev/null +++ b/src/test/java/org/apache/sysds/test/component/frame/transform/DecoderCompositeTest.java @@ -0,0 +1,132 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.sysds.test.component.frame.transform; + +import static org.junit.Assert.fail; + +import java.util.Random; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.sysds.common.Types.ValueType; +import org.apache.sysds.runtime.frame.data.FrameBlock; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.transform.decode.Decoder; +import org.apache.sysds.runtime.transform.decode.DecoderComposite; +import org.apache.sysds.runtime.transform.decode.DecoderFactory; +import org.apache.sysds.runtime.transform.encode.EncoderFactory; +import org.apache.sysds.runtime.transform.encode.MultiColumnEncoder; +import org.apache.sysds.test.TestUtils; +import org.junit.Test; + +/** + * Tests for the multi-threaded {@link DecoderComposite#decode(MatrixBlock, FrameBlock, int)} path. + * + *

+ * The parallel decode partitions over row blocks and runs all sub decoders in order within each block. This is + * important for the dummycode+recode case: the recode-on-output decoder reads the category indexes written by the + * preceding dummycode decoder, so running them out of order produces wrong (or null) values. These tests verify the + * parallel result equals the single-threaded result and reconstructs the original frame, and they also exercise the + * {@code k <= 1} short-circuit to the sequential path. + *

+ */ +public class DecoderCompositeTest { + protected static final Log LOG = LogFactory.getLog(DecoderCompositeTest.class.getName()); + + /** Enough rows that the parallel path forms multiple row blocks (block size is max(rows/k, 1000)). */ + private static final int ROWS = 8000; + + private static FrameBlock categoricalFrame(int rows, int nCol, int nCat, int seed) { + ValueType[] schema = new ValueType[nCol]; + for(int c = 0; c < nCol; c++) + schema[c] = ValueType.STRING; + String[][] data = new String[rows][nCol]; + Random r = new Random(seed); + for(int i = 0; i < rows; i++) + for(int c = 0; c < nCol; c++) + data[i][c] = "v" + r.nextInt(nCat); + return new FrameBlock(schema, data); + } + + private static Decoder buildDecoder(FrameBlock data, String spec, MultiColumnEncoder encoder) { + FrameBlock meta = encoder.getMetaData(new FrameBlock(data.getNumColumns(), ValueType.STRING)); + return DecoderFactory.createDecoder(spec, data.getColumnNames(), data.getSchema(), meta); + } + + private void runDecode(String spec, int nCol, int nCat) { + try { + FrameBlock data = categoricalFrame(ROWS, nCol, nCat, 17); + + MultiColumnEncoder encoder = EncoderFactory.createEncoder(spec, data.getColumnNames(), + data.getNumColumns(), null); + MatrixBlock encoded = encoder.encode(data, 1); + + Decoder decoder = buildDecoder(data, spec, encoder); + + FrameBlock single = decoder.decode(encoded, new FrameBlock(decoder.getSchema()), 1); + FrameBlock parallel = decoder.decode(encoded, new FrameBlock(decoder.getSchema()), 4); + + // Parallel decode must match the single-threaded decode exactly. + TestUtils.compareFrames(single, parallel, false); + // And both must reconstruct the original categorical values. 
+ TestUtils.compareFrames(data, parallel, false); + } + catch(Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + + @Test + public void recodeOnly() { + runDecode("{recode:[C1,C2,C3]}", 3, 6); + } + + @Test + public void dummycodeAndRecode() { + // dummycode implies recode-on-output: the composite decoder is [Dummycode, Recode-on-output] + // and the recode step depends on the indexes the dummycode step writes. This is exactly the + // ordering the parallel fix protects against breaking. + runDecode("{dummycode:[C1,C2,C3]}", 3, 5); + } + + @Test + public void dummycodeAndRecodeSameColumns() { + // recode and dummycode listed on the same columns -> recoded then dummycoded, decoded in order. + runDecode("{recode:[C1,C2], dummycode:[C1,C2]}", 2, 4); + } + + @Test + public void singleThreadEqualsParallelManyCategories() { + runDecode("{dummycode:[C1,C2]}", 2, 25); + } + + @Test + public void decoderIsComposite() { + FrameBlock data = categoricalFrame(100, 2, 3, 1); + String spec = "{recode:[C1], dummycode:[C2]}"; + MultiColumnEncoder encoder = EncoderFactory.createEncoder(spec, data.getColumnNames(), + data.getNumColumns(), null); + encoder.encode(data, 1); + Decoder decoder = buildDecoder(data, spec, encoder); + if(!(decoder instanceof DecoderComposite)) + fail("expected a DecoderComposite but got " + decoder.getClass().getSimpleName()); + } +} diff --git a/src/test/java/org/apache/sysds/test/functions/sparkexectype/SparkTransitiveExecTypeTest.java b/src/test/java/org/apache/sysds/test/functions/sparkexectype/SparkTransitiveExecTypeTest.java new file mode 100644 index 00000000000..04e1e2e0be6 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/sparkexectype/SparkTransitiveExecTypeTest.java @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.functions.sparkexectype; + +import java.util.HashMap; + +import org.apache.sysds.common.Types.ExecMode; +import org.apache.sysds.hops.OptimizerUtils; +import org.apache.sysds.hops.recompile.Recompiler; +import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestConfiguration; +import org.apache.sysds.test.TestUtils; +import org.apache.sysds.utils.Statistics; +import org.apache.sysds.utils.stats.InfrastructureAnalyzer; +import org.junit.Assert; +import org.junit.Test; + +/** + * Exercises the transitive Spark exec-type refinement in {@link org.apache.sysds.hops.UnaryOp} and + * {@link org.apache.sysds.hops.BinaryOp}: cheap unary / matrix-scalar / matrix-vector operations whose input already + * has a Spark output are pulled into Spark. + * + *

+ * Each script is run in HYBRID mode with a constrained memory budget: first with the transitive decision disabled (the + * reference run) and then enabled. The results must match (correctness regardless of placement), and the transitive run + * must actually execute Spark instructions. + *

+ */ +public class SparkTransitiveExecTypeTest extends AutomatedTestBase { + + private static final String TEST_DIR = "functions/sparkexectype/"; + private static final String TEST_CLASS_DIR = TEST_DIR + SparkTransitiveExecTypeTest.class.getSimpleName() + "/"; + private static final String TEST_UNARY = "SparkExecTypeUnary"; + private static final String TEST_BINARY = "SparkExecTypeBinary"; + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + addTestConfiguration(TEST_UNARY, new TestConfiguration(TEST_CLASS_DIR, TEST_UNARY, new String[] {"R"})); + addTestConfiguration(TEST_BINARY, new TestConfiguration(TEST_CLASS_DIR, TEST_BINARY, new String[] {"R"})); + } + + @Test + public void testUnaryPulledIntoSpark() { + runTransitiveExecTypeTest(TEST_UNARY); + } + + @Test + public void testBinaryPulledIntoSpark() { + runTransitiveExecTypeTest(TEST_BINARY); + } + + private void runTransitiveExecTypeTest(String testname) { + final boolean oldTransitive = OptimizerUtils.ALLOW_TRANSITIVE_SPARK_EXEC_TYPE; + final ExecMode oldPlatform = setExecMode(ExecMode.HYBRID); + final long oldMem = InfrastructureAnalyzer.getLocalMaxMemory(); + // Small memory budget so the large operations are placed on Spark. + InfrastructureAnalyzer.setLocalMaxMemory(1024 * 1024 * 8); + + try { + getAndLoadTestConfiguration(testname); + fullDMLScriptName = getScript(); + programArgs = new String[] {"-args", output("R")}; + + // Reference run with the transitive Spark decision disabled. + OptimizerUtils.ALLOW_TRANSITIVE_SPARK_EXEC_TYPE = false; + runTest(true, false, null, -1); + HashMap expected = readDMLScalarFromOutputDir("R"); + + // Run with the transitive Spark decision enabled (the path under test). 
+ OptimizerUtils.ALLOW_TRANSITIVE_SPARK_EXEC_TYPE = true; + runTest(true, false, null, -1); + HashMap actual = readDMLScalarFromOutputDir("R"); + + TestUtils.compareScalars(expected.get(new CellIndex(1, 1)), actual.get(new CellIndex(1, 1)), 1e-8); + Assert.assertTrue("Expected Spark instructions to be executed in the transitive run.", + Statistics.getNoOfExecutedSPInst() > 0); + } + finally { + OptimizerUtils.ALLOW_TRANSITIVE_SPARK_EXEC_TYPE = oldTransitive; + resetExecMode(oldPlatform); + InfrastructureAnalyzer.setLocalMaxMemory(oldMem); + Recompiler.reinitRecompiler(); + } + } +} diff --git a/src/test/scripts/functions/sparkexectype/SparkExecTypeBinary.dml b/src/test/scripts/functions/sparkexectype/SparkExecTypeBinary.dml new file mode 100644 index 00000000000..b15391d5c60 --- /dev/null +++ b/src/test/scripts/functions/sparkexectype/SparkExecTypeBinary.dml @@ -0,0 +1,33 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +# Binary operations where exactly one operand is a scalar or a small vector and +# the other operand has a Spark output. 
These should be pulled into Spark by the +# transitive exec-type decision (matrix-scalar / matrix-vector broadcasting). +X = rand(rows=10000, cols=200, seed=42); +v = rand(rows=1, cols=200, seed=7); # small row vector (below block size) +c = rand(rows=10000, cols=1, seed=9); # tall column vector + +sp1 = X * 2.0; # matrix-scalar, spark input +sp2 = sp1 + v; # matrix + small row vector, spark input +sp3 = sp2 - c; # matrix - column vector, spark input +R = sum(sp3); +write(R, $1, format="text"); diff --git a/src/test/scripts/functions/sparkexectype/SparkExecTypeUnary.dml b/src/test/scripts/functions/sparkexectype/SparkExecTypeUnary.dml new file mode 100644 index 00000000000..7c2c6e3c03d --- /dev/null +++ b/src/test/scripts/functions/sparkexectype/SparkExecTypeUnary.dml @@ -0,0 +1,31 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +# Large input forces a Spark-resident output. The following unary operations +# are individually cheap but their input already has a Spark output, so the +# transitive exec-type decision should pull them into Spark. 
+X = rand(rows=10000, cols=200, seed=42); +sp1 = X + ceil(X); # spark transformation -> spark output +sp2 = round(sp1); # unary on spark input +sp3 = abs(sp2); # unary on spark input +sp4 = sp3 * 2.0; # binary matrix-scalar on spark input +R = sum(sp4); +write(R, $1, format="text"); From e4dd03bd76530a59f4a10cfb597a193414b35b48 Mon Sep 17 00:00:00 2001 From: Sebastian Baunsgaard Date: Tue, 16 Jun 2026 20:51:56 +0000 Subject: [PATCH 4/4] Align unary/binary transitive Spark exec-type decision and add compile test Make the spark-specific decision refinement consistent between UnaryOp and BinaryOp: - UnaryOp: restore the input-is-not-checkpoint and single-parent guards, and drop the redundant `_etype != ExecType.SPARK` clause - BinaryOp: use the shared hasSparkOutput() helper instead of an inline optFindExecType() == SPARK check Add a compilation-verification test suite under component/compile that compiles a DML script into a runtime program and inspects instruction exec types without executing. CompilerTestBase provides the compile and plan-inspection helpers; SparkTransitiveExecTypeCompileTest verifies a CP-by-estimate unary on a Spark-resident input is pulled into Spark only when it is the sole consumer. 
--- .../java/org/apache/sysds/hops/BinaryOp.java | 2 +- .../java/org/apache/sysds/hops/UnaryOp.java | 5 +- .../component/compile/CompilerTestBase.java | 189 ++++++++++++++++++ .../SparkTransitiveExecTypeCompileTest.java | 59 ++++++ 4 files changed, 251 insertions(+), 4 deletions(-) create mode 100644 src/test/java/org/apache/sysds/test/component/compile/CompilerTestBase.java create mode 100644 src/test/java/org/apache/sysds/test/component/compile/SparkTransitiveExecTypeCompileTest.java diff --git a/src/main/java/org/apache/sysds/hops/BinaryOp.java b/src/main/java/org/apache/sysds/hops/BinaryOp.java index 5accb497501..1bf15475b94 100644 --- a/src/main/java/org/apache/sysds/hops/BinaryOp.java +++ b/src/main/java/org/apache/sysds/hops/BinaryOp.java @@ -828,7 +828,7 @@ else if ( dt1 == DataType.SCALAR && dt2 == DataType.MATRIX ) { && (supportsMatrixScalarOperations() || op == OpOp2.APPLY_SCHEMA) // supported operation && sparkIn.getParent().size() == 1 // only one parent && !HopRewriteUtils.isSingleBlock(sparkIn) // single block triggered exec - && sparkIn.optFindExecType() == ExecType.SPARK // input was spark op. + && sparkIn.hasSparkOutput() // input was spark op. && !(sparkIn instanceof DataOp) // input is not checkpoint ) { // pull operation into spark diff --git a/src/main/java/org/apache/sysds/hops/UnaryOp.java b/src/main/java/org/apache/sysds/hops/UnaryOp.java index 73e24eb17e2..c06d15961bc 100644 --- a/src/main/java/org/apache/sysds/hops/UnaryOp.java +++ b/src/main/java/org/apache/sysds/hops/UnaryOp.java @@ -508,13 +508,12 @@ else if ( getInput().get(0).areDimsBelowThreshold() || getInput().get(0).isVecto //spark-specific decision refinement (execute unary w/ spark input and //single parent also in spark because it's likely cheap and reduces intermediates) if(_etype == ExecType.CP // currently CP instruction - && _etype != ExecType.SPARK /// currently not SP. 
&& _etypeForced != ExecType.CP // not forced as CP instruction && getInput(0).hasSparkOutput() // input is a spark instruction && (getDataType().isMatrix() || getDataType().isFrame()) // output is a matrix or frame && !isDisallowedSparkOps() // is invalid spark instruction - // && !(getInput().get(0) instanceof DataOp) // input is not checkpoint - // && getInput(0).getParent().size() <= 1// unary is only parent + && !(getInput(0) instanceof DataOp) // input is not checkpoint + && getInput(0).getParent().size() == 1 // unary is only parent ) { //pull unary operation into spark _etype = ExecType.SPARK; diff --git a/src/test/java/org/apache/sysds/test/component/compile/CompilerTestBase.java b/src/test/java/org/apache/sysds/test/component/compile/CompilerTestBase.java new file mode 100644 index 00000000000..07ec9752928 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/component/compile/CompilerTestBase.java @@ -0,0 +1,189 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */
+
+package org.apache.sysds.test.component.compile;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.stream.Collectors;
+
+import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.conf.ConfigurationManager;
+import org.apache.sysds.conf.DMLConfig;
+import org.apache.sysds.hops.OptimizerUtils;
+import org.apache.sysds.hops.recompile.Recompiler;
+import org.apache.sysds.parser.DMLProgram;
+import org.apache.sysds.parser.DMLTranslator;
+import org.apache.sysds.parser.ParserFactory;
+import org.apache.sysds.parser.ParserWrapper;
+import org.apache.sysds.runtime.controlprogram.BasicProgramBlock;
+import org.apache.sysds.runtime.controlprogram.ForProgramBlock;
+import org.apache.sysds.runtime.controlprogram.FunctionProgramBlock;
+import org.apache.sysds.runtime.controlprogram.IfProgramBlock;
+import org.apache.sysds.runtime.controlprogram.Program;
+import org.apache.sysds.runtime.controlprogram.ProgramBlock;
+import org.apache.sysds.runtime.controlprogram.WhileProgramBlock;
+import org.apache.sysds.runtime.instructions.Instruction;
+import org.apache.sysds.runtime.instructions.cp.CPInstruction;
+import org.apache.sysds.runtime.instructions.spark.SPInstruction;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.utils.Explain;
+import org.apache.sysds.utils.stats.InfrastructureAnalyzer;
+import org.junit.Assert;
+
+/**
+ * Base class for compilation-verification tests: compile a DML script into a runtime {@link Program} and inspect the
+ * resulting plan (instructions and their exec types) without ever executing it.
+ */
+public abstract class CompilerTestBase extends AutomatedTestBase {
+
+	/** A small default local memory budget (8 MB) that forces large operations onto Spark in HYBRID mode. */
+	public static final long SMALL_MEM_BUDGET = 8L * 1024 * 1024;
+
+	@Override
+	public void setUp() {
+		// no test-configuration setup needed; scripts are compiled from in-memory strings
+	}
+
+	/**
+	 * Compile a DML script string into a runtime {@link Program} without executing it.
+	 *
+	 * @param dmlScript the DML source
+	 * @param args named command-line arguments ($name -> value), may be null
+	 * @param mode the global execution mode (e.g. {@link ExecMode#HYBRID})
+	 * @param localMaxMem the local memory budget in bytes used for memory-based exec-type decisions
+	 * @return the compiled runtime program
+	 * @throws RuntimeException if any compilation step fails; the original exception is kept as the cause
+	 */
+	protected Program compile(String dmlScript, Map<String, String> args, ExecMode mode, long localMaxMem) {
+		final ExecMode oldMode = DMLScript.getGlobalExecMode(); // global settings that are overridden below
+		final long oldMem = InfrastructureAnalyzer.getLocalMaxMemory();
+		final DMLConfig oldConfig = ConfigurationManager.getDMLConfig();
+		try {
+			ConfigurationManager.setGlobalConfig(new DMLConfig());
+			DMLScript.setGlobalExecMode(mode);
+			InfrastructureAnalyzer.setLocalMaxMemory(localMaxMem);
+			OptimizerUtils.resetDefaultSize();
+			final Map<String, String> argVals = (args == null) ? new HashMap<>() : new HashMap<>(args); // never null
+			final ParserWrapper parser = ParserFactory.createParser();
+			final DMLProgram prog = parser.parse(null, dmlScript, argVals);
+			final DMLTranslator dmlt = new DMLTranslator(prog);
+			dmlt.liveVariableAnalysis(prog);
+			dmlt.validateParseTree(prog);
+			dmlt.constructHops(prog);
+			dmlt.rewriteHopsDAG(prog);
+			dmlt.constructLops(prog);
+			dmlt.rewriteLopDAG(prog);
+			return dmlt.getRuntimeProgram(prog, ConfigurationManager.getDMLConfig());
+		}
+		catch(Exception e) {
+			throw new RuntimeException("Failed to compile DML script:\n" + dmlScript, e);
+		}
+		finally { // runs on success and failure: put back the settings captured at method entry
+			DMLScript.setGlobalExecMode(oldMode);
+			InfrastructureAnalyzer.setLocalMaxMemory(oldMem);
+			ConfigurationManager.setGlobalConfig(oldConfig);
+			Recompiler.reinitRecompiler();
+		}
+	}
+
+	/** Recursively collect every instruction in the program, including control-flow predicates and function bodies. */
+	protected List<Instruction> getInstructions(Program prog) {
+		final List<Instruction> out = new ArrayList<>();
+		prog.getProgramBlocks().forEach(pb -> collect(pb, out));
+		prog.getFunctionProgramBlocks(false).values().forEach(fpb -> collect(fpb, out));
+		return out;
+	}
+
+	/** Append the instructions of {@code pb} to {@code out}, descending into predicates and child blocks. */
+	private void collect(ProgramBlock pb, List<Instruction> out) {
+		if(pb instanceof BasicProgramBlock)
+			out.addAll(((BasicProgramBlock) pb).getInstructions());
+		else if(pb instanceof IfProgramBlock) {
+			IfProgramBlock ipb = (IfProgramBlock) pb;
+			out.addAll(ipb.getPredicate());
+			ipb.getChildBlocksIfBody().forEach(c -> collect(c, out));
+			ipb.getChildBlocksElseBody().forEach(c -> collect(c, out));
+		}
+		else if(pb instanceof WhileProgramBlock) {
+			out.addAll(((WhileProgramBlock) pb).getPredicate());
+			((WhileProgramBlock) pb).getChildBlocks().forEach(c -> collect(c, out));
+		}
+		else if(pb instanceof ForProgramBlock) { // incl. ParForProgramBlock
+			ForProgramBlock fpb = (ForProgramBlock) pb;
+			out.addAll(fpb.getFromInstructions());
+			out.addAll(fpb.getToInstructions());
+			out.addAll(fpb.getIncrementInstructions());
+			fpb.getChildBlocks().forEach(c -> collect(c, out));
+		}
+		else if(pb instanceof FunctionProgramBlock)
+			((FunctionProgramBlock) pb).getChildBlocks().forEach(c -> collect(c, out));
+	}
+
+	/** All instructions whose opcode equals {@code opcode} (exact match). */
+	protected List<Instruction> getByOpcode(Program prog, String opcode) {
+		return getInstructions(prog).stream().filter(i -> opcode.equals(i.getOpcode())).collect(Collectors.toList());
+	}
+
+	/** Whether {@code inst} is a Spark instruction, i.e. an instance of {@link SPInstruction}. */
+	protected static boolean isSpark(Instruction inst) {
+		return inst instanceof SPInstruction;
+	}
+
+	/** Whether {@code inst} is a CP instruction, i.e. an instance of {@link CPInstruction}. */
+	protected static boolean isCP(Instruction inst) {
+		return inst instanceof CPInstruction;
+	}
+
+	/** Assert that at least one instruction with the given opcode exists and that all such instructions are Spark. */
+	protected void assertSpark(Program prog, String opcode) {
+		assertExecType(prog, opcode, true);
+	}
+
+	/** Assert that at least one instruction with the given opcode exists and that all such instructions are CP. */
+	protected void assertCP(Program prog, String opcode) {
+		assertExecType(prog, opcode, false);
+	}
+
+	/** Shared check behind assertSpark and assertCP; instructions that are neither Spark nor CP fail both. */
+	private void assertExecType(Program prog, String opcode, boolean expectSpark) {
+		final List<Instruction> matches = getByOpcode(prog, opcode);
+		Assert.assertFalse("Expected at least one '" + opcode + "' instruction but found none.\n"
+			+ Explain.explain(prog), matches.isEmpty());
+		final String expected = expectSpark ? "SPARK" : "CP";
+		for(Instruction inst : matches) {
+			final String actual = isSpark(inst) ? "SPARK" : isCP(inst) ? "CP" : inst.getClass().getSimpleName();
+			Assert.assertEquals("Instruction '" + opcode + "' expected exec type " + expected + " but was " + actual
+				+ ".\n" + Explain.explain(prog), expected, actual);
+		}
+	}
+
+	/** Number of Spark instructions among everything returned by {@link #getInstructions}. */
+	protected long countSpark(Program prog) {
+		return getInstructions(prog).stream().filter(CompilerTestBase::isSpark).count();
+	}
+
+	/** Textual rendering of {@code prog} as returned by {@code Explain.explain(prog)}. */
+	protected String explain(Program prog) {
+		return Explain.explain(prog);
+	}
+}
diff --git a/src/test/java/org/apache/sysds/test/component/compile/SparkTransitiveExecTypeCompileTest.java b/src/test/java/org/apache/sysds/test/component/compile/SparkTransitiveExecTypeCompileTest.java
new file mode 100644
index 00000000000..f60a9d61758
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/component/compile/SparkTransitiveExecTypeCompileTest.java
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.component.compile;
+
+import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.runtime.controlprogram.Program;
+import org.junit.Test;
+
+/**
+ * Verifies the transitive Spark exec-type refinement in {@link org.apache.sysds.hops.UnaryOp}: a CP-by-estimate unary on
+ * a Spark-resident input is pulled into Spark only when it is the sole consumer ({@code getParent().size() == 1}).
+ */
+public class SparkTransitiveExecTypeCompileTest extends CompilerTestBase {
+
+	/** Script prelude shared by both tests; X is ~1.2GB, so rand and colSums are expected to run on Spark. */
+	private static final String DML_HEADER = "X = rand(rows=20000000, cols=8, seed=1);\n"
+		+ "v = colSums(X);\n"; // 1x8 Spark-resident vector (opcode uack+)
+
+	/** A unary that is the only consumer of the Spark-resident vector is expected to compile to Spark. */
+	@Test
+	public void singleConsumerUnaryPulledIntoSpark() {
+		final String script = DML_HEADER + "r = round(v);\n" // round is the sole consumer of v
+			+ "print(sum(r));\n";
+		final Program plan = compile(script, null, ExecMode.HYBRID, SMALL_MEM_BUDGET);
+
+		assertSpark(plan, "uack+"); // input genuinely has a Spark output
+		assertSpark(plan, "round"); // unary pulled into Spark (CP by mem estimate, single consumer)
+	}
+
+	/** With two unaries reading the same Spark-resident vector, both are expected to stay in CP. */
+	@Test
+	public void multiConsumerUnaryStaysCP() {
+		final String script = DML_HEADER + "a = round(v);\n" // v now has two consumers (round + abs) ...
+			+ "b = abs(v);\n"
+			+ "print(sum(a) + sum(b));\n";
+		final Program plan = compile(script, null, ExecMode.HYBRID, SMALL_MEM_BUDGET);
+
+		assertSpark(plan, "uack+"); // input still has a Spark output ...
+		assertCP(plan, "round"); // ... but the multi-parent guard keeps both unaries in CP
+		assertCP(plan, "abs");
+	}
+}