Skip to content
This repository was archived by the owner on Sep 2, 2024. It is now read-only.

Commit c803198

Browse files
committed
fixed issue with SQLite update function overwritting doc
1 parent 991014e commit c803198

File tree

2 files changed

+59
-1
lines changed

2 files changed

+59
-1
lines changed

database/sqlite/base.go

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,15 @@ func (sl *SQLite) GetDocumentsByIDs(auth model.Auth, dbName, col string, ids []s
278278
}
279279

280280
func (sl *SQLite) UpdateDocument(auth model.Auth, dbName, col, id string, doc map[string]interface{}) (map[string]interface{}, error) {
281+
orig, err := sl.GetDocumentByID(auth, dbName, col, id)
282+
if err != nil {
283+
return nil, err
284+
}
285+
286+
for key, val := range doc {
287+
orig[key] = val
288+
}
289+
281290
where := secureWrite(auth, col)
282291

283292
qry := fmt.Sprintf(`
@@ -286,7 +295,7 @@ func (sl *SQLite) UpdateDocument(auth model.Auth, dbName, col, id string, doc ma
286295
%s AND id = $3
287296
`, dbName, model.CleanCollectionName(col), where)
288297

289-
b, err := json.Marshal(doc)
298+
b, err := json.Marshal(orig)
290299
if err != nil {
291300
return nil, err
292301
}
@@ -297,6 +306,7 @@ func (sl *SQLite) UpdateDocument(auth model.Auth, dbName, col, id string, doc ma
297306

298307
updated, err := sl.GetDocumentByID(auth, dbName, col, id)
299308
if err != nil {
309+
fmt.Println("DEBUG: in getbyid", err)
300310
return nil, err
301311
}
302312

db_test.go

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,54 @@ func TestDBCreate(t *testing.T) {
148148
}
149149
}
150150

151+
func TestDBUpdateShouldNotOverwriteDoc(t *testing.T) {
152+
task :=
153+
Task{
154+
Title: "item created",
155+
Created: time.Now(),
156+
}
157+
158+
resp := dbReq(t, db.add, "POST", "/db/tasks", task)
159+
defer resp.Body.Close()
160+
161+
if resp.StatusCode > 299 {
162+
t.Fatal(GetResponseBody(t, resp))
163+
}
164+
165+
var saved Task
166+
if err := parseBody(resp.Body, &saved); err != nil {
167+
t.Fatal(err)
168+
} else if task.Title != saved.Title {
169+
t.Errorf("expected title to be %s go %s", task.Title, saved.Title)
170+
}
171+
172+
update := new(struct {
173+
Done bool `json:"done"`
174+
})
175+
update.Done = true
176+
177+
uresp := dbReq(t, db.update, "PUT", "/db/tasks/"+saved.ID, update)
178+
defer uresp.Body.Close()
179+
180+
if uresp.StatusCode > 299 {
181+
t.Fatal(GetResponseBody(t, uresp))
182+
}
183+
184+
resp = dbReq(t, db.get, "GET", "/db/tasks/"+saved.ID, nil)
185+
defer resp.Body.Close()
186+
187+
if resp.StatusCode > 299 {
188+
t.Fatal(GetResponseBody(t, resp))
189+
}
190+
191+
var afterUpdate Task
192+
if err := parseBody(resp.Body, &afterUpdate); err != nil {
193+
t.Fatal(err)
194+
} else if task.Title != afterUpdate.Title {
195+
t.Errorf("expected title to be '%s' got '%s'", task.Title, afterUpdate.Title)
196+
}
197+
}
198+
151199
func TestDBListCollections(t *testing.T) {
152200
req := httptest.NewRequest("GET", "/sudolistall", nil)
153201
w := httptest.NewRecorder()

0 commit comments

Comments
 (0)