@@ -13,8 +13,9 @@ import (
13
13
14
14
"cloud.google.com/go/spanner"
15
15
sdb "cloud.google.com/go/spanner/admin/database/apiv1"
16
- "cloud.google.com/go/spanner/spansql"
17
16
17
+ "github.com/cloudspannerecosystem/memefish"
18
+ "github.com/cloudspannerecosystem/memefish/token"
18
19
"github.com/golang-migrate/migrate/v4"
19
20
"github.com/golang-migrate/migrate/v4/database"
20
21
@@ -60,11 +61,9 @@ type Config struct {
60
61
61
62
// Spanner implements database.Driver for Google Cloud Spanner
62
63
type Spanner struct {
63
- db * DB
64
-
64
+ db * DB
65
65
config * Config
66
-
67
- lock * uatomic.Uint32
66
+ lock * uatomic.Uint32
68
67
}
69
68
70
69
type DB struct {
@@ -179,26 +178,65 @@ func (s *Spanner) Run(migration io.Reader) error {
179
178
return err
180
179
}
181
180
182
- stmts := []string {string (migr )}
183
- if s .config .CleanStatements {
184
- stmts , err = cleanStatements (migr )
185
- if err != nil {
186
- return err
181
+ ctx := context .Background ()
182
+
183
+ if ! s .config .CleanStatements {
184
+ return s .runDdl (ctx , []string {string (migr )})
185
+ }
186
+
187
+ stmtGroups , err := statementGroups (migr )
188
+ if err != nil {
189
+ return err
190
+ }
191
+
192
+ for _ , group := range stmtGroups {
193
+ switch group .typ {
194
+ case statementTypeDDL :
195
+ if err := s .runDdl (ctx , group .stmts ); err != nil {
196
+ return err
197
+ }
198
+ case statementTypeDML :
199
+ if err := s .runDml (ctx , group .stmts ); err != nil {
200
+ return err
201
+ }
202
+ default :
203
+ return fmt .Errorf ("unknown statement type: %s" , group .typ )
187
204
}
188
205
}
189
206
190
- ctx := context .Background ()
207
+ return nil
208
+ }
209
+
210
+ func (s * Spanner ) runDdl (ctx context.Context , stmts []string ) error {
191
211
op , err := s .db .admin .UpdateDatabaseDdl (ctx , & adminpb.UpdateDatabaseDdlRequest {
192
212
Database : s .config .DatabaseName ,
193
213
Statements : stmts ,
194
214
})
195
215
196
216
if err != nil {
197
- return & database.Error {OrigErr : err , Err : "migration failed" , Query : migr }
217
+ return & database.Error {OrigErr : err , Err : "migration failed" , Query : [] byte ( strings . Join ( stmts , "; \n " )) }
198
218
}
199
219
200
220
if err := op .Wait (ctx ); err != nil {
201
- return & database.Error {OrigErr : err , Err : "migration failed" , Query : migr }
221
+ return & database.Error {OrigErr : err , Err : "migration failed" , Query : []byte (strings .Join (stmts , ";\n " ))}
222
+ }
223
+
224
+ return nil
225
+ }
226
+
227
+ func (s * Spanner ) runDml (ctx context.Context , stmts []string ) error {
228
+ _ , err := s .db .data .ReadWriteTransaction (ctx ,
229
+ func (ctx context.Context , txn * spanner.ReadWriteTransaction ) error {
230
+ for _ , s := range stmts {
231
+ _ , err := txn .Update (ctx , spanner.Statement {SQL : s })
232
+ if err != nil {
233
+ return err
234
+ }
235
+ }
236
+ return nil
237
+ })
238
+ if err != nil {
239
+ return & database.Error {OrigErr : err , Err : "migration failed" , Query : []byte (strings .Join (stmts , ";\n " ))}
202
240
}
203
241
204
242
return nil
@@ -345,17 +383,80 @@ func (s *Spanner) ensureVersionTable() (err error) {
345
383
return nil
346
384
}
347
385
348
- func cleanStatements (migration []byte ) ([]string , error ) {
349
- // The Spanner GCP backend does not yet support comments for the UpdateDatabaseDdl RPC
350
- // (see https://issuetracker.google.com/issues/159730604) we use
351
- // spansql to parse the DDL and output valid stamements without comments
352
- ddl , err := spansql .ParseDDL ("" , string (migration ))
353
- if err != nil {
354
- return nil , err
386
+ type statementType string
387
+
388
+ const (
389
+ statementTypeUnknown statementType = ""
390
+ statementTypeDDL statementType = "DDL"
391
+ statementTypeDML statementType = "DML"
392
+ )
393
+
394
+ type statementGroup struct {
395
+ typ statementType
396
+ stmts []string
397
+ }
398
+
399
+ func statementGroups (migr []byte ) (groups []* statementGroup , err error ) {
400
+ lex := & memefish.Lexer {
401
+ File : & token.File {Buffer : string (migr )},
355
402
}
356
- stmts := make ([]string , 0 , len (ddl .List ))
357
- for _ , stmt := range ddl .List {
358
- stmts = append (stmts , stmt .SQL ())
403
+
404
+ group := & statementGroup {}
405
+ var stmtTyp statementType
406
+ var stmt strings.Builder
407
+ for {
408
+ if err := lex .NextToken (); err != nil {
409
+ return nil , err
410
+ }
411
+
412
+ if stmtTyp == statementTypeUnknown {
413
+ switch {
414
+ case lex .Token .IsKeywordLike ("INSERT" ) || lex .Token .IsKeywordLike ("DELETE" ) || lex .Token .IsKeywordLike ("UPDATE" ):
415
+ stmtTyp = statementTypeDML
416
+ default :
417
+ stmtTyp = statementTypeDDL
418
+ }
419
+ if group .typ != stmtTyp {
420
+ if len (group .stmts ) > 0 {
421
+ groups = append (groups , group )
422
+ }
423
+ group = & statementGroup {typ : stmtTyp }
424
+ }
425
+ }
426
+
427
+ if lex .Token .Kind == token .TokenEOF || lex .Token .Kind == ";" {
428
+ if stmt .Len () > 0 {
429
+ group .stmts = append (group .stmts , stmt .String ())
430
+ }
431
+ stmtTyp = statementTypeUnknown
432
+ stmt .Reset ()
433
+
434
+ if lex .Token .Kind == token .TokenEOF {
435
+ if len (group .stmts ) > 0 {
436
+ groups = append (groups , group )
437
+ }
438
+
439
+ break
440
+ }
441
+
442
+ continue
443
+ }
444
+
445
+ if len (lex .Token .Comments ) > 0 {
446
+ // preserve newline where comments are removed
447
+ if _ , err := stmt .WriteString ("\n " ); err != nil {
448
+ return nil , err
449
+ }
450
+ }
451
+ if stmt .Len () > 0 {
452
+ if _ , err := stmt .WriteString (lex .Token .Space ); err != nil {
453
+ return nil , err
454
+ }
455
+ }
456
+ if _ , err := stmt .WriteString (lex .Token .Raw ); err != nil {
457
+ return nil , err
458
+ }
359
459
}
360
- return stmts , nil
460
+
461
+ return groups , nil
361
462
}
0 commit comments