Skip to content

Commit 2d4193b

Browse files
committed
Add interact_with_state closure cheatcode
Towards #3331 commit-id:0b8b8c6b
1 parent 2f9d9c9 commit 2d4193b

File tree

13 files changed

+488
-0
lines changed

13 files changed

+488
-0
lines changed

.github/workflows/ci.yml

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,22 @@ jobs:
159159

160160
- run: cargo test --package forge --features scarb_since_2_10 sierra_gas
161161

162+
# TODO(#3212): Closures in Cairo are fully supported since version 2.11
163+
test-interact-with-state:
164+
name: Test interact with state
165+
runs-on: ubuntu-latest
166+
steps:
167+
- uses: actions/checkout@v4
168+
- uses: dtolnay/rust-toolchain@stable
169+
- uses: Swatinem/rust-cache@9d47c6ad4b02e050fd481d890b2ea34778fd09d6
170+
- uses: software-mansion/setup-scarb@v1
171+
with:
172+
scarb-version: "2.11.0"
173+
- uses: software-mansion/setup-universal-sierra-compiler@v1
174+
175+
- run: cargo test --release --package forge --features interact-with-state --test main integration::interact_with_state
176+
- run: cargo test --release --package forge --features interact-with-state --test main e2e::running::test_interact_with_state
177+
162178
test-forge-runner:
163179
name: Test Forge Runner
164180
runs-on: ubuntu-latest

crates/forge/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ no_scarb_installed = []
1313
debugging = []
1414
assert_non_exact_gas = ["test_utils/assert_non_exact_gas"]
1515
supports-panic-backtrace = []
16+
interact-with-state = []
1617

1718
[dependencies]
1819
anyhow.workspace = true
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
[package]
2+
name = "contract_state"
3+
version = "0.1.0"
4+
edition = "2024_07"
5+
6+
[dependencies]
7+
starknet = "2.11.0"
8+
9+
[dev-dependencies]
10+
snforge_std = { path = "../../../../../snforge_std" }
11+
assert_macros = "2.11.0"
12+
13+
[[target.starknet-contract]]
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
#[starknet::interface]
2+
pub trait IHelloStarknetExtended<TContractState> {
3+
fn increase_balance(ref self: TContractState, amount: u256);
4+
fn get_balance(self: @TContractState) -> u256;
5+
fn get_caller_info(self: @TContractState, address: starknet::ContractAddress) -> u256;
6+
fn get_balance_at(self: @TContractState, index: u64) -> u256;
7+
}
8+
9+
10+
#[starknet::contract]
11+
pub mod HelloStarknetExtended {
12+
use starknet::storage::{
13+
Map, MutableVecTrait, StoragePathEntry, StoragePointerReadAccess, StoragePointerWriteAccess,
14+
Vec, VecTrait,
15+
};
16+
use starknet::{ContractAddress, get_caller_address};
17+
18+
#[derive(starknet::Store)]
19+
struct Owner {
20+
pub address: ContractAddress,
21+
pub name: felt252,
22+
}
23+
24+
#[storage]
25+
struct Storage {
26+
pub owner: Owner,
27+
pub balance: u256,
28+
pub balance_records: Vec<u256>,
29+
pub callers: Map<ContractAddress, u256>,
30+
}
31+
32+
#[constructor]
33+
fn constructor(ref self: ContractState, owner_name: felt252) {
34+
self
35+
._set_owner(
36+
starknet::get_execution_info().tx_info.account_contract_address, owner_name,
37+
);
38+
self.balance_records.push(0);
39+
}
40+
41+
#[abi(embed_v0)]
42+
impl HelloStarknetExtendedImpl of super::IHelloStarknetExtended<ContractState> {
43+
fn increase_balance(ref self: ContractState, amount: u256) {
44+
let caller = get_caller_address();
45+
let value_before = self.callers.entry(caller).read();
46+
47+
assert(amount != 0, 'Amount cannot be 0');
48+
49+
self.balance.write(self.balance.read() + amount);
50+
self.callers.entry(caller).write(value_before + amount);
51+
self.balance_records.push(self.balance.read());
52+
}
53+
54+
fn get_balance(self: @ContractState) -> u256 {
55+
self.balance.read()
56+
}
57+
58+
fn get_caller_info(self: @ContractState, address: ContractAddress) -> u256 {
59+
self.callers.entry(address).read()
60+
}
61+
62+
fn get_balance_at(self: @ContractState, index: u64) -> u256 {
63+
assert(index < self.balance_records.len(), 'Index out of range');
64+
self.balance_records.at(index).read()
65+
}
66+
}
67+
68+
#[generate_trait]
69+
pub impl InternalFunctions of InternalFunctionsTrait {
70+
fn _set_owner(ref self: ContractState, address: ContractAddress, name: felt252) {
71+
self.owner.address.write(address);
72+
self.owner.name.write(name);
73+
}
74+
}
75+
}
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
pub mod balance;
2+
pub mod storage_node;
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
#[starknet::contract]
2+
pub mod StorageNodeContract {
3+
use starknet::ContractAddress;
4+
use starknet::storage::Map;
5+
6+
#[starknet::storage_node]
7+
struct RandomData {
8+
title: felt252,
9+
description: felt252,
10+
counter: u32,
11+
data: Map<(ContractAddress, u16), ByteArray>,
12+
}
13+
14+
#[storage]
15+
struct Storage {
16+
pub random_data: RandomData,
17+
}
18+
}
Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,164 @@
1+
use contract_state::balance::HelloStarknetExtended::InternalFunctionsTrait;
2+
use contract_state::balance::{
3+
HelloStarknetExtended, IHelloStarknetExtendedDispatcher, IHelloStarknetExtendedDispatcherTrait,
4+
};
5+
use snforge_std::{ContractClassTrait, DeclareResultTrait, declare, interact_with_state};
6+
use starknet::ContractAddress;
7+
use starknet::storage::{
8+
MutableVecTrait, StorageMapWriteAccess, StoragePointerReadAccess, StoragePointerWriteAccess,
9+
};
10+
11+
fn deploy_contract(name: ByteArray) -> ContractAddress {
12+
let contract = declare(name).unwrap().contract_class();
13+
let mut calldata = array![];
14+
calldata.append('Name');
15+
16+
let (contract_address, _) = contract.deploy(@calldata).unwrap();
17+
contract_address
18+
}
19+
20+
#[test]
21+
fn test_interact_with_state() {
22+
let contract_address = deploy_contract("HelloStarknetExtended");
23+
let dispatcher = IHelloStarknetExtendedDispatcher { contract_address };
24+
25+
assert(dispatcher.get_balance() == 0, 'Wrong balance');
26+
27+
interact_with_state(
28+
contract_address,
29+
|| {
30+
let mut state = HelloStarknetExtended::contract_state_for_testing();
31+
state.balance.write(987);
32+
},
33+
);
34+
35+
assert(dispatcher.get_balance() == 987, 'Wrong balance');
36+
dispatcher.increase_balance(13);
37+
assert(dispatcher.get_balance() == 1000, 'Wrong balance');
38+
}
39+
40+
#[test]
41+
fn test_interact_with_state_return() {
42+
let contract_address = deploy_contract("HelloStarknetExtended");
43+
let dispatcher = IHelloStarknetExtendedDispatcher { contract_address };
44+
45+
assert(dispatcher.get_balance() == 0, 'Wrong balance');
46+
47+
let res = interact_with_state(
48+
contract_address,
49+
|| -> u256 {
50+
let mut state = HelloStarknetExtended::contract_state_for_testing();
51+
state.balance.write(111);
52+
state.balance.read()
53+
},
54+
);
55+
56+
assert(res == 111, 'Wrong balance');
57+
}
58+
59+
#[test]
60+
fn test_interact_with_initialized_state() {
61+
let contract_address = deploy_contract("HelloStarknetExtended");
62+
let dispatcher = IHelloStarknetExtendedDispatcher { contract_address };
63+
64+
dispatcher.increase_balance(199);
65+
66+
interact_with_state(
67+
contract_address,
68+
|| {
69+
let mut state = HelloStarknetExtended::contract_state_for_testing();
70+
assert(state.balance.read() == 199, 'Wrong balance');
71+
state.balance.write(1);
72+
},
73+
);
74+
75+
assert(dispatcher.get_balance() == 1, 'Wrong balance');
76+
}
77+
78+
#[test]
79+
fn test_interact_with_state_vec() {
80+
let contract_address = deploy_contract("HelloStarknetExtended");
81+
let dispatcher = IHelloStarknetExtendedDispatcher { contract_address };
82+
83+
dispatcher.increase_balance(1);
84+
dispatcher.increase_balance(1);
85+
dispatcher.increase_balance(1);
86+
87+
interact_with_state(
88+
contract_address,
89+
|| {
90+
let mut state = HelloStarknetExtended::contract_state_for_testing();
91+
assert(state.balance_records.len() == 4, 'Wrong length');
92+
state.balance_records.push(10);
93+
},
94+
);
95+
96+
assert(dispatcher.get_balance_at(0) == 0, 'Wrong balance');
97+
assert(dispatcher.get_balance_at(2) == 2, 'Wrong balance');
98+
assert(dispatcher.get_balance_at(4) == 10, 'Wrong balance');
99+
}
100+
101+
#[test]
102+
fn test_interact_with_state_map() {
103+
let contract_address = deploy_contract("HelloStarknetExtended");
104+
let dispatcher = IHelloStarknetExtendedDispatcher { contract_address };
105+
106+
dispatcher.increase_balance(1);
107+
108+
interact_with_state(
109+
contract_address,
110+
|| {
111+
let mut state = HelloStarknetExtended::contract_state_for_testing();
112+
state.callers.write(0x123.try_into().unwrap(), 1000);
113+
state.callers.write(0x321.try_into().unwrap(), 2000);
114+
},
115+
);
116+
117+
assert(
118+
dispatcher.get_caller_info(0x123.try_into().unwrap()) == 1000,
119+
'Wrong data for address 0x123',
120+
);
121+
assert(
122+
dispatcher.get_caller_info(0x321.try_into().unwrap()) == 2000,
123+
'Wrong data for address 0x321',
124+
);
125+
assert(
126+
dispatcher.get_caller_info(0x12345.try_into().unwrap()) == 0,
127+
'Wrong data for address 0x12345',
128+
);
129+
}
130+
131+
#[test]
132+
fn test_interact_with_state_internal_function() {
133+
let contract_address = deploy_contract("HelloStarknetExtended");
134+
135+
let get_owner =
136+
|| -> (
137+
ContractAddress, felt252,
138+
) {
139+
interact_with_state(
140+
contract_address,
141+
|| -> (
142+
ContractAddress, felt252,
143+
) {
144+
let mut state = HelloStarknetExtended::contract_state_for_testing();
145+
(state.owner.address.read(), state.owner.name.read())
146+
},
147+
)
148+
};
149+
let (owner_address, owner_name) = get_owner();
150+
assert(owner_address == 0.try_into().unwrap(), 'Incorrect owner address');
151+
assert(owner_name == 'Name', 'Incorrect owner name');
152+
153+
interact_with_state(
154+
contract_address,
155+
|| {
156+
let mut state = HelloStarknetExtended::contract_state_for_testing();
157+
state._set_owner(0x777.try_into().unwrap(), 'New name');
158+
},
159+
);
160+
let (owner_address, owner_name) = get_owner();
161+
162+
assert(owner_address == 0x777.try_into().unwrap(), 'Incorrect owner address');
163+
assert(owner_name == 'New name', 'Incorrect owner name');
164+
}
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
#[starknet::interface]
2+
pub trait IHelloStarknetExtended<TContractState> {
3+
fn increase_balance(ref self: TContractState, amount: u256);
4+
fn get_balance(self: @TContractState) -> u256;
5+
fn get_caller_info(self: @TContractState, address: starknet::ContractAddress) -> u256;
6+
fn get_balance_at(self: @TContractState, index: u64) -> u256;
7+
}
8+
9+
10+
#[starknet::contract]
11+
pub mod HelloStarknetExtended {
12+
use starknet::storage::{
13+
Map, MutableVecTrait, StoragePathEntry, StoragePointerReadAccess, StoragePointerWriteAccess,
14+
Vec, VecTrait,
15+
};
16+
use starknet::{ContractAddress, get_caller_address};
17+
18+
#[derive(starknet::Store)]
19+
struct Owner {
20+
pub address: ContractAddress,
21+
pub name: felt252,
22+
}
23+
24+
#[storage]
25+
struct Storage {
26+
pub owner: Owner,
27+
pub balance: u256,
28+
pub balance_records: Vec<u256>,
29+
pub callers: Map<ContractAddress, u256>,
30+
}
31+
32+
#[constructor]
33+
fn constructor(ref self: ContractState, owner_name: felt252) {
34+
self
35+
._set_owner(
36+
starknet::get_execution_info().tx_info.account_contract_address, owner_name,
37+
);
38+
self.balance_records.push(0);
39+
}
40+
41+
#[abi(embed_v0)]
42+
impl HelloStarknetExtendedImpl of super::IHelloStarknetExtended<ContractState> {
43+
fn increase_balance(ref self: ContractState, amount: u256) {
44+
let caller = get_caller_address();
45+
let value_before = self.callers.entry(caller).read();
46+
47+
assert(amount != 0, 'Amount cannot be 0');
48+
49+
self.balance.write(self.balance.read() + amount);
50+
self.callers.entry(caller).write(value_before + amount);
51+
self.balance_records.push(self.balance.read());
52+
}
53+
54+
fn get_balance(self: @ContractState) -> u256 {
55+
self.balance.read()
56+
}
57+
58+
fn get_caller_info(self: @ContractState, address: ContractAddress) -> u256 {
59+
self.callers.entry(address).read()
60+
}
61+
62+
fn get_balance_at(self: @ContractState, index: u64) -> u256 {
63+
assert(index < self.balance_records.len(), 'Index out of range');
64+
self.balance_records.at(index).read()
65+
}
66+
}
67+
68+
#[generate_trait]
69+
pub impl InternalFunctions of InternalFunctionsTrait {
70+
fn _set_owner(ref self: ContractState, address: ContractAddress, name: felt252) {
71+
self.owner.address.write(address);
72+
self.owner.name.write(name);
73+
}
74+
}
75+
}

0 commit comments

Comments
 (0)