
Commit 831e087

Merge pull request #198 from firstbatchxyz/erhant/better-error-reporting
feat: better error reporting
2 parents 0abd136 + 6905cd1 commit 831e087

20 files changed: +406 additions, −109 deletions


Cargo.lock

Lines changed: 5 additions & 4 deletions
Some generated files are not rendered by default.

Cargo.toml

Lines changed: 1 addition & 1 deletion
```diff
@@ -7,7 +7,7 @@ default-members = ["compute"]
 
 [workspace.package]
 edition = "2021"
-version = "0.5.7"
+version = "0.6.0"
 license = "Apache-2.0"
 readme = "README.md"
 
```

compute/src/main.rs

Lines changed: 11 additions & 5 deletions
```diff
@@ -86,21 +86,27 @@ async fn main() -> Result<()> {
 
     // check services & models, will exit if there is an error
     // since service check can take time, we allow early-exit here as well
-    tokio::select! {
+    let model_perf = tokio::select! {
         result = config.executors.check_services() => result,
         _ = cancellation.cancelled() => {
             log::info!("Service check cancelled, exiting.");
             return Ok(());
         }
     }?;
-    log::warn!(
-        "Using models: {}",
-        config.executors.get_model_names().join(", ")
+    log::info!(
+        "Using models: {}\n{}",
+        config.executors.get_model_names().join(", "),
+        model_perf
+            .iter()
+            .map(|(model, perf)| format!("{}: {}", model, perf))
+            .collect::<Vec<_>>()
+            .join("\n")
     );
 
     // create the node
     let batch_size = config.batch_size;
-    let (mut node, p2p, worker_batch, worker_single) = DriaComputeNode::new(config).await?;
+    let (mut node, p2p, worker_batch, worker_single) =
+        DriaComputeNode::new(config, model_perf).await?;
 
     // spawn p2p client first
     log::info!("Spawning peer-to-peer client thread.");
```
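The hunk above binds the result of the service check to `model_perf` while still allowing an early exit on cancellation. A minimal, self-contained sketch of that `tokio::select!` pattern, with stand-in types instead of the `dkn_*` crates (assumes only the `tokio` and `tokio-util` crates):

```rust
use std::collections::HashMap;
use tokio_util::sync::CancellationToken;

// stand-in for `config.executors.check_services()`
async fn check_services() -> Result<HashMap<String, String>, String> {
    Ok(HashMap::from([("model-a".to_string(), "ok".to_string())]))
}

#[tokio::main]
async fn main() -> Result<(), String> {
    let cancellation = CancellationToken::new();

    // race the long-running check against cancellation; the check's value is
    // bound to `model_perf`, while cancellation returns early instead
    let model_perf = tokio::select! {
        result = check_services() => result,
        _ = cancellation.cancelled() => {
            println!("Service check cancelled, exiting.");
            return Ok(());
        }
    }?;

    println!("model performance: {model_perf:?}");
    Ok(())
}
```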

compute/src/node/core.rs

Lines changed: 10 additions & 1 deletion
```diff
@@ -76,7 +76,16 @@ impl DriaComputeNode {
                 _ = diagnostic_refresh_interval.tick() => self.handle_diagnostic_refresh().await,
 
                 // check RPC, and get a new one if we are disconnected
-                _ = rpc_liveness_refresh_interval.tick() => self.handle_rpc_liveness_check().await,
+                _ = rpc_liveness_refresh_interval.tick() => {
+                    let is_connected = self.handle_rpc_liveness_check().await;
+                    if !is_connected {
+                        // make sure we reset the heartbeat and specs intervals so that
+                        // we don't wait the entire duration for this new connection
+                        log::info!("Connecting was re-attempted, resetting timers.");
+                        heartbeat_interval.reset_after(Duration::from_secs(5));
+                        specs_interval.reset_after(Duration::from_secs(5));
+                    }
+                },
 
                 // log points every now and then
                 _ = points_refresh_interval.tick() => self.handle_points_refresh().await,
```
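The new branch pulls the heartbeat and specs timers forward after a reconnect attempt so the node does not wait out the full period. A small illustrative sketch of `tokio::time::Interval::reset_after` (requires a reasonably recent tokio; the interval name below is a stand-in, not the node's actual field):

```rust
use tokio::time::{interval, Duration};

#[tokio::main]
async fn main() {
    // a slow periodic task, e.g. a heartbeat sent once a minute
    let mut heartbeat_interval = interval(Duration::from_secs(60));
    heartbeat_interval.tick().await; // the first tick completes immediately

    // after a reconnect we don't want to wait the full period again,
    // so pull the next tick forward to ~5 seconds from now
    heartbeat_interval.reset_after(Duration::from_secs(5));

    heartbeat_interval.tick().await; // fires about 5 seconds later
    println!("heartbeat tick fired early thanks to reset_after");
}
```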

compute/src/node/diagnostic.rs

Lines changed: 7 additions & 3 deletions
```diff
@@ -85,9 +85,10 @@ impl DriaComputeNode {
 
     /// Dials the existing RPC node if we are not connected to it.
     ///
-    /// If there is an error while doing that,
-    /// it will try to get a new RPC node and dial it.
-    pub(crate) async fn handle_rpc_liveness_check(&mut self) {
+    /// If there is an error while doing that, it will try to get a new RPC node and dial it.
+    ///
+    /// Returns `true` if the RPC is connected, `false` otherwise.
+    pub(crate) async fn handle_rpc_liveness_check(&mut self) -> bool {
         log::debug!("Checking RPC connections for diagnostics.");
 
         // check if we are connected
@@ -124,6 +125,9 @@ impl DriaComputeNode {
         } else {
             log::debug!("Connection with {} is intact.", self.dria_rpc.peer_id);
         }
+
+        // return the connection status
+        is_connected
     }
 
     /// Updates the points for the given address.
```
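Since `handle_rpc_liveness_check` now reports whether the RPC ended up connected, the caller in `core.rs` can react to a reconnect. A rough, hypothetical sketch of that shape with stand-in types (not the real `DriaComputeNode` or libp2p client):

```rust
// stand-in for the RPC handle; the real node holds a libp2p PeerId and client
struct Rpc {
    peer_id: String,
}

struct Node {
    rpc: Rpc,
    connected: bool,
}

impl Node {
    // stand-ins for dialing the current RPC and fetching a fresh one
    fn dial(&self, _rpc: &Rpc) -> Result<(), String> {
        Ok(())
    }
    fn get_new_rpc(&self) -> Result<Rpc, String> {
        Ok(Rpc { peer_id: "new-peer".to_string() })
    }

    /// Dials the existing RPC if disconnected; on failure, tries a new one.
    /// Returns `true` if the RPC is connected, `false` otherwise, so the
    /// caller can react (e.g. reset its heartbeat timers) to a reconnect.
    fn handle_rpc_liveness_check(&mut self) -> bool {
        if self.connected {
            return true;
        }
        if self.dial(&self.rpc).is_err() {
            if let Ok(new_rpc) = self.get_new_rpc() {
                self.connected = self.dial(&new_rpc).is_ok();
                self.rpc = new_rpc;
            }
        } else {
            self.connected = true;
        }
        self.connected
    }
}

fn main() {
    let mut node = Node {
        rpc: Rpc { peer_id: "old-peer".to_string() },
        connected: false,
    };
    println!("connected: {}", node.handle_rpc_liveness_check());
    println!("peer: {}", node.rpc.peer_id);
}
```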

compute/src/node/mod.rs

Lines changed: 4 additions & 2 deletions
```diff
@@ -1,7 +1,8 @@
+use dkn_executor::Model;
 use dkn_p2p::{
     libp2p::PeerId, DriaP2PClient, DriaP2PCommander, DriaP2PProtocol, DriaReqResMessage,
 };
-use dkn_utils::crypto::secret_to_keypair;
+use dkn_utils::{crypto::secret_to_keypair, payloads::SpecModelPerformance};
 use eyre::Result;
 use std::collections::{HashMap, HashSet};
 use tokio::sync::mpsc;
@@ -68,6 +69,7 @@ impl DriaComputeNode {
     /// Returns the node instance and p2p client together. P2p MUST be run in a separate task before this node is used at all.
     pub async fn new(
         mut config: DriaComputeNodeConfig,
+        model_perf: HashMap<Model, SpecModelPerformance>,
     ) -> Result<(
         DriaComputeNode,
         DriaP2PClient,
@@ -124,7 +126,7 @@ impl DriaComputeNode {
         let model_names = config.executors.get_model_names();
         let points_client = DriaPointsClient::new(&config.address, &config.network)?;
 
-        let spec_collector = SpecCollector::new(model_names.clone(), config.version);
+        let spec_collector = SpecCollector::new(model_names.clone(), model_perf, config.version);
         Ok((
             DriaComputeNode {
                 config,
```

compute/src/reqres/task.rs

Lines changed: 127 additions & 18 deletions
```diff
@@ -1,7 +1,9 @@
 use colored::Colorize;
-use dkn_executor::TaskBody;
+use dkn_executor::{CompletionError, ModelProvider, PromptError, TaskBody};
 use dkn_p2p::libp2p::request_response::ResponseChannel;
-use dkn_utils::payloads::{TaskRequestPayload, TaskResponsePayload, TaskStats, TASK_RESULT_TOPIC};
+use dkn_utils::payloads::{
+    TaskError, TaskRequestPayload, TaskResponsePayload, TaskStats, TASK_RESULT_TOPIC,
+};
 use dkn_utils::DriaMessage;
 use eyre::{Context, Result};
 
@@ -25,27 +27,23 @@ impl TaskResponder {
         let task = compute_message
             .parse_payload::<TaskRequestPayload<serde_json::Value>>()
             .wrap_err("could not parse task request payload")?;
-        let task_body = match serde_json::from_value::<TaskBody>(task.input)
-            .wrap_err("could not parse task body")
-        {
+        let task_body = match serde_json::from_value::<TaskBody>(task.input) {
             Ok(task_body) => task_body,
             Err(err) => {
-                let err_string = format!("{:#}", err);
                 log::error!(
-                    "Task {}/{} failed due to parsing error: {}",
+                    "Task {}/{} failed due to parsing error: {err}",
                     task.file_id,
                     task.row_id,
-                    err_string
                 );
 
                 // prepare error payload
                 let error_payload = TaskResponsePayload {
                     result: None,
-                    error: Some(err_string),
+                    error: Some(TaskError::ParseError(err.to_string())),
                     row_id: task.row_id,
                     file_id: task.file_id,
                     task_id: task.task_id,
-                    model: Default::default(),
+                    model: "<n/a>".to_string(), // no model available due to parsing error
                     stats: TaskStats::new(),
                 };
 
@@ -56,7 +54,8 @@ impl TaskResponder {
                 let response = node.new_message(error_payload_str, TASK_RESULT_TOPIC);
                 node.p2p.respond(response.into(), channel).await?;
 
-                return Err(err);
+                // return with error
+                eyre::bail!("could not parse task body: {err}")
             }
         };
 
@@ -75,7 +74,7 @@ impl TaskResponder {
         let task_metadata = TaskWorkerMetadata {
             task_id: task.task_id,
             file_id: task.file_id,
-            model_name: task_body.model.to_string(),
+            model: task_body.model,
             channel,
         };
         let task_input = TaskWorkerInput {
@@ -112,7 +111,7 @@ impl TaskResponder {
                     file_id: task_metadata.file_id,
                     task_id: task_metadata.task_id,
                     row_id: task_output.row_id,
-                    model: task_metadata.model_name,
+                    model: task_metadata.model.to_string(),
                     stats: task_output
                         .stats
                         .record_published_at()
@@ -125,22 +124,24 @@ impl TaskResponder {
             }
             Err(err) => {
                 // use pretty display string for error logging with causes
-                let err_string = format!("{:#}", err);
                 log::error!(
-                    "Task {}/{} failed: {}",
+                    "Task {}/{} failed: {:#}",
                     task_metadata.file_id,
                     task_output.row_id,
-                    err_string
+                    err
                 );
 
                 // prepare error payload
                 let error_payload = TaskResponsePayload {
                     result: None,
-                    error: Some(err_string),
+                    error: Some(map_prompt_error_to_task_error(
+                        task_metadata.model.provider(),
+                        err,
+                    )),
                     row_id: task_output.row_id,
                     file_id: task_metadata.file_id,
                     task_id: task_metadata.task_id,
-                    model: task_metadata.model_name,
+                    model: task_metadata.model.to_string(),
                     stats: task_output
                         .stats
                         .record_published_at()
@@ -161,3 +162,111 @@ impl TaskResponder {
         Ok(())
     }
 }
+
+/// Maps a [`PromptError`] to a [`TaskError`] with respect to the given provider.
+fn map_prompt_error_to_task_error(provider: ModelProvider, err: PromptError) -> TaskError {
+    match &err {
+        // if the error is a provider error, we can try to parse it
+        PromptError::CompletionError(CompletionError::ProviderError(err_inner)) => {
+            /// A wrapper for `{ error: T }` to match the provider error format.
+            #[derive(Clone, serde::Deserialize)]
+            struct ErrorObject<T> {
+                error: T,
+            }
+
+            match provider {
+                ModelProvider::Gemini => {
+                    /// Gemini API [error object](https://github.com/googleapis/go-genai/blob/main/api_client.go#L273).
+                    #[derive(Clone, serde::Deserialize)]
+                    pub struct GeminiError {
+                        code: u32,
+                        message: String,
+                        status: String,
+                    }
+
+                    serde_json::from_str::<ErrorObject<GeminiError>>(err_inner).map(
+                        |ErrorObject { error: gemini_error }| TaskError::ProviderError {
+                            code: format!("{} ({})", gemini_error.code, gemini_error.status),
+                            message: gemini_error.message,
+                            provider: provider.to_string(),
+                        },
+                    )
+                }
+                ModelProvider::OpenAI => {
+                    /// OpenAI API [error object](https://github.com/openai/openai-go/blob/main/internal/apierror/apierror.go#L17).
+                    #[derive(Clone, serde::Deserialize)]
+                    pub struct OpenAIError {
+                        code: String,
+                        message: String,
+                    }
+
+                    serde_json::from_str::<ErrorObject<OpenAIError>>(err_inner).map(
+                        |ErrorObject { error: openai_error }| TaskError::ProviderError {
+                            code: openai_error.code,
+                            message: openai_error.message,
+                            provider: provider.to_string(),
+                        },
+                    )
+                }
+                ModelProvider::OpenRouter => {
+                    /// OpenRouter API [error object](https://openrouter.ai/docs/api-reference/errors).
+                    #[derive(Clone, serde::Deserialize)]
+                    pub struct OpenRouterError {
+                        code: u32,
+                        message: String,
+                    }
+
+                    serde_json::from_str::<ErrorObject<OpenRouterError>>(err_inner).map(
+                        |ErrorObject { error: openrouter_error }| {
+                            TaskError::ProviderError {
+                                code: openrouter_error.code.to_string(),
+                                message: openrouter_error.message,
+                                provider: provider.to_string(),
+                            }
+                        },
+                    )
+                }
+                ModelProvider::Ollama => serde_json::from_str::<ErrorObject<String>>(err_inner)
+                    .map(
+                        // Ollama just returns a string error message
+                        |ErrorObject { error: ollama_error }| {
+                            // based on the error message, we can come up with our own "dummy" codes
+                            let code = if ollama_error.contains("server busy, please try again.") {
+                                "server_busy"
+                            } else if ollama_error.contains("model requires more system memory") {
+                                "model_requires_more_memory"
+                            } else if ollama_error.contains("cudaMalloc failed: out of memory") {
+                                "cuda_malloc_failed"
+                            } else if ollama_error.contains("CUDA error: out of memory") {
+                                "cuda_oom"
+                            } else {
+                                "unknown"
+                            };
+
+                            TaskError::ProviderError {
+                                code: code.to_string(),
+                                message: ollama_error,
+                                provider: provider.to_string(),
+                            }
+                        },
+                    ),
+            }
+            // if we couldn't parse it, just return a generic prompt error
+            .unwrap_or(TaskError::ExecutorError(err_inner.clone()))
+        }
+        // if it's an HTTP error, we can try to parse it as well
+        PromptError::CompletionError(CompletionError::HttpError(err_inner)) => {
+            TaskError::HttpError(err_inner.to_string())
+        }
+        // if it's not a completion error, we just return the error as is
+        err => TaskError::Other(err.to_string()),
+    }
+}
```
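The new `map_prompt_error_to_task_error` leans on the fact that these provider APIs wrap their error payload in an `{ "error": ... }` object, so a generic wrapper plus a per-provider payload type is enough to pull out a code and message. A self-contained sketch of that parsing idea for a Gemini-style payload, using stand-in types (`TaskError` below is not the `dkn_utils` type) and assuming `serde` with the `derive` feature plus `serde_json`:

```rust
use serde::Deserialize;

// stand-in for the richer error enum carried in the task response payload
#[derive(Debug)]
enum TaskError {
    ProviderError { provider: String, code: String, message: String },
    ExecutorError(String),
}

/// Wrapper matching the common `{ "error": T }` shape.
#[derive(Deserialize)]
struct ErrorObject<T> {
    error: T,
}

/// Gemini-style error payload: numeric code, message, and status string.
#[derive(Deserialize)]
struct GeminiError {
    code: u32,
    message: String,
    status: String,
}

fn map_gemini_error(raw: &str) -> TaskError {
    serde_json::from_str::<ErrorObject<GeminiError>>(raw)
        .map(|ErrorObject { error }| TaskError::ProviderError {
            provider: "gemini".to_string(),
            code: format!("{} ({})", error.code, error.status),
            message: error.message,
        })
        // if the payload doesn't match the expected shape, carry the raw string
        .unwrap_or_else(|_| TaskError::ExecutorError(raw.to_string()))
}

fn main() {
    let raw = r#"{"error":{"code":429,"message":"quota exceeded","status":"RESOURCE_EXHAUSTED"}}"#;
    println!("{:?}", map_gemini_error(raw));
}
```

The fallback mirrors the `unwrap_or(TaskError::ExecutorError(...))` path in the diff: an unparseable provider error still reaches the requester, just without a structured code.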

0 commit comments