|
2 | 2 |
|
3 | 3 | import uuid
|
4 | 4 | from io import StringIO
|
5 |
| -from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union, cast |
| 5 | +from typing import ( |
| 6 | + Any, |
| 7 | + Callable, |
| 8 | + Dict, |
| 9 | + Iterable, |
| 10 | + List, |
| 11 | + Optional, |
| 12 | + Sequence, |
| 13 | + Tuple, |
| 14 | + Union, |
| 15 | + cast, |
| 16 | +) |
6 | 17 |
|
7 | 18 | import sqlalchemy as sa
|
8 | 19 | from pendulum import now
|
@@ -163,24 +174,43 @@ def bulk_insert_records( # type: ignore[override]
|
163 | 174 | for record in records
|
164 | 175 | ]
|
165 | 176 |
|
166 |
| - # Prepare processor functions for each column type. These are used to convert |
167 |
| - # from Python values to database values. |
| 177 | + # Prepare to process the rows into csv. Use each column's bind_processor to do |
| 178 | + # most of the work, then do the final construction of the csv rows ourselves |
| 179 | + # to control exactly how values are converted and which ones are quoted. |
168 | 180 | column_processors = [
|
169 | 181 | column.type.bind_processor(connection.dialect) or str for column in columns
|
170 | 182 | ]
|
171 | 183 |
|
172 |
| - # Create a buffer of CSV formatted values to send in bulk. |
| 184 | + def process_column_value(data: Any, proc: Callable) -> str: |
| 185 | + # If the data is null, return nothing (unquoted). |
| 186 | + if data is None: |
| 187 | + return "" |
| 188 | + |
| 189 | + # Pass the Python value through the bind_processor. |
| 190 | + value = proc(data) |
| 191 | + |
| 192 | + # If the value is a string, escape double-quotes as "" and return |
| 193 | + # a quoted value. |
| 194 | + if isinstance(value, str): |
| 195 | + # escape double quotes as "". |
| 196 | + return '"' + value.replace('"', '""') + '"' |
| 197 | + |
| 198 | + # If the value is a list (for ARRAY), escape double-quotes as \" and return |
| 199 | + # a quoted value in literal array format. |
| 200 | + if isinstance(value, list): |
| 201 | + # for each member of value, escape double quotes as \". |
| 202 | + return ( |
| 203 | + '"{' |
| 204 | + + ",".join('""' + v.replace('"', r'\""') + '""' for v in value) |
| 205 | + + '}"' |
| 206 | + ) |
| 207 | + |
| 208 | + # Otherwise, return the string representation of the value. |
| 209 | + return str(value) |
| 210 | + |
173 | 211 | buffer = StringIO()
|
174 | 212 | for row in data_to_insert:
|
175 |
| - processed_row = ",".join( |
176 |
| - map( |
177 |
| - lambda data, proc: ( |
178 |
| - "" if data is None else str(proc(data)).replace('"', '""') |
179 |
| - ), |
180 |
| - row, |
181 |
| - column_processors, |
182 |
| - ) |
183 |
| - ) |
| 213 | + processed_row = ",".join(map(process_column_value, row, column_processors)) |
184 | 214 |
|
185 | 215 | buffer.write(processed_row)
|
186 | 216 | buffer.write("\n")
|
|
0 commit comments