diff --git a/muskrat/s3consumer.py b/muskrat/s3consumer.py index 97551b7..d3b59ea 100644 --- a/muskrat/s3consumer.py +++ b/muskrat/s3consumer.py @@ -10,7 +10,7 @@ import os import time -import boto +import boto3 from muskrat.util import config_loader class S3Cursor(object): @@ -40,6 +40,11 @@ def __init__(self, name, type, **kwargs ): else: raise NotImplementedError('File cursor types currently the only types supported') + @classmethod + def at_path(cls, path): + """Creates a cursor object at the given path.""" + name = os.path.basename(path) + return cls(name, 'file', location=os.path.dirname(path)) def _update_file_cursor( self, key ): #instead of opening and re-opening we could just seek and truncate @@ -67,14 +72,29 @@ def update( self, key ): def get( self ): return self._get_func() + def filter_collection(self, collection): + """lowlevel helper to filter an s3 object collection with marker.""" + marker = self.get() + if marker: + collection = collection.filter(Marker=marker) + return collection.filter(Delimiter='/') + + def persist_progress(self, collection): + """Iterates through a collection, maintaining a persistent cursor.""" + for obj in collection: + yield obj + self.update(obj.key) + + def each(self, collection): + collection = self.filter_collection(collection) + return self.persist_progress(collection) + class S3Consumer(object): def __init__(self, routing_key, func, name=None, config='config.py'): self.config = config_loader( config ) - self._s3conn = None - self._bucket = None self.routing_key = routing_key.upper() self.callback = func @@ -92,15 +112,15 @@ def __init__(self, routing_key, func, name=None, config='config.py'): @property def s3conn(self): - if self._s3conn is None: - self._s3conn = boto.connect_s3( self.config.s3_key, self.config.s3_secret ) - return self._s3conn + return boto3.resource( + 's3', + aws_access_key_id=self.config.s3_key, + aws_secret_access_key=self.config.s3_secret, + ) @property def bucket(self): - if self._bucket is None: - self._bucket = self.s3conn.get_bucket( self.config.s3_bucket ) - return self._bucket + return self.s3conn.Bucket(self.config.s3_bucket) def _gen_name(self, func): """ Generates a cursor name so that the cursor can be re-attached to """ @@ -111,11 +131,8 @@ def _gen_routing_key( self, routing_key ): def _get_msg_iterator(self): #If marker is not matched to a key then the returned list is none. - msg_iterator = self.bucket.list( - prefix=self._gen_routing_key( self.routing_key ) + '/', - delimiter= '/', - marker=self._cursor.get() - ) + prefix = self._gen_routing_key(self.routing_key) + '/' + msg_iterator = self.bucket.objects.filter(Prefix=prefix) return msg_iterator @@ -125,12 +142,8 @@ def consume(self): #Update: actually... this doesn't seem to be a problem... msg_iterator = self._get_msg_iterator() - for msg in msg_iterator: - #Sub 'directories' are prefix objects, so ignore them - if isinstance( msg, boto.s3.key.Key ): - self.callback( msg.get_contents_as_string() ) - self._cursor.update( msg.name ) - + for obj in self._cursor.each(msg_iterator): + self.callback(obj.get()['Body'].read()) def consumption_loop( self, interval=2 ): """ @@ -157,14 +170,11 @@ class S3AggregateConsumer( S3Consumer ): def consume( self ): msg_iterator = self._get_msg_iterator() - cursor = None - messages = [] - for msg in msg_iterator: - if isinstance( msg, boto.s3.key.Key ): - messages.append( msg.get_contents_as_string() ) - cursor = msg.name + objs = list(self._cursor.filter_collection(msg_iterator)) + messages = [x.get()['Body'].read() for x in objs] if messages: + cursor = objs[-1].key self.callback( messages ) self._cursor.update( cursor ) diff --git a/muskrat/tests/test_s3consumer.py b/muskrat/tests/test_s3consumer.py index bd805a0..b787ca8 100644 --- a/muskrat/tests/test_s3consumer.py +++ b/muskrat/tests/test_s3consumer.py @@ -8,17 +8,30 @@ """ import unittest import os +import tempfile +import uuid +from datetime import datetime import boto +import boto3 os.environ['MUSKRAT'] = 'TEST' from ..producer import S3Producer -from ..s3consumer import S3Consumer, Consumer +from ..s3consumer import S3Consumer, Consumer, S3Cursor from ..util import config_loader config_path = 'config.py' TEST_KEY_PREFIX = 'Muskrat.Consumer' +class TempCursorFile(): + def __enter__(self): + fd, self.path = tempfile.mkstemp() + os.close(fd) + return self.path + + def __exit__(self, type, value, traceback): + os.remove(self.path) + class TestS3ConsumerBase( unittest.TestCase ): def setUp(self): @@ -113,7 +126,87 @@ def decorated_consumer( msg ): self.assertIsInstance( decorated_consumer.consumer, S3Consumer, 'Decorator did not correctly attach S3Consumer' ) decorated_consumer.consumer.consume() - + + +class TestS3CollectionEach(unittest.TestCase): + def setUp(self): + config = config_loader(config_path) + self.time_format = config.s3_timestamp_format + self.prefix = 'MUSKRAT/TEST/S3COLLECTIONEACH/' + + s3 = boto3.resource( + 's3', + aws_access_key_id=config.s3_key, + aws_secret_access_key=config.s3_secret, + ) + self.bucket = s3.Bucket(config.s3_bucket) + + def tearDown(self): + for obj in self.bucket.objects.filter(Prefix=self.prefix): + obj.delete() + + def _add_message(self, message): + key = self.prefix + datetime.today().strftime(self.time_format) + self.bucket.put_object(Key=key, Body=message) + + def test_s3collection_marker_each(self): + """ An s3collection iterator which persists the marker in a file """ + collection = self.bucket.objects.filter(Prefix=self.prefix) + + with TempCursorFile() as path: + cursor = S3Cursor.at_path(path) + + # add a message to the queue + message = str(uuid.uuid4()) + self._add_message(message) + + # iterate over queue, validate message & marker + counter = 0 + for obj in cursor.each(collection): + counter += 1 + last_key = obj.key + self.assertEqual(message, obj.get()['Body'].read()) + self.assertEqual(1, counter) + with open(path, 'r') as f: + self.assertEqual(last_key, f.read()) + + # add more messages to the queue + messages = [str(uuid.uuid4()), str(uuid.uuid4())] + self._add_message(messages[0]) + self._add_message(messages[1]) + + # iterate over queue & validate messages + counter = 0 + for obj in cursor.each(collection): + self.assertEqual(messages[counter], obj.get()['Body'].read()) + counter += 1 + self.assertEqual(2, counter) + + def test_prefix_match_extra_levels(self): + collection = self.bucket.objects.filter(Prefix=self.prefix) + + with TempCursorFile() as path: + cursor = S3Cursor.at_path(path) + + # add a message to the queue + message1 = str(uuid.uuid4()) + self._add_message(message1) + + # add a message with extra levels + message2 = str(uuid.uuid4()) + ts = datetime.today().strftime(self.time_format) + key = self.prefix + 'FOO/BAR/' + ts + self.bucket.put_object(Key=key, Body=message2) + + # iterate over queue, validate message & marker + counter = 0 + for obj in cursor.each(collection): + counter += 1 + last_key = obj.key + self.assertEqual(message1, obj.get()['Body'].read()) + self.assertEqual(1, counter) + with open(path, 'r') as f: + self.assertEqual(last_key, f.read()) if '__main__' == __name__: unittest.main() diff --git a/requirements.txt b/requirements.txt index 2301b14..cc7a3e6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ pika==0.9.8 boto==2.7.0 +boto3==1.4.0 simplejson==3.0.7