Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions sdk/storage/azure_storage_blob/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ pin-project.workspace = true
serde.workspace = true
serde_json.workspace = true
time.workspace = true
tokio = { workspace = true, features = ["rt", "sync"] }
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tokio dependencies need to be under a feature flag so first party customers can use them.

You can also use the customer configured runtime if you need to (azure_core::get_async_runtime()).


[lints]
workspace = true
Expand Down
60 changes: 60 additions & 0 deletions sdk/storage/azure_storage_blob/src/clients/blob_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,64 @@ impl BlobClient {
self.block_blob_client().upload(content, options).await
}

/// Downloads blob content into the supplied buffer, splitting the transfer
/// into parallel ranged GET requests issued through the Azure Core pipeline.
///
/// The first request uses a large partition size (256MB by default) so that
/// small and medium blobs complete in a single round trip; any remaining
/// bytes are fetched concurrently and written straight into `buffer` at
/// their final offsets, with no re-ordering step.
///
/// Returns the total number of bytes written into `buffer`.
pub async fn managed_download_to(
    &self,
    buffer: &mut [u8],
    options: Option<BlobClientManagedDownloadOptions<'_>>,
) -> Result<usize> {
    let opts = options.unwrap_or_default();

    // Resolve tuning knobs, falling back to the module-level defaults.
    let concurrency = opts.parallel.unwrap_or(DEFAULT_DOWNLOAD_TO_PARALLEL);
    let first_partition = opts
        .initial_partition_size
        .unwrap_or(DEFAULT_INITIAL_PARTITION_SIZE);
    let chunk_partition = opts.partition_size.unwrap_or(DEFAULT_PARTITION_SIZE);

    // Per-request options forwarded to every ranged GET; the concrete range
    // is filled in per chunk by the transfer engine, so it stays `None` here.
    let download_options = BlobClientDownloadOptions {
        encryption_algorithm: opts.encryption_algorithm,
        encryption_key: opts.encryption_key,
        encryption_key_sha256: opts.encryption_key_sha256,
        if_match: opts.if_match,
        if_modified_since: opts.if_modified_since,
        if_none_match: opts.if_none_match,
        if_tags: opts.if_tags,
        if_unmodified_since: opts.if_unmodified_since,
        lease_id: opts.lease_id,
        range: None,
        range_get_content_crc64: opts.range_get_content_crc64,
        range_get_content_md5: opts.range_get_content_md5,
        snapshot: opts.snapshot,
        structured_body_type: opts.structured_body_type,
        timeout: opts.timeout,
        version_id: opts.version_id,
        ..Default::default()
    };

    // Rebuild a generated client that shares this client's pipeline state,
    // then wrap it in the behavior the partitioned transfer engine drives.
    let generated = GeneratedBlobClient {
        endpoint: self.endpoint.clone(),
        pipeline: self.pipeline.clone(),
        version: self.version.clone(),
        tracer: self.tracer.clone(),
    };
    let behavior = BlobClientDownloadBehavior::new(generated, download_options);

    partitioned_transfer::download_to(
        buffer,
        opts.range,
        concurrency,
        first_partition,
        chunk_partition,
        Arc::new(behavior),
    )
    .await
}

/// Checks if the blob exists.
///
/// Returns `true` if the blob exists, `false` if the blob does not exist, and propagates all other errors.
Expand All @@ -289,6 +347,8 @@ impl BlobClient {
// unwrap evaluated at compile time
// Default number of concurrent range transfers for the streaming download path.
const DEFAULT_PARALLEL: NonZero<usize> = NonZero::new(4).unwrap();
// Default size (4 MiB) of each ranged GET after the initial request.
const DEFAULT_PARTITION_SIZE: NonZero<usize> = NonZero::new(4 * 1024 * 1024).unwrap();
// Default concurrency used by `managed_download_to`; intentionally separate
// from DEFAULT_PARALLEL so the buffer-based path can be tuned independently.
const DEFAULT_DOWNLOAD_TO_PARALLEL: NonZero<usize> = NonZero::new(5).unwrap();
// Size (256 MiB) of the probing first request in `managed_download_to`, chosen so
// small/medium blobs complete in one round trip.
// NOTE(review): 256 MiB in a single response is large for a default — confirm this
// matches service limits and expected client memory budgets before shipping.
const DEFAULT_INITIAL_PARTITION_SIZE: NonZero<usize> = NonZero::new(256 * 1024 * 1024).unwrap();

struct BlobClientDownloadBehavior<'a> {
client: GeneratedBlobClient,
Expand Down
5 changes: 5 additions & 0 deletions sdk/storage/azure_storage_blob/src/models/method_options.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,11 @@ pub struct BlobClientManagedDownloadOptions<'a> {
/// Allows customization of the method call.
pub method_options: ClientMethodOptions<'a>,

/// Optional. Size of the initial download request. A larger value means small/medium blobs
/// can be downloaded in a single request. Only used by `managed_download_to`.
/// A default value will be chosen if none is provided.
pub initial_partition_size: Option<NonZero<usize>>,

/// Optional. Number of concurrent network transfers to maintain for this operation.
/// A default value will be chosen if none is provided.
pub parallel: Option<NonZero<usize>>,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ use bytes::{Bytes, BytesMut};
use futures::{
channel::mpsc::{self, UnboundedReceiver, UnboundedSender},
future::Either,
SinkExt, TryStream,
SinkExt, StreamExt, TryStream,
};

use crate::models::{drains::SequentialBoundedDrain, http_ranges::ContentRange};
Expand Down Expand Up @@ -324,6 +324,275 @@ fn analyze_initial_response(
trait DownloadRangeFuture: Future + Send {}
impl<T: Future + Send> DownloadRangeFuture for T {}

/// Thread-transferable handle to a fixed, exclusive region of a byte buffer.
///
/// Stores a raw pointer plus a length so the region can be moved into a
/// spawned task without borrowing the buffer.
///
/// # Safety
/// Soundness rests entirely on the creator:
/// - the backing buffer must stay alive for as long as any `SendSlice` exists;
/// - two `SendSlice` values used at the same time must never describe
///   overlapping byte ranges.
pub(crate) struct SendSlice {
    ptr: *mut u8,
    len: usize,
}

// SAFETY: every SendSlice is carved from a distinct, non-overlapping region of
// a buffer that the creator guarantees outlives the tasks using it, so moving
// one to another thread cannot introduce aliased mutable access.
unsafe impl Send for SendSlice {}

impl SendSlice {
    /// Builds a `SendSlice` over the `len` bytes starting at `ptr`.
    ///
    /// # Safety
    /// `ptr` must be valid for writes of `len` bytes, and no other
    /// `SendSlice` or live reference may touch that region concurrently.
    pub(crate) unsafe fn from_raw(ptr: *mut u8, len: usize) -> Self {
        SendSlice { ptr, len }
    }

    /// Reborrows the region as a mutable byte slice.
    ///
    /// # Safety
    /// For the lifetime of the returned slice, no other reference to this
    /// memory region may exist.
    pub(crate) unsafe fn as_mut_slice(&mut self) -> &mut [u8] {
        std::slice::from_raw_parts_mut(self.ptr, self.len)
    }
}

/// Downloads a blob directly into a caller-provided buffer with parallel chunk writes.
///
/// Uses a large `initial_partition_size` for the first request (so small/medium blobs complete
/// in a single HTTP request), then fills the remainder with `partition_size` chunks spawned as
/// tasks on the tokio runtime. Each spawned task streams its response body directly into a
/// non-overlapping sub-slice of the buffer via unsafe `SendSlice`, avoiding intermediate
/// `Bytes` allocations.
///
/// Returns the number of bytes written to the buffer.
///
/// # Errors
///
/// Fails when the buffer is too small for the data being written, when any range transfer
/// fails, or when a spawned task panics or is cancelled (surfaced as a join error).
///
/// NOTE(review): cancellation-safety concern — the spawned tasks write into `buffer` through
/// raw pointers while `tokio::spawn` requires `'static` futures. If the future returned by
/// this function is dropped before completion (e.g. a caller-side timeout or `select!`), the
/// `JoinSet` aborts its tasks on drop, but a task already running on another worker thread may
/// still be mid-write through its `SendSlice` after the borrow of `buffer` has ended — a
/// potential use-after-free. Confirm the cancellation story before shipping.
pub(crate) async fn download_to<Behavior>(
    buffer: &mut [u8],
    range: Option<Range<usize>>,
    parallel: NonZero<usize>,
    initial_partition_size: NonZero<usize>,
    partition_size: NonZero<usize>,
    client: Arc<Behavior>,
) -> AzureResult<usize>
where
    Behavior: PartitionedDownloadBehavior + Send + Sync + 'static,
{
    let parallel = parallel.get();
    let initial_partition_size = initial_partition_size.get();
    let partition_size = partition_size.get();

    // No explicit range means "the whole blob"; an empty range is a no-op.
    let max_download_range = range.unwrap_or(0..usize::MAX);
    if max_download_range.is_empty() {
        return Ok(0);
    }

    // Initial request uses the large initial_partition_size to probe blob size
    // and download small/medium blobs in a single request.
    let initial_response = match client
        .transfer_range(Some(
            max_download_range.start
                ..min(
                    max_download_range.end,
                    max_download_range.start.saturating_add(initial_partition_size),
                ),
        ))
        .await
    {
        Ok(response) => response,
        Err(err) => match (err.http_status(), max_download_range.start) {
            // A 416 on a zero-based range presumably means an empty blob;
            // retry without a range so the service can answer. TODO confirm.
            (Some(StatusCode::RequestedRangeNotSatisfiable), 0) => {
                client.transfer_range(None).await?
            }
            _ => Err(err)?,
        },
    };

    // Parse Content-Range to determine total blob size and compute remaining ranges.
    let (initial_chunk_len, ranges): (usize, Vec<Range<usize>>) = match initial_response
        .headers()
        .get_optional_as::<ContentRange, _>(&"content-range".into())?
    {
        Some(content_range) => match (content_range.range, content_range.total_len) {
            (Some(received_range), Some(resource_len)) => {
                // Bytes the initial request will deliver, and the tail still owed.
                let initial_chunk_len = received_range.1 - received_range.0;
                let remainder_start = received_range.1;
                let remainder_end = min(max_download_range.end, resource_len);
                // Carve the remainder into fixed-size, non-overlapping chunks.
                let ranges = (remainder_start..remainder_end)
                    .step_by(partition_size)
                    .map(|i| i..min(i.saturating_add(partition_size), remainder_end))
                    .collect();
                (initial_chunk_len, ranges)
            }
            // No usable range/total info: treat the one response as everything.
            _ => (0, Vec::new()),
        },
        None => (0, Vec::new()),
    };

    // If the entire blob was returned in the initial request, stream it directly into the buffer.
    if ranges.is_empty() {
        let mut body = initial_response.into_body();
        let mut write_offset = 0usize;
        while let Some(chunk) = body.next().await {
            let chunk = chunk?;
            let end = write_offset + chunk.len();
            // Guard before copying: the buffer is caller-sized, not ours.
            if end > buffer.len() {
                return Err(azure_core::Error::with_message(
                    azure_core::error::ErrorKind::Io,
                    format!(
                        "Buffer too small: size {} but need at least {} bytes",
                        buffer.len(),
                        end
                    ),
                ));
            }
            buffer[write_offset..end].copy_from_slice(&chunk);
            write_offset += chunk.len();
        }
        return Ok(write_offset);
    }

    // Multiple ranges needed: overlap initial body streaming with parallel chunk downloads.
    let buffer_ptr = buffer.as_mut_ptr();
    let buffer_len = buffer.len();
    let mut join_set = tokio::task::JoinSet::new();
    let range_start_offset = max_download_range.start;

    // Spawn initial body streaming as a concurrent task so it overlaps with chunk downloads.
    if initial_chunk_len > buffer_len {
        return Err(azure_core::Error::with_message(
            azure_core::error::ErrorKind::Io,
            format!(
                "Buffer too small: size {} but need at least {} bytes",
                buffer_len, initial_chunk_len
            ),
        ));
    }
    // SAFETY: initial body writes to buffer[0..initial_chunk_len], which does not overlap
    // with any subsequent chunk range. The buffer outlives all spawned tasks because we
    // join/shutdown before returning.
    // NOTE(review): the "outlives" claim holds only when this future runs to completion;
    // see the cancellation note in the function docs.
    let mut initial_slice =
        unsafe { SendSlice::from_raw(buffer_ptr, initial_chunk_len) };
    let mut initial_body = initial_response.into_body();
    join_set.spawn(async move {
        // SAFETY: No other task accesses this slice region concurrently.
        let slice = unsafe { initial_slice.as_mut_slice() };
        let mut written = 0usize;
        while let Some(chunk) = initial_body.next().await {
            let chunk = chunk?;
            slice[written..written + chunk.len()].copy_from_slice(&chunk);
            written += chunk.len();
        }
        Ok::<usize, azure_core::Error>(written)
    });

    // Lazily spawn chunk download tasks, maintaining at most `parallel` in-flight downloads
    // (plus the initial body task above). This avoids creating thousands of suspended futures
    // for very large blobs.
    let mut range_iter = ranges.into_iter();

    // Spawn initial batch of parallel downloads.
    for _ in 0..parallel {
        match range_iter.next() {
            Some(r) => spawn_download_chunk(
                &mut join_set,
                client.clone(),
                r,
                buffer_ptr,
                buffer_len,
                range_start_offset,
            )?,
            None => break,
        }
    }

    // Sliding-window drain: each completed task frees a slot for the next range.
    let mut total_written = 0usize;
    while let Some(result) = join_set.join_next().await {
        match result {
            Ok(Ok(written)) => {
                total_written += written;
                // Spawn next chunk download as a slot frees up.
                if let Some(r) = range_iter.next() {
                    spawn_download_chunk(
                        &mut join_set,
                        client.clone(),
                        r,
                        buffer_ptr,
                        buffer_len,
                        range_start_offset,
                    )?;
                }
            }
            // A transfer failed: stop all siblings before propagating, so no
            // task can keep writing into the buffer after we return.
            Ok(Err(e)) => {
                join_set.shutdown().await;
                return Err(e);
            }
            // A task panicked or was aborted; same teardown-then-propagate.
            Err(join_err) => {
                join_set.shutdown().await;
                return Err(azure_core::Error::new(
                    azure_core::error::ErrorKind::Other,
                    join_err,
                ));
            }
        }
    }

    Ok(total_written)
}

/// Spawns a single chunk download task into the JoinSet.
///
/// Validates that `range` (translated by `range_start_offset` into a buffer
/// offset) fits inside the buffer, then spawns a task that streams the ranged
/// response body straight into that sub-region.
///
/// # Safety
///
/// The caller must ensure that the buffer region for this range does not overlap
/// with any other concurrently-active task's region, and that the buffer outlives
/// all spawned tasks.
///
/// NOTE(review): this function carries a `# Safety` contract but is not declared
/// `unsafe fn`, so the compiler cannot hold callers to it — consider marking it
/// `unsafe` (or wrapping the pointer/len pair in a checked type).
fn spawn_download_chunk<Behavior>(
    join_set: &mut tokio::task::JoinSet<AzureResult<usize>>,
    client: Arc<Behavior>,
    range: Range<usize>,
    buffer_ptr: *mut u8,
    buffer_len: usize,
    range_start_offset: usize,
) -> AzureResult<()>
where
    Behavior: PartitionedDownloadBehavior + Send + Sync + 'static,
{
    // Assumes range.start >= range_start_offset; otherwise this subtraction
    // panics in debug / wraps in release — TODO confirm callers guarantee it.
    let buf_offset = range.start - range_start_offset;
    let chunk_max_len = range.end - range.start;

    // Reject ranges that would land past the end of the caller's buffer.
    if buf_offset + chunk_max_len > buffer_len {
        return Err(azure_core::Error::with_message(
            azure_core::error::ErrorKind::Io,
            format!(
                "Buffer too small: size {} but download requires {} bytes",
                buffer_len,
                buf_offset + chunk_max_len
            ),
        ));
    }

    // SAFETY: Each range is non-overlapping and within buffer bounds (checked above).
    let mut send_slice =
        unsafe { SendSlice::from_raw(buffer_ptr.add(buf_offset), chunk_max_len) };

    join_set.spawn(async move {
        let response = client.transfer_range(Some(range)).await?;
        let mut body = response.into_body();
        let mut written = 0usize;
        // SAFETY: No other task accesses this slice region concurrently.
        let slice = unsafe { send_slice.as_mut_slice() };
        while let Some(chunk) = body.next().await {
            let chunk = chunk?;
            // If the service ever returns more bytes than requested, this
            // slice index panics; the panic surfaces to the caller as a
            // task join error rather than a buffer overrun.
            slice[written..written + chunk.len()].copy_from_slice(&chunk);
            written += chunk.len();
        }
        Ok::<usize, azure_core::Error>(written)
    });

    Ok(())
}

#[cfg(test)]
mod tests {
use std::cmp::min;
Expand Down
Loading