Skip to content

Commit 2214b16

Browse files
authored
feat(null-propagation): Null value for required args -> Null return (#834)
* feat: propagate `None` in `ExtractByLlm` * feat(null-propagation): Null value for required args -> Null return
1 parent d3bc087 commit 2214b16

File tree

8 files changed

+249
-111
lines changed

8 files changed

+249
-111
lines changed

src/base/schema.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,11 @@ impl EnrichedValueType {
262262
attrs: Default::default(),
263263
}
264264
}
265+
266+
pub fn with_nullable(mut self, nullable: bool) -> Self {
267+
self.nullable = nullable;
268+
self
269+
}
265270
}
266271

267272
impl<DataType> EnrichedValueType<DataType> {

src/ops/factory_bases.rs

Lines changed: 55 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -89,10 +89,11 @@ pub struct OpArgsResolver<'a> {
8989
num_positional_args: usize,
9090
next_positional_idx: usize,
9191
remaining_kwargs: HashMap<&'a str, usize>,
92+
required_args_idx: &'a mut Vec<usize>,
9293
}
9394

9495
impl<'a> OpArgsResolver<'a> {
95-
pub fn new(args: &'a [OpArgSchema]) -> Result<Self> {
96+
pub fn new(args: &'a [OpArgSchema], required_args_idx: &'a mut Vec<usize>) -> Result<Self> {
9697
let mut num_positional_args = 0;
9798
let mut kwargs = HashMap::new();
9899
for (idx, arg) in args.iter().enumerate() {
@@ -110,6 +111,7 @@ impl<'a> OpArgsResolver<'a> {
110111
num_positional_args,
111112
next_positional_idx: 0,
112113
remaining_kwargs: kwargs,
114+
required_args_idx,
113115
})
114116
}
115117

@@ -135,9 +137,11 @@ impl<'a> OpArgsResolver<'a> {
135137
}
136138

137139
pub fn next_arg(&mut self, name: &str) -> Result<ResolvedOpArg> {
138-
Ok(self
140+
let arg = self
139141
.next_optional_arg(name)?
140-
.ok_or_else(|| api_error!("Required argument `{name}` is missing",))?)
142+
.ok_or_else(|| api_error!("Required argument `{name}` is missing",))?;
143+
self.required_args_idx.push(arg.idx);
144+
Ok(arg)
141145
}
142146

143147
pub fn done(self) -> Result<()> {
@@ -233,7 +237,7 @@ pub trait SimpleFunctionFactoryBase: SimpleFunctionFactory + Send + Sync + 'stat
233237
spec: Self::Spec,
234238
resolved_input_schema: Self::ResolvedArgs,
235239
context: Arc<FlowInstanceContext>,
236-
) -> Result<Box<dyn SimpleFunctionExecutor>>;
240+
) -> Result<impl SimpleFunctionExecutor>;
237241

238242
fn register(self, registry: &mut ExecutorFactoryRegistry) -> Result<()>
239243
where
@@ -246,6 +250,31 @@ pub trait SimpleFunctionFactoryBase: SimpleFunctionFactory + Send + Sync + 'stat
246250
}
247251
}
248252

253+
struct FunctionExecutorWrapper<E: SimpleFunctionExecutor> {
254+
executor: E,
255+
required_args_idx: Vec<usize>,
256+
}
257+
258+
#[async_trait]
259+
impl<E: SimpleFunctionExecutor> SimpleFunctionExecutor for FunctionExecutorWrapper<E> {
260+
async fn evaluate(&self, args: Vec<value::Value>) -> Result<value::Value> {
261+
for idx in &self.required_args_idx {
262+
if args[*idx].is_null() {
263+
return Ok(value::Value::Null);
264+
}
265+
}
266+
self.executor.evaluate(args).await
267+
}
268+
269+
fn enable_cache(&self) -> bool {
270+
self.executor.enable_cache()
271+
}
272+
273+
fn behavior_version(&self) -> Option<u32> {
274+
self.executor.behavior_version()
275+
}
276+
}
277+
249278
#[async_trait]
250279
impl<T: SimpleFunctionFactoryBase> SimpleFunctionFactory for T {
251280
async fn build(
@@ -258,13 +287,31 @@ impl<T: SimpleFunctionFactoryBase> SimpleFunctionFactory for T {
258287
BoxFuture<'static, Result<Box<dyn SimpleFunctionExecutor>>>,
259288
)> {
260289
let spec: T::Spec = serde_json::from_value(spec)?;
261-
let mut args_resolver = OpArgsResolver::new(&input_schema)?;
262-
let (resolved_input_schema, output_schema) = self
290+
let mut required_args_idx = vec![];
291+
let mut args_resolver = OpArgsResolver::new(&input_schema, &mut required_args_idx)?;
292+
let (resolved_input_schema, mut output_schema) = self
263293
.resolve_schema(&spec, &mut args_resolver, &context)
264294
.await?;
295+
296+
// If any required argument is nullable, the output schema should be nullable.
297+
if args_resolver
298+
.required_args_idx
299+
.iter()
300+
.any(|idx| input_schema[*idx].value_type.nullable)
301+
{
302+
output_schema.nullable = true;
303+
}
304+
265305
args_resolver.done()?;
266-
let executor = self.build_executor(spec, resolved_input_schema, context);
267-
Ok((output_schema, executor))
306+
let executor = async move {
307+
Ok(Box::new(FunctionExecutorWrapper {
308+
executor: self
309+
.build_executor(spec, resolved_input_schema, context)
310+
.await?,
311+
required_args_idx,
312+
}) as Box<dyn SimpleFunctionExecutor>)
313+
};
314+
Ok((output_schema, Box::pin(executor)))
268315
}
269316
}
270317

src/ops/functions/embed_text.rs

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ use crate::{
55
ops::sdk::*,
66
};
77

8-
#[derive(Deserialize)]
8+
#[derive(Serialize, Deserialize)]
99
struct Spec {
1010
api_type: LlmApiType,
1111
model: String,
@@ -92,8 +92,8 @@ impl SimpleFunctionFactoryBase for Factory {
9292
spec: Spec,
9393
args: Args,
9494
_context: Arc<FlowInstanceContext>,
95-
) -> Result<Box<dyn SimpleFunctionExecutor>> {
96-
Ok(Box::new(Executor { spec, args }))
95+
) -> Result<impl SimpleFunctionExecutor> {
96+
Ok(Executor { spec, args })
9797
}
9898
}
9999

@@ -123,9 +123,10 @@ mod tests {
123123

124124
let input_args_values = vec![text_content.to_string().into()];
125125

126-
let input_arg_schemas = vec![build_arg_schema("text", BasicValueType::Str)];
126+
let input_arg_schemas = &[build_arg_schema("text", BasicValueType::Str)];
127127

128-
let result = test_flow_function(factory, spec, input_arg_schemas, input_args_values).await;
128+
let result =
129+
test_flow_function(&factory, &spec, input_arg_schemas, input_args_values).await;
129130

130131
if result.is_err() {
131132
eprintln!(

src/ops/functions/extract_by_llm.rs

Lines changed: 58 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -81,22 +81,24 @@ impl SimpleFunctionExecutor for Executor {
8181
}
8282

8383
async fn evaluate(&self, input: Vec<Value>) -> Result<Value> {
84-
let image_bytes: Option<Cow<'_, [u8]>> = self
85-
.args
86-
.image
87-
.as_ref()
88-
.map(|arg| arg.value(&input)?.as_bytes())
89-
.transpose()?
90-
.map(|bytes| Cow::Borrowed(bytes.as_ref()));
91-
let text = self
92-
.args
93-
.text
94-
.as_ref()
95-
.map(|arg| arg.value(&input)?.as_str())
96-
.transpose()?;
84+
let image_bytes: Option<Cow<'_, [u8]>> = if let Some(arg) = self.args.image.as_ref()
85+
&& let Some(value) = arg.value(&input)?.optional()
86+
{
87+
Some(Cow::Borrowed(value.as_bytes()?))
88+
} else {
89+
None
90+
};
91+
92+
let text = if let Some(arg) = self.args.text.as_ref()
93+
&& let Some(value) = arg.value(&input)?.optional()
94+
{
95+
Some(value.as_str()?)
96+
} else {
97+
None
98+
};
9799

98100
if text.is_none() && image_bytes.is_none() {
99-
api_bail!("At least one of `text` or `image` must be provided");
101+
return Ok(Value::Null);
100102
}
101103

102104
let user_prompt = text.map_or("", |v| v);
@@ -147,16 +149,22 @@ impl SimpleFunctionFactoryBase for Factory {
147149
api_bail!("At least one of 'text' or 'image' must be provided");
148150
}
149151

150-
Ok((args, spec.output_type.clone()))
152+
let mut output_type = spec.output_type.clone();
153+
if args.text.as_ref().map_or(true, |arg| arg.typ.nullable)
154+
&& args.image.as_ref().map_or(true, |arg| arg.typ.nullable)
155+
{
156+
output_type.nullable = true;
157+
}
158+
Ok((args, output_type))
151159
}
152160

153161
async fn build_executor(
154162
self: Arc<Self>,
155163
spec: Spec,
156164
resolved_input_schema: Args,
157165
_context: Arc<FlowInstanceContext>,
158-
) -> Result<Box<dyn SimpleFunctionExecutor>> {
159-
Ok(Box::new(Executor::new(spec, resolved_input_schema).await?))
166+
) -> Result<impl SimpleFunctionExecutor> {
167+
Executor::new(spec, resolved_input_schema).await
160168
}
161169
}
162170

@@ -205,9 +213,10 @@ mod tests {
205213

206214
let input_args_values = vec![text_content.to_string().into()];
207215

208-
let input_arg_schemas = vec![build_arg_schema("text", BasicValueType::Str)];
216+
let input_arg_schemas = &[build_arg_schema("text", BasicValueType::Str)];
209217

210-
let result = test_flow_function(factory, spec, input_arg_schemas, input_args_values).await;
218+
let result =
219+
test_flow_function(&factory, &spec, input_arg_schemas, input_args_values).await;
211220

212221
if result.is_err() {
213222
eprintln!(
@@ -253,4 +262,34 @@ mod tests {
253262
_ => panic!("Expected Value::Struct, got {value:?}"),
254263
}
255264
}
265+
266+
#[tokio::test]
267+
#[ignore = "This test requires an OpenAI API key or a configured local LLM and may make network calls."]
268+
async fn test_null_inputs() {
269+
let factory = Arc::new(Factory);
270+
let spec = Spec {
271+
llm_spec: LlmSpec {
272+
api_type: crate::llm::LlmApiType::OpenAi,
273+
model: "gpt-4o".to_string(),
274+
address: None,
275+
api_config: None,
276+
},
277+
output_type: make_output_type(BasicValueType::Str),
278+
instruction: None,
279+
};
280+
let input_arg_schemas = &[
281+
(
282+
Some("text"),
283+
make_output_type(BasicValueType::Str).with_nullable(true),
284+
),
285+
(
286+
Some("image"),
287+
make_output_type(BasicValueType::Bytes).with_nullable(true),
288+
),
289+
];
290+
let input_args_values = vec![Value::Null, Value::Null];
291+
let result =
292+
test_flow_function(&factory, &spec, input_arg_schemas, input_args_values).await;
293+
assert_eq!(result.unwrap(), Value::Null);
294+
}
256295
}

src/ops/functions/parse_json.rs

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -98,8 +98,8 @@ impl SimpleFunctionFactoryBase for Factory {
9898
_spec: EmptySpec,
9999
args: Args,
100100
_context: Arc<FlowInstanceContext>,
101-
) -> Result<Box<dyn SimpleFunctionExecutor>> {
102-
Ok(Box::new(Executor { args }))
101+
) -> Result<impl SimpleFunctionExecutor> {
102+
Ok(Executor { args })
103103
}
104104
}
105105

@@ -119,12 +119,13 @@ mod tests {
119119

120120
let input_args_values = vec![json_string_content.to_string().into(), lang_value.clone()];
121121

122-
let input_arg_schemas = vec![
122+
let input_arg_schemas = &[
123123
build_arg_schema("text", BasicValueType::Str),
124124
build_arg_schema("language", BasicValueType::Str),
125125
];
126126

127-
let result = test_flow_function(factory, spec, input_arg_schemas, input_args_values).await;
127+
let result =
128+
test_flow_function(&factory, &spec, input_arg_schemas, input_args_values).await;
128129

129130
assert!(
130131
result.is_ok(),

0 commit comments

Comments
 (0)