@@ -6,7 +6,7 @@
 import tarfile
 
 from airflow.hooks.S3_hook import S3Hook
-from airflow.models import BaseOperator
+from airflow.models import BaseOperator, SkipMixin
 from airflow.utils.decorators import apply_defaults
 
 from mailchimp_plugin.hooks.mailchimp_hook import MailchimpHook
@@ -32,7 +32,7 @@
 }
 
 
-class MailchimpToS3Operator(BaseOperator):
+class MailchimpToS3Operator(BaseOperator, SkipMixin):
     """
     Make a query against Mailchimp and write the resulting data to s3
     """
@@ -147,7 +147,7 @@ def execute(self, context):
         )
 
         self.hook.get_conn()
-
+
         logging.info(
             "Making request for"
             " {0} object".format(self.mailchimp_resource)
@@ -166,28 +166,43 @@ def execute(self, context):
             results = self.read_file(url, results_field='sites')
 
             if self.mailchimp_resource == 'connected_sites_details':
-                endpoints = ["/connected-sites/{}".format(result['id']) for result in results]
+                endpoints = [
+                    "/connected-sites/{}".format(result['id']) for result in results]
                 url = self.hook.run_batch(endpoints)
                 results = self.read_file(url)
         else:
             results = self.hook.run_query(self.mailchimp_resource)
 
         # write the results to a temporary file and save that file to s3
-        with NamedTemporaryFile("w") as tmp:
-            for result in filterd_results:
-                tmp.write(json.dumps(result) + '\n')
-
-            tmp.flush()
-
-            dest_s3 = S3Hook(s3_conn_id=self.s3_conn_id)
-            dest_s3.load_file(
-                filename=tmp.name,
-                key=self.s3_key,
-                bucket_name=self.s3_bucket,
-                replace=True
+        if results is None or len(results) == 0:
+            logging.info("No records pulled from Mailchimp.")
+            downstream_tasks = context['task'].get_flat_relatives(
+                upstream=False)
+            logging.info('Skipping downstream tasks...')
+            logging.debug("Downstream task_ids %s", downstream_tasks)
+
+            if downstream_tasks:
+                self.skip(context['dag_run'],
+                          context['ti'].execution_date,
+                          downstream_tasks)
+            return True
 
-            )
-            dest_s3.connection.close()
-            tmp.close()
-
-        logging.info("Query finished!")
+        else:
+            # Write the results to a temporary file and save that file to s3.
+            with NamedTemporaryFile("w") as tmp:
+                for result in results:
+                    filtered_result = self.filter_fields(result)
+                    tmp.write(json.dumps(filtered_result) + '\n')
+
+                tmp.flush()
+
+                dest_s3 = S3Hook(s3_conn_id=self.s3_conn_id)
+                dest_s3.load_file(
+                    filename=tmp.name,
+                    key=self.s3_key,
+                    bucket_name=self.s3_bucket,
+                    replace=True
+
+                )
+                dest_s3.connection.close()
+                tmp.close()
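
The core of this hunk is the short-circuit for empty result sets: instead of always writing a file to S3, the operator now skips every downstream task when Mailchimp returns nothing, while the task itself still finishes successfully. Below is a condensed, self-contained sketch of that pattern under the same Airflow 1.x SkipMixin.skip(dag_run, execution_date, tasks) signature used in the diff; fetch_records() is a hypothetical stand-in for the Mailchimp calls, not part of this plugin.

# Minimal sketch of the skip-when-empty pattern, assuming Airflow 1.x.
import logging

from airflow.models import BaseOperator, SkipMixin


class FetchOrSkipOperator(BaseOperator, SkipMixin):
    """Hypothetical operator: skip downstream tasks when nothing was fetched."""

    def fetch_records(self):
        # Stand-in for the real data pull (e.g. the Mailchimp batch/query calls).
        return []

    def execute(self, context):
        records = self.fetch_records()
        if not records:
            logging.info("No records pulled; skipping downstream tasks.")
            downstream_tasks = context['task'].get_flat_relatives(upstream=False)
            if downstream_tasks:
                self.skip(context['dag_run'],
                          context['ti'].execution_date,
                          downstream_tasks)
            return
        # Otherwise process the records as usual (write a temp file, load to S3, ...).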