Skip to content

Commit 4025291

Browse files
committed
Add EVAL_RO and EVALSHA_RO commands
1 parent 8205bc2 commit 4025291

File tree

3 files changed

+142
-7
lines changed

3 files changed

+142
-7
lines changed

cmd_scripting.go

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,9 @@ import (
1818

1919
func commandsScripting(m *Miniredis) {
2020
m.srv.Register("EVAL", m.cmdEval)
21+
m.srv.RegisterWithOptions("EVAL_RO", m.cmdEvalro, server.ReadOnlyOption())
2122
m.srv.Register("EVALSHA", m.cmdEvalsha)
23+
m.srv.RegisterWithOptions("EVALSHA_RO", m.cmdEvalshaRo, server.ReadOnlyOption())
2224
m.srv.Register("SCRIPT", m.cmdScript)
2325
}
2426

@@ -28,7 +30,7 @@ var (
2830

2931
// Execute lua. Needs to run m.Lock()ed, from within withTx().
3032
// Returns true if the lua was OK (and hence should be cached).
31-
func (m *Miniredis) runLuaScript(c *server.Peer, sha, script string, args []string) bool {
33+
func (m *Miniredis) runLuaScript(c *server.Peer, sha, script string, readOnly bool, args []string) bool {
3234
l := lua.NewState(lua.Options{SkipOpenLibs: true})
3335
defer l.Close()
3436

@@ -85,7 +87,7 @@ func (m *Miniredis) runLuaScript(c *server.Peer, sha, script string, args []stri
8587
}
8688
l.SetGlobal("ARGV", argvTable)
8789

88-
redisFuncs, redisConstants := mkLua(m.srv, c, sha)
90+
redisFuncs, redisConstants := mkLua(m.srv, c, sha, readOnly)
8991
// Register command handlers
9092
l.Push(l.NewFunction(func(l *lua.LState) int {
9193
mod := l.RegisterModule("redis", redisFuncs).(*lua.LTable)
@@ -150,7 +152,8 @@ func compile(script string) (*lua.FunctionProto, error) {
150152
return proto, nil
151153
}
152154

153-
func (m *Miniredis) cmdEval(c *server.Peer, cmd string, args []string) {
155+
// Shared implementation for EVAL and EVALRO
156+
func (m *Miniredis) cmdEvalShared(c *server.Peer, cmd string, readOnly bool, args []string) {
154157
if !m.isValidCMD(c, cmd, args, atLeast(2)) {
155158
return
156159
}
@@ -165,14 +168,20 @@ func (m *Miniredis) cmdEval(c *server.Peer, cmd string, args []string) {
165168

166169
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
167170
sha := sha1Hex(script)
168-
ok := m.runLuaScript(c, sha, script, args)
171+
ok := m.runLuaScript(c, sha, script, readOnly, args)
169172
if ok {
170173
m.scripts[sha] = script
171174
}
172175
})
173176
}
174177

175-
func (m *Miniredis) cmdEvalsha(c *server.Peer, cmd string, args []string) {
178+
// Wrapper function for EVAL command
179+
func (m *Miniredis) cmdEval(c *server.Peer, cmd string, args []string) {
180+
m.cmdEvalShared(c, cmd, false, args)
181+
}
182+
183+
// Shared implementation for EVALSHA and EVALSHA_RO
184+
func (m *Miniredis) cmdEvalshaShared(c *server.Peer, cmd string, readOnly bool, args []string) {
176185
if !m.isValidCMD(c, cmd, args, atLeast(2)) {
177186
return
178187
}
@@ -192,10 +201,25 @@ func (m *Miniredis) cmdEvalsha(c *server.Peer, cmd string, args []string) {
192201
return
193202
}
194203

195-
m.runLuaScript(c, sha, script, args)
204+
m.runLuaScript(c, sha, script, readOnly, args)
196205
})
197206
}
198207

208+
// Wrapper function for EVALSHA command
209+
func (m *Miniredis) cmdEvalsha(c *server.Peer, cmd string, args []string) {
210+
m.cmdEvalshaShared(c, cmd, false, args)
211+
}
212+
213+
// Wrapper function for EVALRO command
214+
func (m *Miniredis) cmdEvalro(c *server.Peer, cmd string, args []string) {
215+
m.cmdEvalShared(c, cmd, true, args)
216+
}
217+
218+
// Wrapper function for EVALSHA_RO command
219+
func (m *Miniredis) cmdEvalshaRo(c *server.Peer, cmd string, args []string) {
220+
m.cmdEvalshaShared(c, cmd, true, args)
221+
}
222+
199223
func (m *Miniredis) cmdScript(c *server.Peer, cmd string, args []string) {
200224
if !m.isValidCMD(c, cmd, args, atLeast(1)) {
201225
return

cmd_scripting_test.go

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -598,6 +598,102 @@ func TestLuaTX(t *testing.T) {
598598
})
599599
}
600600

601+
func TestEvalRo(t *testing.T) {
602+
_, c := runWithClient(t)
603+
604+
t.Run("read-only command", func(t *testing.T) {
605+
mustOK(t, c,
606+
"SET", "readonly", "foo",
607+
)
608+
609+
// Test EVALRO with read-only command (should work)
610+
mustDo(t, c,
611+
"EVALRO", "return redis.call('GET', KEYS[1])", "1", "readonly",
612+
proto.String("foo"),
613+
)
614+
})
615+
616+
t.Run("write command", func(t *testing.T) {
617+
// Test EVALRO with write command (should fail)
618+
mustContain(t, c,
619+
"EVALRO", "return redis.call('SET', KEYS[1], ARGV[1])", "1", "key1", "value1",
620+
"Write commands are not allowed in read-only scripts",
621+
)
622+
})
623+
}
624+
625+
func TestEvalshaRo(t *testing.T) {
626+
_, c := runWithClient(t)
627+
628+
// First load a read-only script
629+
script := "return redis.call('GET', KEYS[1])"
630+
t.Run("read-only script", func(t *testing.T) {
631+
mustDo(t, c,
632+
"SCRIPT", "LOAD", script,
633+
proto.String("d3c21d0c2b9ca22f82737626a27bcaf5d288f99f"),
634+
)
635+
636+
mustOK(t, c,
637+
"SET", "readonly", "foo",
638+
)
639+
640+
// Test EVALSHA_RO with read-only command (should work)
641+
mustDo(t, c,
642+
"EVALSHA_RO", "d3c21d0c2b9ca22f82737626a27bcaf5d288f99f", "1", "readonly",
643+
proto.String("foo"),
644+
)
645+
646+
})
647+
648+
t.Run("write script", func(t *testing.T) {
649+
// Load a write script
650+
writeScript := "return redis.call('SET', KEYS[1], ARGV[1])"
651+
mustDo(t, c,
652+
"SCRIPT", "LOAD", writeScript,
653+
proto.String("d8f2fad9f8e86a53d2a6ebd960b33c4972cacc37"),
654+
)
655+
656+
// Test EVALSHA_RO with write command (should fail)
657+
mustContain(t, c,
658+
"EVALSHA_RO", "d8f2fad9f8e86a53d2a6ebd960b33c4972cacc37", "1", "key1", "value1",
659+
"Write commands are not allowed in read-only scripts",
660+
)
661+
})
662+
}
663+
664+
func TestEvalRoWriteCommandWithPcall(t *testing.T) {
665+
_, c := runWithClient(t)
666+
667+
t.Run("return error", func(t *testing.T) {
668+
// Test EVAL with pcall and write command (should fail)
669+
mustContain(t, c,
670+
"EVALRO", "return redis.pcall('FAKECOMMAND', KEYS[1], ARGV[1])", "1", "key1", "value1",
671+
"Unknown Redis command called from script",
672+
)
673+
})
674+
675+
t.Run("extra work after error", func(t *testing.T) {
676+
script := `
677+
local err = redis.pcall('FAKECOMMAND', KEYS[1], ARGV[1]);
678+
local res = "pcall:" .. err['err'];
679+
return res;
680+
`
681+
// Test EVAL with pcall and write command (should fail)
682+
mustContain(t, c,
683+
"EVALRO", script, "1", "key1", "value1",
684+
"pcall:ERR Unknown Redis command called from script",
685+
)
686+
})
687+
688+
t.Run("write command in read-only script", func(t *testing.T) {
689+
// Test EVALRO with pcall and write command (should fail)
690+
mustContain(t, c,
691+
"EVALRO", "return redis.pcall('SET', KEYS[1], ARGV[1])", "1", "key1", "value1",
692+
"Write commands are not allowed in read-only scripts",
693+
)
694+
})
695+
}
696+
601697
func TestEvalWithPcall(t *testing.T) {
602698
_, c := runWithClient(t)
603699

lua.go

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ var luaRedisConstants = map[string]lua.LValue{
1818
"LOG_WARNING": lua.LNumber(3),
1919
}
2020

21-
func mkLua(srv *server.Server, c *server.Peer, sha string) (map[string]lua.LGFunction, map[string]lua.LValue) {
21+
func mkLua(srv *server.Server, c *server.Peer, sha string, readOnly bool) (map[string]lua.LGFunction, map[string]lua.LValue) {
2222
mkCall := func(failFast bool) func(l *lua.LState) int {
2323
// one server.Ctx for a single Lua run
2424
pCtx := &connCtx{}
@@ -52,6 +52,21 @@ func mkLua(srv *server.Server, c *server.Peer, sha string) (map[string]lua.LGFun
5252
return 0
5353
}
5454

55+
// Add readonly check
56+
if readOnly && len(args) > 0 {
57+
if srv.IsRegisteredCommand(args[0]) && !srv.IsReadOnlyCommand(args[0]) {
58+
if failFast {
59+
l.Error(lua.LString("Write commands are not allowed in read-only scripts"), 1)
60+
return 0
61+
}
62+
// pcall() mode - return error table
63+
res := &lua.LTable{}
64+
res.RawSetString("err", lua.LString("Write commands are not allowed in read-only scripts"))
65+
l.Push(res)
66+
return 1
67+
}
68+
}
69+
5570
buf := &bytes.Buffer{}
5671
wr := bufio.NewWriter(buf)
5772
peer := server.NewPeer(wr)

0 commit comments

Comments
 (0)