Skip to content

Commit 3353643

Browse files
authored
Fix schema sampling apoc meta data (#450)
* Expose sample parameter in get_structured_schema and get_schema
* Update changelog for schema sampling parameter
* Update unit tests and e2e tests
* Apply ruff formatting to test_schema
1 parent c4a21ae commit 3353643

File tree

5 files changed

+39
-15
lines changed

5 files changed

+39
-15
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
### Added
66

77
- Added an optional `node_label_neo4j` parameter in the external retrievers to speed up the search query in Neo4j.
8+
9+
- Exposed an optional `sample` parameter on `get_schema` and `get_structured_schema` to control APOC sampling for schema discovery.
810
- Added an optional `id_property_getter` callable parameter in the Qdrant retriever to allow for custom ID retrieval.
911

1012
## 1.10.1

src/neo4j_graphrag/schema.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
DISTINCT_VALUE_LIMIT = 10
3030

3131
NODE_PROPERTIES_QUERY = (
32-
"CALL apoc.meta.data() "
32+
"CALL apoc.meta.data({sample: $SAMPLE}) "
3333
"YIELD label, other, elementType, type, property "
3434
"WHERE NOT type = 'RELATIONSHIP' AND elementType = 'node' "
3535
"AND NOT label IN $EXCLUDED_LABELS "
@@ -38,7 +38,7 @@
3838
)
3939

4040
REL_PROPERTIES_QUERY = (
41-
"CALL apoc.meta.data() "
41+
"CALL apoc.meta.data({sample: $SAMPLE}) "
4242
"YIELD label, other, elementType, type, property "
4343
"WHERE NOT type = 'RELATIONSHIP' AND elementType = 'relationship' "
4444
"AND NOT label in $EXCLUDED_LABELS "
@@ -47,7 +47,7 @@
4747
)
4848

4949
REL_QUERY = (
50-
"CALL apoc.meta.data() "
50+
"CALL apoc.meta.data({sample: $SAMPLE}) "
5151
"YIELD label, other, elementType, type, property "
5252
"WHERE type = 'RELATIONSHIP' AND elementType = 'node' "
5353
"UNWIND other AS other_node "
@@ -186,6 +186,7 @@ def get_schema(
186186
database: Optional[str] = None,
187187
timeout: Optional[float] = None,
188188
sanitize: bool = False,
189+
sample: int = 1000,
189190
) -> str:
190191
"""
191192
Returns the schema of the graph as a string with following format:
@@ -210,6 +211,8 @@ def get_schema(
210211
sanitize (bool): A flag to indicate whether to remove lists with
211212
more than 128 elements from results. Useful for removing
212213
embedding-like properties from database responses. Default is False.
214+
sample (int): Number of nodes to sample for the apoc.meta.data procedure. Setting sample to -1 will remove sampling.
215+
Defaults to 1000.
213216
214217
215218
Returns:
@@ -221,6 +224,7 @@ def get_schema(
221224
database=database,
222225
timeout=timeout,
223226
sanitize=sanitize,
227+
sample=sample,
224228
)
225229
return format_schema(structured_schema, is_enhanced)
226230

@@ -231,6 +235,7 @@ def get_structured_schema(
231235
database: Optional[str] = None,
232236
timeout: Optional[float] = None,
233237
sanitize: bool = False,
238+
sample: int = 1000,
234239
) -> dict[str, Any]:
235240
"""
236241
Returns the structured schema of the graph.
@@ -280,6 +285,8 @@ def get_structured_schema(
280285
sanitize (bool): A flag to indicate whether to remove lists with
281286
more than 128 elements from results. Useful for removing
282287
embedding-like properties from database responses. Default is False.
288+
sample (int): Number of nodes to sample for the apoc.meta.data procedure. Setting sample to -1 will remove sampling.
289+
Defaults to 1000.
283290
284291
Returns:
285292
dict[str, Any]: the graph schema information in a structured format.
@@ -291,7 +298,8 @@ def get_structured_schema(
291298
query=NODE_PROPERTIES_QUERY,
292299
params={
293300
"EXCLUDED_LABELS": EXCLUDED_LABELS
294-
+ [BASE_ENTITY_LABEL, BASE_KG_BUILDER_LABEL]
301+
+ [BASE_ENTITY_LABEL, BASE_KG_BUILDER_LABEL],
302+
"SAMPLE": sample,
295303
},
296304
database=database,
297305
timeout=timeout,
@@ -304,7 +312,7 @@ def get_structured_schema(
304312
for data in query_database(
305313
driver=driver,
306314
query=REL_PROPERTIES_QUERY,
307-
params={"EXCLUDED_LABELS": EXCLUDED_RELS},
315+
params={"EXCLUDED_LABELS": EXCLUDED_RELS, "SAMPLE": sample},
308316
database=database,
309317
timeout=timeout,
310318
sanitize=sanitize,
@@ -318,7 +326,8 @@ def get_structured_schema(
318326
query=REL_QUERY,
319327
params={
320328
"EXCLUDED_LABELS": EXCLUDED_LABELS
321-
+ [BASE_ENTITY_LABEL, BASE_KG_BUILDER_LABEL]
329+
+ [BASE_ENTITY_LABEL, BASE_KG_BUILDER_LABEL],
330+
"SAMPLE": sample,
322331
},
323332
database=database,
324333
timeout=timeout,

tests/e2e/test_schema_e2e.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,9 @@
2929
@pytest.mark.usefixtures("setup_neo4j_for_schema_query")
3030
def test_cypher_returns_correct_node_properties(driver: Driver) -> None:
3131
node_properties = query_database(
32-
driver, NODE_PROPERTIES_QUERY, params={"EXCLUDED_LABELS": [BASE_ENTITY_LABEL]}
32+
driver,
33+
NODE_PROPERTIES_QUERY,
34+
params={"EXCLUDED_LABELS": [BASE_ENTITY_LABEL], "SAMPLE": 1000},
3335
)
3436

3537
expected_node_properties = [
@@ -47,7 +49,9 @@ def test_cypher_returns_correct_node_properties(driver: Driver) -> None:
4749
@pytest.mark.usefixtures("setup_neo4j_for_schema_query")
4850
def test_cypher_returns_correct_relationship_properties(driver: Driver) -> None:
4951
relationships_properties = query_database(
50-
driver, REL_PROPERTIES_QUERY, params={"EXCLUDED_LABELS": [BASE_ENTITY_LABEL]}
52+
driver,
53+
REL_PROPERTIES_QUERY,
54+
params={"EXCLUDED_LABELS": [BASE_ENTITY_LABEL], "SAMPLE": 1000},
5155
)
5256

5357
expected_relationships_properties = [
@@ -65,7 +69,9 @@ def test_cypher_returns_correct_relationship_properties(driver: Driver) -> None:
6569
@pytest.mark.usefixtures("setup_neo4j_for_schema_query")
6670
def test_cypher_returns_correct_relationships(driver: Driver) -> None:
6771
relationships = query_database(
68-
driver, REL_QUERY, params={"EXCLUDED_LABELS": [BASE_ENTITY_LABEL]}
72+
driver,
73+
REL_QUERY,
74+
params={"EXCLUDED_LABELS": [BASE_ENTITY_LABEL], "SAMPLE": 1000},
6975
)
7076

7177
expected_relationships = [

tests/e2e/test_schema_filters_e2e.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def test_filtering_labels_node_properties(driver: Driver) -> None:
3333
for data in query_database(
3434
driver,
3535
NODE_PROPERTIES_QUERY,
36-
params={"EXCLUDED_LABELS": EXCLUDED_LABELS},
36+
params={"EXCLUDED_LABELS": EXCLUDED_LABELS, "SAMPLE": 1000},
3737
)
3838
]
3939

@@ -45,7 +45,9 @@ def test_filtering_labels_relationship_properties(driver: Driver) -> None:
4545
relationship_properties = [
4646
data["output"]
4747
for data in query_database(
48-
driver, REL_PROPERTIES_QUERY, params={"EXCLUDED_LABELS": EXCLUDED_RELS}
48+
driver,
49+
REL_PROPERTIES_QUERY,
50+
params={"EXCLUDED_LABELS": EXCLUDED_RELS, "SAMPLE": 1000},
4951
)
5052
]
5153

@@ -59,7 +61,10 @@ def test_filtering_labels_relationships(driver: Driver) -> None:
5961
for data in query_database(
6062
driver,
6163
REL_QUERY,
62-
params={"EXCLUDED_LABELS": EXCLUDED_LABELS + [BASE_ENTITY_LABEL]},
64+
params={
65+
"EXCLUDED_LABELS": EXCLUDED_LABELS + [BASE_ENTITY_LABEL],
66+
"SAMPLE": 1000,
67+
},
6368
)
6469
]
6570

tests/unit/test_schema.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,8 @@ def test_get_structured_schema_happy_path(driver: MagicMock) -> None:
9797
assert query_obj.timeout is None
9898
assert kwargs["database_"] is None
9999
assert kwargs["parameters_"] == {
100-
"EXCLUDED_LABELS": EXCLUDED_LABELS + [BASE_ENTITY_LABEL, BASE_KG_BUILDER_LABEL]
100+
"EXCLUDED_LABELS": EXCLUDED_LABELS + [BASE_ENTITY_LABEL, BASE_KG_BUILDER_LABEL],
101+
"SAMPLE": 1000,
101102
}
102103

103104
args, kwargs = calls[1]
@@ -106,7 +107,7 @@ def test_get_structured_schema_happy_path(driver: MagicMock) -> None:
106107
assert query_obj.text == REL_PROPERTIES_QUERY
107108
assert query_obj.timeout is None
108109
assert kwargs["database_"] is None
109-
assert kwargs["parameters_"] == {"EXCLUDED_LABELS": EXCLUDED_RELS}
110+
assert kwargs["parameters_"] == {"EXCLUDED_LABELS": EXCLUDED_RELS, "SAMPLE": 1000}
110111

111112
args, kwargs = calls[2]
112113
query_obj = args[0]
@@ -115,7 +116,8 @@ def test_get_structured_schema_happy_path(driver: MagicMock) -> None:
115116
assert query_obj.timeout is None
116117
assert kwargs["database_"] is None
117118
assert kwargs["parameters_"] == {
118-
"EXCLUDED_LABELS": EXCLUDED_LABELS + [BASE_ENTITY_LABEL, BASE_KG_BUILDER_LABEL]
119+
"EXCLUDED_LABELS": EXCLUDED_LABELS + [BASE_ENTITY_LABEL, BASE_KG_BUILDER_LABEL],
120+
"SAMPLE": 1000,
119121
}
120122

121123
args, kwargs = calls[3]

0 commit comments

Comments
 (0)