Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 63 additions & 22 deletions storages/backends/s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,54 @@ def _filter_download_params(params):
}


class S3SeekableFile(io.RawIOBase):
    """
    Read-only, seekable raw stream over an S3 object.

    Only the current position is held in memory. Every read fetches just the
    requested byte range from S3 through a ranged ``get`` call on the object.

    :param obj: S3 object resource; must expose ``content_length`` and a
        ``get(Range=..., **kwargs)`` method returning a mapping with a
        readable ``"Body"``.
    :param ExtraArgs: extra keyword arguments forwarded to every ``get`` call.
    """

    def __init__(self, obj, ExtraArgs):
        super().__init__()
        self.s3_object = obj
        self.size = obj.content_length
        self.pos = 0
        self.ExtraArgs = ExtraArgs

    def _ensure_open(self):
        """Raise ``ValueError`` if the stream has already been closed."""
        if self.closed:
            raise ValueError("I/O operation on closed file.")

    def seek(self, offset, whence=io.SEEK_SET):
        """
        Move the stream position and return the new absolute position.

        Seeking past the end is allowed (subsequent reads return ``b""``).

        :raises ValueError: if ``whence`` is not one of ``io.SEEK_SET``,
            ``io.SEEK_CUR`` or ``io.SEEK_END``, if the resulting position
            would be negative, or if the stream is closed.
        """
        self._ensure_open()

        if whence == io.SEEK_SET:
            target = offset
        elif whence == io.SEEK_CUR:
            target = self.pos + offset
        elif whence == io.SEEK_END:
            target = self.size + offset
        else:
            raise ValueError("Invalid value for whence.")

        # A negative position would later be formatted into a malformed
        # ``Range`` header, so reject it here and leave ``pos`` untouched.
        if target < 0:
            raise ValueError(f"Negative seek position {target}.")

        self.pos = target
        return self.pos

    def readable(self):
        return True

    def seekable(self):
        return True

    def tell(self):
        """Return the current absolute stream position."""
        return self.pos

    def read(self, size=-1):
        """
        Read up to ``size`` bytes starting at the current position.

        ``size`` of ``None`` or a negative number reads to the end of the
        object. Returns ``b""`` when the position is at or past the end;
        no request is issued in that case.
        """
        self._ensure_open()

        if size is None or size < 0 or self.pos + size > self.size:
            size = self.size - self.pos

        if size <= 0:
            return b""

        # HTTP byte ranges are inclusive on both ends, hence the ``- 1``.
        range_header = f"bytes={self.pos}-{self.pos + size - 1}"
        resp = self.s3_object.get(Range=range_header, **self.ExtraArgs)
        data = resp["Body"].read()
        # Advance by what was actually received, not by what was requested.
        self.pos += len(data)

        return data

    def readinto(self, b):
        """
        Read bytes into the pre-allocated, writable buffer ``b``.

        Returns the number of bytes stored. Providing this method lets the
        stream be wrapped by ``io.BufferedReader``.
        """
        with memoryview(b).cast("B") as view:
            data = self.read(len(view))
            view[: len(data)] = data
        return len(data)

@deconstructible
class S3File(CompressedFileMixin, File):
"""
Expand Down Expand Up @@ -167,35 +215,24 @@ def closed(self):

def _get_file(self):
if self._file is None:
self._file = tempfile.SpooledTemporaryFile(
max_size=self._storage.max_memory_size,
suffix=".S3File",
dir=setting("FILE_UPLOAD_TEMP_DIR"),
)
if "r" in self._mode:
self._is_dirty = False
params = _filter_download_params(
self._storage.get_object_parameters(self.name)
)
self.obj.download_fileobj(
self._file, ExtraArgs=params, Config=self._storage.transfer_config
)
self._file.seek(0)
self._file = S3SeekableFile(self.obj, ExtraArgs=params)

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please note that `Config=self._storage.transfer_config` is no longer passed. I'm not sure whether it is still needed, since the call also changed from `download_fileobj` to `get`.


if self._storage.gzip and self.obj.content_encoding == "gzip":
self._file = self._decompress_file(mode=self._mode, file=self._file)
elif "b" not in self._mode:
if hasattr(self._file, "readable"):
# For versions > Python 3.10 compatibility
# See SpooledTemporaryFile changes in 3.11 (https://docs.python.org/3/library/tempfile.html) # noqa: E501
# Now fully implements the io.BufferedIOBase and io.TextIOBase abstract base classes allowing the file # noqa: E501
# to be readable in the mode that it was specified (without accessing the underlying _file object). # noqa: E501
# In this case, we need to wrap the file in a TextIOWrapper to ensure that the file is read as a text file. # noqa: E501
self._file = io.TextIOWrapper(self._file, encoding="utf-8")
else:
# For versions <= Python 3.10 compatibility
self._file = io.TextIOWrapper(
self._file._file, encoding="utf-8"
)
if "b" not in self._mode:
self._file = io.TextIOWrapper(self._file, encoding="utf-8")
else:
self._file = tempfile.SpooledTemporaryFile(
max_size=self._storage.max_memory_size,
suffix=".S3File",
dir=setting("FILE_UPLOAD_TEMP_DIR"),
)

self._closed = False
return self._file

Expand All @@ -204,6 +241,10 @@ def _set_file(self, value):

file = property(_get_file, _set_file)

# Expose the positioning API of the underlying file object directly on the
# wrapper. Each name is a property returning the bound method of ``self.file``,
# so the lookup goes through the ``file`` property on every access.
seek = property(lambda self: self.file.seek)
seekable = property(lambda self: self.file.seekable)
tell = property(lambda self: self.file.tell)

def read(self, *args, **kwargs):
if "r" not in self._mode:
raise AttributeError("File was not opened in read mode.")
Expand Down
106 changes: 103 additions & 3 deletions tests/test_s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,9 @@ def test_storage_open_read_string(self):
"""
Test opening a file in "r" mode (ie reading as string, not bytes)
"""
obj = self.storage.bucket.Object.return_value
obj.content_length = 0

name = "test_open_read_string.txt"
with self.storage.open(name, "r") as file:
content_str = file.read()
Expand Down Expand Up @@ -1031,6 +1034,35 @@ def setUp(self) -> None:
self.storage = s3.S3Storage()
self.storage._connections.connection = mock.MagicMock()

def test_open(self):
    """Accessing ``file`` in read mode must not trigger any ``get`` call."""
    object_params = {"CacheControl": "never"}
    self.storage.get_object_parameters = lambda name: object_params

    s3_file = s3.S3File("test", "r", self.storage)
    s3_file.file

    s3_file.obj.get.assert_not_called()

def test_read(self):
    """Reading a five-byte text file issues one ``get`` for bytes 0-4."""
    object_params = {"CacheControl": "never"}
    self.storage.get_object_parameters = lambda name: object_params
    payload = b"01234"

    s3_file = s3.S3File("test", "r", self.storage)
    s3_file.obj.content_length = len(payload)
    s3_file.obj.get.return_value = {"Body": io.BytesIO(payload)}

    self.assertEqual(payload.decode("utf-8"), s3_file.file.read())
    s3_file.obj.get.assert_called_once_with(Range="bytes=0-4")

def test_seek(self):
    """Seeking on the opened file must not trigger any ``get`` call."""
    payload = b"01234"

    s3_file = s3.S3File("test", "r", self.storage)
    s3_file.obj.content_length = len(payload)
    s3_file.obj.get.return_value = {"Body": io.BytesIO(payload)}

    s3_file.file.seek(5)

    s3_file.obj.get.assert_not_called()

def test_loading_ssec(self):
params = {"SSECustomerKey": "xyz", "CacheControl": "never"}
self.storage.get_object_parameters = lambda name: params
Expand All @@ -1040,9 +1072,7 @@ def test_loading_ssec(self):
f.obj.load.assert_called_once_with(**filtered)

f.file
f.obj.download_fileobj.assert_called_once_with(
mock.ANY, ExtraArgs=filtered, Config=self.storage.transfer_config
)
f.obj.get.assert_not_called()

def test_closed(self):
with s3.S3File("test", "wb", self.storage) as f:
Expand Down Expand Up @@ -1092,6 +1122,76 @@ def setUp(cls):
cls.bucket = cls.storage.connection.Bucket(settings.AWS_STORAGE_BUCKET_NAME)
cls.bucket.create()

def test_readable(self):
    """A stored text file opened in "r" mode reports itself as readable."""
    source = File(io.StringIO("01234"))
    self.storage.save("file.txt", source)

    s3_file = s3.S3File("file.txt", "r", self.storage)

    self.assertTrue(s3_file.readable())

def test_seekable(self):
    """A stored text file opened in "r" mode reports itself as seekable."""
    source = File(io.StringIO("01234"))
    self.storage.save("file.txt", source)

    s3_file = s3.S3File("file.txt", "r", self.storage)

    self.assertTrue(s3_file.seekable())

def test_tell(self):
    """``tell`` reports 0 initially and follows absolute and end-relative seeks."""
    self.storage.save("file.txt", File(io.StringIO("01234")))

    file = s3.S3File("file.txt", "r", self.storage)
    self.assertEqual(0, file.tell())

    file.seek(3)
    self.assertEqual(3, file.tell())

    # Use io.SEEK_END like the sibling seek tests do, rather than
    # depending on the os module for the same constant.
    file.seek(0, io.SEEK_END)
    self.assertEqual(5, file.tell())

def test_seek_string_file(self):
    """Seeking to offset 2 in a text file returns 2; the rest reads as "234"."""
    source = File(io.StringIO("01234"))
    self.storage.save("string_file.txt", source)

    s3_file = s3.S3File("string_file.txt", "r", self.storage)
    new_position = s3_file.seek(2)

    self.assertEqual(2, new_position)
    self.assertEqual("234", s3_file.read())

def test_seek_start_string_file(self):
    """After a full read, seeking back to 0 lets the text be read again."""
    source = File(io.StringIO("01234"))
    self.storage.save("string_file.txt", source)
    s3_file = s3.S3File("string_file.txt", "r", self.storage)

    first_pass = s3_file.read()
    self.assertEqual("01234", first_pass)

    rewound_to = s3_file.seek(0)
    self.assertEqual(0, rewound_to)
    second_pass = s3_file.read()
    self.assertEqual("01234", second_pass)

def test_seek_end_string_file(self):
    """Seeking to the end of a text file returns 5; reading then yields ""."""
    source = File(io.StringIO("01234"))
    self.storage.save("string_file.txt", source)

    s3_file = s3.S3File("string_file.txt", "r", self.storage)
    end_position = s3_file.seek(0, io.SEEK_END)

    self.assertEqual(5, end_position)
    self.assertEqual("", s3_file.read())

def test_seek_bytes_file(self):
    """Seeking to offset 2 in a binary file returns 2; the rest reads as b"234"."""
    source = File(io.BytesIO(b"01234"))
    self.storage.save("string_file.txt", source)

    s3_file = s3.S3File("string_file.txt", "rb", self.storage)
    new_position = s3_file.seek(2)

    self.assertEqual(2, new_position)
    self.assertEqual(b"234", s3_file.read())

def test_seek_start_bytes_file(self):
    """After a full read, seeking back to 0 lets the bytes be read again."""
    source = File(io.BytesIO(b"01234"))
    self.storage.save("bytes_file.txt", source)
    s3_file = s3.S3File("bytes_file.txt", "rb", self.storage)

    first_pass = s3_file.read()
    self.assertEqual(b"01234", first_pass)

    rewound_to = s3_file.seek(0)
    self.assertEqual(0, rewound_to)
    second_pass = s3_file.read()
    self.assertEqual(b"01234", second_pass)

def test_seek_end_bytes_file(self):
    """Seeking to the end of a binary file returns 5; reading then yields b""."""
    source = File(io.BytesIO(b"01234"))
    self.storage.save("bytes_file.txt", source)

    s3_file = s3.S3File("bytes_file.txt", "rb", self.storage)
    end_position = s3_file.seek(0, io.SEEK_END)

    self.assertEqual(5, end_position)
    self.assertEqual(b"", s3_file.read())

def test_save_bytes_file(self):
self.storage.save("bytes_file.txt", File(io.BytesIO(b"foo1")))

Expand Down