Skip to content

Commit f3186ac

Browse files
committed
feat: add ConstraintType to GraphSchema for constraint extraction
1 parent 3353643 commit f3186ac

File tree

2 files changed

+96
-3
lines changed

2 files changed

+96
-3
lines changed

src/neo4j_graphrag/experimental/components/schema.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,22 @@ def property_type_from_name(self, name: str) -> Optional[PropertyType]:
161161
return None
162162

163163

164+
class ConstraintType(BaseModel):
165+
"""
166+
Represents a constraint on a node in the graph.
167+
"""
168+
169+
type: Literal[
170+
"UNIQUENESS"
171+
] # TODO: add other constraint types ["propertyExistence", "propertyType", "key"]
172+
node_type: str
173+
property_name: str
174+
175+
model_config = ConfigDict(
176+
frozen=True,
177+
)
178+
179+
164180
class GraphSchema(DataModel):
165181
"""This model represents the expected
166182
node and relationship types in the graph.
@@ -177,6 +193,7 @@ class GraphSchema(DataModel):
177193
node_types: Tuple[NodeType, ...]
178194
relationship_types: Tuple[RelationshipType, ...] = tuple()
179195
patterns: Tuple[Tuple[str, str, str], ...] = tuple()
196+
constraints: Tuple[ConstraintType, ...] = tuple()
180197

181198
additional_node_types: bool = Field(
182199
default_factory=default_additional_item("node_types")
@@ -239,6 +256,21 @@ def validate_additional_parameters(self) -> Self:
239256
)
240257
return self
241258

259+
@model_validator(mode="after")
260+
def validate_constraints_against_node_types(self) -> Self:
261+
if not self.constraints:
262+
return self
263+
for constraint in self.constraints:
264+
if not constraint.property_name:
265+
raise SchemaValidationError(
266+
f"Constraint has no property name: {constraint}. Property name is required."
267+
)
268+
if constraint.node_type not in self._node_type_index:
269+
raise SchemaValidationError(
270+
f"Constraint references undefined node type: {constraint.node_type}"
271+
)
272+
return self
273+
242274
def node_type_from_label(self, label: str) -> Optional[NodeType]:
243275
return self._node_type_index.get(label)
244276

@@ -382,6 +414,7 @@ def create_schema_model(
382414
node_types: Sequence[NodeType],
383415
relationship_types: Optional[Sequence[RelationshipType]] = None,
384416
patterns: Optional[Sequence[Tuple[str, str, str]]] = None,
417+
constraints: Optional[Sequence[ConstraintType]] = None,
385418
**kwargs: Any,
386419
) -> GraphSchema:
387420
"""
@@ -403,6 +436,7 @@ def create_schema_model(
403436
node_types=node_types,
404437
relationship_types=relationship_types or (),
405438
patterns=patterns or (),
439+
constraints=constraints or (),
406440
**kwargs,
407441
)
408442
)
@@ -415,6 +449,7 @@ async def run(
415449
node_types: Sequence[NodeType],
416450
relationship_types: Optional[Sequence[RelationshipType]] = None,
417451
patterns: Optional[Sequence[Tuple[str, str, str]]] = None,
452+
constraints: Optional[Sequence[ConstraintType]] = None,
418453
**kwargs: Any,
419454
) -> GraphSchema:
420455
"""
@@ -432,6 +467,7 @@ async def run(
432467
node_types,
433468
relationship_types,
434469
patterns,
470+
constraints,
435471
**kwargs,
436472
)
437473

@@ -555,6 +591,41 @@ def _filter_relationships_without_labels(
555591
relationship_types, "relationship type"
556592
)
557593

594+
def _filter_invalid_constraints(
595+
self, constraints: List[Dict[str, Any]], node_types: List[Dict[str, Any]]
596+
) -> List[Dict[str, Any]]:
597+
"""Filter out constraints that reference undefined node types or have no property name."""
598+
if not constraints:
599+
return []
600+
601+
if not node_types:
602+
logging.info(
603+
"Filtering out all constraints because no node types are defined. "
604+
"Constraints reference node types that must be defined."
605+
)
606+
return []
607+
608+
valid_node_labels = {node_type.get("label") for node_type in node_types}
609+
610+
filtered_constraints = []
611+
for constraint in constraints:
612+
# check if the property_name is provided
613+
if not constraint.get("property_name"):
614+
logging.info(
615+
f"Filtering out constraint: {constraint}. "
616+
f"Property name is not provided."
617+
)
618+
continue
619+
# check if the node_type is valid
620+
if constraint.get("node_type") not in valid_node_labels:
621+
logging.info(
622+
f"Filtering out constraint: {constraint}. "
623+
f"Node type '{constraint.get('node_type')}' is not valid. Valid node types: {valid_node_labels}"
624+
)
625+
continue
626+
filtered_constraints.append(constraint)
627+
return filtered_constraints
628+
558629
def _clean_json_content(self, content: str) -> str:
559630
content = content.strip()
560631

@@ -624,6 +695,9 @@ async def run(self, text: str, examples: str = "", **kwargs: Any) -> GraphSchema
624695
extracted_patterns: Optional[List[Tuple[str, str, str]]] = extracted_schema.get(
625696
"patterns"
626697
)
698+
extracted_constraints: Optional[List[Dict[str, Any]]] = extracted_schema.get(
699+
"constraints"
700+
)
627701

628702
# Filter out nodes and relationships without labels
629703
extracted_node_types = self._filter_nodes_without_labels(extracted_node_types)
@@ -638,11 +712,18 @@ async def run(self, text: str, examples: str = "", **kwargs: Any) -> GraphSchema
638712
extracted_patterns, extracted_node_types, extracted_relationship_types
639713
)
640714

715+
# Filter out invalid constraints
716+
if extracted_constraints:
717+
extracted_constraints = self._filter_invalid_constraints(
718+
extracted_constraints, extracted_node_types
719+
)
720+
641721
return GraphSchema.model_validate(
642722
{
643723
"node_types": extracted_node_types,
644724
"relationship_types": extracted_relationship_types,
645725
"patterns": extracted_patterns,
726+
"constraints": extracted_constraints or [],
646727
}
647728
)
648729

src/neo4j_graphrag/generation/prompts.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,11 @@ class SchemaExtractionTemplate(PromptTemplate):
217217
4. Include property definitions only when the type can be confidently inferred, otherwise omit them.
218218
5. When defining patterns, ensure that every node label and relationship label mentioned exists in your lists of node types and relationship types.
219219
6. Do not create node types that aren't clearly mentioned in the text.
220-
7. Keep your schema minimal and focused on clearly identifiable patterns in the text.
220+
7. For each node type, identify a unique identifier property and add it as a UNIQUENESS constraint to the list of constraints.
221+
8. Constraints must reference a node_type label that exists in the list of node types.
222+
9. Each constraint must have a property_name having a name that indicates it is a unique identifier for the node type (e.g., person_id for Person, company_id for Company)
223+
10. Keep your schema minimal and focused on clearly identifiable patterns in the text.
224+
221225
222226
Accepted property types are: BOOLEAN, DATE, DURATION, FLOAT, INTEGER, LIST,
223227
LOCAL_DATETIME, LOCAL_TIME, POINT, STRING, ZONED_DATETIME, ZONED_TIME.
@@ -233,18 +237,26 @@ class SchemaExtractionTemplate(PromptTemplate):
233237
"type": "STRING"
234238
}}
235239
]
236-
}},
240+
}}
237241
...
238242
],
239243
"relationship_types": [
240244
{{
241245
"label": "WORKS_FOR"
242-
}},
246+
}}
243247
...
244248
],
245249
"patterns": [
246250
["Person", "WORKS_FOR", "Company"],
247251
...
252+
],
253+
"constraints": [
254+
{{
255+
"type": "UNIQUENESS",
256+
"node_type": "Person",
257+
"property_name": "person_id"
258+
}}
259+
...
248260
]
249261
}}
250262

0 commit comments

Comments
 (0)