Skip to content
This repository was archived by the owner on Apr 28, 2021. It is now read-only.

Commit a2b6992

Browse files
author
pheel
authored
feat: programmatic required slots
2 parents 271a723 + b9a9cda commit a2b6992

File tree

4 files changed

+347
-113
lines changed

4 files changed

+347
-113
lines changed

rasa_addons/core/actions/action_botfront_form.py

Lines changed: 28 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
11
import logging
22
import functools
3-
from typing import Dict, Text, Any, List, Union, Optional, Tuple
3+
from typing import Dict, Text, Any, List, Optional
4+
from rasa.core.slots import Slot
5+
6+
from rasa_addons.core.actions.required_slots_graph_parser import (
7+
RequiredSlotsGraphParser,
8+
)
49
from rasa_addons.core.actions.slot_rule_validator import validate_with_rule
510
from rasa_addons.core.actions.submit_form_to_botfront import submit_form_to_botfront
611

@@ -41,7 +46,12 @@ def __str__(self) -> Text:
4146
return f"FormAction('{self.name()}')"
4247

4348
def required_slots(self, tracker):
44-
return [s.get("name") for s in self.form_spec.get("slots", [])]
49+
graph = self.form_spec.get("graph_elements")
50+
if not graph:
51+
return [s.get("name") for s in self.form_spec.get("slots", [])]
52+
parser = RequiredSlotsGraphParser(graph)
53+
required_slots = parser.get_required_slots(tracker)
54+
return required_slots
4555

4656
def get_field_for_slot(
4757
self, slot: Text, field: Text, default: Optional[Any] = None,
@@ -51,32 +61,6 @@ def get_field_for_slot(
5161
return s.get(field, default)
5262
return default
5363

54-
async def validate_prefilled(
55-
self,
56-
output_channel: "OutputChannel",
57-
nlg: "NaturalLanguageGenerator",
58-
tracker: "DialogueStateTracker",
59-
domain: "Domain",
60-
):
61-
# collect values of required slots filled before activation
62-
prefilled_slots = {}
63-
events = []
64-
65-
for slot_name in self.required_slots(tracker):
66-
if not self._should_request_slot(tracker, slot_name):
67-
prefilled_slots[slot_name] = tracker.get_slot(slot_name)
68-
69-
if prefilled_slots:
70-
logger.debug(f"Validating pre-filled required slots: {prefilled_slots}")
71-
events.extend(
72-
await self.validate_slots(
73-
prefilled_slots, output_channel, nlg, tracker, domain
74-
)
75-
)
76-
else:
77-
logger.debug("No pre-filled required slots to validate.")
78-
return events
79-
8064
async def run(
8165
self,
8266
output_channel: "OutputChannel",
@@ -86,7 +70,9 @@ async def run(
8670
) -> List[Event]:
8771
# attempt retrieving spec
8872
if not len(self.form_spec):
89-
for form in tracker.slots.get("bf_forms", {}).initial_value:
73+
for form in tracker.slots.get(
74+
"bf_forms", Slot("bf_forms", initial_value=[])
75+
).initial_value:
9076
if form.get("name") == self.name():
9177
self.form_spec = clean_none_values(form)
9278
if not len(self.form_spec):
@@ -150,7 +136,8 @@ async def submit(
150136
template = await nlg.generate(
151137
f"utter_submit_{self.name()}", tracker, output_channel.name(),
152138
)
153-
events += [create_bot_utterance(template)]
139+
if template:
140+
events += [create_bot_utterance(template)]
154141
if collect_in_botfront:
155142
submit_form_to_botfront(tracker)
156143
return events
@@ -222,11 +209,13 @@ def entity_is_desired(
222209

223210
@staticmethod
224211
def get_entity_value(
225-
name: Text,
212+
name: Optional[Text],
226213
tracker: "DialogueStateTracker",
227214
role: Optional[Text] = None,
228215
group: Optional[Text] = None,
229216
) -> Any:
217+
if not name:
218+
return None
230219
# list is used to cover the case of list slot type
231220
value = list(
232221
tracker.get_latest_entity_values(name, entity_group=group, entity_role=role)
@@ -246,6 +235,8 @@ def extract_other_slots(
246235
domain: "Domain",
247236
) -> Dict[Text, Any]:
248237
slot_to_fill = tracker.get_slot(REQUESTED_SLOT)
238+
if not slot_to_fill:
239+
return {}
249240

250241
slot_values = {}
251242
for slot in self.required_slots(tracker):
@@ -300,6 +291,8 @@ def extract_requested_slot(
300291
else return None
301292
"""
302293
slot_to_fill = tracker.get_slot(REQUESTED_SLOT)
294+
if not slot_to_fill:
295+
return {}
303296
logger.debug(f"Trying to extract requested slot '{slot_to_fill}' ...")
304297

305298
# get mapping for requested slot
@@ -352,14 +345,16 @@ async def utter_post_validation(
352345
and self.get_field_for_slot(slot, "utter_on_new_valid_slot", False) is False
353346
):
354347
return []
355-
valid = "valid" if valid else "invalid"
348+
utter_what = "valid" if valid else "invalid"
356349

357350
# so utter_(in)valid_slot supports {slot} template replacements
358351
temp_tracker = tracker.copy()
359352
temp_tracker.slots[slot].value = value
360353
template = await nlg.generate(
361-
f"utter_{valid}_{slot}", temp_tracker, output_channel.name(),
354+
f"utter_{utter_what}_{slot}", temp_tracker, output_channel.name(),
362355
)
356+
if not template:
357+
return []
363358
return [create_bot_utterance(template)]
364359

365360
async def validate_slots(
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
from typing import Dict, Text, Any
2+
from rasa_addons.core.actions.slot_rule_validator import validate_with_rule
3+
4+
5+
class RequiredSlotsGraphParser:
6+
def __init__(self, required_slots_graph: Dict[Text, Any]) -> None:
7+
self.start = None
8+
self.nodes = {}
9+
for node in required_slots_graph.get("nodes", []):
10+
if node.get("type") == "start":
11+
self.start = node.get("id")
12+
continue
13+
self.nodes[node.get("id")] = node.get("slotName")
14+
self.edges = {}
15+
for edge in required_slots_graph.get("edges", []):
16+
source = edge.get("source")
17+
self.edges[source] = [*self.edges.get(source, []), edge]
18+
19+
def get_required_slots(self, tracker, start=None):
20+
required_slots = []
21+
current_source = start or self.start
22+
current_edges = self.edges.get(current_source, [])
23+
for edge in sorted(current_edges, key=lambda e: e.get("condition") is None):
24+
target, condition = edge.get("target"), edge.get("condition")
25+
if self.check_condition(tracker, condition):
26+
required_slots.append(self.nodes.get(target))
27+
required_slots += self.get_required_slots(tracker, start=target)
28+
break # use first matching condition, that's it
29+
else:
30+
continue
31+
return required_slots
32+
33+
def check_condition(self, tracker, condition):
34+
if condition is None:
35+
return True
36+
props = condition.get("properties", {})
37+
children = condition.get("children1", {}).values()
38+
if condition.get("type") == "rule":
39+
return self.check_atomic_condition(tracker, **props)
40+
conjunction_operator = any if props.get("conjunction") == "OR" else all
41+
polarity = (lambda p: not p) if props.get("not") else (lambda p: p)
42+
return polarity(
43+
conjunction_operator(
44+
self.check_condition(tracker, child) for child in children
45+
)
46+
)
47+
48+
def check_atomic_condition(self, tracker, field, operator, value, **rest):
49+
slot = tracker.slots.get(field)
50+
return validate_with_rule(
51+
slot.value if slot else None,
52+
{
53+
"operator": operator,
54+
"comparatum": [*value, None][0] # value is always a singleton list
55+
},
56+
)

rasa_addons/core/actions/slot_rule_validator.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -75,19 +75,19 @@ def validate_with_rule(value, validation_rule) -> bool:
7575
if operator in NUM_COMPARATUM_OPERATORS:
7676
try:
7777
comparatum = float(comparatum)
78-
except ValueError:
78+
except (ValueError, TypeError):
7979
raise ValueError(
8080
f"Validation operator '{operator}' requires a numerical comparatum."
8181
)
82-
except ValueError as e:
82+
except (ValueError, TypeError) as e:
8383
logger.error(str(e))
8484
return False
8585
if operator in TEXT_VALUE_OPERATORS and not isinstance(value, str):
8686
return False
8787
if operator in NUM_VALUE_OPERATORS:
8888
try:
8989
value = float(value)
90-
except ValueError:
90+
except (ValueError, TypeError):
9191
return False
9292
if operator == "is_in":
9393
return value in comparatum

0 commit comments

Comments
 (0)