From de43b8243c9f94cea9f046e6701ef526bc5b4bf4 Mon Sep 17 00:00:00 2001 From: Evan Lezar Date: Mon, 8 Jun 2026 18:22:20 +0200 Subject: [PATCH 1/6] feat(gpu)!: add resource requirements BREAKING CHANGE: SandboxSpec.gpu and DriverSandboxSpec.gpu were replaced with resource_requirements.gpu, changing protobuf field 9 from a bool to a message for both public and driver APIs. Signed-off-by: Evan Lezar --- architecture/compute-runtimes.md | 5 +- crates/openshell-cli/src/main.rs | 171 +++++++++++++++++- crates/openshell-cli/src/run.rs | 20 +- .../sandbox_create_lifecycle_integration.rs | 106 +++++++++++ crates/openshell-core/src/gpu.rs | 34 ++++ crates/openshell-driver-docker/README.md | 2 +- crates/openshell-driver-docker/src/lib.rs | 23 ++- crates/openshell-driver-docker/src/tests.rs | 41 ++++- crates/openshell-driver-kubernetes/README.md | 11 +- .../openshell-driver-kubernetes/src/driver.rs | 126 +++++++++++-- crates/openshell-driver-podman/README.md | 2 +- .../openshell-driver-podman/src/container.rs | 22 ++- crates/openshell-driver-podman/src/driver.rs | 38 +++- crates/openshell-driver-vm/src/driver.rs | 78 ++++++-- crates/openshell-server/src/compute/mod.rs | 62 +++++-- crates/openshell-server/src/grpc/sandbox.rs | 4 +- .../openshell-server/src/grpc/validation.rs | 39 +++- docs/reference/sandbox-compute-drivers.mdx | 3 +- docs/sandboxes/manage-sandboxes.mdx | 11 ++ proto/compute_driver.proto | 17 +- proto/openshell.proto | 17 +- 21 files changed, 740 insertions(+), 92 deletions(-) diff --git a/architecture/compute-runtimes.md b/architecture/compute-runtimes.md index ec0efded6..10ef69838 100644 --- a/architecture/compute-runtimes.md +++ b/architecture/compute-runtimes.md @@ -92,7 +92,10 @@ users. Custom sandbox images must include the agent runtime and any system dependencies, but they should not need to include the gateway. GPU-capable images must include the user-space libraries required by the workload. The -runtime still owns GPU device injection. +runtime still owns GPU device injection. GPU requests are explicit, and can be +refined with a driver-native device identifier or requested count; the gateway +validates the request shape and each runtime enforces the GPU allocation modes it +supports. ## Deployment Shape diff --git a/crates/openshell-cli/src/main.rs b/crates/openshell-cli/src/main.rs index ea0dd79ca..220010c6d 100644 --- a/crates/openshell-cli/src/main.rs +++ b/crates/openshell-cli/src/main.rs @@ -28,6 +28,12 @@ struct GatewayContext { endpoint: String, } +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +enum GpuCliRequest { + DriverDefault, + Count(u32), +} + /// Resolve the gateway name to a [`GatewayContext`] with the gateway endpoint. /// /// Resolution priority: @@ -109,6 +115,31 @@ fn resolve_gateway( }) } +fn resolve_gpu_args(gpu: Option) -> (bool, Option) { + let gpu_count = match gpu { + Some(GpuCliRequest::Count(count)) => Some(count), + Some(GpuCliRequest::DriverDefault) | None => None, + }; + let gpu = gpu.is_some(); + + (gpu, gpu_count) +} + +fn parse_gpu_request(value: &str) -> std::result::Result { + if value.is_empty() { + return Ok(GpuCliRequest::DriverDefault); + } + + let count = value + .parse::() + .map_err(|_| "GPU count must be a positive integer".to_string())?; + if count == 0 { + return Err("GPU count must be greater than 0".to_string()); + } + + Ok(GpuCliRequest::Count(count)) +} + fn resolve_gateway_name(gateway_flag: &Option) -> Option { gateway_flag .clone() @@ -1212,8 +1243,11 @@ enum SandboxCommands { editor: Option, /// Request GPU resources for the sandbox. - #[arg(long)] - gpu: bool, + /// + /// Omit COUNT for the driver's default GPU selection, or pass COUNT + /// to request a specific number of GPUs. + #[arg(long, num_args = 0..=1, value_name = "COUNT", default_missing_value = "", value_parser = parse_gpu_request)] + gpu: Option, /// CPU limit for the sandbox (for example: 500m, 1, 2.5). #[arg(long)] @@ -2621,6 +2655,7 @@ async fn main() -> Result<()> { .map(|s| openshell_core::forward::ForwardSpec::parse(&s)) .transpose()?; let keep = keep || !no_keep || editor.is_some() || forward.is_some(); + let (gpu, gpu_count) = resolve_gpu_args(gpu); let ctx = resolve_gateway(&cli.gateway, &cli.gateway_endpoint)?; let endpoint = &ctx.endpoint; @@ -2634,6 +2669,7 @@ async fn main() -> Result<()> { &upload_specs, keep, gpu, + gpu_count, cpu.as_deref(), memory.as_deref(), driver_config_json.as_deref(), @@ -3628,6 +3664,30 @@ mod tests { }); } + #[test] + fn resolve_gpu_args_handles_absent_gpu() { + let (gpu, gpu_count) = resolve_gpu_args(None); + + assert!(!gpu); + assert_eq!(gpu_count, None); + } + + #[test] + fn resolve_gpu_args_handles_driver_default() { + let (gpu, gpu_count) = resolve_gpu_args(Some(GpuCliRequest::DriverDefault)); + + assert!(gpu); + assert_eq!(gpu_count, None); + } + + #[test] + fn resolve_gpu_args_handles_gpu_count() { + let (gpu, gpu_count) = resolve_gpu_args(Some(GpuCliRequest::Count(2))); + + assert!(gpu); + assert_eq!(gpu_count, Some(2)); + } + #[test] fn apply_auth_uses_stored_token() { let tmp = tempfile::tempdir().unwrap(); @@ -4443,6 +4503,113 @@ mod tests { } } + #[test] + fn sandbox_create_gpu_parses_driver_default() { + let cli = Cli::try_parse_from(["openshell", "sandbox", "create", "--gpu"]) + .expect("sandbox create --gpu should parse"); + + match cli.command { + Some(Commands::Sandbox { + command: Some(SandboxCommands::Create { gpu, .. }), + .. + }) => { + assert_eq!(gpu, Some(GpuCliRequest::DriverDefault)); + } + other => panic!("expected SandboxCommands::Create, got: {other:?}"), + } + } + + #[test] + fn sandbox_create_gpu_count_parses_from_gpu_flag() { + let cli = Cli::try_parse_from(["openshell", "sandbox", "create", "--gpu", "2"]) + .expect("sandbox create --gpu 2 should parse"); + + match cli.command { + Some(Commands::Sandbox { + command: Some(SandboxCommands::Create { gpu, .. }), + .. + }) => { + assert_eq!(gpu, Some(GpuCliRequest::Count(2))); + } + other => panic!("expected SandboxCommands::Create, got: {other:?}"), + } + } + + #[test] + fn sandbox_create_gpu_driver_default_allows_trailing_command() { + let cli = Cli::try_parse_from(["openshell", "sandbox", "create", "--gpu", "--", "claude"]) + .expect("sandbox create --gpu -- claude should parse"); + + match cli.command { + Some(Commands::Sandbox { + command: Some(SandboxCommands::Create { gpu, command, .. }), + .. + }) => { + assert_eq!(gpu, Some(GpuCliRequest::DriverDefault)); + assert_eq!(command, vec!["claude".to_string()]); + } + other => panic!("expected SandboxCommands::Create, got: {other:?}"), + } + } + + #[test] + fn sandbox_create_gpu_count_allows_trailing_command() { + let cli = Cli::try_parse_from([ + "openshell", + "sandbox", + "create", + "--gpu", + "2", + "--", + "claude", + ]) + .expect("sandbox create --gpu 2 -- claude should parse"); + + match cli.command { + Some(Commands::Sandbox { + command: Some(SandboxCommands::Create { gpu, command, .. }), + .. + }) => { + assert_eq!(gpu, Some(GpuCliRequest::Count(2))); + assert_eq!(command, vec!["claude".to_string()]); + } + other => panic!("expected SandboxCommands::Create, got: {other:?}"), + } + } + + #[test] + fn sandbox_create_gpu_count_rejects_zero() { + let result = Cli::try_parse_from(["openshell", "sandbox", "create", "--gpu", "0"]); + + assert!(result.is_err(), "sandbox create --gpu 0 should be rejected"); + } + + #[test] + fn sandbox_create_gpu_count_accepts_equals_syntax() { + let cli = Cli::try_parse_from(["openshell", "sandbox", "create", "--gpu=2"]) + .expect("sandbox create --gpu=2 should parse"); + + match cli.command { + Some(Commands::Sandbox { + command: Some(SandboxCommands::Create { gpu, .. }), + .. + }) => { + assert_eq!(gpu, Some(GpuCliRequest::Count(2))); + } + other => panic!("expected SandboxCommands::Create, got: {other:?}"), + } + } + + #[test] + fn sandbox_create_gpu_count_rejects_non_integer() { + let result = Cli::try_parse_from(["openshell", "sandbox", "create", "--gpu", "many"]); + + assert!( + result.is_err(), + "sandbox create --gpu many should be rejected" + ); + } + #[test] fn service_expose_accepts_positional_target_port_and_service() { let cli = Cli::try_parse_from([ diff --git a/crates/openshell-cli/src/run.rs b/crates/openshell-cli/src/run.rs index dbd240238..afb9e1d55 100644 --- a/crates/openshell-cli/src/run.rs +++ b/crates/openshell-cli/src/run.rs @@ -39,12 +39,13 @@ use openshell_core::proto::{ GetClusterInferenceRequest, GetDraftHistoryRequest, GetDraftPolicyRequest, GetGatewayConfigRequest, GetProviderProfileRequest, GetProviderRefreshStatusRequest, GetProviderRequest, GetSandboxConfigRequest, GetSandboxLogsRequest, - GetSandboxPolicyStatusRequest, GetSandboxRequest, GetServiceRequest, HealthRequest, - ImportProviderProfilesRequest, LintProviderProfilesRequest, ListProviderProfilesRequest, - ListProvidersRequest, ListSandboxPoliciesRequest, ListSandboxProvidersRequest, - ListSandboxesRequest, ListServicesRequest, PlatformEvent, PolicySource, PolicyStatus, Provider, - ProviderCredentialRefreshStatus, ProviderCredentialRefreshStrategy, ProviderProfile, - ProviderProfileDiagnostic, ProviderProfileImportItem, RejectDraftChunkRequest, + GetSandboxPolicyStatusRequest, GetSandboxRequest, GetServiceRequest, GpuResourceRequirements, + HealthRequest, ImportProviderProfilesRequest, LintProviderProfilesRequest, + ListProviderProfilesRequest, ListProvidersRequest, ListSandboxPoliciesRequest, + ListSandboxProvidersRequest, ListSandboxesRequest, ListServicesRequest, PlatformEvent, + PolicySource, PolicyStatus, Provider, ProviderCredentialRefreshStatus, + ProviderCredentialRefreshStrategy, ProviderProfile, ProviderProfileDiagnostic, + ProviderProfileImportItem, RejectDraftChunkRequest, ResourceRequirements, RevokeSshSessionRequest, RotateProviderCredentialRequest, Sandbox, SandboxPhase, SandboxPolicy, SandboxSpec, SandboxTemplate, ServiceEndpointResponse, SetClusterInferenceRequest, SettingScope, SettingValue, TcpForwardFrame, TcpForwardInit, TcpRelayTarget, @@ -1725,6 +1726,7 @@ pub async fn sandbox_create( uploads: &[(String, Option, bool)], keep: bool, gpu: bool, + gpu_count: Option, cpu: Option<&str>, memory: Option<&str>, driver_config_json: Option<&str>, @@ -1813,9 +1815,13 @@ pub async fn sandbox_create( None }; + let resource_requirements = requested_gpu.then_some(ResourceRequirements { + gpu: Some(GpuResourceRequirements { count: gpu_count }), + }); + let request = CreateSandboxRequest { spec: Some(SandboxSpec { - gpu: requested_gpu, + resource_requirements, environment: environment.clone(), policy, providers: configured_providers, diff --git a/crates/openshell-cli/tests/sandbox_create_lifecycle_integration.rs b/crates/openshell-cli/tests/sandbox_create_lifecycle_integration.rs index 7061614cb..2adf04587 100644 --- a/crates/openshell-cli/tests/sandbox_create_lifecycle_integration.rs +++ b/crates/openshell-cli/tests/sandbox_create_lifecycle_integration.rs @@ -787,6 +787,7 @@ async fn sandbox_create_keeps_command_sessions_by_default() { None, None, None, + None, &[], None, None, @@ -826,6 +827,7 @@ async fn sandbox_create_sends_cpu_and_memory_limits_only() { &[], true, false, + None, Some("500m"), Some("2Gi"), None, @@ -905,6 +907,7 @@ async fn sandbox_create_sends_driver_config_json() { false, None, None, + None, Some(r#"{"kubernetes":{"pod":{"priority_class_name":"batch-low"}}}"#), None, &[], @@ -959,6 +962,100 @@ async fn sandbox_create_sends_driver_config_json() { ); } +#[tokio::test] +async fn sandbox_create_sends_gpu_default_request() { + let server = run_server().await; + let fake_ssh_dir = tempfile::tempdir().unwrap(); + let xdg_dir = tempfile::tempdir().unwrap(); + let _env = test_env(&fake_ssh_dir, &xdg_dir); + let tls = test_tls(&server); + install_fake_ssh(&fake_ssh_dir); + + run::sandbox_create( + &server.endpoint, + Some("gpu-default"), + None, + "openshell", + &[], + true, + true, + None, + None, + None, + None, + None, + &[], + None, + None, + &["echo".to_string(), "OK".to_string()], + Some(false), + Some(false), + &HashMap::new(), + &HashMap::new(), + "manual", + &tls, + ) + .await + .expect("sandbox create should succeed"); + + let requests = create_requests(&server).await; + let gpu = requests[0] + .spec + .as_ref() + .and_then(|spec| spec.resource_requirements.as_ref()) + .and_then(|requirements| requirements.gpu.as_ref()) + .expect("GPU requirement should be sent"); + + assert_eq!(gpu.count, None); +} + +#[tokio::test] +async fn sandbox_create_sends_gpu_count_request() { + let server = run_server().await; + let fake_ssh_dir = tempfile::tempdir().unwrap(); + let xdg_dir = tempfile::tempdir().unwrap(); + let _env = test_env(&fake_ssh_dir, &xdg_dir); + let tls = test_tls(&server); + install_fake_ssh(&fake_ssh_dir); + + run::sandbox_create( + &server.endpoint, + Some("gpu-two"), + None, + "openshell", + &[], + true, + true, + Some(2), + None, + None, + None, + None, + &[], + None, + None, + &["echo".to_string(), "OK".to_string()], + Some(false), + Some(false), + &HashMap::new(), + &HashMap::new(), + "manual", + &tls, + ) + .await + .expect("sandbox create should succeed"); + + let requests = create_requests(&server).await; + let gpu = requests[0] + .spec + .as_ref() + .and_then(|spec| spec.resource_requirements.as_ref()) + .and_then(|requirements| requirements.gpu.as_ref()) + .expect("GPU requirement should be sent"); + + assert_eq!(gpu.count, Some(2)); +} + #[tokio::test] async fn sandbox_create_does_not_infer_command_providers_when_v2_enabled() { let server = run_server().await; @@ -981,6 +1078,7 @@ async fn sandbox_create_does_not_infer_command_providers_when_v2_enabled() { None, None, None, + None, &[], None, None, @@ -1039,6 +1137,7 @@ async fn sandbox_create_returns_vm_error_without_waiting_for_timeout() { None, None, None, + None, &[], None, None, @@ -1093,6 +1192,7 @@ async fn sandbox_create_keeps_waiting_while_vm_progress_arrives() { None, None, None, + None, &[], None, None, @@ -1139,6 +1239,7 @@ async fn sandbox_create_times_out_when_only_logs_arrive() { None, None, None, + None, &[], None, None, @@ -1181,6 +1282,7 @@ async fn sandbox_create_deletes_command_sessions_with_no_keep() { None, None, None, + None, &[], None, None, @@ -1227,6 +1329,7 @@ async fn sandbox_create_deletes_shell_sessions_with_no_keep() { None, None, None, + None, &[], None, None, @@ -1273,6 +1376,7 @@ async fn sandbox_create_keeps_sandbox_with_hidden_keep_flag() { None, None, None, + None, &[], None, None, @@ -1319,6 +1423,7 @@ async fn sandbox_create_keeps_sandbox_with_forwarding() { None, None, None, + None, &[], None, Some(openshell_core::forward::ForwardSpec::new(forward_port)), @@ -1361,6 +1466,7 @@ async fn sandbox_create_sends_environment_variables() { None, None, None, + None, &[], None, None, diff --git a/crates/openshell-core/src/gpu.rs b/crates/openshell-core/src/gpu.rs index 9718b50f2..8c25f47b5 100644 --- a/crates/openshell-core/src/gpu.rs +++ b/crates/openshell-core/src/gpu.rs @@ -4,6 +4,40 @@ //! Shared GPU request helpers. use crate::config::CDI_GPU_DEVICE_ALL; +use crate::proto::ResourceRequirements as PublicResourceRequirements; +use crate::proto::compute::v1::ResourceRequirements as DriverResourceRequirements; + +/// Return whether public resource requirements request a GPU. +#[must_use] +pub fn public_gpu_requested(resources: Option<&PublicResourceRequirements>) -> bool { + resources + .and_then(|resources| resources.gpu.as_ref()) + .is_some() +} + +/// Return the requested public GPU count, if one was specified. +#[must_use] +pub fn public_gpu_count(resources: Option<&PublicResourceRequirements>) -> Option { + resources + .and_then(|resources| resources.gpu.as_ref()) + .and_then(|gpu| gpu.count) +} + +/// Return whether driver resource requirements request a GPU. +#[must_use] +pub fn driver_gpu_requested(resources: Option<&DriverResourceRequirements>) -> bool { + resources + .and_then(|resources| resources.gpu.as_ref()) + .is_some() +} + +/// Return the requested driver GPU count, if one was specified. +#[must_use] +pub fn driver_gpu_count(resources: Option<&DriverResourceRequirements>) -> Option { + resources + .and_then(|resources| resources.gpu.as_ref()) + .and_then(|gpu| gpu.count) +} /// Resolve a GPU request into CDI device identifiers. /// diff --git a/crates/openshell-driver-docker/README.md b/crates/openshell-driver-docker/README.md index 7f74cbe17..b8e244a13 100644 --- a/crates/openshell-driver-docker/README.md +++ b/crates/openshell-driver-docker/README.md @@ -32,7 +32,7 @@ contract: | `apparmor=unconfined` | Avoids Docker's default profile blocking required mount operations. | | `restart_policy = unless-stopped` | Keeps managed sandboxes resumable across daemon or gateway restarts. | | `PidsLimit` | Enforces the sandbox PID budget at the Docker cgroup layer. Set `[openshell.drivers.docker].sandbox_pids_limit = 0` to inherit the Docker/runtime default. | -| CDI GPU request | Uses `driver_config.cdi_devices` when set; otherwise requests all NVIDIA GPUs when the sandbox spec asks for GPU support and daemon CDI support is detected. | +| CDI GPU request | Uses `driver_config.cdi_devices` when set; otherwise requests all NVIDIA GPUs when the sandbox spec asks for GPU support and daemon CDI support is detected. Counted GPU requests are rejected. | The agent child process does not retain these supervisor privileges. diff --git a/crates/openshell-driver-docker/src/lib.rs b/crates/openshell-driver-docker/src/lib.rs index 05137a2b0..036f2683d 100644 --- a/crates/openshell-driver-docker/src/lib.rs +++ b/crates/openshell-driver-docker/src/lib.rs @@ -27,7 +27,7 @@ use openshell_core::driver_utils::{ LABEL_MANAGED_BY, LABEL_MANAGED_BY_VALUE, LABEL_SANDBOX_ID, LABEL_SANDBOX_NAME, LABEL_SANDBOX_NAMESPACE, SUPERVISOR_IMAGE_BINARY_PATH, supervisor_image_should_refresh, }; -use openshell_core::gpu::cdi_gpu_device_ids; +use openshell_core::gpu::{cdi_gpu_device_ids, driver_gpu_count, driver_gpu_requested}; use openshell_core::progress::{ PROGRESS_STEP_PULLING_IMAGE, PROGRESS_STEP_REQUESTING_SANDBOX, PROGRESS_STEP_STARTING_SANDBOX, format_bytes, mark_progress_active, mark_progress_complete, mark_progress_detail, @@ -461,7 +461,14 @@ impl DockerComputeDriver { let driver_config = DockerSandboxDriverConfig::from_template(template).map_err(Status::invalid_argument)?; - Self::validate_gpu_request(spec.gpu, config.supports_gpu, &driver_config)?; + let gpu_requested = driver_gpu_requested(spec.resource_requirements.as_ref()); + let gpu_count = driver_gpu_count(spec.resource_requirements.as_ref()); + Self::validate_gpu_request( + gpu_requested, + gpu_count, + config.supports_gpu, + &driver_config, + )?; Ok(()) } @@ -510,6 +517,7 @@ impl DockerComputeDriver { fn validate_gpu_request( gpu: bool, + gpu_count: Option, supports_gpu: bool, driver_config: &DockerSandboxDriverConfig, ) -> Result<(), Status> { @@ -519,6 +527,12 @@ impl DockerComputeDriver { )); } + if gpu_count.is_some() { + return Err(Status::invalid_argument( + "docker GPU count requests are not supported; use --gpu without a count or driver_config.cdi_devices", + )); + } + if gpu && !supports_gpu { return Err(Status::failed_precondition( "docker GPU sandboxes require Docker CDI support. Enable CDI on the Docker daemon, then restart the OpenShell gateway/server so GPU capability is detected.", @@ -2121,14 +2135,15 @@ fn build_device_requests(sandbox: &DriverSandbox) -> Result DriverSandbox { environment: HashMap::from([("TEMPLATE_ENV".to_string(), "template".to_string())]), ..Default::default() }), - gpu: false, + resource_requirements: None, sandbox_token: String::new(), }), status: None, @@ -79,6 +80,12 @@ fn list_string_driver_config(field: &str, values: &[&str]) -> prost_types::Struc } } +fn gpu_resources(count: Option) -> ResourceRequirements { + ResourceRequirements { + gpu: Some(GpuResourceRequirements { count }), + } +} + fn runtime_config() -> DockerDriverRuntimeConfig { DockerDriverRuntimeConfig { default_image: "image:latest".to_string(), @@ -1007,7 +1014,7 @@ fn build_container_create_body_clears_inherited_cmd() { fn validate_sandbox_rejects_gpu_when_cdi_unavailable() { let config = runtime_config(); let mut sandbox = test_sandbox(); - sandbox.spec.as_mut().unwrap().gpu = true; + sandbox.spec.as_mut().unwrap().resource_requirements = Some(gpu_resources(None)); let err = DockerComputeDriver::validate_sandbox(&sandbox, &config).unwrap_err(); @@ -1020,7 +1027,7 @@ fn validate_sandbox_rejects_invalid_cdi_devices_before_gpu_capability() { let config = runtime_config(); let mut sandbox = test_sandbox(); let spec = sandbox.spec.as_mut().unwrap(); - spec.gpu = true; + spec.resource_requirements = Some(gpu_resources(None)); spec.template.as_mut().unwrap().driver_config = Some(cdi_devices_config(&[])); let err = DockerComputeDriver::validate_sandbox(&sandbox, &config).unwrap_err(); @@ -1035,7 +1042,7 @@ fn validate_sandbox_rejects_unknown_driver_config_fields() { let config = runtime_config(); let mut sandbox = test_sandbox(); let spec = sandbox.spec.as_mut().unwrap(); - spec.gpu = true; + spec.resource_requirements = Some(gpu_resources(None)); spec.template.as_mut().unwrap().driver_config = Some(cdi_device_typo_config(&["nvidia.com/gpu=0"])); @@ -1045,12 +1052,28 @@ fn validate_sandbox_rejects_unknown_driver_config_fields() { assert!(err.message().contains("unknown field")); } +#[test] +fn validate_sandbox_rejects_gpu_count_request() { + let mut config = runtime_config(); + config.supports_gpu = true; + let mut sandbox = test_sandbox(); + sandbox.spec.as_mut().unwrap().resource_requirements = Some(gpu_resources(Some(2))); + + let err = DockerComputeDriver::validate_sandbox(&sandbox, &config).unwrap_err(); + + assert_eq!(err.code(), tonic::Code::InvalidArgument); + assert!( + err.message() + .contains("GPU count requests are not supported") + ); +} + #[test] fn validate_sandbox_rejects_template_errors_before_device_config() { let config = runtime_config(); let mut sandbox = test_sandbox(); let spec = sandbox.spec.as_mut().unwrap(); - spec.gpu = true; + spec.resource_requirements = Some(gpu_resources(None)); let template = spec.template.as_mut().unwrap(); template.agent_socket_path = "/tmp/agent.sock".to_string(); template.driver_config = Some(cdi_devices_config(&[])); @@ -1088,7 +1111,7 @@ fn build_container_create_body_maps_gpu_to_all_cdi_device() { let mut config = runtime_config(); config.supports_gpu = true; let mut sandbox = test_sandbox(); - sandbox.spec.as_mut().unwrap().gpu = true; + sandbox.spec.as_mut().unwrap().resource_requirements = Some(gpu_resources(None)); let create_body = build_container_create_body(&sandbox, &config).unwrap(); let request = create_body @@ -1111,7 +1134,7 @@ fn build_container_create_body_passes_explicit_cdi_device_id_through() { config.supports_gpu = true; let mut sandbox = test_sandbox(); let spec = sandbox.spec.as_mut().unwrap(); - spec.gpu = true; + spec.resource_requirements = Some(gpu_resources(None)); spec.template.as_mut().unwrap().driver_config = Some(cdi_devices_config(&["nvidia.com/gpu=0"])); let create_body = build_container_create_body(&sandbox, &config).unwrap(); @@ -1150,7 +1173,7 @@ fn build_container_create_body_rejects_cdi_devices_without_gpu() { fn build_container_create_body_rejects_empty_cdi_devices() { let mut sandbox = test_sandbox(); let spec = sandbox.spec.as_mut().unwrap(); - spec.gpu = true; + spec.resource_requirements = Some(gpu_resources(None)); spec.template.as_mut().unwrap().driver_config = Some(cdi_devices_config(&[])); let err = build_container_create_body(&sandbox, &runtime_config()).unwrap_err(); diff --git a/crates/openshell-driver-kubernetes/README.md b/crates/openshell-driver-kubernetes/README.md index 6ad0b27c8..3cdb9fa57 100644 --- a/crates/openshell-driver-kubernetes/README.md +++ b/crates/openshell-driver-kubernetes/README.md @@ -62,9 +62,9 @@ the supervisor's network namespace mount setup on AppArmor-enabled nodes. ## GPU Support When a sandbox requests GPU support, the driver checks node allocatable capacity -for `nvidia.com/gpu` and requests one GPU resource in the workload spec. The -sandbox image must provide the user-space libraries needed by the agent -workload. +for `nvidia.com/gpu` and requests the configured GPU count in the workload spec. +When no count is set, the driver requests one GPU resource. The sandbox image +must provide the user-space libraries needed by the agent workload. ## Driver Config POC @@ -97,5 +97,6 @@ POC parser renders the keys listed above and rejects unknown fields. `pod.runtime_class_name` maps to PodSpec `runtimeClassName` and overrides the driver's configured `default_runtime_class_name`; the typed public `SandboxTemplate.runtime_class_name` still takes precedence when set. Use the -public `gpu` flag for the default GPU request and `driver_config` only for -additional driver-owned resource details. +public `--gpu` flag for the default GPU request, pass a count to `--gpu` for +counted GPU requests, and use `driver_config` only for additional driver-owned +resource details. diff --git a/crates/openshell-driver-kubernetes/src/driver.rs b/crates/openshell-driver-kubernetes/src/driver.rs index dc636efc3..e1467f3e9 100644 --- a/crates/openshell-driver-kubernetes/src/driver.rs +++ b/crates/openshell-driver-kubernetes/src/driver.rs @@ -17,6 +17,7 @@ use kube::{Client, Error as KubeError}; use openshell_core::driver_utils::{ LABEL_MANAGED_BY, LABEL_MANAGED_BY_VALUE, LABEL_SANDBOX_ID, SUPERVISOR_IMAGE_BINARY_PATH, }; +use openshell_core::gpu::{driver_gpu_count, driver_gpu_requested}; use openshell_core::progress::{ PROGRESS_STEP_PULLING_IMAGE, PROGRESS_STEP_REQUESTING_SANDBOX, PROGRESS_STEP_STARTING_SANDBOX, format_bytes, mark_progress_active, mark_progress_complete, mark_progress_detail, @@ -281,7 +282,7 @@ impl KubernetesComputeDriver { pub async fn validate_sandbox_create(&self, sandbox: &Sandbox) -> Result<(), tonic::Status> { let _ = KubernetesSandboxDriverConfig::from_sandbox(sandbox) .map_err(tonic::Status::invalid_argument)?; - let gpu_requested = sandbox.spec.as_ref().is_some_and(|spec| spec.gpu); + let gpu_requested = validate_gpu_resource_requirements(sandbox)?; if gpu_requested && !self.has_gpu_capacity().await.map_err(|err| { tonic::Status::internal(format!("check GPU node capacity failed: {err}")) @@ -376,6 +377,9 @@ impl KubernetesComputeDriver { pub async fn create_sandbox(&self, sandbox: &Sandbox) -> Result<(), KubernetesDriverError> { let _ = KubernetesSandboxDriverConfig::from_sandbox(sandbox) .map_err(KubernetesDriverError::InvalidArgument)?; + validate_gpu_resource_requirements(sandbox).map_err(|status| { + KubernetesDriverError::InvalidArgument(status.message().to_string()) + })?; let name = sandbox.name.as_str(); info!( sandbox_id = %sandbox.id, @@ -638,6 +642,24 @@ impl KubernetesComputeDriver { } } +fn validate_gpu_resource_requirements(sandbox: &Sandbox) -> Result { + let Some(resource_requirements) = sandbox + .spec + .as_ref() + .and_then(|spec| spec.resource_requirements.as_ref()) + else { + return Ok(false); + }; + + if driver_gpu_count(Some(resource_requirements)) == Some(0) { + return Err(tonic::Status::invalid_argument( + "gpu count must be greater than 0", + )); + } + + Ok(driver_gpu_requested(Some(resource_requirements))) +} + fn sandbox_labels(sandbox: &Sandbox) -> BTreeMap { let mut labels = BTreeMap::new(); labels.insert(LABEL_SANDBOX_ID.to_string(), sandbox.id.clone()); @@ -1201,7 +1223,14 @@ fn sandbox_to_k8s_spec( if let Some(template) = spec.template.as_ref() { root.insert( "podTemplate".to_string(), - sandbox_template_to_k8s(template, spec.gpu, &pod_env, inject_workspace, params), + sandbox_template_to_k8s_with_gpu_count( + template, + driver_gpu_requested(spec.resource_requirements.as_ref()), + driver_gpu_count(spec.resource_requirements.as_ref()), + &pod_env, + inject_workspace, + params, + ), ); if !template.agent_socket_path.is_empty() { root.insert( @@ -1231,9 +1260,12 @@ fn sandbox_to_k8s_spec( let pod_env = spec_pod_env(spec); root.insert( "podTemplate".to_string(), - sandbox_template_to_k8s( + sandbox_template_to_k8s_with_gpu_count( &SandboxTemplate::default(), - spec.is_some_and(|s| s.gpu), + spec.and_then(|s| s.resource_requirements.as_ref()) + .is_some_and(|requirements| driver_gpu_requested(Some(requirements))), + spec.and_then(|s| s.resource_requirements.as_ref()) + .and_then(|requirements| driver_gpu_count(Some(requirements))), &pod_env, inject_workspace, params, @@ -1246,12 +1278,31 @@ fn sandbox_to_k8s_spec( ) } +#[cfg(test)] fn sandbox_template_to_k8s( template: &SandboxTemplate, gpu: bool, spec_environment: &std::collections::HashMap, inject_workspace: bool, params: &SandboxPodParams<'_>, +) -> serde_json::Value { + sandbox_template_to_k8s_with_gpu_count( + template, + gpu, + None, + spec_environment, + inject_workspace, + params, + ) +} + +fn sandbox_template_to_k8s_with_gpu_count( + template: &SandboxTemplate, + gpu: bool, + gpu_count: Option, + spec_environment: &std::collections::HashMap, + inject_workspace: bool, + params: &SandboxPodParams<'_>, ) -> serde_json::Value { let driver_config = kubernetes_driver_config(template); @@ -1440,7 +1491,7 @@ fn sandbox_template_to_k8s( serde_json::Value::Array(volume_mounts), ); - if let Some(resources) = container_resources(template, gpu) { + if let Some(resources) = container_resources(template, gpu, gpu_count) { container.insert("resources".to_string(), resources); } apply_agent_driver_resources(&mut container, &driver_config.containers.agent.resources); @@ -1618,7 +1669,11 @@ fn app_armor_profile_to_k8s(profile: &AppArmorProfile) -> serde_json::Value { value } -fn container_resources(template: &SandboxTemplate, gpu: bool) -> Option { +fn container_resources( + template: &SandboxTemplate, + gpu: bool, + gpu_count: Option, +) -> Option { // Start from the raw resources passthrough in platform_config (preserves // custom resource types like GPU limits that users set via the public API // Struct), then overlay the typed DriverResourceRequirements on top. @@ -1652,7 +1707,11 @@ fn container_resources(template: &SandboxTemplate, gpu: bool) -> Option Option> = @@ -2000,10 +2057,9 @@ mod tests { let sandbox = Sandbox { id: "sandbox-123".to_string(), spec: Some(SandboxSpec { - gpu: true, template: Some(SandboxTemplate { driver_config: Some(json_struct(serde_json::json!({ - "cdi_devices": ["nvidia.com/gpu=0"] + "gpu_device_ids": ["0000:2d:00.0"] }))), ..Default::default() }), @@ -2014,6 +2070,24 @@ mod tests { let err = KubernetesSandboxDriverConfig::from_sandbox(&sandbox).unwrap_err(); assert!(err.contains("unknown field")); + assert!(err.contains("gpu_device_ids")); + } + + #[test] + fn validate_rejects_zero_gpu_count() { + let sandbox = Sandbox { + spec: Some(SandboxSpec { + resource_requirements: Some(ResourceRequirements { + gpu: Some(GpuResourceRequirements { count: Some(0) }), + }), + ..SandboxSpec::default() + }), + ..Sandbox::default() + }; + + let err = validate_gpu_resource_requirements(&sandbox).unwrap_err(); + assert_eq!(err.code(), tonic::Code::InvalidArgument); + assert!(err.message().contains("gpu count must be greater than 0")); } #[test] @@ -2345,6 +2419,26 @@ mod tests { ); } + #[test] + fn gpu_count_sandbox_adds_requested_gpu_limit() { + let pod_template = { + let params = SandboxPodParams::default(); + sandbox_template_to_k8s_with_gpu_count( + &SandboxTemplate::default(), + true, + Some(2), + &std::collections::HashMap::new(), + true, + ¶ms, + ) + }; + + assert_eq!( + pod_template["spec"]["containers"][0]["resources"]["limits"][GPU_RESOURCE_NAME], + serde_json::json!("2") + ); + } + #[test] fn gpu_sandbox_uses_template_runtime_class_name_when_set() { let template = SandboxTemplate { diff --git a/crates/openshell-driver-podman/README.md b/crates/openshell-driver-podman/README.md index 68a223bde..3c22c919a 100644 --- a/crates/openshell-driver-podman/README.md +++ b/crates/openshell-driver-podman/README.md @@ -46,7 +46,7 @@ The container spec in `container.rs` sets these security-critical fields: | `no_new_privileges` | `true` | Prevents privilege escalation after exec. | | `seccomp_profile_path` | `unconfined` | The supervisor installs its own policy-aware BPF filter. A container-level profile can block Landlock/seccomp syscalls during setup. | | `mounts` | Private tmpfs at `/run/netns` | Lets the supervisor create named network namespaces in rootless Podman. | -| CDI GPU devices | `driver_config.cdi_devices` when set, otherwise all NVIDIA GPUs | Exposes requested GPUs to GPU-enabled sandbox containers. | +| CDI GPU devices | `driver_config.cdi_devices` when set, otherwise all NVIDIA GPUs | Exposes requested GPUs to GPU-enabled sandbox containers. Counted GPU requests are rejected. | The restricted agent child does not retain these supervisor privileges. diff --git a/crates/openshell-driver-podman/src/container.rs b/crates/openshell-driver-podman/src/container.rs index afcc17585..f14c01934 100644 --- a/crates/openshell-driver-podman/src/container.rs +++ b/crates/openshell-driver-podman/src/container.rs @@ -5,7 +5,7 @@ use crate::config::PodmanComputeConfig; use openshell_core::ComputeDriverError; -use openshell_core::gpu::cdi_gpu_device_ids; +use openshell_core::gpu::{cdi_gpu_device_ids, driver_gpu_requested}; use openshell_core::proto::compute::v1::{DriverSandbox, DriverSandboxTemplate}; use openshell_core::proto_struct::deserialize_optional_non_empty_string_list; use openshell_core::{driver_mounts, proto_struct}; @@ -484,14 +484,15 @@ fn build_devices(sandbox: &DriverSandbox) -> Result>, Co let cdi_devices = PodmanSandboxDriverConfig::from_sandbox(sandbox)? .cdi_devices .unwrap_or_default(); - if !spec.gpu && !cdi_devices.is_empty() { + let gpu_requested = driver_gpu_requested(spec.resource_requirements.as_ref()); + if !gpu_requested && !cdi_devices.is_empty() { return Err(ComputeDriverError::InvalidArgument( "driver_config.cdi_devices requires gpu=true".to_string(), )); } Ok( - cdi_gpu_device_ids(spec.gpu, &cdi_devices).map(|device_ids| { + cdi_gpu_device_ids(gpu_requested, &cdi_devices).map(|device_ids| { device_ids .into_iter() .map(|path| LinuxDevice { path }) @@ -1092,6 +1093,7 @@ fn parse_memory_to_bytes(quantity: &str) -> Option { #[cfg(test)] mod tests { use super::*; + use openshell_core::proto::compute::v1::{GpuResourceRequirements, ResourceRequirements}; static ENV_LOCK: std::sync::LazyLock> = std::sync::LazyLock::new(|| std::sync::Mutex::new(())); @@ -1133,6 +1135,12 @@ mod tests { } } + fn gpu_resources(count: Option) -> ResourceRequirements { + ResourceRequirements { + gpu: Some(GpuResourceRequirements { count }), + } + } + #[test] fn parse_cpu_millicore() { assert_eq!(parse_cpu_to_microseconds("500m"), Some(50_000)); @@ -1246,7 +1254,7 @@ mod tests { let mut sandbox = test_sandbox("test-id", "test-name"); sandbox.spec = Some(DriverSandboxSpec { - gpu: true, + resource_requirements: Some(gpu_resources(None)), ..Default::default() }); let config = test_config(); @@ -1264,7 +1272,7 @@ mod tests { let mut sandbox = test_sandbox("test-id", "test-name"); sandbox.spec = Some(DriverSandboxSpec { - gpu: true, + resource_requirements: Some(gpu_resources(None)), template: Some(DriverSandboxTemplate { driver_config: Some(cdi_devices_config(&["nvidia.com/gpu=0"])), ..Default::default() @@ -1305,7 +1313,7 @@ mod tests { let mut sandbox = test_sandbox("test-id", "test-name"); sandbox.spec = Some(DriverSandboxSpec { - gpu: true, + resource_requirements: Some(gpu_resources(None)), template: Some(DriverSandboxTemplate { driver_config: Some(cdi_devices_config(&[])), ..Default::default() @@ -1325,7 +1333,7 @@ mod tests { let mut sandbox = test_sandbox("test-id", "test-name"); sandbox.spec = Some(DriverSandboxSpec { - gpu: true, + resource_requirements: Some(gpu_resources(None)), template: Some(DriverSandboxTemplate { driver_config: Some(cdi_device_typo_config(&["nvidia.com/gpu=0"])), ..Default::default() diff --git a/crates/openshell-driver-podman/src/driver.rs b/crates/openshell-driver-podman/src/driver.rs index 6f9762c15..2fbe8292d 100644 --- a/crates/openshell-driver-podman/src/driver.rs +++ b/crates/openshell-driver-podman/src/driver.rs @@ -11,6 +11,7 @@ use crate::watcher::{ }; use openshell_core::ComputeDriverError; use openshell_core::driver_utils::supervisor_image_should_refresh; +use openshell_core::gpu::{driver_gpu_count, driver_gpu_requested}; use openshell_core::proto::compute::v1::{DriverSandbox, GetCapabilitiesResponse}; use std::path::PathBuf; use std::time::Duration; @@ -281,19 +282,38 @@ impl PodmanComputeDriver { &self, sandbox: &DriverSandbox, ) -> Result<(), ComputeDriverError> { - let gpu_requested = sandbox.spec.as_ref().is_some_and(|s| s.gpu); + let gpu_requested = sandbox + .spec + .as_ref() + .and_then(|spec| spec.resource_requirements.as_ref()) + .is_some_and(|requirements| driver_gpu_requested(Some(requirements))); + let gpu_count = sandbox + .spec + .as_ref() + .and_then(|spec| spec.resource_requirements.as_ref()) + .and_then(|requirements| driver_gpu_count(Some(requirements))); let driver_config = PodmanSandboxDriverConfig::from_sandbox(sandbox)?; if !gpu_requested && driver_config.cdi_devices.is_some() { return Err(ComputeDriverError::InvalidArgument( "driver_config.cdi_devices requires gpu=true".to_string(), )); } - Self::validate_gpu_request(gpu_requested)?; + Self::validate_gpu_request(gpu_requested, gpu_count)?; self.validate_user_volume_mounts_available(sandbox).await?; Ok(()) } - fn validate_gpu_request(gpu_requested: bool) -> Result<(), ComputeDriverError> { + fn validate_gpu_request( + gpu_requested: bool, + gpu_count: Option, + ) -> Result<(), ComputeDriverError> { + if gpu_count.is_some() { + return Err(ComputeDriverError::InvalidArgument( + "podman GPU count requests are not supported; use --gpu without a count or driver_config.cdi_devices" + .to_string(), + )); + } + if gpu_requested && !Self::has_gpu_capacity() { return Err(ComputeDriverError::Precondition( "GPU sandbox requested, but no NVIDIA GPU devices are available.".to_string(), @@ -693,6 +713,18 @@ mod tests { assert!(matches!(err, ComputeDriverError::Message(_))); } + #[test] + fn validate_gpu_request_rejects_gpu_count() { + let err = PodmanComputeDriver::validate_gpu_request(true, Some(2)) + .expect_err("gpu count should be rejected"); + + assert!(matches!(err, ComputeDriverError::InvalidArgument(_))); + assert!( + err.to_string() + .contains("GPU count requests are not supported") + ); + } + // ── grpc_endpoint auto-detection ─────────────────────────────────── // // PodmanComputeDriver::new() fills grpc_endpoint when it is empty. diff --git a/crates/openshell-driver-vm/src/driver.rs b/crates/openshell-driver-vm/src/driver.rs index 30fecd8be..de34cb507 100644 --- a/crates/openshell-driver-vm/src/driver.rs +++ b/crates/openshell-driver-vm/src/driver.rs @@ -29,6 +29,7 @@ use oci_client::manifest::{ }; use oci_client::secrets::RegistryAuth; use oci_client::{Reference, RegistryOperation}; +use openshell_core::gpu::{driver_gpu_count, driver_gpu_requested}; use openshell_core::progress::{ PROGRESS_STEP_PULLING_IMAGE, PROGRESS_STEP_REQUESTING_SANDBOX, PROGRESS_STEP_STARTING_SANDBOX, format_bytes, mark_progress_active, mark_progress_complete, mark_progress_detail, @@ -627,7 +628,11 @@ impl VmDriver { overlay_preparation: OverlayPreparation, ) -> Result<(), Status> { self.ensure_provisioning_active(&sandbox.id).await?; - let is_gpu = sandbox.spec.as_ref().is_some_and(|spec| spec.gpu); + let is_gpu = sandbox + .spec + .as_ref() + .and_then(|spec| spec.resource_requirements.as_ref()) + .is_some_and(|requirements| driver_gpu_requested(Some(requirements))); self.publish_platform_event( sandbox.id.clone(), platform_event( @@ -3106,8 +3111,21 @@ fn validate_vm_gpu_request(sandbox: &Sandbox, gpu_enabled: bool) -> Result<(), S .as_ref() .ok_or_else(|| Status::invalid_argument("sandbox spec is required"))?; + let gpu_requested = driver_gpu_requested(spec.resource_requirements.as_ref()); + let gpu_count = driver_gpu_count(spec.resource_requirements.as_ref()); + + if gpu_count == Some(0) { + return Err(Status::invalid_argument("gpu count must be greater than 0")); + } + + if gpu_count.is_some_and(|count| count > 1) { + return Err(Status::invalid_argument( + "VM GPU sandboxes support only one GPU", + )); + } + let _ = vm_gpu_device_id(sandbox)?; - if spec.gpu && !gpu_enabled { + if gpu_requested && !gpu_enabled { return Err(Status::failed_precondition( "GPU support is not enabled on this driver; start with --gpu", )); @@ -3124,7 +3142,8 @@ fn vm_gpu_device_id(sandbox: &Sandbox) -> Result, Status> { .map_err(Status::invalid_argument)? .gpu_device_ids .unwrap_or_default(); - if !spec.gpu && !gpu_device_ids.is_empty() { + let gpu_requested = driver_gpu_requested(spec.resource_requirements.as_ref()); + if !gpu_requested && !gpu_device_ids.is_empty() { return Err(Status::invalid_argument( "driver_config.gpu_device_ids requires gpu=true", )); @@ -3135,9 +3154,7 @@ fn vm_gpu_device_id(sandbox: &Sandbox) -> Result, Status> { )); } - Ok(spec - .gpu - .then(|| gpu_device_ids.into_iter().next().unwrap_or_default())) + Ok(gpu_requested.then(|| gpu_device_ids.into_iter().next().unwrap_or_default())) } #[allow(clippy::result_large_err)] @@ -5064,6 +5081,7 @@ mod tests { }; use openshell_core::proto::compute::v1::{ DriverSandboxSpec as SandboxSpec, DriverSandboxTemplate as SandboxTemplate, + GpuResourceRequirements, ResourceRequirements, }; use prost_types::{Struct, Value, value::Kind}; use std::fs; @@ -5102,6 +5120,12 @@ mod tests { } } + fn gpu_resources(count: Option) -> ResourceRequirements { + ResourceRequirements { + gpu: Some(GpuResourceRequirements { count }), + } + } + #[test] fn vm_pulling_layer_event_adds_progress_detail_metadata() { let mut event = platform_event( @@ -5169,7 +5193,7 @@ mod tests { let sandbox = Sandbox { id: "sandbox-123".to_string(), spec: Some(SandboxSpec { - gpu: true, + resource_requirements: Some(gpu_resources(None)), ..Default::default() }), ..Default::default() @@ -5185,7 +5209,7 @@ mod tests { let sandbox = Sandbox { id: "sandbox-123".to_string(), spec: Some(SandboxSpec { - gpu: true, + resource_requirements: Some(gpu_resources(None)), ..Default::default() }), ..Default::default() @@ -5193,12 +5217,40 @@ mod tests { validate_vm_sandbox(&sandbox, true).expect("gpu should be accepted when enabled"); } + #[test] + fn validate_vm_sandbox_accepts_gpu_count_one() { + let sandbox = Sandbox { + id: "sandbox-123".to_string(), + spec: Some(SandboxSpec { + resource_requirements: Some(gpu_resources(Some(1))), + ..Default::default() + }), + ..Default::default() + }; + validate_vm_sandbox(&sandbox, true).expect("one GPU should be accepted when enabled"); + } + + #[test] + fn validate_vm_sandbox_rejects_gpu_count_above_one() { + let sandbox = Sandbox { + id: "sandbox-123".to_string(), + spec: Some(SandboxSpec { + resource_requirements: Some(gpu_resources(Some(2))), + ..Default::default() + }), + ..Default::default() + }; + let err = validate_vm_sandbox(&sandbox, true) + .expect_err("multiple GPU VM request should be rejected"); + assert_eq!(err.code(), Code::InvalidArgument); + assert!(err.message().contains("support only one GPU")); + } + #[test] fn validate_vm_sandbox_rejects_gpu_device_without_gpu() { let sandbox = Sandbox { id: "sandbox-123".to_string(), spec: Some(SandboxSpec { - gpu: false, template: Some(SandboxTemplate { driver_config: Some(gpu_device_ids_config(&["0000:2d:00.0"])), ..Default::default() @@ -5218,7 +5270,7 @@ mod tests { let sandbox = Sandbox { id: "sandbox-123".to_string(), spec: Some(SandboxSpec { - gpu: true, + resource_requirements: Some(gpu_resources(None)), template: Some(SandboxTemplate { driver_config: Some(gpu_device_ids_config(&["0000:2d:00.0", "0000:31:00.0"])), ..Default::default() @@ -5238,7 +5290,7 @@ mod tests { let sandbox = Sandbox { id: "sandbox-123".to_string(), spec: Some(SandboxSpec { - gpu: true, + resource_requirements: Some(gpu_resources(None)), template: Some(SandboxTemplate { driver_config: Some(gpu_device_ids_config(&[])), ..Default::default() @@ -5258,7 +5310,7 @@ mod tests { let sandbox = Sandbox { id: "sandbox-123".to_string(), spec: Some(SandboxSpec { - gpu: true, + resource_requirements: Some(gpu_resources(None)), template: Some(SandboxTemplate { driver_config: Some(gpu_device_id_typo_config(&["0000:2d:00.0"])), ..Default::default() @@ -5278,7 +5330,7 @@ mod tests { let sandbox = Sandbox { id: "sandbox-123".to_string(), spec: Some(SandboxSpec { - gpu: true, + resource_requirements: Some(gpu_resources(None)), template: Some(SandboxTemplate { agent_socket_path: "/tmp/agent.sock".to_string(), driver_config: Some(gpu_device_ids_config(&[])), diff --git a/crates/openshell-server/src/compute/mod.rs b/crates/openshell-server/src/compute/mod.rs index 812b9c59a..5c3e53860 100644 --- a/crates/openshell-server/src/compute/mod.rs +++ b/crates/openshell-server/src/compute/mod.rs @@ -19,10 +19,11 @@ use openshell_core::ComputeDriverKind; use openshell_core::proto::compute::v1::{ CreateSandboxRequest, DeleteSandboxRequest, DriverCondition, DriverPlatformEvent, DriverResourceRequirements, DriverSandbox, DriverSandboxSpec, DriverSandboxStatus, - DriverSandboxTemplate, GetCapabilitiesRequest, GetSandboxRequest, ListSandboxesRequest, - ValidateSandboxCreateRequest, WatchSandboxesEvent, WatchSandboxesRequest, - compute_driver_client::ComputeDriverClient, compute_driver_server::ComputeDriver, - watch_sandboxes_event, + DriverSandboxTemplate, GetCapabilitiesRequest, GetSandboxRequest, + GpuResourceRequirements as DriverGpuResourceRequirements, ListSandboxesRequest, + ResourceRequirements as DriverSandboxResourceRequirements, ValidateSandboxCreateRequest, + WatchSandboxesEvent, WatchSandboxesRequest, compute_driver_client::ComputeDriverClient, + compute_driver_server::ComputeDriver, watch_sandboxes_event, }; use openshell_core::proto::{ PlatformEvent, Sandbox, SandboxCondition, SandboxPhase, SandboxSpec, SandboxStatus, @@ -1279,7 +1280,14 @@ fn driver_sandbox_spec_from_public( .as_ref() .map(|template| driver_sandbox_template_from_public(template, driver_kind)) .transpose()?, - gpu: spec.gpu, + resource_requirements: spec.resource_requirements.as_ref().map(|requirements| { + DriverSandboxResourceRequirements { + gpu: requirements + .gpu + .as_ref() + .map(|gpu| DriverGpuResourceRequirements { count: gpu.count }), + } + }), sandbox_token: String::new(), }) } @@ -1660,7 +1668,9 @@ fn derive_phase(status: Option<&DriverSandboxStatus>) -> SandboxPhase { } fn rewrite_user_facing_conditions(status: &mut Option, spec: Option<&SandboxSpec>) { - let gpu_requested = spec.is_some_and(|sandbox_spec| sandbox_spec.gpu); + let gpu_requested = spec + .and_then(|sandbox_spec| sandbox_spec.resource_requirements.as_ref()) + .is_some_and(|requirements| openshell_core::gpu::public_gpu_requested(Some(requirements))); if !gpu_requested { return; } @@ -1856,6 +1866,26 @@ mod tests { } } + #[test] + fn driver_sandbox_spec_from_public_preserves_gpu_requirement() { + let public = SandboxSpec { + resource_requirements: Some(openshell_core::proto::ResourceRequirements { + gpu: Some(openshell_core::proto::GpuResourceRequirements { count: Some(2) }), + }), + ..Default::default() + }; + + let driver = + driver_sandbox_spec_from_public(&public, None).expect("driver spec should map"); + + let gpu = driver + .resource_requirements + .as_ref() + .and_then(|requirements| requirements.gpu.as_ref()) + .expect("driver GPU requirement should be set"); + assert_eq!(gpu.count, Some(2)); + } + #[test] fn select_driver_config_forwards_only_matching_driver_block() { let config = prost_types::Struct { @@ -2355,7 +2385,9 @@ mod tests { rewrite_user_facing_conditions( &mut status, Some(&SandboxSpec { - gpu: true, + resource_requirements: Some(openshell_core::proto::ResourceRequirements { + gpu: Some(openshell_core::proto::GpuResourceRequirements { count: None }), + }), ..Default::default() }), ); @@ -2383,13 +2415,7 @@ mod tests { ..Default::default() }); - rewrite_user_facing_conditions( - &mut status, - Some(&SandboxSpec { - gpu: false, - ..Default::default() - }), - ); + rewrite_user_facing_conditions(&mut status, Some(&SandboxSpec::default())); assert_eq!(status.unwrap().conditions[0].message, original); } @@ -2668,7 +2694,9 @@ mod tests { let sandbox = Sandbox { spec: Some(SandboxSpec { - gpu: true, + resource_requirements: Some(openshell_core::proto::ResourceRequirements { + gpu: Some(openshell_core::proto::GpuResourceRequirements { count: None }), + }), ..Default::default() }), ..sandbox_record("sb-1", "sandbox-a", SandboxPhase::Provisioning) @@ -2691,7 +2719,9 @@ mod tests { SandboxPhase::try_from(stored.phase()).unwrap(), SandboxPhase::Ready ); - assert!(stored.spec.as_ref().is_some_and(|spec| spec.gpu)); + assert!(stored.spec.as_ref().is_some_and(|spec| { + openshell_core::gpu::public_gpu_requested(spec.resource_requirements.as_ref()) + })); } #[tokio::test] diff --git a/crates/openshell-server/src/grpc/sandbox.rs b/crates/openshell-server/src/grpc/sandbox.rs index e60ce3995..2817e7381 100644 --- a/crates/openshell-server/src/grpc/sandbox.rs +++ b/crates/openshell-server/src/grpc/sandbox.rs @@ -98,9 +98,11 @@ fn emit_sandbox_create_telemetry( } else { SandboxTemplateSource::Default }; + let gpu_requested = + openshell_core::gpu::public_gpu_requested(spec.resource_requirements.as_ref()); openshell_core::telemetry::emit_sandbox_create( outcome, - spec.gpu, + gpu_requested, spec.providers.len() as u64, spec.policy.is_some(), template_source, diff --git a/crates/openshell-server/src/grpc/validation.rs b/crates/openshell-server/src/grpc/validation.rs index 03a69d6e9..16b05a90b 100644 --- a/crates/openshell-server/src/grpc/validation.rs +++ b/crates/openshell-server/src/grpc/validation.rs @@ -134,6 +134,9 @@ pub(super) fn validate_sandbox_spec( validate_env_entries(&tmpl.environment, "spec.template.environment")?; } + // --- spec.resource_requirements.gpu --- + validate_gpu_request_fields(spec)?; + // --- spec.policy serialized size --- if let Some(ref policy) = spec.policy { let size = policy.encoded_len(); @@ -147,6 +150,14 @@ pub(super) fn validate_sandbox_spec( Ok(()) } +fn validate_gpu_request_fields(spec: &openshell_core::proto::SandboxSpec) -> Result<(), Status> { + if openshell_core::gpu::public_gpu_count(spec.resource_requirements.as_ref()) == Some(0) { + return Err(Status::invalid_argument("gpu count must be greater than 0")); + } + + Ok(()) +} + /// Validate template-level field sizes. fn validate_sandbox_template(tmpl: &SandboxTemplate) -> Result<(), Status> { // String fields. @@ -760,12 +771,38 @@ mod tests { #[test] fn validate_sandbox_spec_accepts_gpu_flag() { let spec = SandboxSpec { - gpu: true, + resource_requirements: Some(openshell_core::proto::ResourceRequirements { + gpu: Some(openshell_core::proto::GpuResourceRequirements { count: None }), + }), ..Default::default() }; assert!(validate_sandbox_spec("gpu-sandbox", &spec).is_ok()); } + #[test] + fn validate_sandbox_spec_accepts_gpu_count() { + let spec = SandboxSpec { + resource_requirements: Some(openshell_core::proto::ResourceRequirements { + gpu: Some(openshell_core::proto::GpuResourceRequirements { count: Some(2) }), + }), + ..Default::default() + }; + assert!(validate_sandbox_spec("gpu-sandbox", &spec).is_ok()); + } + + #[test] + fn validate_sandbox_spec_rejects_zero_gpu_count() { + let spec = SandboxSpec { + resource_requirements: Some(openshell_core::proto::ResourceRequirements { + gpu: Some(openshell_core::proto::GpuResourceRequirements { count: Some(0) }), + }), + ..Default::default() + }; + let err = validate_sandbox_spec("gpu-sandbox", &spec).unwrap_err(); + assert_eq!(err.code(), Code::InvalidArgument); + assert!(err.message().contains("gpu count must be greater than 0")); + } + #[test] fn validate_sandbox_spec_accepts_empty_defaults() { assert!(validate_sandbox_spec("", &default_spec()).is_ok()); diff --git a/docs/reference/sandbox-compute-drivers.mdx b/docs/reference/sandbox-compute-drivers.mdx index caf8e9961..cbadac071 100644 --- a/docs/reference/sandbox-compute-drivers.mdx +++ b/docs/reference/sandbox-compute-drivers.mdx @@ -53,7 +53,8 @@ openshell sandbox create \ ``` Driver config is for fields without a stable public flag. Prefer `--cpu`, -`--memory`, and `--gpu` for portable resource intent. +`--memory`, and `--gpu` for supported resource intent. Pass a count to `--gpu` +when the active driver supports counted allocation. Exact GPU device selection remains driver-owned and requires `--gpu`. Docker and Podman accept `cdi_devices`; replace the top-level `docker` key with diff --git a/docs/sandboxes/manage-sandboxes.mdx b/docs/sandboxes/manage-sandboxes.mdx index 1a54d0a06..58851b632 100644 --- a/docs/sandboxes/manage-sandboxes.mdx +++ b/docs/sandboxes/manage-sandboxes.mdx @@ -70,6 +70,17 @@ To request GPU resources, add `--gpu`: openshell sandbox create --gpu -- claude ``` +Request a specific number of GPUs by passing a count to `--gpu`: + +```shell +openshell sandbox create --gpu 2 -- claude +``` + +Support for counted GPU requests is driver-dependent. Kubernetes honors a +counted `--gpu` request by setting the `nvidia.com/gpu` limit. Docker and Podman +reject count-based selection. VM gateways accept only one GPU, either through +`--gpu` or `--gpu 1`. + For Docker-backed sandboxes, GPU injection uses Docker CDI. If you enable Docker CDI after the gateway starts, restart the gateway so OpenShell can detect the updated Docker daemon capability. diff --git a/proto/compute_driver.proto b/proto/compute_driver.proto index dbcb9e818..679433a6f 100644 --- a/proto/compute_driver.proto +++ b/proto/compute_driver.proto @@ -83,8 +83,9 @@ message DriverSandboxSpec { map environment = 5; // Runtime template consumed by the driver during provisioning. DriverSandboxTemplate template = 6; - // Request NVIDIA GPU resources for this sandbox. - bool gpu = 9; + // Portable resource requirements used by the gateway for driver selection + // and by drivers for provisioning. + ResourceRequirements resource_requirements = 9; reserved 10; reserved "gpu_device"; // Gateway-minted JWT identifying this sandbox to the gateway. Set by @@ -96,6 +97,18 @@ message DriverSandboxSpec { string sandbox_token = 11; } +message ResourceRequirements { + // GPU requirements for the sandbox. Presence indicates a GPU request. + GpuResourceRequirements gpu = 1; +} + +// Driver GPU resource requirements. +message GpuResourceRequirements { + // Optional number of GPUs requested. When omitted, the driver uses its + // default GPU assignment behavior. + optional uint32 count = 1; +} + // Driver-owned runtime template consumed by the compute platform. // // This message describes the sandbox workload in backend-neutral terms. diff --git a/proto/openshell.proto b/proto/openshell.proto index d701956d3..fc8975b02 100644 --- a/proto/openshell.proto +++ b/proto/openshell.proto @@ -317,8 +317,9 @@ message SandboxSpec { openshell.sandbox.v1.SandboxPolicy policy = 7; // Provider names to attach to this sandbox. repeated string providers = 8; - // Request NVIDIA GPU resources for this sandbox. - bool gpu = 9; + // Portable resource requirements used by the gateway for driver selection + // and by drivers for provisioning. + ResourceRequirements resource_requirements = 9; reserved 10; reserved "gpu_device"; // Field 11 was `proposal_approval_mode`. The approval mode is now a @@ -329,6 +330,18 @@ message SandboxSpec { reserved "proposal_approval_mode"; } +message ResourceRequirements { + // GPU requirements for the sandbox. Presence indicates a GPU request. + GpuResourceRequirements gpu = 1; +} + +// Public GPU resource requirements. +message GpuResourceRequirements { + // Optional number of GPUs requested. When omitted, the driver uses its + // default GPU assignment behavior. + optional uint32 count = 1; +} + // Public sandbox template mapped onto compute-driver template inputs. message SandboxTemplate { // Fully-qualified OCI image reference used to boot the sandbox. From 1af16fd61300dd5c2acac29d979fd50a3eee643a Mon Sep 17 00:00:00 2001 From: Evan Lezar Date: Thu, 11 Jun 2026 13:54:39 +0200 Subject: [PATCH 2/6] refactor(gpu): pass requirements through sandbox create Pass the coupled GPU requirement object through the CLI sandbox_create boundary instead of splitting presence and count into separate arguments. Signed-off-by: Evan Lezar --- crates/openshell-cli/src/main.rs | 43 +++++++++---------- crates/openshell-cli/src/run.rs | 9 +--- .../sandbox_create_lifecycle_integration.rs | 37 ++++++---------- 3 files changed, 36 insertions(+), 53 deletions(-) diff --git a/crates/openshell-cli/src/main.rs b/crates/openshell-cli/src/main.rs index 220010c6d..01b57ec97 100644 --- a/crates/openshell-cli/src/main.rs +++ b/crates/openshell-cli/src/main.rs @@ -19,6 +19,7 @@ use openshell_bootstrap::{ use openshell_cli::completers; use openshell_cli::run; use openshell_cli::tls::TlsOptions; +use openshell_core::proto::GpuResourceRequirements; /// Resolved gateway context: name + gateway endpoint. struct GatewayContext { @@ -115,14 +116,12 @@ fn resolve_gateway( }) } -fn resolve_gpu_args(gpu: Option) -> (bool, Option) { - let gpu_count = match gpu { - Some(GpuCliRequest::Count(count)) => Some(count), - Some(GpuCliRequest::DriverDefault) | None => None, - }; - let gpu = gpu.is_some(); - - (gpu, gpu_count) +fn resolve_gpu_requirements(gpu: Option) -> Option { + match gpu { + Some(GpuCliRequest::Count(count)) => Some(GpuResourceRequirements { count: Some(count) }), + Some(GpuCliRequest::DriverDefault) => Some(GpuResourceRequirements { count: None }), + None => None, + } } fn parse_gpu_request(value: &str) -> std::result::Result { @@ -2655,7 +2654,7 @@ async fn main() -> Result<()> { .map(|s| openshell_core::forward::ForwardSpec::parse(&s)) .transpose()?; let keep = keep || !no_keep || editor.is_some() || forward.is_some(); - let (gpu, gpu_count) = resolve_gpu_args(gpu); + let gpu_requirements = resolve_gpu_requirements(gpu); let ctx = resolve_gateway(&cli.gateway, &cli.gateway_endpoint)?; let endpoint = &ctx.endpoint; @@ -2668,8 +2667,7 @@ async fn main() -> Result<()> { &ctx.name, &upload_specs, keep, - gpu, - gpu_count, + gpu_requirements, cpu.as_deref(), memory.as_deref(), driver_config_json.as_deref(), @@ -3665,27 +3663,26 @@ mod tests { } #[test] - fn resolve_gpu_args_handles_absent_gpu() { - let (gpu, gpu_count) = resolve_gpu_args(None); + fn resolve_gpu_requirements_handles_absent_gpu() { + let gpu = resolve_gpu_requirements(None); - assert!(!gpu); - assert_eq!(gpu_count, None); + assert_eq!(gpu, None); } #[test] - fn resolve_gpu_args_handles_driver_default() { - let (gpu, gpu_count) = resolve_gpu_args(Some(GpuCliRequest::DriverDefault)); + fn resolve_gpu_requirements_handles_driver_default() { + let gpu = resolve_gpu_requirements(Some(GpuCliRequest::DriverDefault)) + .expect("GPU requirement should be present"); - assert!(gpu); - assert_eq!(gpu_count, None); + assert_eq!(gpu.count, None); } #[test] - fn resolve_gpu_args_handles_gpu_count() { - let (gpu, gpu_count) = resolve_gpu_args(Some(GpuCliRequest::Count(2))); + fn resolve_gpu_requirements_handles_gpu_count() { + let gpu = resolve_gpu_requirements(Some(GpuCliRequest::Count(2))) + .expect("GPU requirement should be present"); - assert!(gpu); - assert_eq!(gpu_count, Some(2)); + assert_eq!(gpu.count, Some(2)); } #[test] diff --git a/crates/openshell-cli/src/run.rs b/crates/openshell-cli/src/run.rs index afb9e1d55..50fbab85a 100644 --- a/crates/openshell-cli/src/run.rs +++ b/crates/openshell-cli/src/run.rs @@ -1725,8 +1725,7 @@ pub async fn sandbox_create( gateway_name: &str, uploads: &[(String, Option, bool)], keep: bool, - gpu: bool, - gpu_count: Option, + gpu_requirements: Option, cpu: Option<&str>, memory: Option<&str>, driver_config_json: Option<&str>, @@ -1782,8 +1781,6 @@ pub async fn sandbox_create( } None => None, }; - let requested_gpu = gpu; - let providers_v2_enabled = gateway_providers_v2_enabled(&mut client).await?; let inferred_types: Vec = if providers_v2_enabled { Vec::new() @@ -1815,9 +1812,7 @@ pub async fn sandbox_create( None }; - let resource_requirements = requested_gpu.then_some(ResourceRequirements { - gpu: Some(GpuResourceRequirements { count: gpu_count }), - }); + let resource_requirements = gpu_requirements.map(|gpu| ResourceRequirements { gpu: Some(gpu) }); let request = CreateSandboxRequest { spec: Some(SandboxSpec { diff --git a/crates/openshell-cli/tests/sandbox_create_lifecycle_integration.rs b/crates/openshell-cli/tests/sandbox_create_lifecycle_integration.rs index 2adf04587..4d3614d2c 100644 --- a/crates/openshell-cli/tests/sandbox_create_lifecycle_integration.rs +++ b/crates/openshell-cli/tests/sandbox_create_lifecycle_integration.rs @@ -18,13 +18,14 @@ use openshell_core::proto::{ ExecSandboxInput, ExecSandboxRequest, GatewayMessage, GetGatewayConfigRequest, GetGatewayConfigResponse, GetProviderRequest, GetSandboxConfigRequest, GetSandboxConfigResponse, GetSandboxProviderEnvironmentRequest, - GetSandboxProviderEnvironmentResponse, GetSandboxRequest, HealthRequest, HealthResponse, - ListProvidersRequest, ListProvidersResponse, ListSandboxProvidersRequest, - ListSandboxProvidersResponse, ListSandboxesRequest, ListSandboxesResponse, PlatformEvent, - ProviderResponse, RevokeSshSessionRequest, RevokeSshSessionResponse, Sandbox, SandboxCondition, - SandboxLogLine, SandboxPhase, SandboxResponse, SandboxStatus, SandboxStreamEvent, - ServiceStatus, SettingValue, SupervisorMessage, UpdateProviderRequest, WatchSandboxRequest, - sandbox_stream_event, setting_value, + GetSandboxProviderEnvironmentResponse, GetSandboxRequest, GpuResourceRequirements, + HealthRequest, HealthResponse, ListProvidersRequest, ListProvidersResponse, + ListSandboxProvidersRequest, ListSandboxProvidersResponse, ListSandboxesRequest, + ListSandboxesResponse, PlatformEvent, ProviderResponse, RevokeSshSessionRequest, + RevokeSshSessionResponse, Sandbox, SandboxCondition, SandboxLogLine, SandboxPhase, + SandboxResponse, SandboxStatus, SandboxStreamEvent, ServiceStatus, SettingValue, + SupervisorMessage, UpdateProviderRequest, WatchSandboxRequest, sandbox_stream_event, + setting_value, }; use std::collections::HashMap; use std::fs; @@ -766,6 +767,10 @@ fn test_tls(server: &TestServer) -> TlsOptions { server.tls.with_gateway_name("openshell") } +fn gpu_requirements(count: Option) -> GpuResourceRequirements { + GpuResourceRequirements { count } +} + #[tokio::test] async fn sandbox_create_keeps_command_sessions_by_default() { let server = run_server().await; @@ -782,7 +787,6 @@ async fn sandbox_create_keeps_command_sessions_by_default() { "openshell", &[], true, - false, None, None, None, @@ -826,7 +830,6 @@ async fn sandbox_create_sends_cpu_and_memory_limits_only() { "openshell", &[], true, - false, None, Some("500m"), Some("2Gi"), @@ -904,7 +907,6 @@ async fn sandbox_create_sends_driver_config_json() { "openshell", &[], true, - false, None, None, None, @@ -978,8 +980,7 @@ async fn sandbox_create_sends_gpu_default_request() { "openshell", &[], true, - true, - None, + Some(gpu_requirements(None)), None, None, None, @@ -1025,8 +1026,7 @@ async fn sandbox_create_sends_gpu_count_request() { "openshell", &[], true, - true, - Some(2), + Some(gpu_requirements(Some(2))), None, None, None, @@ -1073,7 +1073,6 @@ async fn sandbox_create_does_not_infer_command_providers_when_v2_enabled() { "openshell", &[], true, - false, None, None, None, @@ -1132,7 +1131,6 @@ async fn sandbox_create_returns_vm_error_without_waiting_for_timeout() { "openshell", &[], true, - false, None, None, None, @@ -1187,7 +1185,6 @@ async fn sandbox_create_keeps_waiting_while_vm_progress_arrives() { "openshell", &[], true, - false, None, None, None, @@ -1234,7 +1231,6 @@ async fn sandbox_create_times_out_when_only_logs_arrive() { "openshell", &[], true, - false, None, None, None, @@ -1277,7 +1273,6 @@ async fn sandbox_create_deletes_command_sessions_with_no_keep() { "openshell", &[], false, - false, None, None, None, @@ -1324,7 +1319,6 @@ async fn sandbox_create_deletes_shell_sessions_with_no_keep() { "openshell", &[], false, - false, None, None, None, @@ -1371,7 +1365,6 @@ async fn sandbox_create_keeps_sandbox_with_hidden_keep_flag() { "openshell", &[], true, - false, None, None, None, @@ -1418,7 +1411,6 @@ async fn sandbox_create_keeps_sandbox_with_forwarding() { "openshell", &[], false, - false, None, None, None, @@ -1461,7 +1453,6 @@ async fn sandbox_create_sends_environment_variables() { "openshell", &[], true, - false, None, None, None, From 579c163bcc6b754fd90457cbc34fc713937fd1ef Mon Sep 17 00:00:00 2001 From: Evan Lezar Date: Thu, 11 Jun 2026 16:45:49 +0200 Subject: [PATCH 3/6] refactor(gpu): pass requirements to timeout message Pass ResourceRequirements into the provisioning timeout message helper so GPU hints are derived from the same nested request object used to create the sandbox. Signed-off-by: Evan Lezar --- crates/openshell-cli/src/run.rs | 30 +++++++++++++++++++++--------- 1 file changed, 21 insertions(+), 9 deletions(-) diff --git a/crates/openshell-cli/src/run.rs b/crates/openshell-cli/src/run.rs index 50fbab85a..d242434fb 100644 --- a/crates/openshell-cli/src/run.rs +++ b/crates/openshell-cli/src/run.rs @@ -121,7 +121,7 @@ fn ready_false_condition_message( fn provisioning_timeout_message( timeout_secs: u64, - requested_gpu: bool, + resource_requirements: Option<&ResourceRequirements>, condition_message: Option<&str>, ) -> String { let mut message = format!("sandbox provisioning timed out after {timeout_secs}s"); @@ -131,7 +131,7 @@ fn provisioning_timeout_message( message.push_str(condition_message); } - if requested_gpu { + if resource_requirements.is_some_and(|requirements| requirements.gpu.is_some()) { message.push_str( ". Hint: this may be because the available GPU is already in use by another sandbox.", ); @@ -1961,7 +1961,7 @@ pub async fn sandbox_create( if remaining.is_zero() { let timeout_message = provisioning_timeout_message( provision_timeout.as_secs(), - requested_gpu, + resource_requirements.as_ref(), last_condition_message.as_deref(), ); if let Some(d) = display.as_mut() { @@ -1980,7 +1980,7 @@ pub async fn sandbox_create( // Timeout fired — the stream was idle for too long. let timeout_message = provisioning_timeout_message( provision_timeout.as_secs(), - requested_gpu, + resource_requirements.as_ref(), last_condition_message.as_deref(), ); if let Some(d) = display.as_mut() { @@ -7596,9 +7596,10 @@ mod tests { PROGRESS_STEP_STARTING_SANDBOX, }; use openshell_core::proto::{ - Provider, ProviderCredentialRefresh, ProviderCredentialRefreshStatus, - ProviderCredentialRefreshStrategy, ProviderCredentialTokenGrant, ProviderProfile, - ProviderProfileCredential, SandboxCondition, SandboxStatus, datamodel::v1::ObjectMeta, + GpuResourceRequirements, Provider, ProviderCredentialRefresh, + ProviderCredentialRefreshStatus, ProviderCredentialRefreshStrategy, + ProviderCredentialTokenGrant, ProviderProfile, ProviderProfileCredential, + ResourceRequirements, SandboxCondition, SandboxStatus, datamodel::v1::ObjectMeta, }; struct EnvVarGuard { @@ -8286,9 +8287,12 @@ mod tests { #[test] fn provisioning_timeout_message_includes_condition_and_gpu_hint() { + let resource_requirements = ResourceRequirements { + gpu: Some(GpuResourceRequirements { count: None }), + }; let message = provisioning_timeout_message( 120, - true, + Some(&resource_requirements), Some("DependenciesNotReady: Pod exists with phase: Pending; Service Exists"), ); @@ -8299,7 +8303,15 @@ mod tests { #[test] fn provisioning_timeout_message_omits_gpu_hint_for_non_gpu_requests() { - let message = provisioning_timeout_message(120, false, None); + let message = provisioning_timeout_message(120, None, None); + + assert_eq!(message, "sandbox provisioning timed out after 120s"); + } + + #[test] + fn provisioning_timeout_message_omits_gpu_hint_without_gpu_requirements() { + let resource_requirements = ResourceRequirements { gpu: None }; + let message = provisioning_timeout_message(120, Some(&resource_requirements), None); assert_eq!(message, "sandbox provisioning timed out after 120s"); } From f2569faa91368da312bdb864c76e2f8bcad4b6ab Mon Sep 17 00:00:00 2001 From: Evan Lezar Date: Thu, 11 Jun 2026 17:18:15 +0200 Subject: [PATCH 4/6] refactor(gpu): pass driver requirements through helpers Thread Option through driver validation and rendering helpers instead of splitting GPU presence and count into separate arguments. Signed-off-by: Evan Lezar --- crates/openshell-core/src/gpu.rs | 66 +++++++++------ crates/openshell-driver-docker/src/lib.rs | 37 ++++----- .../openshell-driver-kubernetes/src/driver.rs | 83 ++++++++++--------- .../openshell-driver-podman/src/container.rs | 8 +- crates/openshell-driver-podman/src/driver.rs | 29 +++---- crates/openshell-driver-vm/src/driver.rs | 23 ++--- 6 files changed, 126 insertions(+), 120 deletions(-) diff --git a/crates/openshell-core/src/gpu.rs b/crates/openshell-core/src/gpu.rs index 8c25f47b5..a9bbc87ae 100644 --- a/crates/openshell-core/src/gpu.rs +++ b/crates/openshell-core/src/gpu.rs @@ -1,58 +1,66 @@ // SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-License-Identifier: Apache-2.0 -//! Shared GPU request helpers. +//! Shared GPU resource requirement helpers. use crate::config::CDI_GPU_DEVICE_ALL; -use crate::proto::ResourceRequirements as PublicResourceRequirements; -use crate::proto::compute::v1::ResourceRequirements as DriverResourceRequirements; +use crate::proto::ResourceRequirements as SandboxResourceRequirements; +use crate::proto::compute::v1::{ + GpuResourceRequirements as DriverGpuResourceRequirements, + ResourceRequirements as DriverResourceRequirements, +}; -/// Return whether public resource requirements request a GPU. +/// Return whether sandbox resource requirements request a GPU. #[must_use] -pub fn public_gpu_requested(resources: Option<&PublicResourceRequirements>) -> bool { +pub fn public_gpu_requested(resources: Option<&SandboxResourceRequirements>) -> bool { resources .and_then(|resources| resources.gpu.as_ref()) .is_some() } -/// Return the requested public GPU count, if one was specified. +/// Return the requested sandbox GPU count, if one was specified. #[must_use] -pub fn public_gpu_count(resources: Option<&PublicResourceRequirements>) -> Option { +pub fn public_gpu_count(resources: Option<&SandboxResourceRequirements>) -> Option { resources .and_then(|resources| resources.gpu.as_ref()) .and_then(|gpu| gpu.count) } -/// Return whether driver resource requirements request a GPU. +/// Return whether compute-driver resource requirements request a GPU. #[must_use] pub fn driver_gpu_requested(resources: Option<&DriverResourceRequirements>) -> bool { - resources - .and_then(|resources| resources.gpu.as_ref()) - .is_some() + driver_gpu_requirements(resources).is_some() } -/// Return the requested driver GPU count, if one was specified. +/// Return the requested compute-driver GPU count, if one was specified. #[must_use] pub fn driver_gpu_count(resources: Option<&DriverResourceRequirements>) -> Option { - resources - .and_then(|resources| resources.gpu.as_ref()) - .and_then(|gpu| gpu.count) + driver_gpu_requirements(resources).and_then(|gpu| gpu.count) } -/// Resolve a GPU request into CDI device identifiers. +/// Return the requested compute-driver GPU requirements, if present. +#[must_use] +pub fn driver_gpu_requirements( + resources: Option<&DriverResourceRequirements>, +) -> Option<&DriverGpuResourceRequirements> { + resources.and_then(|resources| resources.gpu.as_ref()) +} + +/// Resolve a compute-driver GPU request into CDI device identifiers. /// /// `None` means no GPU was requested. A GPU request with no explicit CDI /// devices uses the CDI all-GPU request; otherwise the driver-configured CDI /// devices pass through unchanged. #[must_use] -pub fn cdi_gpu_device_ids(gpu: bool, cdi_devices: &[String]) -> Option> { - gpu.then(|| { - if cdi_devices.is_empty() { - vec![CDI_GPU_DEVICE_ALL.to_string()] - } else { - cdi_devices.to_vec() - } - }) +pub fn cdi_gpu_device_ids( + gpu: Option<&DriverGpuResourceRequirements>, + cdi_devices: &[String], +) -> Option> { + match gpu { + Some(_) if cdi_devices.is_empty() => Some(vec![CDI_GPU_DEVICE_ALL.to_string()]), + Some(_) => Some(cdi_devices.to_vec()), + None => None, + } } #[cfg(test)] @@ -61,22 +69,26 @@ mod tests { #[test] fn cdi_gpu_device_ids_returns_none_when_absent() { - assert_eq!(cdi_gpu_device_ids(false, &[]), None); + assert_eq!(cdi_gpu_device_ids(None, &[]), None); } #[test] fn cdi_gpu_device_ids_defaults_empty_request_to_all_gpus() { + let gpu = DriverGpuResourceRequirements { count: None }; + assert_eq!( - cdi_gpu_device_ids(true, &[]), + cdi_gpu_device_ids(Some(&gpu), &[]), Some(vec![CDI_GPU_DEVICE_ALL.to_string()]) ); } #[test] fn cdi_gpu_device_ids_passes_explicit_device_ids_through() { + let gpu = DriverGpuResourceRequirements { count: None }; + assert_eq!( cdi_gpu_device_ids( - true, + Some(&gpu), &[ "nvidia.com/gpu=0".to_string(), "nvidia.com/gpu=1".to_string() diff --git a/crates/openshell-driver-docker/src/lib.rs b/crates/openshell-driver-docker/src/lib.rs index 036f2683d..5138c34d3 100644 --- a/crates/openshell-driver-docker/src/lib.rs +++ b/crates/openshell-driver-docker/src/lib.rs @@ -27,7 +27,7 @@ use openshell_core::driver_utils::{ LABEL_MANAGED_BY, LABEL_MANAGED_BY_VALUE, LABEL_SANDBOX_ID, LABEL_SANDBOX_NAME, LABEL_SANDBOX_NAMESPACE, SUPERVISOR_IMAGE_BINARY_PATH, supervisor_image_should_refresh, }; -use openshell_core::gpu::{cdi_gpu_device_ids, driver_gpu_count, driver_gpu_requested}; +use openshell_core::gpu::{cdi_gpu_device_ids, driver_gpu_requirements}; use openshell_core::progress::{ PROGRESS_STEP_PULLING_IMAGE, PROGRESS_STEP_REQUESTING_SANDBOX, PROGRESS_STEP_STARTING_SANDBOX, format_bytes, mark_progress_active, mark_progress_complete, mark_progress_detail, @@ -36,11 +36,11 @@ use openshell_core::proto::compute::v1::{ CreateSandboxRequest, CreateSandboxResponse, DeleteSandboxRequest, DeleteSandboxResponse, DriverCondition, DriverPlatformEvent, DriverSandbox, DriverSandboxStatus, DriverSandboxTemplate, GetCapabilitiesRequest, GetCapabilitiesResponse, GetSandboxRequest, - GetSandboxResponse, ListSandboxesRequest, ListSandboxesResponse, StopSandboxRequest, - StopSandboxResponse, ValidateSandboxCreateRequest, ValidateSandboxCreateResponse, - WatchSandboxesDeletedEvent, WatchSandboxesEvent, WatchSandboxesPlatformEvent, - WatchSandboxesRequest, WatchSandboxesSandboxEvent, compute_driver_server::ComputeDriver, - watch_sandboxes_event, + GetSandboxResponse, GpuResourceRequirements, ListSandboxesRequest, ListSandboxesResponse, + StopSandboxRequest, StopSandboxResponse, ValidateSandboxCreateRequest, + ValidateSandboxCreateResponse, WatchSandboxesDeletedEvent, WatchSandboxesEvent, + WatchSandboxesPlatformEvent, WatchSandboxesRequest, WatchSandboxesSandboxEvent, + compute_driver_server::ComputeDriver, watch_sandboxes_event, }; use openshell_core::proto_struct::{ deserialize_optional_non_empty_string_list, struct_to_json_value, @@ -461,14 +461,8 @@ impl DockerComputeDriver { let driver_config = DockerSandboxDriverConfig::from_template(template).map_err(Status::invalid_argument)?; - let gpu_requested = driver_gpu_requested(spec.resource_requirements.as_ref()); - let gpu_count = driver_gpu_count(spec.resource_requirements.as_ref()); - Self::validate_gpu_request( - gpu_requested, - gpu_count, - config.supports_gpu, - &driver_config, - )?; + let gpu_requirements = driver_gpu_requirements(spec.resource_requirements.as_ref()); + Self::validate_gpu_request(gpu_requirements, config.supports_gpu, &driver_config)?; Ok(()) } @@ -516,24 +510,23 @@ impl DockerComputeDriver { } fn validate_gpu_request( - gpu: bool, - gpu_count: Option, + gpu_requirements: Option<&GpuResourceRequirements>, supports_gpu: bool, driver_config: &DockerSandboxDriverConfig, ) -> Result<(), Status> { - if !gpu && driver_config.cdi_devices.is_some() { + if gpu_requirements.is_none() && driver_config.cdi_devices.is_some() { return Err(Status::invalid_argument( "driver_config.cdi_devices requires gpu=true", )); } - if gpu_count.is_some() { + if gpu_requirements.and_then(|gpu| gpu.count).is_some() { return Err(Status::invalid_argument( "docker GPU count requests are not supported; use --gpu without a count or driver_config.cdi_devices", )); } - if gpu && !supports_gpu { + if gpu_requirements.is_some() && !supports_gpu { return Err(Status::failed_precondition( "docker GPU sandboxes require Docker CDI support. Enable CDI on the Docker daemon, then restart the OpenShell gateway/server so GPU capability is detected.", )); @@ -2135,15 +2128,15 @@ fn build_device_requests(sandbox: &DriverSandbox) -> Result Result<(), tonic::Status> { let _ = KubernetesSandboxDriverConfig::from_sandbox(sandbox) .map_err(tonic::Status::invalid_argument)?; - let gpu_requested = validate_gpu_resource_requirements(sandbox)?; - if gpu_requested + let gpu_requirements = sandbox + .spec + .as_ref() + .and_then(|spec| driver_gpu_requirements(spec.resource_requirements.as_ref())); + validate_gpu_request(gpu_requirements)?; + if gpu_requirements.is_some() && !self.has_gpu_capacity().await.map_err(|err| { tonic::Status::internal(format!("check GPU node capacity failed: {err}")) })? @@ -377,7 +382,11 @@ impl KubernetesComputeDriver { pub async fn create_sandbox(&self, sandbox: &Sandbox) -> Result<(), KubernetesDriverError> { let _ = KubernetesSandboxDriverConfig::from_sandbox(sandbox) .map_err(KubernetesDriverError::InvalidArgument)?; - validate_gpu_resource_requirements(sandbox).map_err(|status| { + let gpu_requirements = sandbox + .spec + .as_ref() + .and_then(|spec| driver_gpu_requirements(spec.resource_requirements.as_ref())); + validate_gpu_request(gpu_requirements).map_err(|status| { KubernetesDriverError::InvalidArgument(status.message().to_string()) })?; let name = sandbox.name.as_str(); @@ -642,22 +651,16 @@ impl KubernetesComputeDriver { } } -fn validate_gpu_resource_requirements(sandbox: &Sandbox) -> Result { - let Some(resource_requirements) = sandbox - .spec - .as_ref() - .and_then(|spec| spec.resource_requirements.as_ref()) - else { - return Ok(false); - }; - - if driver_gpu_count(Some(resource_requirements)) == Some(0) { +fn validate_gpu_request( + gpu_requirements: Option<&GpuResourceRequirements>, +) -> Result<(), tonic::Status> { + if gpu_requirements.and_then(|gpu| gpu.count) == Some(0) { return Err(tonic::Status::invalid_argument( "gpu count must be greater than 0", )); } - Ok(driver_gpu_requested(Some(resource_requirements))) + Ok(()) } fn sandbox_labels(sandbox: &Sandbox) -> BTreeMap { @@ -1223,10 +1226,9 @@ fn sandbox_to_k8s_spec( if let Some(template) = spec.template.as_ref() { root.insert( "podTemplate".to_string(), - sandbox_template_to_k8s_with_gpu_count( + sandbox_template_to_k8s_with_gpu_requirements( template, - driver_gpu_requested(spec.resource_requirements.as_ref()), - driver_gpu_count(spec.resource_requirements.as_ref()), + driver_gpu_requirements(spec.resource_requirements.as_ref()), &pod_env, inject_workspace, params, @@ -1260,12 +1262,9 @@ fn sandbox_to_k8s_spec( let pod_env = spec_pod_env(spec); root.insert( "podTemplate".to_string(), - sandbox_template_to_k8s_with_gpu_count( + sandbox_template_to_k8s_with_gpu_requirements( &SandboxTemplate::default(), - spec.and_then(|s| s.resource_requirements.as_ref()) - .is_some_and(|requirements| driver_gpu_requested(Some(requirements))), - spec.and_then(|s| s.resource_requirements.as_ref()) - .and_then(|requirements| driver_gpu_count(Some(requirements))), + driver_gpu_requirements(spec.and_then(|s| s.resource_requirements.as_ref())), &pod_env, inject_workspace, params, @@ -1286,20 +1285,19 @@ fn sandbox_template_to_k8s( inject_workspace: bool, params: &SandboxPodParams<'_>, ) -> serde_json::Value { - sandbox_template_to_k8s_with_gpu_count( + let gpu_requirements = gpu.then_some(GpuResourceRequirements { count: None }); + sandbox_template_to_k8s_with_gpu_requirements( template, - gpu, - None, + gpu_requirements.as_ref(), spec_environment, inject_workspace, params, ) } -fn sandbox_template_to_k8s_with_gpu_count( +fn sandbox_template_to_k8s_with_gpu_requirements( template: &SandboxTemplate, - gpu: bool, - gpu_count: Option, + gpu_requirements: Option<&GpuResourceRequirements>, spec_environment: &std::collections::HashMap, inject_workspace: bool, params: &SandboxPodParams<'_>, @@ -1382,7 +1380,7 @@ fn sandbox_template_to_k8s_with_gpu_count( if use_user_namespaces { spec.insert("hostUsers".to_string(), serde_json::json!(false)); - if gpu { + if gpu_requirements.is_some() { warn!( "GPU sandbox with user namespaces enabled — \ NVIDIA device plugin compatibility is unverified" @@ -1491,7 +1489,7 @@ fn sandbox_template_to_k8s_with_gpu_count( serde_json::Value::Array(volume_mounts), ); - if let Some(resources) = container_resources(template, gpu, gpu_count) { + if let Some(resources) = container_resources(template, gpu_requirements) { container.insert("resources".to_string(), resources); } apply_agent_driver_resources(&mut container, &driver_config.containers.agent.resources); @@ -1671,8 +1669,7 @@ fn app_armor_profile_to_k8s(profile: &AppArmorProfile) -> serde_json::Value { fn container_resources( template: &SandboxTemplate, - gpu: bool, - gpu_count: Option, + gpu_requirements: Option<&GpuResourceRequirements>, ) -> Option { // Start from the raw resources passthrough in platform_config (preserves // custom resource types like GPU limits that users set via the public API @@ -1706,8 +1703,8 @@ fn container_resources( apply("requests", "memory", memory_request); } - if gpu { - let quantity = gpu_count.map_or_else( + if let Some(gpu) = gpu_requirements { + let quantity = gpu.count.map_or_else( || GPU_RESOURCE_QUANTITY.to_string(), |count| count.to_string(), ); @@ -2085,7 +2082,11 @@ mod tests { ..Sandbox::default() }; - let err = validate_gpu_resource_requirements(&sandbox).unwrap_err(); + let gpu_requirements = sandbox + .spec + .as_ref() + .and_then(|spec| driver_gpu_requirements(spec.resource_requirements.as_ref())); + let err = validate_gpu_request(gpu_requirements).unwrap_err(); assert_eq!(err.code(), tonic::Code::InvalidArgument); assert!(err.message().contains("gpu count must be greater than 0")); } @@ -2423,10 +2424,10 @@ mod tests { fn gpu_count_sandbox_adds_requested_gpu_limit() { let pod_template = { let params = SandboxPodParams::default(); - sandbox_template_to_k8s_with_gpu_count( + let gpu_requirements = GpuResourceRequirements { count: Some(2) }; + sandbox_template_to_k8s_with_gpu_requirements( &SandboxTemplate::default(), - true, - Some(2), + Some(&gpu_requirements), &std::collections::HashMap::new(), true, ¶ms, diff --git a/crates/openshell-driver-podman/src/container.rs b/crates/openshell-driver-podman/src/container.rs index f14c01934..ad42fb42c 100644 --- a/crates/openshell-driver-podman/src/container.rs +++ b/crates/openshell-driver-podman/src/container.rs @@ -5,7 +5,7 @@ use crate::config::PodmanComputeConfig; use openshell_core::ComputeDriverError; -use openshell_core::gpu::{cdi_gpu_device_ids, driver_gpu_requested}; +use openshell_core::gpu::{cdi_gpu_device_ids, driver_gpu_requirements}; use openshell_core::proto::compute::v1::{DriverSandbox, DriverSandboxTemplate}; use openshell_core::proto_struct::deserialize_optional_non_empty_string_list; use openshell_core::{driver_mounts, proto_struct}; @@ -484,15 +484,15 @@ fn build_devices(sandbox: &DriverSandbox) -> Result>, Co let cdi_devices = PodmanSandboxDriverConfig::from_sandbox(sandbox)? .cdi_devices .unwrap_or_default(); - let gpu_requested = driver_gpu_requested(spec.resource_requirements.as_ref()); - if !gpu_requested && !cdi_devices.is_empty() { + let gpu_requirements = driver_gpu_requirements(spec.resource_requirements.as_ref()); + if gpu_requirements.is_none() && !cdi_devices.is_empty() { return Err(ComputeDriverError::InvalidArgument( "driver_config.cdi_devices requires gpu=true".to_string(), )); } Ok( - cdi_gpu_device_ids(gpu_requested, &cdi_devices).map(|device_ids| { + cdi_gpu_device_ids(gpu_requirements, &cdi_devices).map(|device_ids| { device_ids .into_iter() .map(|path| LinuxDevice { path }) diff --git a/crates/openshell-driver-podman/src/driver.rs b/crates/openshell-driver-podman/src/driver.rs index 2fbe8292d..90f24c0de 100644 --- a/crates/openshell-driver-podman/src/driver.rs +++ b/crates/openshell-driver-podman/src/driver.rs @@ -11,8 +11,10 @@ use crate::watcher::{ }; use openshell_core::ComputeDriverError; use openshell_core::driver_utils::supervisor_image_should_refresh; -use openshell_core::gpu::{driver_gpu_count, driver_gpu_requested}; -use openshell_core::proto::compute::v1::{DriverSandbox, GetCapabilitiesResponse}; +use openshell_core::gpu::driver_gpu_requirements; +use openshell_core::proto::compute::v1::{ + DriverSandbox, GetCapabilitiesResponse, GpuResourceRequirements, +}; use std::path::PathBuf; use std::time::Duration; use tracing::{info, warn}; @@ -282,39 +284,33 @@ impl PodmanComputeDriver { &self, sandbox: &DriverSandbox, ) -> Result<(), ComputeDriverError> { - let gpu_requested = sandbox - .spec - .as_ref() - .and_then(|spec| spec.resource_requirements.as_ref()) - .is_some_and(|requirements| driver_gpu_requested(Some(requirements))); - let gpu_count = sandbox + let gpu_requirements = sandbox .spec .as_ref() .and_then(|spec| spec.resource_requirements.as_ref()) - .and_then(|requirements| driver_gpu_count(Some(requirements))); + .and_then(|requirements| driver_gpu_requirements(Some(requirements))); let driver_config = PodmanSandboxDriverConfig::from_sandbox(sandbox)?; - if !gpu_requested && driver_config.cdi_devices.is_some() { + if gpu_requirements.is_none() && driver_config.cdi_devices.is_some() { return Err(ComputeDriverError::InvalidArgument( "driver_config.cdi_devices requires gpu=true".to_string(), )); } - Self::validate_gpu_request(gpu_requested, gpu_count)?; + Self::validate_gpu_request(gpu_requirements)?; self.validate_user_volume_mounts_available(sandbox).await?; Ok(()) } fn validate_gpu_request( - gpu_requested: bool, - gpu_count: Option, + gpu_requirements: Option<&GpuResourceRequirements>, ) -> Result<(), ComputeDriverError> { - if gpu_count.is_some() { + if gpu_requirements.and_then(|gpu| gpu.count).is_some() { return Err(ComputeDriverError::InvalidArgument( "podman GPU count requests are not supported; use --gpu without a count or driver_config.cdi_devices" .to_string(), )); } - if gpu_requested && !Self::has_gpu_capacity() { + if gpu_requirements.is_some() && !Self::has_gpu_capacity() { return Err(ComputeDriverError::Precondition( "GPU sandbox requested, but no NVIDIA GPU devices are available.".to_string(), )); @@ -715,7 +711,8 @@ mod tests { #[test] fn validate_gpu_request_rejects_gpu_count() { - let err = PodmanComputeDriver::validate_gpu_request(true, Some(2)) + let gpu = GpuResourceRequirements { count: Some(2) }; + let err = PodmanComputeDriver::validate_gpu_request(Some(&gpu)) .expect_err("gpu count should be rejected"); assert!(matches!(err, ComputeDriverError::InvalidArgument(_))); diff --git a/crates/openshell-driver-vm/src/driver.rs b/crates/openshell-driver-vm/src/driver.rs index de34cb507..4cd84a9be 100644 --- a/crates/openshell-driver-vm/src/driver.rs +++ b/crates/openshell-driver-vm/src/driver.rs @@ -29,7 +29,7 @@ use oci_client::manifest::{ }; use oci_client::secrets::RegistryAuth; use oci_client::{Reference, RegistryOperation}; -use openshell_core::gpu::{driver_gpu_count, driver_gpu_requested}; +use openshell_core::gpu::driver_gpu_requirements; use openshell_core::progress::{ PROGRESS_STEP_PULLING_IMAGE, PROGRESS_STEP_REQUESTING_SANDBOX, PROGRESS_STEP_STARTING_SANDBOX, format_bytes, mark_progress_active, mark_progress_complete, mark_progress_detail, @@ -632,7 +632,8 @@ impl VmDriver { .spec .as_ref() .and_then(|spec| spec.resource_requirements.as_ref()) - .is_some_and(|requirements| driver_gpu_requested(Some(requirements))); + .and_then(|requirements| driver_gpu_requirements(Some(requirements))) + .is_some(); self.publish_platform_event( sandbox.id.clone(), platform_event( @@ -3084,7 +3085,7 @@ fn validate_vm_sandbox(sandbox: &Sandbox, gpu_enabled: bool) -> Result<(), Statu if let Some(template) = spec.template.as_ref() { validate_vm_sandbox_template(template)?; } - validate_vm_gpu_request(sandbox, gpu_enabled)?; + validate_gpu_request(sandbox, gpu_enabled)?; Ok(()) } @@ -3105,14 +3106,14 @@ fn validate_vm_sandbox_template(template: &SandboxTemplate) -> Result<(), Status } #[allow(clippy::result_large_err)] -fn validate_vm_gpu_request(sandbox: &Sandbox, gpu_enabled: bool) -> Result<(), Status> { +fn validate_gpu_request(sandbox: &Sandbox, gpu_enabled: bool) -> Result<(), Status> { let spec = sandbox .spec .as_ref() .ok_or_else(|| Status::invalid_argument("sandbox spec is required"))?; - let gpu_requested = driver_gpu_requested(spec.resource_requirements.as_ref()); - let gpu_count = driver_gpu_count(spec.resource_requirements.as_ref()); + let gpu_requirements = driver_gpu_requirements(spec.resource_requirements.as_ref()); + let gpu_count = gpu_requirements.and_then(|gpu| gpu.count); if gpu_count == Some(0) { return Err(Status::invalid_argument("gpu count must be greater than 0")); @@ -3125,7 +3126,7 @@ fn validate_vm_gpu_request(sandbox: &Sandbox, gpu_enabled: bool) -> Result<(), S } let _ = vm_gpu_device_id(sandbox)?; - if gpu_requested && !gpu_enabled { + if gpu_requirements.is_some() && !gpu_enabled { return Err(Status::failed_precondition( "GPU support is not enabled on this driver; start with --gpu", )); @@ -3142,8 +3143,8 @@ fn vm_gpu_device_id(sandbox: &Sandbox) -> Result, Status> { .map_err(Status::invalid_argument)? .gpu_device_ids .unwrap_or_default(); - let gpu_requested = driver_gpu_requested(spec.resource_requirements.as_ref()); - if !gpu_requested && !gpu_device_ids.is_empty() { + let gpu_requirements = driver_gpu_requirements(spec.resource_requirements.as_ref()); + if gpu_requirements.is_none() && !gpu_device_ids.is_empty() { return Err(Status::invalid_argument( "driver_config.gpu_device_ids requires gpu=true", )); @@ -3154,7 +3155,9 @@ fn vm_gpu_device_id(sandbox: &Sandbox) -> Result, Status> { )); } - Ok(gpu_requested.then(|| gpu_device_ids.into_iter().next().unwrap_or_default())) + Ok(gpu_requirements + .is_some() + .then(|| gpu_device_ids.into_iter().next().unwrap_or_default())) } #[allow(clippy::result_large_err)] From da6fbd8752372c86a834d59edb238d61d38dc9b7 Mon Sep 17 00:00:00 2001 From: Evan Lezar Date: Thu, 11 Jun 2026 18:42:29 +0200 Subject: [PATCH 5/6] fix(gpu): validate exact device requests Require exact driver GPU device lists to be tied to a GPU request, allow a single exact device to use the default countless request, and require explicit matching counts for multi-device lists. Signed-off-by: Evan Lezar --- crates/openshell-core/src/gpu.rs | 129 ++++++++++++++++ crates/openshell-driver-docker/README.md | 2 +- crates/openshell-driver-docker/src/lib.rs | 35 +++-- crates/openshell-driver-docker/src/tests.rs | 124 +++++++++++++++- crates/openshell-driver-podman/README.md | 2 +- .../openshell-driver-podman/src/container.rs | 72 ++++++++- crates/openshell-driver-podman/src/driver.rs | 126 ++++++++++++++-- crates/openshell-driver-vm/src/driver.rs | 139 ++++++++++++++++-- docs/reference/sandbox-compute-drivers.mdx | 8 +- docs/sandboxes/manage-sandboxes.mdx | 7 +- 10 files changed, 583 insertions(+), 61 deletions(-) diff --git a/crates/openshell-core/src/gpu.rs b/crates/openshell-core/src/gpu.rs index a9bbc87ae..fb4927aac 100644 --- a/crates/openshell-core/src/gpu.rs +++ b/crates/openshell-core/src/gpu.rs @@ -63,6 +63,47 @@ pub fn cdi_gpu_device_ids( } } +/// Validate a compute-driver GPU request against driver-owned specific devices. +/// +/// Drivers call this when a sandbox request combines portable GPU requirements +/// with exact device identifiers in `driver_config`. +/// +/// # Errors +/// Returns an error when the sandbox GPU request is absent or when `gpu.count` +/// does not equal the number of specific devices. A single exact device is +/// compatible with the default sandbox GPU request where `gpu.count` is absent. +pub fn validate_specific_gpu_device_request( + gpu: Option<&DriverGpuResourceRequirements>, + specific_devices: &[String], + driver_config_field: &str, +) -> Result<(), String> { + let device_count = specific_devices.len(); + if device_count == 0 { + return Ok(()); + } + + let Some(gpu) = gpu else { + return Err(format!("{driver_config_field} requires a gpu request")); + }; + + let Some(count) = gpu.count else { + if device_count == 1 { + return Ok(()); + } + return Err(format!( + "{driver_config_field} requires an explicit gpu count matching its length ({device_count})" + )); + }; + + if usize::try_from(count).ok() != Some(device_count) { + return Err(format!( + "gpu count ({count}) must match {driver_config_field} length ({device_count})" + )); + } + + Ok(()) +} + #[cfg(test)] mod tests { use super::*; @@ -100,4 +141,92 @@ mod tests { ]) ); } + + #[test] + fn validate_specific_gpu_device_request_ignores_empty_devices() { + validate_specific_gpu_device_request(None, &[], "driver_config.cdi_devices") + .expect("empty exact device lists should not be validated"); + } + + #[test] + fn validate_specific_gpu_device_request_accepts_matching_count() { + let gpu = DriverGpuResourceRequirements { count: Some(2) }; + let specific_devices = vec![ + "nvidia.com/gpu=0".to_string(), + "nvidia.com/gpu=1".to_string(), + ]; + + validate_specific_gpu_device_request( + Some(&gpu), + &specific_devices, + "driver_config.cdi_devices", + ) + .expect("matching count should be accepted"); + } + + #[test] + fn validate_specific_gpu_device_request_accepts_missing_count_for_one_device() { + let gpu = DriverGpuResourceRequirements { count: None }; + let specific_devices = vec!["nvidia.com/gpu=0".to_string()]; + + validate_specific_gpu_device_request( + Some(&gpu), + &specific_devices, + "driver_config.cdi_devices", + ) + .expect("single exact device should be compatible with a default GPU request"); + } + + #[test] + fn validate_specific_gpu_device_request_rejects_missing_gpu_request() { + let specific_devices = vec!["nvidia.com/gpu=0".to_string()]; + + let err = validate_specific_gpu_device_request( + None, + &specific_devices, + "driver_config.cdi_devices", + ) + .expect_err("missing GPU request should be rejected"); + + assert_eq!(err, "driver_config.cdi_devices requires a gpu request"); + } + + #[test] + fn validate_specific_gpu_device_request_rejects_missing_count_for_multiple_devices() { + let gpu = DriverGpuResourceRequirements { count: None }; + let specific_devices = vec![ + "nvidia.com/gpu=0".to_string(), + "nvidia.com/gpu=1".to_string(), + ]; + + let err = validate_specific_gpu_device_request( + Some(&gpu), + &specific_devices, + "driver_config.cdi_devices", + ) + .expect_err("missing count should be rejected for multiple devices"); + + assert_eq!( + err, + "driver_config.cdi_devices requires an explicit gpu count matching its length (2)" + ); + } + + #[test] + fn validate_specific_gpu_device_request_rejects_mismatch() { + let gpu = DriverGpuResourceRequirements { count: Some(2) }; + let specific_devices = vec!["nvidia.com/gpu=0".to_string()]; + + let err = validate_specific_gpu_device_request( + Some(&gpu), + &specific_devices, + "driver_config.cdi_devices", + ) + .expect_err("mismatched count should be rejected"); + + assert_eq!( + err, + "gpu count (2) must match driver_config.cdi_devices length (1)" + ); + } } diff --git a/crates/openshell-driver-docker/README.md b/crates/openshell-driver-docker/README.md index b8e244a13..d16f53456 100644 --- a/crates/openshell-driver-docker/README.md +++ b/crates/openshell-driver-docker/README.md @@ -32,7 +32,7 @@ contract: | `apparmor=unconfined` | Avoids Docker's default profile blocking required mount operations. | | `restart_policy = unless-stopped` | Keeps managed sandboxes resumable across daemon or gateway restarts. | | `PidsLimit` | Enforces the sandbox PID budget at the Docker cgroup layer. Set `[openshell.drivers.docker].sandbox_pids_limit = 0` to inherit the Docker/runtime default. | -| CDI GPU request | Uses `driver_config.cdi_devices` when set; otherwise requests all NVIDIA GPUs when the sandbox spec asks for GPU support and daemon CDI support is detected. Counted GPU requests are rejected. | +| CDI GPU request | Uses `driver_config.cdi_devices` when set; otherwise requests all NVIDIA GPUs when the sandbox spec asks for GPU support and daemon CDI support is detected. Count-only GPU requests are rejected; exact CDI device lists with more than one entry require an explicit GPU count matching the device list length. | The agent child process does not retain these supervisor privileges. diff --git a/crates/openshell-driver-docker/src/lib.rs b/crates/openshell-driver-docker/src/lib.rs index 5138c34d3..f77205e6b 100644 --- a/crates/openshell-driver-docker/src/lib.rs +++ b/crates/openshell-driver-docker/src/lib.rs @@ -27,7 +27,9 @@ use openshell_core::driver_utils::{ LABEL_MANAGED_BY, LABEL_MANAGED_BY_VALUE, LABEL_SANDBOX_ID, LABEL_SANDBOX_NAME, LABEL_SANDBOX_NAMESPACE, SUPERVISOR_IMAGE_BINARY_PATH, supervisor_image_should_refresh, }; -use openshell_core::gpu::{cdi_gpu_device_ids, driver_gpu_requirements}; +use openshell_core::gpu::{ + cdi_gpu_device_ids, driver_gpu_requirements, validate_specific_gpu_device_request, +}; use openshell_core::progress::{ PROGRESS_STEP_PULLING_IMAGE, PROGRESS_STEP_REQUESTING_SANDBOX, PROGRESS_STEP_STARTING_SANDBOX, format_bytes, mark_progress_active, mark_progress_complete, mark_progress_detail, @@ -514,23 +516,25 @@ impl DockerComputeDriver { supports_gpu: bool, driver_config: &DockerSandboxDriverConfig, ) -> Result<(), Status> { - if gpu_requirements.is_none() && driver_config.cdi_devices.is_some() { - return Err(Status::invalid_argument( - "driver_config.cdi_devices requires gpu=true", + if gpu_requirements.is_some() && !supports_gpu { + return Err(Status::failed_precondition( + "docker GPU sandboxes require Docker CDI support. Enable CDI on the Docker daemon, then restart the OpenShell gateway/server so GPU capability is detected.", )); } - if gpu_requirements.and_then(|gpu| gpu.count).is_some() { + if let Some(cdi_devices) = driver_config.cdi_devices.as_deref() { + validate_specific_gpu_device_request( + gpu_requirements, + cdi_devices, + "driver_config.cdi_devices", + ) + .map_err(Status::invalid_argument)?; + } else if gpu_requirements.and_then(|gpu| gpu.count).is_some() { return Err(Status::invalid_argument( "docker GPU count requests are not supported; use --gpu without a count or driver_config.cdi_devices", )); } - if gpu_requirements.is_some() && !supports_gpu { - return Err(Status::failed_precondition( - "docker GPU sandboxes require Docker CDI support. Enable CDI on the Docker daemon, then restart the OpenShell gateway/server so GPU capability is detected.", - )); - } Ok(()) } @@ -2129,11 +2133,12 @@ fn build_device_requests(sandbox: &DriverSandbox) -> Result Result>, Co .cdi_devices .unwrap_or_default(); let gpu_requirements = driver_gpu_requirements(spec.resource_requirements.as_ref()); - if gpu_requirements.is_none() && !cdi_devices.is_empty() { - return Err(ComputeDriverError::InvalidArgument( - "driver_config.cdi_devices requires gpu=true".to_string(), - )); - } + validate_specific_gpu_device_request( + gpu_requirements, + &cdi_devices, + "driver_config.cdi_devices", + ) + .map_err(ComputeDriverError::InvalidArgument)?; Ok( cdi_gpu_device_ids(gpu_requirements, &cdi_devices).map(|device_ids| { @@ -1289,7 +1292,60 @@ mod tests { } #[test] - fn container_spec_rejects_cdi_devices_without_gpu() { + fn container_spec_accepts_gpu_count_matching_cdi_devices() { + use openshell_core::proto::compute::v1::{DriverSandboxSpec, DriverSandboxTemplate}; + + let mut sandbox = test_sandbox("test-id", "test-name"); + sandbox.spec = Some(DriverSandboxSpec { + resource_requirements: Some(gpu_resources(Some(2))), + template: Some(DriverSandboxTemplate { + driver_config: Some(cdi_devices_config(&[ + "nvidia.com/gpu=0", + "nvidia.com/gpu=1", + ])), + ..Default::default() + }), + ..Default::default() + }); + let config = test_config(); + let spec = build_container_spec(&sandbox, &config); + + assert_eq!(spec["devices"].as_array().map(Vec::len), Some(2)); + assert_eq!( + spec["devices"][0]["path"].as_str(), + Some("nvidia.com/gpu=0") + ); + assert_eq!( + spec["devices"][1]["path"].as_str(), + Some("nvidia.com/gpu=1") + ); + } + + #[test] + fn container_spec_rejects_gpu_count_mismatched_cdi_devices() { + use openshell_core::proto::compute::v1::{DriverSandboxSpec, DriverSandboxTemplate}; + + let mut sandbox = test_sandbox("test-id", "test-name"); + sandbox.spec = Some(DriverSandboxSpec { + resource_requirements: Some(gpu_resources(Some(2))), + template: Some(DriverSandboxTemplate { + driver_config: Some(cdi_devices_config(&["nvidia.com/gpu=0"])), + ..Default::default() + }), + ..Default::default() + }); + let config = test_config(); + + let err = try_build_container_spec_with_token(&sandbox, &config, None).unwrap_err(); + assert!(matches!(err, ComputeDriverError::InvalidArgument(_))); + assert!( + err.to_string() + .contains("gpu count (2) must match driver_config.cdi_devices length (1)") + ); + } + + #[test] + fn container_spec_rejects_cdi_devices_without_gpu_request() { use openshell_core::proto::compute::v1::{DriverSandboxSpec, DriverSandboxTemplate}; let mut sandbox = test_sandbox("test-id", "test-name"); @@ -1304,7 +1360,7 @@ mod tests { let err = try_build_container_spec_with_token(&sandbox, &config, None).unwrap_err(); assert!(matches!(err, ComputeDriverError::InvalidArgument(_))); - assert!(err.to_string().contains("requires gpu=true")); + assert!(err.to_string().contains("requires a gpu request")); } #[test] diff --git a/crates/openshell-driver-podman/src/driver.rs b/crates/openshell-driver-podman/src/driver.rs index 90f24c0de..17864c941 100644 --- a/crates/openshell-driver-podman/src/driver.rs +++ b/crates/openshell-driver-podman/src/driver.rs @@ -11,7 +11,7 @@ use crate::watcher::{ }; use openshell_core::ComputeDriverError; use openshell_core::driver_utils::supervisor_image_should_refresh; -use openshell_core::gpu::driver_gpu_requirements; +use openshell_core::gpu::{driver_gpu_requirements, validate_specific_gpu_device_request}; use openshell_core::proto::compute::v1::{ DriverSandbox, GetCapabilitiesResponse, GpuResourceRequirements, }; @@ -290,31 +290,48 @@ impl PodmanComputeDriver { .and_then(|spec| spec.resource_requirements.as_ref()) .and_then(|requirements| driver_gpu_requirements(Some(requirements))); let driver_config = PodmanSandboxDriverConfig::from_sandbox(sandbox)?; - if gpu_requirements.is_none() && driver_config.cdi_devices.is_some() { - return Err(ComputeDriverError::InvalidArgument( - "driver_config.cdi_devices requires gpu=true".to_string(), - )); - } - Self::validate_gpu_request(gpu_requirements)?; + let cdi_devices = driver_config.cdi_devices.as_deref(); + Self::validate_gpu_request(gpu_requirements, cdi_devices)?; self.validate_user_volume_mounts_available(sandbox).await?; Ok(()) } fn validate_gpu_request( gpu_requirements: Option<&GpuResourceRequirements>, + cdi_devices: Option<&[String]>, + ) -> Result<(), ComputeDriverError> { + Self::validate_gpu_request_with_capacity( + gpu_requirements, + cdi_devices, + Self::has_gpu_capacity(), + ) + } + + fn validate_gpu_request_with_capacity( + gpu_requirements: Option<&GpuResourceRequirements>, + cdi_devices: Option<&[String]>, + has_gpu_capacity: bool, ) -> Result<(), ComputeDriverError> { - if gpu_requirements.and_then(|gpu| gpu.count).is_some() { + if gpu_requirements.is_some() && !has_gpu_capacity { + return Err(ComputeDriverError::Precondition( + "GPU sandbox requested, but no NVIDIA GPU devices are available.".to_string(), + )); + } + + if let Some(cdi_devices) = cdi_devices { + validate_specific_gpu_device_request( + gpu_requirements, + cdi_devices, + "driver_config.cdi_devices", + ) + .map_err(ComputeDriverError::InvalidArgument)?; + } else if gpu_requirements.and_then(|gpu| gpu.count).is_some() { return Err(ComputeDriverError::InvalidArgument( "podman GPU count requests are not supported; use --gpu without a count or driver_config.cdi_devices" .to_string(), )); } - if gpu_requirements.is_some() && !Self::has_gpu_capacity() { - return Err(ComputeDriverError::Precondition( - "GPU sandbox requested, but no NVIDIA GPU devices are available.".to_string(), - )); - } Ok(()) } @@ -712,7 +729,7 @@ mod tests { #[test] fn validate_gpu_request_rejects_gpu_count() { let gpu = GpuResourceRequirements { count: Some(2) }; - let err = PodmanComputeDriver::validate_gpu_request(Some(&gpu)) + let err = PodmanComputeDriver::validate_gpu_request_with_capacity(Some(&gpu), None, true) .expect_err("gpu count should be rejected"); assert!(matches!(err, ComputeDriverError::InvalidArgument(_))); @@ -722,6 +739,87 @@ mod tests { ); } + #[test] + fn validate_gpu_request_accepts_single_cdi_device_without_gpu_count() { + let gpu = GpuResourceRequirements { count: None }; + let cdi_devices = vec!["nvidia.com/gpu=0".to_string()]; + + PodmanComputeDriver::validate_gpu_request_with_capacity( + Some(&gpu), + Some(&cdi_devices), + true, + ) + .expect("single exact CDI device should pass count validation"); + } + + #[test] + fn validate_gpu_request_rejects_missing_gpu_capacity_before_request_shape() { + let gpu = GpuResourceRequirements { count: Some(2) }; + let cdi_devices = vec!["nvidia.com/gpu=0".to_string()]; + let err = PodmanComputeDriver::validate_gpu_request_with_capacity( + Some(&gpu), + Some(&cdi_devices), + false, + ) + .expect_err("missing GPU capacity should be rejected before request shape"); + + assert!(matches!(err, ComputeDriverError::Precondition(_))); + assert!(err.to_string().contains("no NVIDIA GPU devices")); + } + + #[test] + fn validate_gpu_request_rejects_multiple_cdi_devices_without_gpu_count() { + let gpu = GpuResourceRequirements { count: None }; + let cdi_devices = vec![ + "nvidia.com/gpu=0".to_string(), + "nvidia.com/gpu=1".to_string(), + ]; + let err = PodmanComputeDriver::validate_gpu_request_with_capacity( + Some(&gpu), + Some(&cdi_devices), + true, + ) + .expect_err("missing CDI device count should be rejected for multiple devices"); + + assert!(matches!(err, ComputeDriverError::InvalidArgument(_))); + assert!( + err.to_string() + .contains("requires an explicit gpu count matching its length (2)") + ); + } + + #[test] + fn validate_gpu_request_rejects_cdi_devices_without_gpu_request() { + let cdi_devices = vec!["nvidia.com/gpu=0".to_string()]; + let err = PodmanComputeDriver::validate_gpu_request_with_capacity( + None, + Some(&cdi_devices), + false, + ) + .expect_err("missing GPU request should be rejected"); + + assert!(matches!(err, ComputeDriverError::InvalidArgument(_))); + assert!(err.to_string().contains("requires a gpu request")); + } + + #[test] + fn validate_gpu_request_rejects_mismatched_cdi_device_count() { + let gpu = GpuResourceRequirements { count: Some(2) }; + let cdi_devices = vec!["nvidia.com/gpu=0".to_string()]; + let err = PodmanComputeDriver::validate_gpu_request_with_capacity( + Some(&gpu), + Some(&cdi_devices), + true, + ) + .expect_err("mismatched CDI device count should be rejected"); + + assert!(matches!(err, ComputeDriverError::InvalidArgument(_))); + assert!( + err.to_string() + .contains("gpu count (2) must match driver_config.cdi_devices length (1)") + ); + } + // ── grpc_endpoint auto-detection ─────────────────────────────────── // // PodmanComputeDriver::new() fills grpc_endpoint when it is empty. diff --git a/crates/openshell-driver-vm/src/driver.rs b/crates/openshell-driver-vm/src/driver.rs index 4cd84a9be..16a9742a5 100644 --- a/crates/openshell-driver-vm/src/driver.rs +++ b/crates/openshell-driver-vm/src/driver.rs @@ -29,7 +29,7 @@ use oci_client::manifest::{ }; use oci_client::secrets::RegistryAuth; use oci_client::{Reference, RegistryOperation}; -use openshell_core::gpu::driver_gpu_requirements; +use openshell_core::gpu::{driver_gpu_requirements, validate_specific_gpu_device_request}; use openshell_core::progress::{ PROGRESS_STEP_PULLING_IMAGE, PROGRESS_STEP_REQUESTING_SANDBOX, PROGRESS_STEP_STARTING_SANDBOX, format_bytes, mark_progress_active, mark_progress_complete, mark_progress_detail, @@ -3115,22 +3115,24 @@ fn validate_gpu_request(sandbox: &Sandbox, gpu_enabled: bool) -> Result<(), Stat let gpu_requirements = driver_gpu_requirements(spec.resource_requirements.as_ref()); let gpu_count = gpu_requirements.and_then(|gpu| gpu.count); + if gpu_requirements.is_some() && !gpu_enabled { + return Err(Status::failed_precondition( + "GPU support is not enabled on this driver; start with --gpu", + )); + } + if gpu_count == Some(0) { return Err(Status::invalid_argument("gpu count must be greater than 0")); } + let _ = vm_gpu_device_id(sandbox)?; + if gpu_count.is_some_and(|count| count > 1) { return Err(Status::invalid_argument( "VM GPU sandboxes support only one GPU", )); } - let _ = vm_gpu_device_id(sandbox)?; - if gpu_requirements.is_some() && !gpu_enabled { - return Err(Status::failed_precondition( - "GPU support is not enabled on this driver; start with --gpu", - )); - } Ok(()) } @@ -3144,11 +3146,12 @@ fn vm_gpu_device_id(sandbox: &Sandbox) -> Result, Status> { .gpu_device_ids .unwrap_or_default(); let gpu_requirements = driver_gpu_requirements(spec.resource_requirements.as_ref()); - if gpu_requirements.is_none() && !gpu_device_ids.is_empty() { - return Err(Status::invalid_argument( - "driver_config.gpu_device_ids requires gpu=true", - )); - } + validate_specific_gpu_device_request( + gpu_requirements, + &gpu_device_ids, + "driver_config.gpu_device_ids", + ) + .map_err(Status::invalid_argument)?; if gpu_device_ids.len() > 1 { return Err(Status::invalid_argument( "vm driver currently supports at most one gpu_device_ids entry", @@ -5207,6 +5210,26 @@ mod tests { assert!(err.message().contains("GPU support is not enabled")); } + #[test] + fn validate_vm_sandbox_rejects_missing_gpu_support_before_request_shape() { + let sandbox = Sandbox { + id: "sandbox-123".to_string(), + spec: Some(SandboxSpec { + resource_requirements: Some(gpu_resources(Some(2))), + template: Some(SandboxTemplate { + driver_config: Some(gpu_device_ids_config(&["0000:2d:00.0"])), + ..Default::default() + }), + ..Default::default() + }), + ..Default::default() + }; + let err = validate_vm_sandbox(&sandbox, false) + .expect_err("missing GPU support should be rejected before request shape"); + assert_eq!(err.code(), Code::FailedPrecondition); + assert!(err.message().contains("GPU support is not enabled")); + } + #[test] fn validate_vm_sandbox_accepts_gpu_when_enabled() { let sandbox = Sandbox { @@ -5233,6 +5256,66 @@ mod tests { validate_vm_sandbox(&sandbox, true).expect("one GPU should be accepted when enabled"); } + #[test] + fn validate_vm_sandbox_accepts_single_gpu_device_without_gpu_count() { + let sandbox = Sandbox { + id: "sandbox-123".to_string(), + spec: Some(SandboxSpec { + resource_requirements: Some(gpu_resources(None)), + template: Some(SandboxTemplate { + driver_config: Some(gpu_device_ids_config(&["0000:2d:00.0"])), + ..Default::default() + }), + ..Default::default() + }), + ..Default::default() + }; + validate_vm_sandbox(&sandbox, true) + .expect("single exact GPU device should be compatible with a default GPU request"); + } + + #[test] + fn validate_vm_sandbox_rejects_multiple_gpu_device_ids_without_gpu_count() { + let sandbox = Sandbox { + id: "sandbox-123".to_string(), + spec: Some(SandboxSpec { + resource_requirements: Some(gpu_resources(None)), + template: Some(SandboxTemplate { + driver_config: Some(gpu_device_ids_config(&["0000:2d:00.0", "0000:31:00.0"])), + ..Default::default() + }), + ..Default::default() + }), + ..Default::default() + }; + let err = validate_vm_sandbox(&sandbox, true) + .expect_err("multiple GPU device IDs without count should be rejected"); + + assert_eq!(err.code(), Code::InvalidArgument); + assert!( + err.message() + .contains("requires an explicit gpu count matching its length (2)") + ); + } + + #[test] + fn validate_vm_sandbox_accepts_gpu_count_matching_device_id() { + let sandbox = Sandbox { + id: "sandbox-123".to_string(), + spec: Some(SandboxSpec { + resource_requirements: Some(gpu_resources(Some(1))), + template: Some(SandboxTemplate { + driver_config: Some(gpu_device_ids_config(&["0000:2d:00.0"])), + ..Default::default() + }), + ..Default::default() + }), + ..Default::default() + }; + validate_vm_sandbox(&sandbox, true) + .expect("matching explicit GPU device count should be accepted"); + } + #[test] fn validate_vm_sandbox_rejects_gpu_count_above_one() { let sandbox = Sandbox { @@ -5250,10 +5333,11 @@ mod tests { } #[test] - fn validate_vm_sandbox_rejects_gpu_device_without_gpu() { + fn validate_vm_sandbox_rejects_gpu_count_mismatched_device_id() { let sandbox = Sandbox { id: "sandbox-123".to_string(), spec: Some(SandboxSpec { + resource_requirements: Some(gpu_resources(Some(2))), template: Some(SandboxTemplate { driver_config: Some(gpu_device_ids_config(&["0000:2d:00.0"])), ..Default::default() @@ -5263,9 +5347,32 @@ mod tests { ..Default::default() }; let err = validate_vm_sandbox(&sandbox, true) - .expect_err("gpu_device_ids without gpu should be rejected"); + .expect_err("mismatched explicit GPU device count should be rejected"); + assert_eq!(err.code(), Code::InvalidArgument); - assert!(err.message().contains("gpu_device_ids requires gpu=true")); + assert!( + err.message() + .contains("gpu count (2) must match driver_config.gpu_device_ids length (1)") + ); + } + + #[test] + fn validate_vm_sandbox_rejects_gpu_device_without_gpu_request() { + let sandbox = Sandbox { + id: "sandbox-123".to_string(), + spec: Some(SandboxSpec { + template: Some(SandboxTemplate { + driver_config: Some(gpu_device_ids_config(&["0000:2d:00.0"])), + ..Default::default() + }), + ..Default::default() + }), + ..Default::default() + }; + let err = validate_vm_sandbox(&sandbox, true) + .expect_err("gpu_device_ids without a GPU request should be rejected"); + assert_eq!(err.code(), Code::InvalidArgument); + assert!(err.message().contains("requires a gpu request")); } #[test] @@ -5273,7 +5380,7 @@ mod tests { let sandbox = Sandbox { id: "sandbox-123".to_string(), spec: Some(SandboxSpec { - resource_requirements: Some(gpu_resources(None)), + resource_requirements: Some(gpu_resources(Some(2))), template: Some(SandboxTemplate { driver_config: Some(gpu_device_ids_config(&["0000:2d:00.0", "0000:31:00.0"])), ..Default::default() diff --git a/docs/reference/sandbox-compute-drivers.mdx b/docs/reference/sandbox-compute-drivers.mdx index cbadac071..c75ec3438 100644 --- a/docs/reference/sandbox-compute-drivers.mdx +++ b/docs/reference/sandbox-compute-drivers.mdx @@ -54,14 +54,18 @@ openshell sandbox create \ Driver config is for fields without a stable public flag. Prefer `--cpu`, `--memory`, and `--gpu` for supported resource intent. Pass a count to `--gpu` -when the active driver supports counted allocation. +when the active driver supports counted allocation. Docker and Podman reject +count-only GPU selection. If `driver_config` lists more than one exact CDI +device, pass `--gpu COUNT`; the count must match the number of listed devices. +A single exact CDI device is compatible with the default `--gpu` request. Exact GPU device selection remains driver-owned and requires `--gpu`. Docker and Podman accept `cdi_devices`; replace the top-level `docker` key with `podman` when using the Podman driver, for example `{"docker":{"cdi_devices":["nvidia.com/gpu=0"]}}`. The VM driver accepts `gpu_device_ids`, for example `{"vm":{"gpu_device_ids":["0000:2d:00.0"]}}`; -the current VM implementation accepts at most one entry. +the current VM implementation accepts at most one entry and allows either +`--gpu` or `--gpu 1` when `gpu_device_ids` is set. For Kubernetes, `pod.runtime_class_name` maps to PodSpec `runtimeClassName`. It overrides the gateway's configured default runtime class for that sandbox, diff --git a/docs/sandboxes/manage-sandboxes.mdx b/docs/sandboxes/manage-sandboxes.mdx index 58851b632..c531e9d15 100644 --- a/docs/sandboxes/manage-sandboxes.mdx +++ b/docs/sandboxes/manage-sandboxes.mdx @@ -78,8 +78,11 @@ openshell sandbox create --gpu 2 -- claude Support for counted GPU requests is driver-dependent. Kubernetes honors a counted `--gpu` request by setting the `nvidia.com/gpu` limit. Docker and Podman -reject count-based selection. VM gateways accept only one GPU, either through -`--gpu` or `--gpu 1`. +reject count-only selection. If `driver_config` lists more than one exact CDI +device, pass `--gpu COUNT`; the count must match the number of listed devices. +A single exact CDI device is compatible with the default `--gpu` request. VM +gateways accept only one GPU, either through `--gpu` or `--gpu 1`; a single +`gpu_device_ids` entry works with either form. For Docker-backed sandboxes, GPU injection uses Docker CDI. If you enable Docker CDI after the gateway starts, restart the gateway so OpenShell can detect the From c7cf9d635895b4802aedee4e9563ce4db1020e47 Mon Sep 17 00:00:00 2001 From: Evan Lezar Date: Sat, 13 Jun 2026 11:57:35 +0200 Subject: [PATCH 6/6] docs(rfc): document gpu field replacement Signed-off-by: Evan Lezar --- .../README.md | 33 ++++++++++++++----- 1 file changed, 25 insertions(+), 8 deletions(-) diff --git a/rfc/0004-sandbox-resource-requirements/README.md b/rfc/0004-sandbox-resource-requirements/README.md index 01b4319dd..e18c97bb7 100644 --- a/rfc/0004-sandbox-resource-requirements/README.md +++ b/rfc/0004-sandbox-resource-requirements/README.md @@ -69,8 +69,10 @@ tracked separately in issue #1492. - Defining the general driver-specific configuration passthrough API. Issue #1492 tracks that related API surface. - Publishing allocated resource identities in sandbox status. -- Preserving long-term compatibility for `gpu`, `gpu_device`, or a - GPU-specific `gpu_count` request field. +- Preserving alpha-era compatibility for `gpu`, `gpu_device`, or a + GPU-specific `gpu_count` request field. The legacy GPU-specific request + fields are intentionally not carried forward into the API shape this RFC + aims to stabilize. ## Proposal @@ -89,13 +91,22 @@ message SandboxSpec { // Portable resource requirements used by the gateway for driver selection // and by drivers for provisioning. - SandboxResourceRequirements resource_requirements = 11; + SandboxResourceRequirements resource_requirements = 9; - reserved 9, 10; - reserved "gpu", "gpu_device"; + reserved 10; + reserved "gpu_device"; } ``` +The public sandbox API is still alpha. This migration intentionally replaces +the old `bool gpu = 9` field with the typed `resource_requirements = 9` message +instead of reserving the legacy field number. Old live requests and persisted +sandbox records that encode GPU intent through the legacy boolean are not +migrated; callers should use a matching OpenShell CLI/API version and recreate +GPU sandboxes after upgrade when they need the new typed shape. Avoiding +alpha-era reserved fields keeps the proto surface closer to the API intended +for stabilization. + `SandboxTemplate.resources` keeps its existing role as platform-native workload configuration. It may contain Kubernetes-style CPU, memory, and extended resource requests and limits, but it is not the portable resource contract. @@ -551,10 +562,10 @@ message DriverSandboxSpec { string log_level = 1; map environment = 5; DriverSandboxTemplate template = 6; - DriverSandboxResourceRequirements resource_requirements = 11; + DriverSandboxResourceRequirements resource_requirements = 9; - reserved 9, 10; - reserved "gpu", "gpu_device"; + reserved 10; + reserved "gpu_device"; } ``` @@ -562,6 +573,12 @@ Driver-owned resource requirement messages should have the same semantics as the public messages, but live in `compute_driver.proto` to keep the public and internal contracts separated. +The compute-driver API is version-coupled to the gateway in current deployments: +local drivers are launched by the gateway at startup, and the driver proto is +not treated as a public compatibility surface. It follows the same alpha-era +field replacement as the public API rather than preserving transitional GPU +fields. + ### Driver capabilities Replace GPU-specific capability fields with coarse resource capability