1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

//! Common logic for interacting with remote object stores
use std::{
    fmt::Display,
    ops::{Range, RangeBounds},
};

use super::Result;
use bytes::Bytes;
use futures::{stream::StreamExt, Stream, TryStreamExt};
use snafu::Snafu;

/// `chrono` format string used by [`deserialize_rfc1123`] to parse RFC 1123
/// style timestamps; the trailing `GMT` is matched literally.
#[cfg(any(feature = "azure", feature = "http"))]
pub static RFC1123_FMT: &str = "%a, %d %h %Y %T GMT";

/// Deserializes an RFC 1123 formatted date string into a UTC timestamp
///
/// The string is parsed with [`RFC1123_FMT`]; a value that does not match
/// that format is reported as a custom deserialization error.
#[cfg(any(feature = "azure", feature = "http"))]
pub fn deserialize_rfc1123<'de, D>(
    deserializer: D,
) -> Result<chrono::DateTime<chrono::Utc>, D::Error>
where
    D: serde::Deserializer<'de>,
{
    use chrono::TimeZone;
    use serde::Deserialize;

    let raw = String::deserialize(deserializer)?;
    chrono::NaiveDateTime::parse_from_str(&raw, RFC1123_FMT)
        // The format pins the zone to the literal `GMT`, so the naive value is
        // interpreted as UTC
        .map(|naive| chrono::Utc.from_utc_datetime(&naive))
        .map_err(serde::de::Error::custom)
}

/// Computes the HMAC-SHA256 tag of `bytes` keyed with `secret`
#[cfg(any(feature = "aws", feature = "azure"))]
pub(crate) fn hmac_sha256(secret: impl AsRef<[u8]>, bytes: impl AsRef<[u8]>) -> ring::hmac::Tag {
    use ring::hmac::{sign, Key, HMAC_SHA256};

    let signing_key = Key::new(HMAC_SHA256, secret.as_ref());
    sign(&signing_key, bytes.as_ref())
}

/// Collect a stream into [`Bytes`] avoiding copying in the event of a single chunk
pub async fn collect_bytes<S, E>(mut stream: S, size_hint: Option<usize>) -> Result<Bytes, E>
where
    E: Send,
    S: Stream<Item = Result<Bytes, E>> + Send + Unpin,
{
    let first = stream.next().await.transpose()?.unwrap_or_default();

    // Avoid copying if single response
    match stream.next().await.transpose()? {
        None => Ok(first),
        Some(second) => {
            let size_hint = size_hint.unwrap_or_else(|| first.len() + second.len());

            let mut buf = Vec::with_capacity(size_hint);
            buf.extend_from_slice(&first);
            buf.extend_from_slice(&second);
            while let Some(maybe_bytes) = stream.next().await {
                buf.extend_from_slice(&maybe_bytes?);
            }

            Ok(buf.into())
        }
    }
}

#[cfg(not(target_arch = "wasm32"))]
/// Runs `f` on the tokio blocking pool when called from within a tokio runtime,
/// otherwise runs it inline
///
/// A failure to join the spawned task is converted into the returned error by `?`.
pub async fn maybe_spawn_blocking<F, T>(f: F) -> Result<T>
where
    F: FnOnce() -> Result<T> + Send + 'static,
    T: Send + 'static,
{
    if let Ok(handle) = tokio::runtime::Handle::try_current() {
        handle.spawn_blocking(f).await?
    } else {
        // No runtime handle available: just call the function directly
        f()
    }
}

/// Range requests separated by a gap less than or equal to this many bytes (1 MiB)
/// will be coalesced into a single request by [`coalesce_ranges`] when this value
/// is passed as its `coalesce` argument
pub const OBJECT_STORE_COALESCE_DEFAULT: usize = 1024 * 1024;

/// Maximum number of range requests that [`coalesce_ranges`] keeps in flight
/// at the same time
pub const OBJECT_STORE_COALESCE_PARALLEL: usize = 10;

/// Fetches the provided byte `ranges` using `fetch`, a function that retrieves
/// a single contiguous range of bytes
///
/// To reduce the number of requests and their latency it will:
///
/// * Merge ranges whose gap is at most `coalesce` bytes into a single call to `fetch`
/// * Run up to [`OBJECT_STORE_COALESCE_PARALLEL`] `fetch` calls concurrently
///
/// The returned buffers are in the same order as `ranges`. A buffer whose range
/// extends past the end of the data returned by `fetch` is truncated to what
/// was returned.
pub async fn coalesce_ranges<F, E, Fut>(
    ranges: &[Range<usize>],
    fetch: F,
    coalesce: usize,
) -> Result<Vec<Bytes>, E>
where
    F: Send + FnMut(Range<usize>) -> Fut,
    E: Send,
    Fut: std::future::Future<Output = Result<Bytes, E>> + Send,
{
    let merged = merge_ranges(ranges, coalesce);

    // `buffered` preserves order, so `payloads[i]` holds the data for `merged[i]`
    let payloads: Vec<Bytes> = futures::stream::iter(merged.iter().cloned())
        .map(fetch)
        .buffered(OBJECT_STORE_COALESCE_PARALLEL)
        .try_collect()
        .await?;

    let mut out = Vec::with_capacity(ranges.len());
    for wanted in ranges {
        // `merged` is sorted by start: the covering range is the last one
        // starting at or before `wanted.start`
        let idx = merged.partition_point(|m| m.start <= wanted.start) - 1;
        let base = merged[idx].start;
        let data = &payloads[idx];

        // Translate object offsets into offsets within the fetched buffer
        let lo = wanted.start - base;
        let hi = (wanted.end - base).min(data.len());
        out.push(data.slice(lo..hi));
    }

    Ok(out)
}

/// Returns a list of ranges, sorted by start, that together cover every range in `ranges`
///
/// Ranges that overlap, or whose gap is at most `coalesce` bytes, are merged
fn merge_ranges(ranges: &[Range<usize>], coalesce: usize) -> Vec<Range<usize>> {
    let mut sorted = ranges.to_vec();
    sorted.sort_unstable_by_key(|r| r.start);

    let mut merged: Vec<Range<usize>> = Vec::with_capacity(sorted.len());
    for next in sorted {
        if let Some(current) = merged.last_mut() {
            // `checked_sub` is `None` when `next` starts before the current end,
            // i.e. the ranges overlap and are merged regardless of `coalesce`
            let close_enough = next
                .start
                .checked_sub(current.end)
                .map_or(true, |gap| gap <= coalesce);

            if close_enough {
                current.end = current.end.max(next.end);
                continue;
            }
        }
        merged.push(next);
    }

    merged
}

/// Request only a portion of an object's bytes
///
/// These can be created from [`usize`] ranges, like
///
/// ```rust
/// # use object_store::GetRange;
/// let range1: GetRange = (50..150).into();
/// let range2: GetRange = (50..=150).into();
/// let range3: GetRange = (50..).into();
/// let range4: GetRange = (..150).into();
/// ```
///
/// A range without a lower bound starts at byte 0, and a range without an
/// upper bound becomes [`GetRange::Offset`].
///
/// Implementations may wish to inspect [`GetResult`] for the exact byte
/// range returned.
///
/// [`GetResult`]: crate::GetResult
#[derive(Debug, PartialEq, Eq, Clone)]
pub enum GetRange {
    /// Request a specific range of bytes
    ///
    /// If the given range is zero-length or starts at or after the end of the
    /// object, an error will be returned. Additionally, if the range ends after
    /// the end of the object, the entire remainder of the object will be returned.
    /// Otherwise, the exact requested range will be returned.
    Bounded(Range<usize>),
    /// Request all bytes starting from a given byte offset
    ///
    /// If the offset is at or after the end of the object, an error will be returned.
    Offset(usize),
    /// Request up to the last n bytes
    ///
    /// If the object is shorter than n bytes, the entire object will be returned.
    Suffix(usize),
}

/// Reasons a [`GetRange`] cannot be resolved against an object
#[derive(Debug, Snafu)]
pub(crate) enum InvalidGetRange {
    /// The range starts at or after the end of the object
    #[snafu(display(
        "Wanted range starting at {requested}, but object was only {length} bytes long"
    ))]
    StartTooLarge { requested: usize, length: usize },

    /// A bounded range whose end is not strictly greater than its start
    #[snafu(display("Range started at {start} and ended at {end}"))]
    Inconsistent { start: usize, end: usize },
}

impl GetRange {
    /// Checks the range for internal consistency, independent of any object length.
    ///
    /// Only `Bounded` ranges can fail: their end must be strictly greater than
    /// their start.
    pub(crate) fn is_valid(&self) -> Result<(), InvalidGetRange> {
        if let Self::Bounded(r) = self {
            if r.end <= r.start {
                return Err(InvalidGetRange::Inconsistent {
                    start: r.start,
                    end: r.end,
                });
            }
        }
        Ok(())
    }

    /// Resolve against an object of `len` bytes, yielding a concrete [`Range`] if valid.
    ///
    /// `Bounded` and `Offset` fail when they start at or after `len`; a `Bounded`
    /// end past `len` is clamped to `len`. `Suffix` never fails and yields at
    /// most the whole object.
    pub(crate) fn as_range(&self, len: usize) -> Result<Range<usize>, InvalidGetRange> {
        self.is_valid()?;

        let (start, end) = match self {
            // A suffix is anchored to the end of the object, so it has no start to check
            Self::Suffix(n) => return Ok(len.saturating_sub(*n)..len),
            Self::Bounded(r) => (r.start, r.end.min(len)),
            Self::Offset(o) => (*o, len),
        };

        if start >= len {
            return Err(InvalidGetRange::StartTooLarge {
                requested: start,
                length: len,
            });
        }
        Ok(start..end)
    }
}

impl Display for GetRange {
    /// Renders the range in the `bytes=...` form used by HTTP `Range` headers.
    ///
    /// `Bounded` stores an exclusive end while the rendered last-byte position
    /// is inclusive, hence the `- 1`.
    // NOTE(review): a `Bounded` range with `end == 0` underflows here; presumably
    // callers validate with `is_valid` first — confirm.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Suffix(n) => write!(f, "bytes=-{n}"),
            Self::Offset(o) => write!(f, "bytes={o}-"),
            Self::Bounded(r) => {
                let (first, last) = (r.start, r.end - 1);
                write!(f, "bytes={}-{}", first, last)
            }
        }
    }
}

impl<T: RangeBounds<usize>> From<T> for GetRange {
    /// Converts any `usize` range into a [`GetRange`].
    ///
    /// A missing lower bound is treated as 0. A missing upper bound produces
    /// `Offset`; every other combination produces a half-open `Bounded` range.
    fn from(value: T) -> Self {
        use std::ops::Bound;

        // Normalise the lower bound to an inclusive start
        let start = match value.start_bound() {
            Bound::Unbounded => 0,
            Bound::Included(&s) => s,
            Bound::Excluded(&s) => s + 1,
        };

        // Normalise the upper bound to an exclusive end
        let end = match value.end_bound() {
            Bound::Unbounded => return Self::Offset(start),
            Bound::Included(&e) => e + 1,
            Bound::Excluded(&e) => e,
        };

        Self::Bounded(start..end)
    }
}
/// Percent-encoding set that leaves only the RFC 3986 unreserved characters unescaped
///
/// See <http://docs.aws.amazon.com/general/latest/gr/sigv4-create-canonical-request.html>:
///
/// > Do not URI-encode any of the unreserved characters that RFC 3986 defines:
/// > A-Z, a-z, 0-9, hyphen ( - ), underscore ( _ ), period ( . ), and tilde ( ~ ).
///
/// The set therefore starts from [`percent_encoding::NON_ALPHANUMERIC`] and
/// removes `-`, `.`, `_` and `~`.
#[cfg(any(feature = "aws", feature = "gcp"))]
pub(crate) const STRICT_ENCODE_SET: percent_encoding::AsciiSet = percent_encoding::NON_ALPHANUMERIC
    .remove(b'-')
    .remove(b'.')
    .remove(b'_')
    .remove(b'~');

/// Computes the SHA256 digest of `bytes` returned as a lower-case hex encoded string
#[cfg(any(feature = "aws", feature = "gcp"))]
pub(crate) fn hex_digest(bytes: &[u8]) -> String {
    hex_encode(ring::digest::digest(&ring::digest::SHA256, bytes).as_ref())
}

/// Returns `bytes` as a lower-case hex encoded string
#[cfg(any(feature = "aws", feature = "gcp"))]
pub(crate) fn hex_encode(bytes: &[u8]) -> String {
    // Lower-case hex digits, indexed by nibble value
    const DIGITS: &[u8; 16] = b"0123456789abcdef";

    // Every input byte expands to exactly two output characters
    let mut encoded = String::with_capacity(bytes.len() * 2);
    for &byte in bytes {
        encoded.push(char::from(DIGITS[usize::from(byte >> 4)]));
        encoded.push(char::from(DIGITS[usize::from(byte & 0x0f)]));
    }
    encoded
}

#[cfg(test)]
mod tests {
    use crate::Error;

    use super::*;
    use rand::{thread_rng, Rng};
    use std::ops::Range;

    /// Calls coalesce_ranges and validates the returned data is correct
    ///
    /// The backing object is synthesized so that byte `i` holds `i as u8`, and
    /// every returned buffer is compared against the matching slice of it.
    ///
    /// Returns the fetched ranges
    async fn do_fetch(ranges: Vec<Range<usize>>, coalesce: usize) -> Vec<Range<usize>> {
        let max = ranges.iter().map(|x| x.end).max().unwrap_or(0);
        let src: Vec<_> = (0..max).map(|x| x as u8).collect();

        let mut fetches = vec![];
        let coalesced = coalesce_ranges::<_, Error, _>(
            &ranges,
            |range| {
                // Record each request so callers can assert on the coalescing
                fetches.push(range.clone());
                futures::future::ready(Ok(Bytes::from(src[range].to_vec())))
            },
            coalesce,
        )
        .await
        .unwrap();

        assert_eq!(ranges.len(), coalesced.len());
        for (range, bytes) in ranges.iter().zip(coalesced) {
            assert_eq!(bytes.as_ref(), &src[range.clone()]);
        }
        fetches
    }

    /// Hand-picked inputs covering empty, adjacent, unsorted and overlapping ranges
    #[tokio::test]
    async fn test_coalesce_ranges() {
        let fetches = do_fetch(vec![], 0).await;
        assert!(fetches.is_empty());

        let fetches = do_fetch(vec![0..3; 1], 0).await;
        assert_eq!(fetches, vec![0..3]);

        let fetches = do_fetch(vec![0..2, 3..5], 0).await;
        assert_eq!(fetches, vec![0..2, 3..5]);

        let fetches = do_fetch(vec![0..1, 1..2], 0).await;
        assert_eq!(fetches, vec![0..2]);

        let fetches = do_fetch(vec![0..1, 2..72], 1).await;
        assert_eq!(fetches, vec![0..72]);

        let fetches = do_fetch(vec![0..1, 56..72, 73..75], 1).await;
        assert_eq!(fetches, vec![0..1, 56..75]);

        let fetches = do_fetch(vec![0..1, 5..6, 7..9, 2..3, 4..6], 1).await;
        assert_eq!(fetches, vec![0..9]);

        let fetches = do_fetch(vec![0..1, 6..7, 8..9, 10..14, 9..10], 4).await;
        assert_eq!(fetches, vec![0..1, 6..14]);
    }

    /// Random inputs: the fetched ranges must be sorted and separated by more
    /// than `coalesce` bytes, otherwise they should have been merged
    #[tokio::test]
    async fn test_coalesce_fuzz() {
        let mut rand = thread_rng();
        for _ in 0..100 {
            let object_len = rand.gen_range(10..250);
            let range_count = rand.gen_range(0..10);
            let ranges: Vec<_> = (0..range_count)
                .map(|_| {
                    let start = rand.gen_range(0..object_len);
                    let max_len = 20.min(object_len - start);
                    let len = rand.gen_range(0..max_len);
                    start..start + len
                })
                .collect();

            let coalesce = rand.gen_range(1..5);
            let fetches = do_fetch(ranges.clone(), coalesce).await;

            for fetch in fetches.windows(2) {
                assert!(
                    fetch[0].start <= fetch[1].start,
                    "fetches should be sorted, {:?} vs {:?}",
                    fetch[0],
                    fetch[1]
                );

                // Measure the gap from the end of one fetch to the start of the
                // next; `checked_sub` yields `None` when the fetches overlap
                let separated = fetch[1]
                    .start
                    .checked_sub(fetch[0].end)
                    .map_or(false, |gap| gap > coalesce);
                assert!(
                    separated,
                    "fetches should not overlap by {}, {:?} vs {:?} for {:?}",
                    coalesce,
                    fetch[0],
                    fetch[1],
                    ranges
                );
            }
        }
    }

    /// `Display` output for each variant
    #[test]
    fn getrange_str() {
        assert_eq!(GetRange::Offset(0).to_string(), "bytes=0-");
        assert_eq!(GetRange::Bounded(10..19).to_string(), "bytes=10-18");
        assert_eq!(GetRange::Suffix(10).to_string(), "bytes=-10");
    }

    /// Conversion from the standard range types
    #[test]
    fn getrange_from() {
        assert_eq!(Into::<GetRange>::into(10..15), GetRange::Bounded(10..15),);
        assert_eq!(Into::<GetRange>::into(10..=15), GetRange::Bounded(10..16),);
        assert_eq!(Into::<GetRange>::into(10..), GetRange::Offset(10),);
        assert_eq!(Into::<GetRange>::into(..=15), GetRange::Bounded(0..16));
    }

    /// Resolution of each variant against concrete object lengths
    #[test]
    fn test_as_range() {
        let range = GetRange::Bounded(2..5);
        assert_eq!(range.as_range(5).unwrap(), 2..5);

        let range = range.as_range(4).unwrap();
        assert_eq!(range, 2..4);

        let range = GetRange::Bounded(3..3);
        let err = range.as_range(2).unwrap_err().to_string();
        assert_eq!(err, "Range started at 3 and ended at 3");

        let range = GetRange::Bounded(2..2);
        let err = range.as_range(3).unwrap_err().to_string();
        assert_eq!(err, "Range started at 2 and ended at 2");

        let range = GetRange::Suffix(3);
        assert_eq!(range.as_range(3).unwrap(), 0..3);
        assert_eq!(range.as_range(2).unwrap(), 0..2);

        let range = GetRange::Suffix(0);
        assert_eq!(range.as_range(0).unwrap(), 0..0);

        let range = GetRange::Offset(2);
        let err = range.as_range(2).unwrap_err().to_string();
        assert_eq!(
            err,
            "Wanted range starting at 2, but object was only 2 bytes long"
        );

        let err = range.as_range(1).unwrap_err().to_string();
        assert_eq!(
            err,
            "Wanted range starting at 2, but object was only 1 bytes long"
        );

        let range = GetRange::Offset(1);
        assert_eq!(range.as_range(2).unwrap(), 1..2);
    }
}