Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ build-backend = "hatchling.build"

[dependency-groups]
dev = [
"opentelemetry-exporter-otlp-proto-http>=1.38.0,<2",
"opentelemetry-sdk>=1.38.0,<2",
"pyright>=1.1.410",
"ruff>=0.15.16",
]
Expand All @@ -30,6 +32,7 @@ classifiers = [
]
dependencies = [
"httpx>=0.28.1,<1",
"opentelemetry-api>=1.38.0,<2",
"pydantic>=2.13.4,<3",
"python-dotenv>=1.2.2,<2",
]
Expand All @@ -41,6 +44,12 @@ keywords = [
Homepage = "https://github.com/sourcegraph/src-py-lib"
Issues = "https://github.com/sourcegraph/src-py-lib/issues"

[project.optional-dependencies]
otel = [
"opentelemetry-exporter-otlp-proto-http>=1.38.0,<2",
"opentelemetry-sdk>=1.38.0,<2",
]

[tool.hatch.build.targets.wheel]
packages = ["src/src_py_lib"]

Expand Down
60 changes: 41 additions & 19 deletions src/src_py_lib/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,29 +73,32 @@
from src_py_lib.utils.logging import (
LoggingConfig,
LoggingSettings,
TraceContext,
configure_logging,
critical,
current_trace_context,
debug,
error,
event,
info,
log,
log_context,
log_event,
logging_context,
logging_settings_from_config,
new_trace_context,
resolve_log_level_name,
sampled_traceparent,
span,
stage,
startup_event,
submit_with_log_context,
trace_context,
trace_context_from_traceparent,
traceparent_header,
warning,
)
from src_py_lib.utils.telemetry import (
OpenTelemetryConfig,
OpenTelemetryRuntime,
OpenTelemetrySettings,
OpenTelemetrySetupError,
configure_open_telemetry,
current_traceparent_header,
open_telemetry_settings_from_config,
traceparent_fields,
)
from src_py_lib.utils.tsv import write_tsv


Expand All @@ -105,15 +108,33 @@ def logging(
command: str | None = None,
git_cwd: Path | str | None = None,
logging_config: LoggingSettings | None = None,
open_telemetry: OpenTelemetrySettings | None = None,
run_fields: Mapping[str, Any] | None = None,
run_summary: Callable[[], Mapping[str, Any]] | None = None,
) -> AbstractContextManager[Path | None]:
"""Configure standard CLI logging and emit startup metadata."""
resolved_logging_config = logging_config
if open_telemetry is not None:
resolved_logging_config = logging_config or logging_settings_from_config(config)
resolved_logging_config = LoggingSettings(
logger_name=resolved_logging_config.logger_name,
terminal_level=resolved_logging_config.terminal_level,
log_file_level=resolved_logging_config.log_file_level,
log_file=resolved_logging_config.log_file,
logs_dir=resolved_logging_config.logs_dir,
run=resolved_logging_config.run,
retain_log_files=resolved_logging_config.retain_log_files,
suppress_http_dependency_logs=resolved_logging_config.suppress_http_dependency_logs,
resource_sample_interval_seconds=(
resolved_logging_config.resource_sample_interval_seconds
),
open_telemetry=open_telemetry,
)
return logging_context(
command or _script_name(),
config,
git_cwd=git_cwd,
logging_config=logging_config,
logging_config=resolved_logging_config,
run_fields=run_fields,
run_summary=run_summary,
)
Expand All @@ -139,6 +160,10 @@ def _script_name() -> str:
"LinearClientConfig",
"LoggingConfig",
"LoggingSettings",
"OpenTelemetryConfig",
"OpenTelemetryRuntime",
"OpenTelemetrySettings",
"OpenTelemetrySetupError",
"PullRequest",
"SlackClient",
"SlackClientConfig",
Expand All @@ -149,23 +174,23 @@ def _script_name() -> str:
"SourcegraphJaegerTraceError",
"SourcegraphJaegerTraceSummary",
"SourcegraphTrace",
"TraceContext",
"aliased_batched_query",
"config_field",
"config_field_names",
"config_help_formatter",
"config_snapshot",
"configure_open_telemetry",
"configure_logging",
"critical",
"current_trace_context",
"current_traceparent_header",
"debug",
"decode_external_service_id",
"decode_repository_id",
"decode_sourcegraph_node_id",
"encode_repository_id",
"encode_sourcegraph_node_id",
"error",
"event",
"span",
"gh_cli_token",
"gcloud_adc_access_token",
"info",
Expand All @@ -182,25 +207,22 @@ def _script_name() -> str:
"logging",
"logging_context",
"logging_settings_from_config",
"log",
"log_event",
"log_context",
"new_trace_context",
"normalize_sourcegraph_endpoint",
"open_telemetry_settings_from_config",
"parse_args",
"pr_ref_from_url",
"quota_project_from_adc",
"resolve_log_level_name",
"save_json_cache",
"sampled_traceparent",
"slack_client_from_config",
"sourcegraph_client_from_config",
"stage",
"startup_event",
"stream_connection_nodes",
"submit_with_log_context",
"trace_context",
"trace_context_from_traceparent",
"traceparent_header",
"traceparent_fields",
"warning",
"write_tsv",
]
4 changes: 2 additions & 2 deletions src/src_py_lib/clients/graphql.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from src_py_lib.utils.http import HTTPClient, HTTPClientError, HTTPResponse, log_safe_url
from src_py_lib.utils.json_types import JSONDict, JSONValue, json_dict, json_list, json_str
from src_py_lib.utils.logging import event
from src_py_lib.utils.logging import span

_OPERATION_NAME_RE = re.compile(r"\b(?:query|mutation|subscription)\s+(\w+)")
HeaderProvider = Mapping[str, str] | Callable[[], Mapping[str, str]]
Expand Down Expand Up @@ -241,7 +241,7 @@ def _execute_once(
after_variable: str = "after",
) -> JSONDict:
body = {"query": query, "variables": variables or {}}
with event(
with span(
"graphql_query",
level="debug",
graphql_client=self.label,
Expand Down
64 changes: 35 additions & 29 deletions src/src_py_lib/clients/sourcegraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,19 +10,18 @@
from collections.abc import Iterable, Iterator, Mapping, Sequence
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass, field
from typing import Final, cast
from typing import Final
from urllib.parse import urlsplit

from src_py_lib.clients.graphql import GraphQLClient, stream_connection_nodes
from src_py_lib.utils.config import Config, config_field
from src_py_lib.utils.http import HTTPClient, HTTPClientError, HTTPResponse
from src_py_lib.utils.json_types import JSONDict, JSONValue, json_dict, json_list
from src_py_lib.utils.logging import (
current_trace_context,
new_trace_context,
submit_with_log_context,
trace_context_from_traceparent,
traceparent_header,
from src_py_lib.utils.logging import submit_with_log_context
from src_py_lib.utils.telemetry import (
current_traceparent_header,
set_current_span_attributes,
traceparent_fields,
)

SOURCEGRAPH_EXTERNAL_SERVICE_NODE_TYPE: Final[str] = "ExternalService"
Expand Down Expand Up @@ -187,16 +186,16 @@ class SourcegraphClient:
Plain HTTP endpoints are rejected unless `allow_insecure_http=True` is set
for local development.

Set `trace=True` to ask Sourcegraph to retain traces for each GraphQL
request. Traced requests are available through `drain_traces()` and can be
fetched from the instance's Jaeger/debug endpoint with
Set `fetch_sg_traces=True` to ask Sourcegraph to retain traces for each
GraphQL request. Traced requests are available through `drain_traces()` and
can be fetched from the instance's Jaeger/debug endpoint with
`stream_jaeger_trace_summaries()`.
"""

endpoint: str
token: str
http: HTTPClient = field(default_factory=HTTPClient)
trace: bool = False
fetch_sg_traces: bool = False
allow_insecure_http: bool = False
_traces: queue.Queue[SourcegraphTrace] = field(
default_factory=lambda: queue.Queue[SourcegraphTrace](), init=False, repr=False
Expand Down Expand Up @@ -355,49 +354,51 @@ def _client(self) -> GraphQLClient:
headers=self._graphql_headers,
label="Sourcegraph",
http=self.http,
response_hook=self._record_trace_response if self.trace else None,
response_hook=self._record_trace_response if self.fetch_sg_traces else None,
)

def _authorization_headers(self) -> dict[str, str]:
return {"Authorization": f"token {self.token}"}

def _graphql_headers(self) -> dict[str, str]:
# Build the headers for a GraphQL request, starting from the authorization headers.
# NOTE(review): this span comes from a diff view with indentation stripped. It appears to
# show both the removed `self.trace` check and the added `self.fetch_sg_traces` check, as
# well as both the removed `traceparent_header(...)` assignment and the added
# `current_traceparent_header()` lookup. Confirm against the merged file before relying
# on how these statements nest.
headers = self._authorization_headers()
if self.trace:
if self.fetch_sg_traces:
headers[REQUEST_TRACE_HEADER] = "true"
headers[TRACEPARENT_HEADER] = traceparent_header(
current_trace_context() or new_trace_context()
)
traceparent = current_traceparent_header()
# TRACEPARENT_HEADER is only set from this lookup when it returned a value.
if traceparent is not None:
headers[TRACEPARENT_HEADER] = traceparent
return headers

def _record_trace_response(
    self, response: HTTPResponse, request_headers: Mapping[str, str]
) -> None:
    """Capture Sourcegraph trace metadata from one response.

    The trace is parsed from the response headers together with the headers
    that were sent on the request. When a trace is found, its identifiers are
    attached to the current span and the trace is queued on `self._traces`;
    otherwise nothing happens.
    """
    recorded = sourcegraph_trace_from_headers(response.headers, request_headers)
    if recorded is None:
        return
    # NOTE(review): `trace_url` and `span_id` are passed through as-is and may
    # be None — confirm `set_current_span_attributes` tolerates None values.
    set_current_span_attributes(
        {
            "sourcegraph.trace_id": recorded.trace_id,
            "sourcegraph.trace_url": recorded.trace_url,
            "sourcegraph.span_id": recorded.span_id,
        }
    )
    self._traces.put(recorded)


def sourcegraph_client_from_config(
config: SourcegraphClientConfig,
*,
http: HTTPClient | None = None,
# NOTE(review): `trace` and `fetch_sg_traces` look like the before/after names of a single
# flag shown together by the diff view — confirm which one the merged file keeps.
trace: bool = False,
fetch_sg_traces: bool = False,
) -> SourcegraphClient:
"""Return a Sourcegraph API client from shared Sourcegraph Config fields."""
# Endpoint and token are read from `config`; a new HTTPClient is created when `http` is
# not supplied.
return SourcegraphClient(
endpoint=config.src_endpoint,
token=config.src_access_token,
http=http or HTTPClient(),
trace=trace,
fetch_sg_traces=fetch_sg_traces,
)


def sampled_traceparent() -> str:
    """Return a W3C traceparent header value generated with `sampled=True`.

    Kept as a compatibility wrapper around `traceparent_header`.
    """
    header = traceparent_header(sampled=True)
    return header


def sourcegraph_trace_from_headers(
response_headers: Mapping[str, str], request_headers: Mapping[str, str]
) -> SourcegraphTrace | None:
Expand All @@ -407,13 +408,13 @@ def sourcegraph_trace_from_headers(
return None
span_id = header_value(response_headers, TRACE_SPAN_RESPONSE_HEADER)
trace_url = header_value(response_headers, TRACE_URL_RESPONSE_HEADER)
parent = trace_context_from_traceparent(header_value(request_headers, TRACEPARENT_HEADER))
parent = traceparent_fields(header_value(request_headers, TRACEPARENT_HEADER))
return SourcegraphTrace(
trace_id=trace_id.lower(),
span_id=span_id.lower() if span_id and is_hex_identifier(span_id, 16) else span_id,
trace_url=trace_url,
parent_trace_id=parent.trace_id if parent is not None else None,
parent_span_id=parent.span_id if parent is not None else None,
parent_trace_id=parent.get("trace_id"),
parent_span_id=parent.get("span_id"),
)


Expand Down Expand Up @@ -472,7 +473,7 @@ def summarize_jaeger_trace(
}
)

hot_operations = [
hot_operations: list[JSONDict] = [
{
"operation": operation,
"count": len(durations),
Expand All @@ -481,12 +482,12 @@ def summarize_jaeger_trace(
}
for operation, durations in durations_by_operation.items()
]
hot_operations.sort(key=lambda operation: float(operation["sum_ms"]), reverse=True)
hot_operations.sort(key=jaeger_summary_operation_sum_ms, reverse=True)
return SourcegraphJaegerTraceSummary(
trace=trace_metadata,
jaeger_found=True,
span_count=len(spans),
hot_operations=tuple(cast(JSONDict, operation) for operation in hot_operations[:10]),
hot_operations=tuple(hot_operations[:10]),
graphql_operations=tuple(
{"operation": operation, "count": count}
for operation, count in graphql_operations.most_common(10)
Expand All @@ -495,6 +496,11 @@ def summarize_jaeger_trace(
)


def jaeger_summary_operation_sum_ms(operation: JSONDict) -> float:
    """Sort key for compact Jaeger operation summaries: the entry's `sum_ms` as a float."""
    total_ms = operation.get("sum_ms")
    return float_value(total_ms)


def jaeger_span_tags(span: JSONDict) -> dict[str, object]:
"""Return Jaeger span tags keyed by tag name."""
tags: dict[str, object] = {}
Expand Down
Loading