1
1
use colored:: Colorize ;
2
- use dkn_executor:: TaskBody ;
2
+ use dkn_executor:: { CompletionError , ModelProvider , PromptError , TaskBody } ;
3
3
use dkn_p2p:: libp2p:: request_response:: ResponseChannel ;
4
- use dkn_utils:: payloads:: { TaskRequestPayload , TaskResponsePayload , TaskStats , TASK_RESULT_TOPIC } ;
4
+ use dkn_utils:: payloads:: {
5
+ TaskError , TaskRequestPayload , TaskResponsePayload , TaskStats , TASK_RESULT_TOPIC ,
6
+ } ;
5
7
use dkn_utils:: DriaMessage ;
6
8
use eyre:: { Context , Result } ;
7
9
@@ -25,27 +27,23 @@ impl TaskResponder {
25
27
let task = compute_message
26
28
. parse_payload :: < TaskRequestPayload < serde_json:: Value > > ( )
27
29
. wrap_err ( "could not parse task request payload" ) ?;
28
- let task_body = match serde_json:: from_value :: < TaskBody > ( task. input )
29
- . wrap_err ( "could not parse task body" )
30
- {
30
+ let task_body = match serde_json:: from_value :: < TaskBody > ( task. input ) {
31
31
Ok ( task_body) => task_body,
32
32
Err ( err) => {
33
- let err_string = format ! ( "{:#}" , err) ;
34
33
log:: error!(
35
- "Task {}/{} failed due to parsing error: {}" ,
34
+ "Task {}/{} failed due to parsing error: {err }" ,
36
35
task. file_id,
37
36
task. row_id,
38
- err_string
39
37
) ;
40
38
41
39
// prepare error payload
42
40
let error_payload = TaskResponsePayload {
43
41
result : None ,
44
- error : Some ( err_string ) ,
42
+ error : Some ( TaskError :: ParseError ( err . to_string ( ) ) ) ,
45
43
row_id : task. row_id ,
46
44
file_id : task. file_id ,
47
45
task_id : task. task_id ,
48
- model : Default :: default ( ) ,
46
+ model : "<n/a>" . to_string ( ) , // no model available due to parsing error
49
47
stats : TaskStats :: new ( ) ,
50
48
} ;
51
49
@@ -56,7 +54,8 @@ impl TaskResponder {
56
54
let response = node. new_message ( error_payload_str, TASK_RESULT_TOPIC ) ;
57
55
node. p2p . respond ( response. into ( ) , channel) . await ?;
58
56
59
- return Err ( err) ;
57
+ // return with error
58
+ eyre:: bail!( "could not parse task body: {err}" )
60
59
}
61
60
} ;
62
61
@@ -75,7 +74,7 @@ impl TaskResponder {
75
74
let task_metadata = TaskWorkerMetadata {
76
75
task_id : task. task_id ,
77
76
file_id : task. file_id ,
78
- model_name : task_body. model . to_string ( ) ,
77
+ model : task_body. model ,
79
78
channel,
80
79
} ;
81
80
let task_input = TaskWorkerInput {
@@ -112,7 +111,7 @@ impl TaskResponder {
112
111
file_id : task_metadata. file_id ,
113
112
task_id : task_metadata. task_id ,
114
113
row_id : task_output. row_id ,
115
- model : task_metadata. model_name ,
114
+ model : task_metadata. model . to_string ( ) ,
116
115
stats : task_output
117
116
. stats
118
117
. record_published_at ( )
@@ -125,22 +124,24 @@ impl TaskResponder {
125
124
}
126
125
Err ( err) => {
127
126
// use pretty display string for error logging with causes
128
- let err_string = format ! ( "{:#}" , err) ;
129
127
log:: error!(
130
- "Task {}/{} failed: {}" ,
128
+ "Task {}/{} failed: {:# }" ,
131
129
task_metadata. file_id,
132
130
task_output. row_id,
133
- err_string
131
+ err
134
132
) ;
135
133
136
134
// prepare error payload
137
135
let error_payload = TaskResponsePayload {
138
136
result : None ,
139
- error : Some ( err_string) ,
137
+ error : Some ( map_prompt_error_to_task_error (
138
+ task_metadata. model . provider ( ) ,
139
+ err,
140
+ ) ) ,
140
141
row_id : task_output. row_id ,
141
142
file_id : task_metadata. file_id ,
142
143
task_id : task_metadata. task_id ,
143
- model : task_metadata. model_name ,
144
+ model : task_metadata. model . to_string ( ) ,
144
145
stats : task_output
145
146
. stats
146
147
. record_published_at ( )
@@ -161,3 +162,111 @@ impl TaskResponder {
161
162
Ok ( ( ) )
162
163
}
163
164
}
165
+
166
+ /// Maps a [`PromptError`] to a [`TaskError`] with respect to the given provider.
167
+ fn map_prompt_error_to_task_error ( provider : ModelProvider , err : PromptError ) -> TaskError {
168
+ match & err {
169
+ // if the error is a provider error, we can try to parse it
170
+ PromptError :: CompletionError ( CompletionError :: ProviderError ( err_inner) ) => {
171
+ /// A wrapper for `{ error: T }` to match the provider error format.
172
+ #[ derive( Clone , serde:: Deserialize ) ]
173
+ struct ErrorObject < T > {
174
+ error : T ,
175
+ }
176
+
177
+ match provider {
178
+ ModelProvider :: Gemini => {
179
+ /// Gemini API [error object](https://github.com/googleapis/go-genai/blob/main/api_client.go#L273).
180
+ #[ derive( Clone , serde:: Deserialize ) ]
181
+ pub struct GeminiError {
182
+ code : u32 ,
183
+ message : String ,
184
+ status : String ,
185
+ }
186
+
187
+ serde_json:: from_str :: < ErrorObject < GeminiError > > ( err_inner) . map (
188
+ |ErrorObject {
189
+ error : gemini_error,
190
+ } | TaskError :: ProviderError {
191
+ code : format ! ( "{} ({})" , gemini_error. code, gemini_error. status) ,
192
+ message : gemini_error. message ,
193
+ provider : provider. to_string ( ) ,
194
+ } ,
195
+ )
196
+ }
197
+ ModelProvider :: OpenAI => {
198
+ /// OpenAI API [error object](https://github.com/openai/openai-go/blob/main/internal/apierror/apierror.go#L17).
199
+ #[ derive( Clone , serde:: Deserialize ) ]
200
+ pub struct OpenAIError {
201
+ code : String ,
202
+ message : String ,
203
+ }
204
+
205
+ serde_json:: from_str :: < ErrorObject < OpenAIError > > ( err_inner) . map (
206
+ |ErrorObject {
207
+ error : openai_error,
208
+ } | TaskError :: ProviderError {
209
+ code : openai_error. code ,
210
+ message : openai_error. message ,
211
+ provider : provider. to_string ( ) ,
212
+ } ,
213
+ )
214
+ }
215
+ ModelProvider :: OpenRouter => {
216
+ /// OpenRouter API [error object](https://openrouter.ai/docs/api-reference/errors).
217
+ #[ derive( Clone , serde:: Deserialize ) ]
218
+ pub struct OpenRouterError {
219
+ code : u32 ,
220
+ message : String ,
221
+ }
222
+
223
+ serde_json:: from_str :: < ErrorObject < OpenRouterError > > ( err_inner) . map (
224
+ |ErrorObject {
225
+ error : openrouter_error,
226
+ } | {
227
+ TaskError :: ProviderError {
228
+ code : openrouter_error. code . to_string ( ) ,
229
+ message : openrouter_error. message ,
230
+ provider : provider. to_string ( ) ,
231
+ }
232
+ } ,
233
+ )
234
+ }
235
+ ModelProvider :: Ollama => serde_json:: from_str :: < ErrorObject < String > > ( err_inner)
236
+ . map (
237
+ // Ollama just returns a string error message
238
+ |ErrorObject {
239
+ error : ollama_error,
240
+ } | {
241
+ // based on the error message, we can come up with out own "dummy" codes
242
+ let code = if ollama_error. contains ( "server busy, please try again." ) {
243
+ "server_busy"
244
+ } else if ollama_error. contains ( "model requires more system memory" ) {
245
+ "model_requires_more_memory"
246
+ } else if ollama_error. contains ( "cudaMalloc failed: out of memory" ) {
247
+ "cuda_malloc_failed"
248
+ } else if ollama_error. contains ( "CUDA error: out of memory" ) {
249
+ "cuda_oom"
250
+ } else {
251
+ "unknown"
252
+ } ;
253
+
254
+ TaskError :: ProviderError {
255
+ code : code. to_string ( ) ,
256
+ message : ollama_error,
257
+ provider : provider. to_string ( ) ,
258
+ }
259
+ } ,
260
+ ) ,
261
+ }
262
+ // if we couldn't parse it, just return a generic prompt error
263
+ . unwrap_or ( TaskError :: ExecutorError ( err_inner. clone ( ) ) )
264
+ }
265
+ // if its a http error, we can try to parse it as well
266
+ PromptError :: CompletionError ( CompletionError :: HttpError ( err_inner) ) => {
267
+ TaskError :: HttpError ( err_inner. to_string ( ) )
268
+ }
269
+ // if it's not a completion error, we just return the error as is
270
+ err => TaskError :: Other ( err. to_string ( ) ) ,
271
+ }
272
+ }
0 commit comments