1
1
import logging
2
2
import functools
3
- from typing import Dict , Text , Any , List , Union , Optional , Tuple
3
+ from typing import Dict , Text , Any , List , Optional
4
+ from rasa .core .slots import Slot
5
+
6
+ from rasa_addons .core .actions .required_slots_graph_parser import (
7
+ RequiredSlotsGraphParser ,
8
+ )
4
9
from rasa_addons .core .actions .slot_rule_validator import validate_with_rule
5
10
from rasa_addons .core .actions .submit_form_to_botfront import submit_form_to_botfront
6
11
@@ -41,7 +46,12 @@ def __str__(self) -> Text:
41
46
return f"FormAction('{ self .name ()} ')"
42
47
43
48
def required_slots (self , tracker ):
44
- return [s .get ("name" ) for s in self .form_spec .get ("slots" , [])]
49
+ graph = self .form_spec .get ("graph_elements" )
50
+ if not graph :
51
+ return [s .get ("name" ) for s in self .form_spec .get ("slots" , [])]
52
+ parser = RequiredSlotsGraphParser (graph )
53
+ required_slots = parser .get_required_slots (tracker )
54
+ return required_slots
45
55
46
56
def get_field_for_slot (
47
57
self , slot : Text , field : Text , default : Optional [Any ] = None ,
@@ -51,32 +61,6 @@ def get_field_for_slot(
51
61
return s .get (field , default )
52
62
return default
53
63
54
- async def validate_prefilled (
55
- self ,
56
- output_channel : "OutputChannel" ,
57
- nlg : "NaturalLanguageGenerator" ,
58
- tracker : "DialogueStateTracker" ,
59
- domain : "Domain" ,
60
- ):
61
- # collect values of required slots filled before activation
62
- prefilled_slots = {}
63
- events = []
64
-
65
- for slot_name in self .required_slots (tracker ):
66
- if not self ._should_request_slot (tracker , slot_name ):
67
- prefilled_slots [slot_name ] = tracker .get_slot (slot_name )
68
-
69
- if prefilled_slots :
70
- logger .debug (f"Validating pre-filled required slots: { prefilled_slots } " )
71
- events .extend (
72
- await self .validate_slots (
73
- prefilled_slots , output_channel , nlg , tracker , domain
74
- )
75
- )
76
- else :
77
- logger .debug ("No pre-filled required slots to validate." )
78
- return events
79
-
80
64
async def run (
81
65
self ,
82
66
output_channel : "OutputChannel" ,
@@ -86,7 +70,9 @@ async def run(
86
70
) -> List [Event ]:
87
71
# attempt retrieving spec
88
72
if not len (self .form_spec ):
89
- for form in tracker .slots .get ("bf_forms" , {}).initial_value :
73
+ for form in tracker .slots .get (
74
+ "bf_forms" , Slot ("bf_forms" , initial_value = [])
75
+ ).initial_value :
90
76
if form .get ("name" ) == self .name ():
91
77
self .form_spec = clean_none_values (form )
92
78
if not len (self .form_spec ):
@@ -150,7 +136,8 @@ async def submit(
150
136
template = await nlg .generate (
151
137
f"utter_submit_{ self .name ()} " , tracker , output_channel .name (),
152
138
)
153
- events += [create_bot_utterance (template )]
139
+ if template :
140
+ events += [create_bot_utterance (template )]
154
141
if collect_in_botfront :
155
142
submit_form_to_botfront (tracker )
156
143
return events
@@ -222,11 +209,13 @@ def entity_is_desired(
222
209
223
210
@staticmethod
224
211
def get_entity_value (
225
- name : Text ,
212
+ name : Optional [ Text ] ,
226
213
tracker : "DialogueStateTracker" ,
227
214
role : Optional [Text ] = None ,
228
215
group : Optional [Text ] = None ,
229
216
) -> Any :
217
+ if not name :
218
+ return None
230
219
# list is used to cover the case of list slot type
231
220
value = list (
232
221
tracker .get_latest_entity_values (name , entity_group = group , entity_role = role )
@@ -246,6 +235,8 @@ def extract_other_slots(
246
235
domain : "Domain" ,
247
236
) -> Dict [Text , Any ]:
248
237
slot_to_fill = tracker .get_slot (REQUESTED_SLOT )
238
+ if not slot_to_fill :
239
+ return {}
249
240
250
241
slot_values = {}
251
242
for slot in self .required_slots (tracker ):
@@ -300,6 +291,8 @@ def extract_requested_slot(
300
291
else return None
301
292
"""
302
293
slot_to_fill = tracker .get_slot (REQUESTED_SLOT )
294
+ if not slot_to_fill :
295
+ return {}
303
296
logger .debug (f"Trying to extract requested slot '{ slot_to_fill } ' ..." )
304
297
305
298
# get mapping for requested slot
@@ -352,14 +345,16 @@ async def utter_post_validation(
352
345
and self .get_field_for_slot (slot , "utter_on_new_valid_slot" , False ) is False
353
346
):
354
347
return []
355
- valid = "valid" if valid else "invalid"
348
+ utter_what = "valid" if valid else "invalid"
356
349
357
350
# so utter_(in)valid_slot supports {slot} template replacements
358
351
temp_tracker = tracker .copy ()
359
352
temp_tracker .slots [slot ].value = value
360
353
template = await nlg .generate (
361
- f"utter_{ valid } _{ slot } " , temp_tracker , output_channel .name (),
354
+ f"utter_{ utter_what } _{ slot } " , temp_tracker , output_channel .name (),
362
355
)
356
+ if not template :
357
+ return []
363
358
return [create_bot_utterance (template )]
364
359
365
360
async def validate_slots (
0 commit comments