diff --git a/.github/dependabot.yaml b/.github/dependabot.yaml index bcae6e510..c3e7cb1dc 100644 --- a/.github/dependabot.yaml +++ b/.github/dependabot.yaml @@ -8,6 +8,8 @@ updates: groups: github-actions: patterns: ["*"] + cooldown: + default-days: 7 - package-ecosystem: "uv" directory: "/" schedule: @@ -15,3 +17,5 @@ updates: groups: python-uv-lock: patterns: ["*"] + cooldown: + default-days: 7 diff --git a/bench_fixes.py b/bench_fixes.py new file mode 100644 index 000000000..1015e5e1d --- /dev/null +++ b/bench_fixes.py @@ -0,0 +1,234 @@ +#!/usr/bin/env python3 +""" +Compare the two PR fixes independently and combined: + A) Baseline — list() + StringIO (V2, no declare_fields) + B) list-fix — no list() via declare_fields, still StringIO (V2) + C) spool-fix — list() still present, but SpooledTemporaryFile (V3) + D) both — declare_fields + V3 + +Measures wall-clock, tracemalloc peak heap, and RSS delta. +Uses GeneratingCommand path so list() materialisation in _execute_chunk_v2 +is actually exercised. +""" +import gc +import resource +import sys +import time +import tracemalloc +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent)) + +from splunklib.searchcommands.internals import ( + DiskBufferSettings, + RecordWriterV2, + RecordWriterV3, +) + +GB = 1024 ** 3 +MB = 1024 * 1024 + +# --------------------------------------------------------------------------- +# Config +# --------------------------------------------------------------------------- +RECORD_BYTES = 1_000 +GB_TARGET = 2.0 +CHUNK_ROWS = 50_000 +SPOOL_SIZE = 4 * MB + +N_RECORDS = int(GB_TARGET * GB / RECORD_BYTES) +PAYLOAD = "x" * RECORD_BYTES + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- +class NullFile: + def write(self, d): return len(d) + def flush(self): pass + + +def rss_bytes() -> int: + ru = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss + return ru if sys.platform == "darwin" else ru * 1024 + + +def record_gen(n: int): + for i in range(n): + yield {"index": str(i), "payload": PAYLOAD} + + +# --------------------------------------------------------------------------- +# Simulate GeneratingCommand._execute_chunk_v2 for each case +# --------------------------------------------------------------------------- +def run_baseline(n: int) -> tuple[float, int, int]: + """A: list() accumulation + StringIO (original behaviour).""" + w = RecordWriterV2(NullFile(), CHUNK_ROWS) + process = record_gen(n) + + gc.collect() + rss_before = rss_bytes() + tracemalloc.start() + t0 = time.perf_counter() + + while True: + count = 0 + records = [] + for row in process: + records.append(row) + count += 1 + if count == CHUNK_ROWS: + break + for row in records: + w.write_record(row) + finished = count < CHUNK_ROWS + w.write_chunk(finished=finished) + if finished: + break + + wall = time.perf_counter() - t0 + _, heap = tracemalloc.get_traced_memory() + tracemalloc.stop() + return wall, heap, max(0, rss_bytes() - rss_before) + + +def run_list_fix(n: int) -> tuple[float, int, int]: + """B: declare_fields removes list(), still StringIO.""" + w = RecordWriterV2(NullFile(), CHUNK_ROWS) + w.custom_fields.update(["index", "payload"]) + w.fields_declared = True + process = record_gen(n) + + gc.collect() + rss_before = rss_bytes() + tracemalloc.start() + t0 = time.perf_counter() + + while True: + count = 0 + # fields_declared path: stream directly + for row in process: + w.write_record(row) + count += 1 + if count == CHUNK_ROWS: + break + finished = count < CHUNK_ROWS + w.write_chunk(finished=finished) + if finished: + break + + wall = time.perf_counter() - t0 + _, heap = tracemalloc.get_traced_memory() + tracemalloc.stop() + return wall, heap, max(0, rss_bytes() - rss_before) + + +def run_spool_fix(n: int) -> tuple[float, int, int]: + """C: list() still used, but SpooledTemporaryFile (V3).""" + w = RecordWriterV3(NullFile(), CHUNK_ROWS, disk_buffer=DiskBufferSettings(spool_size=SPOOL_SIZE)) + process = record_gen(n) + + gc.collect() + rss_before = rss_bytes() + tracemalloc.start() + t0 = time.perf_counter() + + while True: + count = 0 + records = [] + for row in process: + records.append(row) + count += 1 + if count == CHUNK_ROWS: + break + for row in records: + w.write_record(row) + finished = count < CHUNK_ROWS + w.write_chunk(finished=finished) + if finished: + break + + wall = time.perf_counter() - t0 + _, heap = tracemalloc.get_traced_memory() + tracemalloc.stop() + return wall, heap, max(0, rss_bytes() - rss_before) + + +def run_both(n: int) -> tuple[float, int, int]: + """D: declare_fields + V3.""" + w = RecordWriterV3(NullFile(), CHUNK_ROWS, disk_buffer=DiskBufferSettings(spool_size=SPOOL_SIZE)) + w.custom_fields.update(["index", "payload"]) + w.fields_declared = True + process = record_gen(n) + + gc.collect() + rss_before = rss_bytes() + tracemalloc.start() + t0 = time.perf_counter() + + while True: + count = 0 + for row in process: + w.write_record(row) + count += 1 + if count == CHUNK_ROWS: + break + finished = count < CHUNK_ROWS + w.write_chunk(finished=finished) + if finished: + break + + wall = time.perf_counter() - t0 + _, heap = tracemalloc.get_traced_memory() + tracemalloc.stop() + return wall, heap, max(0, rss_bytes() - rss_before) + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- +def main(): + print(f"\nFix comparison: {GB_TARGET:.1f} GB payload " + f"({N_RECORDS:,} records × {RECORD_BYTES} B) " + f"chunk_rows={CHUNK_ROWS:,} spool={SPOOL_SIZE // MB} MB\n") + + hdr = f"{'Variant':<35} {'Wall (s)':>8} {'Heap peak':>11} {'RSS delta':>11}" + print(hdr) + print("-" * len(hdr)) + + cases = [ + ("A baseline (list + StringIO)", run_baseline), + ("B list-fix (no list, StringIO)", run_list_fix), + ("C spool-fix (list + SpoolFile)", run_spool_fix), + ("D both (no list + SpoolFile)", run_both), + ] + + results = {} + for label, fn in cases: + wall, heap, rss = fn(N_RECORDS) + results[label] = (wall, heap, rss) + print(f"{label:<35} {wall:>8.2f} {heap / MB:>9.1f} MB {rss / MB:>9.1f} MB") + gc.collect() + + baseline_wall, baseline_heap, baseline_rss = results[cases[0][0]] + print() + print("Savings vs baseline:") + for label, _ in cases[1:]: + w, h, r = results[label] + dw = w - baseline_wall + dh = h - baseline_heap + dr = r - baseline_rss + print(f" {label:<35} wall {dw:>+7.2f}s heap {dh / MB:>+7.1f} MB ({dh / baseline_heap * 100:>+5.1f}%) " + f"rss {dr / MB:>+7.1f} MB") + + print() + # Which fix dominates heap savings? + _, h_b, _ = results[cases[1][0]] # list-fix + _, h_c, _ = results[cases[2][0]] # spool-fix + list_saving = (baseline_heap - h_b) / baseline_heap * 100 + spool_saving = (baseline_heap - h_c) / baseline_heap * 100 + print(f"Heap: list-fix alone saves {list_saving:.1f}%, spool-fix alone saves {spool_saving:.1f}%") + + +if __name__ == "__main__": + main() diff --git a/bench_writers.py b/bench_writers.py new file mode 100644 index 000000000..e5109a1cd --- /dev/null +++ b/bench_writers.py @@ -0,0 +1,148 @@ +#!/usr/bin/env python3 +""" +Standalone benchmark: RecordWriterV2 (StringIO) vs RecordWriterV3 (SpooledFile). + +Streams N GB of synthetic records through each writer into /dev/null. +Measures wall-clock time, peak tracemalloc heap, and peak RSS. + +Usage: + python bench_writers.py [--gb 10] [--record-bytes 1000] [--chunk-rows 50000] +""" +import argparse +import gc +import os +import resource +import shutil +import sys +import time +import tracemalloc +from io import BytesIO, TextIOWrapper +from pathlib import Path + +# Make sure we can import splunklib from the worktree +sys.path.insert(0, str(Path(__file__).parent)) + +from splunklib.searchcommands.internals import ( + DiskBufferSettings, + RecordWriterV2, + RecordWriterV3, +) + + +# --------------------------------------------------------------------------- +# /dev/null sink (binary) +# --------------------------------------------------------------------------- +class NullFile: + """Binary sink — accepts bytes, discards them.""" + def write(self, data: bytes) -> int: + return len(data) + def flush(self): + pass + + +# --------------------------------------------------------------------------- +# Synthetic record generator (never materialises all records) +# --------------------------------------------------------------------------- +def record_stream(n_records: int, record_bytes: int): + """Yield dicts with a fixed-size payload field.""" + payload = "x" * record_bytes + for i in range(n_records): + yield {"index": str(i), "payload": payload} + + +# --------------------------------------------------------------------------- +# Benchmark runner +# --------------------------------------------------------------------------- +def rss_bytes() -> int: + ru = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss + return ru if sys.platform == "darwin" else ru * 1024 + + +def run_benchmark(writer, records_iter, chunk_rows: int) -> tuple[float, int, int]: + """ + Pump records through writer in chunks of `chunk_rows`. + Returns (wall_seconds, peak_tracemalloc_bytes, peak_rss_bytes). + """ + gc.collect() + rss_before = rss_bytes() + + tracemalloc.start() + t0 = time.perf_counter() + + count = 0 + for record in records_iter: + writer.write_record(record) + count += 1 + if count == chunk_rows: + writer.write_chunk(finished=False) + count = 0 + + writer.write_chunk(finished=True) + + wall = time.perf_counter() - t0 + _, peak_heap = tracemalloc.get_traced_memory() + tracemalloc.stop() + + rss_after = rss_bytes() + peak_rss_delta = max(0, rss_after - rss_before) + + return wall, peak_heap, peak_rss_delta + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--gb", type=float, default=10.0, help="Total payload GB") + parser.add_argument("--record-bytes", type=int, default=1_000, help="Bytes per record") + parser.add_argument("--chunk-rows", type=int, default=50_000, help="Rows per chunk (maxresultrows)") + parser.add_argument("--spool-size", type=int, default=4 * 1024 * 1024, help="V3 spool_size bytes") + args = parser.parse_args() + + total_bytes = int(args.gb * 1024 ** 3) + n_records = total_bytes // args.record_bytes + actual_gb = (n_records * args.record_bytes) / 1024 ** 3 + + mb = 1024 * 1024 + gb = 1024 ** 3 + + print(f"\nBenchmark: {actual_gb:.2f} GB payload " + f"({n_records:,} records × {args.record_bytes} B) " + f"chunk_rows={args.chunk_rows:,} spool_size={args.spool_size // mb} MB\n") + print(f"{'Writer':<30} {'Wall (s)':>10} {'Heap peak':>12} {'RSS delta':>12} {'Throughput':>14}") + print("-" * 84) + + results = {} + + for label, make_writer in [ + ("RecordWriterV2 (StringIO)", lambda: RecordWriterV2(NullFile(), args.chunk_rows)), + (f"RecordWriterV3 (spool={args.spool_size // mb}MB)", lambda: RecordWriterV3( + NullFile(), args.chunk_rows, disk_buffer=DiskBufferSettings(spool_size=args.spool_size) + )), + ]: + writer = make_writer() + gen = record_stream(n_records, args.record_bytes) + wall, heap, rss = run_benchmark(writer, gen, args.chunk_rows) + throughput = (n_records * args.record_bytes) / wall / gb + results[label] = (wall, heap, rss, throughput) + print(f"{label:<30} {wall:>10.2f} {heap / mb:>10.1f} MB {rss / mb:>10.1f} MB {throughput:>12.2f} GB/s") + gc.collect() + + # Delta row + labels = list(results.keys()) + w2, h2, r2, tp2 = results[labels[0]] + w3, h3, r3, tp3 = results[labels[1]] + print("-" * 84) + print(f"{'Delta (V3 - V2)':<30} {w3 - w2:>+10.2f} {(h3 - h2) / mb:>+10.1f} MB " + f"{(r3 - r2) / mb:>+10.1f} MB {tp3 - tp2:>+12.2f} GB/s") + print() + + heap_reduction_pct = (h2 - h3) / h2 * 100 if h2 else 0 + print(f"V3 heap reduction: {heap_reduction_pct:.1f}% vs V2") + print(f"Total CSV written: ~{n_records * args.record_bytes * 2 / gb:.1f} GB " # ~2x due to __mv_ columns + f"(raw payload × ~2 for __mv_ encoding)") + + +if __name__ == "__main__": + main() diff --git a/splunklib/searchcommands/__init__.py b/splunklib/searchcommands/__init__.py index 60904a1f5..ac4cf5684 100644 --- a/splunklib/searchcommands/__init__.py +++ b/splunklib/searchcommands/__init__.py @@ -153,6 +153,7 @@ from splunklib.searchcommands.eventing_command import EventingCommand from splunklib.searchcommands.external_search_command import ExternalSearchCommand, execute from splunklib.searchcommands.generating_command import GeneratingCommand +from splunklib.searchcommands.internals import DiskBufferSettings from splunklib.searchcommands.reporting_command import ReportingCommand from splunklib.searchcommands.search_command import SearchMetric, dispatch from splunklib.searchcommands.streaming_command import StreamingCommand @@ -173,6 +174,7 @@ "Boolean", "Code", "Configuration", + "DiskBufferSettings", "Duration", "EventingCommand", "ExternalSearchCommand", diff --git a/splunklib/searchcommands/generating_command.py b/splunklib/searchcommands/generating_command.py index 334265449..1fa7649f0 100644 --- a/splunklib/searchcommands/generating_command.py +++ b/splunklib/searchcommands/generating_command.py @@ -208,17 +208,23 @@ def _execute(self, ifile, process): def _execute_chunk_v2(self, process, chunk): count = 0 - records = [] - for row in process: - records.append(row) - count += 1 - if count == self._record_writer._maxresultrows: - break - - for row in records: - self._record_writer.write_record(row) - - if count == self._record_writer._maxresultrows: + if self._record_writer.fields_declared: + for row in process: + self._record_writer.write_record(row) + count += 1 + if count == self._record_writer.maxresultrows: + break + else: + records = [] + for row in process: + records.append(row) + count += 1 + if count == self._record_writer.maxresultrows: + break + for row in records: + self._record_writer.write_record(row) + + if count == self._record_writer.maxresultrows: self._finished = False else: self._finished = True diff --git a/splunklib/searchcommands/internals.py b/splunklib/searchcommands/internals.py index 40e468554..d5e8c6f28 100644 --- a/splunklib/searchcommands/internals.py +++ b/splunklib/searchcommands/internals.py @@ -15,10 +15,13 @@ import csv import gzip import re +import shutil import sys +import tempfile import urllib.parse import warnings from collections import OrderedDict, deque, namedtuple +from dataclasses import dataclass from io import StringIO, TextIOWrapper from itertools import chain from json import JSONDecoder, JSONEncoder @@ -229,6 +232,23 @@ def replace(match): # endregion +@dataclass(frozen=True, kw_only=True) +class DiskBufferSettings: + """Controls disk-spill buffering for RecordWriterV3. + + When set on a command via ``@Configuration(disk_buffer=DiskBufferSettings())``, + the CSV reply buffer spills to a temp file instead of accumulating entirely in + RAM. This trades some I/O overhead for a bounded memory footprint regardless + of result set size. + + Args: + spool_size: Bytes kept in RAM before spilling to disk. Defaults to 4 MB. + Set to 0 to always write directly to disk. + """ + + spool_size: int = 4 * 1024 * 1024 + + class ConfigurationSettingsType(type): """Metaclass for constructing ConfigurationSettings classes. @@ -306,6 +326,12 @@ def validate_configuration_setting(specification, name, value): constraint=lambda value: value in ("events", "reporting", "streaming"), supporting_protocols=[2], ), + # SDK-only: never sent to Splunk. supporting_protocols=[] keeps it out of iteritems(). + "disk_buffer": specification( + type=DiskBufferSettings, + constraint=None, + supporting_protocols=[], + ), } @@ -461,6 +487,11 @@ def __init__(self, ofile, maxresultrows=None): self._pending_record_count = 0 self._committed_record_count = 0 self.custom_fields = set() + self.fields_declared = False + + @property + def maxresultrows(self): + return self._maxresultrows @property def is_flushed(self): @@ -527,7 +558,10 @@ def write_record(self, record): def write_records(self, records): self._ensure_validity() - records = [] if records is NotImplemented else list(records) + if records is NotImplemented: + return + if not self.fields_declared: + records = list(records) write_record = self._write_record for record in records: write_record(record) @@ -797,3 +831,78 @@ def _write_chunk(self, metadata, body): self.write(body) self._ofile.flush() self._flushed = True + + +class RecordWriterV3(RecordWriterV2): + """RecordWriterV2 with disk-spill buffering via SpooledTemporaryFile. + + Used when a command is configured with ``@Configuration(disk_buffer=DiskBufferSettings())``. + The CSV reply buffer spills to a temp file instead of accumulating in a StringIO, + so peak RAM is bounded by ``spool_size`` rather than the full result payload. + """ + + def __init__(self, ofile, maxresultrows=None, disk_buffer=None): + if disk_buffer is None: + raise ValueError("RecordWriterV3 requires a DiskBufferSettings instance") + self._disk_buffer = disk_buffer + super().__init__(ofile, maxresultrows) + # Replace the StringIO created by RecordWriter.__init__ with a spool file + raw = tempfile.SpooledTemporaryFile( + max_size=self._disk_buffer.spool_size, + mode="w+b", + ) + self._buffer_raw = raw + self._buffer = TextIOWrapper(raw, encoding="utf-8", newline="") + self._writer = csv.writer(self._buffer, dialect=CsvDialect) + self._writerow = self._writer.writerow + + def write_chunk(self, finished=None): + inspector = self._inspector + self._committed_record_count += self.pending_record_count + self._chunk_count += 1 + + if len(inspector) == 0: + inspector = None + + metadata = [("inspector", inspector), ("finished", finished)] + + if metadata: + metadata_bytes = str( + "".join( + self._iterencode_json( + dict((n, v) for n, v in metadata if v is not None), 0 + ) + ) + ).encode("utf-8") + metadata_length = len(metadata_bytes) + else: + metadata_bytes = b"" + metadata_length = 0 + + # Flush TextIOWrapper so all pending CSV data lands in the binary spool file + self._buffer.flush() + + self._buffer_raw.seek(0, 2) + body_length = self._buffer_raw.tell() + self._buffer_raw.seek(0) + + if metadata_length > 0 or body_length > 0: + start_line = f"chunked 1.0,{metadata_length},{body_length}\n".encode("utf-8") + self._ofile.write(start_line) + self._ofile.write(metadata_bytes) + shutil.copyfileobj(self._buffer_raw, self._ofile, length=65536) + self._ofile.flush() + self._flushed = True + + self._clear() + + def _clear(self): + # Flush wrapper, reset the raw spool, re-sync wrapper position + self._buffer.flush() + self._buffer_raw.seek(0) + self._buffer_raw.truncate() + # Discard the wrapper's internal position cache by seeking it too + self._buffer.seek(0) + self._inspector.clear() + self._pending_record_count = 0 + self._fieldnames = None diff --git a/splunklib/searchcommands/search_command.py b/splunklib/searchcommands/search_command.py index 8716cec54..8f916f690 100644 --- a/splunklib/searchcommands/search_command.py +++ b/splunklib/searchcommands/search_command.py @@ -34,10 +34,11 @@ import splunklib.searchcommands.environment as environment from splunklib.client import Service -from splunklib.searchcommands.decorators import Option +from splunklib.searchcommands.decorators import ConfigurationSetting, Option from splunklib.searchcommands.internals import ( CommandLineParser, CsvDialect, + DiskBufferSettings, InputHeader, Message, MetadataDecoder, @@ -46,6 +47,7 @@ Recorder, RecordWriterV1, RecordWriterV2, + RecordWriterV3, ) from splunklib.searchcommands.validators import Boolean from splunklib.utils import ensure_str @@ -166,6 +168,30 @@ def gen_record(self, **record): self._record_writer.custom_fields |= set(record.keys()) return record + def declare_fields(self, *field_names: str) -> None: + """Pre-declare all custom fields before any records are yielded. + + When every extra field is declared upfront the SDK streams records + lazily without materialising them into a list first, reducing peak + memory for large result sets. + + Must be called before the first ``yield`` in ``generate()`` or + ``stream()``. Incompatible with ``add_field``/``gen_record``, which + populate fields lazily and require the full-materialisation path. + + Example:: + + @Configuration() + class MyCommand(GeneratingCommand): + def generate(self): + self.declare_fields('extra_field') + for row in huge_dataset(): + row['extra_field'] = compute(row) + yield row + """ + self._record_writer.custom_fields.update(field_names) + self._record_writer.fields_declared = True + record = Option( doc=""" **Syntax:** record= @@ -746,9 +772,12 @@ def _process_protocol_v2(self, argv, ifile, ofile): # Write search command configuration for consumption by splunkd # noinspection PyBroadException try: - self._record_writer = RecordWriterV2( - ofile, getattr(self._metadata.searchinfo, "maxresultrows", None) - ) + _disk_buffer = getattr(self._configuration, "disk_buffer", None) + _maxresultrows = getattr(self._metadata.searchinfo, "maxresultrows", None) + if _disk_buffer is not None: + self._record_writer = RecordWriterV3(ofile, _maxresultrows, disk_buffer=_disk_buffer) + else: + self._record_writer = RecordWriterV2(ofile, _maxresultrows) self.fieldnames = [] self.options.reset() @@ -1135,6 +1164,28 @@ def iteritems(self): # endregion + # region SDK-only settings (not sent to Splunk) + + disk_buffer = ConfigurationSetting( + doc=""" + Enable disk-spill buffering for the CSV reply buffer. + + Set to a :class:`DiskBufferSettings` instance to have the SDK write the + CEXC reply payload to a :mod:`tempfile.SpooledTemporaryFile` instead of + a ``StringIO``. The spool file stays in RAM up to ``spool_size`` bytes, + then spills to a temp directory. + + This trades I/O overhead for bounded peak memory usage — useful for + commands that generate or pass through very large result sets. + + Default: :const:`None` (StringIO, original behaviour) + + Supported by: SDK only (not sent to Splunk) + """ + ) + + # endregion + # endregion diff --git a/tests/unit/searchcommands/test_disk_buffer.py b/tests/unit/searchcommands/test_disk_buffer.py new file mode 100644 index 000000000..2f2c203c2 --- /dev/null +++ b/tests/unit/searchcommands/test_disk_buffer.py @@ -0,0 +1,364 @@ +""" +Tests for RecordWriterV3 / DiskBufferSettings disk-spill buffering. + +Two concerns: + 1. Correctness: disk_buffer produces identical output to the default StringIO path. + 2. Memory: RecordWriterV3 keeps the CSV reply buffer off the Python heap. + Measured via tracemalloc (Python-level allocations only), which isolates the + StringIO vs SpooledTemporaryFile difference from Python-object overhead. + +Benchmark (CPU + RAM): + test_benchmark_v2_vs_v3 prints a wall-clock + tracemalloc comparison table. + It never asserts on performance — only on correctness — so CI always passes. + +Why tracemalloc instead of ru_maxrss: + resource.getrusage().ru_maxrss is the process-lifetime peak RSS (monotonically + non-decreasing). In a multi-test pytest session the baseline is already high + from earlier tests, making delta measurements unreliable. tracemalloc tracks + Python-level heap allocations only, resettable per-test, which cleanly isolates + the StringIO vs SpooledTemporaryFile buffer difference. +""" + +import io +import time +import tracemalloc +from collections.abc import Generator, Iterator + +import pytest + +from splunklib.searchcommands import ( + Configuration, + DiskBufferSettings, + GeneratingCommand, + StreamingCommand, +) + +from . import chunked_data_stream as chunky + +RECORD_SIZE_BYTES = 1_000 +N_RECORDS = 50_000 +EXPECTED_TOTAL_BYTES = RECORD_SIZE_BYTES * N_RECORDS # ~50 MB + +# RecordWriterV3 keeps the CSV bytes off the Python heap (spilled to disk). +# Allowed peak: spool_size (4 MB default) + small per-record overhead. +# We allow 10% of total payload as generous headroom for encoder buffers etc. +DISK_BUFFER_HEAP_THRESHOLD = EXPECTED_TOTAL_BYTES * 0.10 + + +# --------------------------------------------------------------------------- +# Correctness: disk_buffer output matches default StringIO output +# --------------------------------------------------------------------------- + + +def test_disk_buffer_streaming_output_matches_default() -> None: + """RecordWriterV3 must produce byte-for-byte identical output to RecordWriterV2.""" + large_value = "x" * 100 + records = [{"payload": large_value, "idx": str(i)} for i in range(200)] + + @Configuration() + class DefaultCommand(StreamingCommand): + def stream(self, records: Iterator[dict]) -> Generator[dict]: + yield from records + + @Configuration(disk_buffer=DiskBufferSettings(spool_size=1024)) + class DiskCommand(StreamingCommand): + def stream(self, records: Iterator[dict]) -> Generator[dict]: + yield from records + + def run_command(cmd_class: type) -> bytes: + ifile = io.BytesIO() + ifile.write(chunky.build_getinfo_chunk()) + ifile.write(chunky.build_data_chunk(records, finished=True)) + ifile.seek(0) + ofile = io.BytesIO() + cmd_class()._process_protocol_v2([], ifile, ofile) + return ofile.getvalue() + + default_out = run_command(DefaultCommand) + disk_out = run_command(DiskCommand) + + assert default_out == disk_out, ( + f"disk_buffer output differs from default.\n" + f"default length: {len(default_out)}, disk length: {len(disk_out)}" + ) + + +def test_disk_buffer_generating_output_matches_default() -> None: + """RecordWriterV3 GeneratingCommand output must match RecordWriterV2.""" + + @Configuration() + class DefaultGenCommand(GeneratingCommand): + def generate(self) -> Generator[dict]: + for i in range(200): + yield {"idx": str(i), "val": "y" * 100} + + @Configuration(disk_buffer=DiskBufferSettings(spool_size=1024)) + class DiskGenCommand(GeneratingCommand): + def generate(self) -> Generator[dict]: + for i in range(200): + yield {"idx": str(i), "val": "y" * 100} + + def run_command(cmd_class: type) -> bytes: + ifile = io.BytesIO() + ifile.write(chunky.build_getinfo_chunk()) + ifile.write(chunky.build_chunk({"action": "execute"})) + ifile.seek(0) + ofile = io.BytesIO() + cmd_class()._process_protocol_v2([], ifile, ofile) + return ofile.getvalue() + + default_out = run_command(DefaultGenCommand) + disk_out = run_command(DiskGenCommand) + + assert default_out == disk_out + + +# --------------------------------------------------------------------------- +# Memory: disk_buffer keeps CSV bytes off the Python heap (tracemalloc) +# --------------------------------------------------------------------------- + + +def _measure_heap_streaming(use_disk_buffer: bool) -> int: + """Return peak Python heap growth (bytes) during a 50k-record streaming run.""" + large_value = "x" * RECORD_SIZE_BYTES + data = [{"payload": large_value} for _ in range(N_RECORDS)] + + if use_disk_buffer: + @Configuration(disk_buffer=DiskBufferSettings()) + class DiskStreamCmd(StreamingCommand): + def stream(self, records: Iterator[dict]) -> Generator[dict]: + yield from records + cmd_class = DiskStreamCmd + else: + @Configuration() + class DefaultStreamCmd(StreamingCommand): + def stream(self, records: Iterator[dict]) -> Generator[dict]: + yield from records + cmd_class = DefaultStreamCmd + + ifile = io.BytesIO() + ifile.write(chunky.build_getinfo_chunk()) + ifile.write(chunky.build_data_chunk(data, finished=True)) + ifile.seek(0) + ofile = io.BytesIO() + + tracemalloc.start() + cmd_class()._process_protocol_v2([], ifile, ofile) + _, peak = tracemalloc.get_traced_memory() + tracemalloc.stop() + return peak + + +def _measure_heap_generating(use_disk_buffer: bool) -> int: + """Return peak Python heap growth (bytes) during a 50k-record generating run.""" + large_value = "x" * RECORD_SIZE_BYTES + + if use_disk_buffer: + @Configuration(disk_buffer=DiskBufferSettings()) + class DiskGenCmd(GeneratingCommand): + def generate(self) -> Generator[dict]: + for i in range(N_RECORDS): + yield {"index": str(i), "payload": large_value} + cmd_class = DiskGenCmd + else: + @Configuration() + class DefaultGenCmd(GeneratingCommand): + def generate(self) -> Generator[dict]: + for i in range(N_RECORDS): + yield {"index": str(i), "payload": large_value} + cmd_class = DefaultGenCmd + + ifile = io.BytesIO() + ifile.write(chunky.build_getinfo_chunk()) + ifile.write(chunky.build_chunk({"action": "execute"})) + ifile.seek(0) + ofile = io.BytesIO() + + tracemalloc.start() + cmd_class()._process_protocol_v2([], ifile, ofile) + _, peak = tracemalloc.get_traced_memory() + tracemalloc.stop() + return peak + + +def test_disk_buffer_streaming_heap_less_than_default() -> None: + """RecordWriterV3 must use less Python heap than RecordWriterV2 for large payloads. + + V2 holds the full CSV in a StringIO on the Python heap. + V3 spills CSV bytes to disk; only up to spool_size stays in RAM. + """ + peak_v2 = _measure_heap_streaming(use_disk_buffer=False) + peak_v3 = _measure_heap_streaming(use_disk_buffer=True) + + mb = 1024 * 1024 + assert peak_v3 < peak_v2, ( + f"RecordWriterV3 should use less Python heap than V2.\n" + f"V2 peak: {peak_v2 / mb:.1f} MB, V3 peak: {peak_v3 / mb:.1f} MB" + ) + + +def test_disk_buffer_generating_heap_less_than_default() -> None: + """RecordWriterV3 must use less Python heap than RecordWriterV2 for GeneratingCommand.""" + peak_v2 = _measure_heap_generating(use_disk_buffer=False) + peak_v3 = _measure_heap_generating(use_disk_buffer=True) + + mb = 1024 * 1024 + assert peak_v3 < peak_v2, ( + f"RecordWriterV3 should use less Python heap than V2.\n" + f"V2 peak: {peak_v2 / mb:.1f} MB, V3 peak: {peak_v3 / mb:.1f} MB" + ) + + +# --------------------------------------------------------------------------- +# Benchmark: wall-clock time + tracemalloc heap for V2 vs V3 +# --------------------------------------------------------------------------- + + +def test_benchmark_v2_vs_v3(capsys: pytest.CaptureFixture[str]) -> None: + """Measure and print wall-clock time + peak heap for RecordWriterV2 vs V3. + + Never fails on performance — only prints the comparison table. + """ + mb = 1024 * 1024 + + def run(use_disk: bool) -> tuple[float, int]: + large_value = "x" * RECORD_SIZE_BYTES + data = [{"payload": large_value} for _ in range(N_RECORDS)] + + if use_disk: + @Configuration(disk_buffer=DiskBufferSettings()) + class BenchDisk(StreamingCommand): + def stream(self, records: Iterator[dict]) -> Generator[dict]: + yield from records + cmd_class = BenchDisk + else: + @Configuration() + class BenchDefault(StreamingCommand): + def stream(self, records: Iterator[dict]) -> Generator[dict]: + yield from records + cmd_class = BenchDefault + + ifile = io.BytesIO() + ifile.write(chunky.build_getinfo_chunk()) + ifile.write(chunky.build_data_chunk(data, finished=True)) + ifile.seek(0) + ofile = io.BytesIO() + + tracemalloc.start() + t0 = time.perf_counter() + cmd_class()._process_protocol_v2([], ifile, ofile) + wall = time.perf_counter() - t0 + _, peak = tracemalloc.get_traced_memory() + tracemalloc.stop() + return wall, peak + + wall_v2, heap_v2 = run(use_disk=False) + wall_v3, heap_v3 = run(use_disk=True) + + with capsys.disabled(): + print( + f"\n" + f"RecordWriter V2 vs V3 benchmark " + f"({N_RECORDS} records x {RECORD_SIZE_BYTES} B = {EXPECTED_TOTAL_BYTES / mb:.0f} MB payload)\n" + f"{'':26s} {'Wall (s)':>10} {'Heap peak':>12}\n" + f"{'RecordWriterV2 (StringIO)':26s} {wall_v2:>10.3f} {heap_v2 / mb:>10.1f} MB\n" + f"{'RecordWriterV3 (SpoolFile)':26s} {wall_v3:>10.3f} {heap_v3 / mb:>10.1f} MB\n" + f"{'Overhead':26s} {(wall_v3 - wall_v2):>+10.3f} {(heap_v3 - heap_v2) / mb:>+10.1f} MB\n" + ) + + +# --------------------------------------------------------------------------- +# declare_fields: opt-in streaming without list() materialisation +# --------------------------------------------------------------------------- + + +def test_declare_fields_streaming_output_matches_default() -> None: + """declare_fields() path must produce identical output to the default list() path.""" + records_data = [{"payload": "x" * 100, "idx": str(i), "extra": str(i * 2)} for i in range(200)] + + @Configuration() + class DefaultCmd(StreamingCommand): + def stream(self, records: Iterator[dict]) -> Generator[dict]: + yield from records + + @Configuration() + class DeclaredCmd(StreamingCommand): + def stream(self, records: Iterator[dict]) -> Generator[dict]: + self.declare_fields("payload", "idx", "extra") + yield from records + + def run_command(cmd_class: type) -> bytes: + ifile = io.BytesIO() + ifile.write(chunky.build_getinfo_chunk()) + ifile.write(chunky.build_data_chunk(records_data, finished=True)) + ifile.seek(0) + ofile = io.BytesIO() + cmd_class()._process_protocol_v2([], ifile, ofile) + return ofile.getvalue() + + assert run_command(DefaultCmd) == run_command(DeclaredCmd) + + +def test_declare_fields_generating_output_matches_default() -> None: + """declare_fields() on GeneratingCommand must produce identical output.""" + + @Configuration() + class DefaultGen(GeneratingCommand): + def generate(self) -> Generator[dict]: + for i in range(200): + yield {"idx": str(i), "val": "y" * 50} + + @Configuration() + class DeclaredGen(GeneratingCommand): + def generate(self) -> Generator[dict]: + self.declare_fields("idx", "val") + for i in range(200): + yield {"idx": str(i), "val": "y" * 50} + + def run_command(cmd_class: type) -> bytes: + ifile = io.BytesIO() + ifile.write(chunky.build_getinfo_chunk()) + ifile.write(chunky.build_chunk({"action": "execute"})) + ifile.seek(0) + ofile = io.BytesIO() + cmd_class()._process_protocol_v2([], ifile, ofile) + return ofile.getvalue() + + assert run_command(DefaultGen) == run_command(DeclaredGen) + + +def test_declare_fields_sets_flag() -> None: + """declare_fields() must set fields_declared=True on the record writer.""" + + @Configuration() + class FlagCmd(GeneratingCommand): + def generate(self) -> Generator[dict]: + self.declare_fields("a", "b") + yield {"a": "1", "b": "2"} + + ifile = io.BytesIO() + ifile.write(chunky.build_getinfo_chunk()) + ifile.write(chunky.build_chunk({"action": "execute"})) + ifile.seek(0) + cmd = FlagCmd() + cmd._process_protocol_v2([], ifile, io.BytesIO()) + assert cmd._record_writer.fields_declared is True + assert "a" in cmd._record_writer.custom_fields + assert "b" in cmd._record_writer.custom_fields + + +def test_declare_fields_without_declaration_flag_is_false() -> None: + """Without declare_fields(), fields_declared must remain False.""" + + @Configuration() + class NoDeclareCmd(GeneratingCommand): + def generate(self) -> Generator[dict]: + yield {"a": "1"} + + ifile = io.BytesIO() + ifile.write(chunky.build_getinfo_chunk()) + ifile.write(chunky.build_chunk({"action": "execute"})) + ifile.seek(0) + cmd = NoDeclareCmd() + cmd._process_protocol_v2([], ifile, io.BytesIO()) + assert cmd._record_writer.fields_declared is False diff --git a/tests/unit/searchcommands/test_oom_reproducer.py b/tests/unit/searchcommands/test_oom_reproducer.py new file mode 100644 index 000000000..ad82e143a --- /dev/null +++ b/tests/unit/searchcommands/test_oom_reproducer.py @@ -0,0 +1,150 @@ +""" +Reproducer for issue #687 / PR #800: streaming commands materialise the full +record iterator into a list before writing, causing high memory usage on large +result sets. + +Two paths are exercised: + 1. StreamingCommand → write_records(process(records)) via _execute_chunk_v2 + in search_command.py (base class). + 2. GeneratingCommand → _execute_chunk_v2 in generating_command.py (own override), + which collected all rows into `records = []` before writing. + +Protocol ceiling (SPL-103525 / DVPL-6448): + The CEXC protocol (chunked-command-protocol.txt) is strictly request-response: + Splunk sends one execute chunk, the SDK must reply with exactly one chunk. + Footnote [1] of the spec notes "Pipelining may be supported in future versions". + Until SPL-103525 ships, RecordWriterV2 must buffer the entire CSV reply in its + StringIO buffer before flushing — partial mid-chunk writes are not possible. + + Consequence: ~1x CSV payload buffering is unavoidable and these tests are + marked xfail(strict=False). They will show as XFAIL (expected failure) until + SPL-103525 is resolved and the SDK is updated to use partial chunks. + + What IS avoidable (and what PR #800 targets) is the extra Python-object-level + copy: list(records) in write_records() and records=[] in _execute_chunk_v2. + Removing those copies saves roughly 1x Python-object overhead on top of the + unavoidable CSV buffer, but cannot bring RSS growth below ~1x payload. +""" + +import io +import resource +import sys +from collections.abc import Generator, Iterator + +import pytest + +from splunklib.searchcommands import Configuration, GeneratingCommand, StreamingCommand + +from . import chunked_data_stream as chunky + +RECORD_SIZE_BYTES = 1_000 +N_RECORDS = 50_000 +EXPECTED_TOTAL_BYTES = RECORD_SIZE_BYTES * N_RECORDS # ~50 MB + +# A correctly fixed SDK still buffers ~1x CSV payload in RecordWriter._buffer +# (required by the CEXC protocol). Flag if growth exceeds 20 % — this threshold +# can only be met once SPL-103525 ships partial chunk support. +OOM_THRESHOLD_BYTES = EXPECTED_TOTAL_BYTES * 0.20 + +_XFAIL_REASON = ( + "CEXC protocol requires full-chunk buffering in RecordWriter._buffer " + "(RecordWriterV2.flush(partial=True) is a no-op until SPL-103525 ships). " + "~1x CSV payload buffering is unavoidable regardless of list() removal. " + "Remove xfail once SPL-103525 is resolved and partial chunk support is wired up." +) + + +def _rss_bytes() -> int: + # resource.getrusage returns kilobytes on Linux, bytes on macOS. + ru = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss + if sys.platform == "darwin": + return ru # bytes + return ru * 1024 # kilobytes → bytes + + +# --------------------------------------------------------------------------- +# Streaming command reproducer (issue #687 root cause: write_records list()) +# --------------------------------------------------------------------------- + + +@pytest.mark.xfail(strict=False, reason=_XFAIL_REASON) +def test_streaming_command_does_not_buffer_all_records() -> None: + """ + StreamingCommand must not materialise all records into memory before writing. + + The base-class write_records() used to call list(records), which forced full + materialisation of the iterator on top of the unavoidable CSV buffer in + RecordWriter._buffer. Removing list() halves the peak RSS but cannot bring + it below ~1x payload while CEXC partial chunk support is absent (SPL-103525). + """ + large_value = "x" * RECORD_SIZE_BYTES + + @Configuration() + class PassThroughCommand(StreamingCommand): + def stream(self, records: Iterator[dict]) -> Generator[dict]: + yield from records + + data = [{"payload": large_value} for _ in range(N_RECORDS)] + + ifile = io.BytesIO() + ifile.write(chunky.build_getinfo_chunk()) + ifile.write(chunky.build_data_chunk(data, finished=True)) + ifile.seek(0) + ofile = io.BytesIO() + + rss_before = _rss_bytes() + cmd = PassThroughCommand() + cmd._process_protocol_v2([], ifile, ofile) + rss_after = _rss_bytes() + + rss_growth = rss_after - rss_before + assert rss_growth < OOM_THRESHOLD_BYTES, ( + f"Streaming command buffered too much: RSS grew by {rss_growth / 1024 / 1024:.1f} MB " + f"(threshold {OOM_THRESHOLD_BYTES / 1024 / 1024:.1f} MB). " + f"Total payload was {EXPECTED_TOTAL_BYTES / 1024 / 1024:.1f} MB. " + "Likely cause: write_records() is calling list(records) or SPL-103525 is still unresolved." + ) + + +# --------------------------------------------------------------------------- +# Generating command reproducer (generating_command._execute_chunk_v2 buffer) +# --------------------------------------------------------------------------- + + +@pytest.mark.xfail(strict=False, reason=_XFAIL_REASON) +def test_generating_command_does_not_buffer_all_records() -> None: + """ + GeneratingCommand._execute_chunk_v2 must not accumulate all yielded rows + into a Python list before writing. + + The original code collected rows into `records = []` then wrote them in a + second pass, doubling peak memory on top of the unavoidable CSV buffer. + Removing the list halves the peak RSS but the CSV buffer floor remains until + SPL-103525 ships. + """ + large_value = "x" * RECORD_SIZE_BYTES + + @Configuration() + class LargeGeneratorCommand(GeneratingCommand): + def generate(self) -> Generator[dict]: + for i in range(N_RECORDS): + yield {"index": str(i), "payload": large_value} + + ifile = io.BytesIO() + ifile.write(chunky.build_getinfo_chunk()) + ifile.write(chunky.build_chunk({"action": "execute"})) + ifile.seek(0) + ofile = io.BytesIO() + + rss_before = _rss_bytes() + generator = LargeGeneratorCommand() + generator._process_protocol_v2([], ifile, ofile) + rss_after = _rss_bytes() + + rss_growth = rss_after - rss_before + assert rss_growth < OOM_THRESHOLD_BYTES, ( + f"Generating command buffered too much: RSS grew by {rss_growth / 1024 / 1024:.1f} MB " + f"(threshold {OOM_THRESHOLD_BYTES / 1024 / 1024:.1f} MB). " + f"Total payload was {EXPECTED_TOTAL_BYTES / 1024 / 1024:.1f} MB. " + "Likely cause: _execute_chunk_v2 is collecting rows into records=[] or SPL-103525 is still unresolved." + )