Skip to content

Commit 683b8a9

Browse files
committed
add reset timer logic
1 parent 5f980f7 commit 683b8a9

File tree

3 files changed

+30
-8
lines changed

3 files changed

+30
-8
lines changed

compute/src/node/core.rs

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,16 @@ impl DriaComputeNode {
7676
_ = diagnostic_refresh_interval.tick() => self.handle_diagnostic_refresh().await,
7777

7878
// check RPC, and get a new one if we are disconnected
79-
_ = rpc_liveness_refresh_interval.tick() => self.handle_rpc_liveness_check().await,
79+
_ = rpc_liveness_refresh_interval.tick() => {
80+
let is_connected = self.handle_rpc_liveness_check().await;
81+
if !is_connected {
82+
// make sure we reset the heartbeat and specs intervals so that
83+
// we don't wait the entire duration for this new connection
84+
log::info!("Connecting was re-attempted, resetting timers.");
85+
heartbeat_interval.reset_after(Duration::from_secs(5));
86+
specs_interval.reset_after(Duration::from_secs(5));
87+
}
88+
},
8089

8190
// log points every now and then
8291
_ = points_refresh_interval.tick() => self.handle_points_refresh().await,

compute/src/node/diagnostic.rs

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -85,9 +85,10 @@ impl DriaComputeNode {
8585

8686
/// Dials the existing RPC node if we are not connected to it.
8787
///
88-
/// If there is an error while doing that,
89-
/// it will try to get a new RPC node and dial it.
90-
pub(crate) async fn handle_rpc_liveness_check(&mut self) {
88+
/// If there is an error while doing that, it will try to get a new RPC node and dial it.
89+
///
90+
/// Returns `true` if the RPC is connected, `false` otherwise.
91+
pub(crate) async fn handle_rpc_liveness_check(&mut self) -> bool {
9192
log::debug!("Checking RPC connections for diagnostics.");
9293

9394
// check if we are connected
@@ -124,6 +125,9 @@ impl DriaComputeNode {
124125
} else {
125126
log::debug!("Connection with {} is intact.", self.dria_rpc.peer_id);
126127
}
128+
129+
// return the connection status
130+
is_connected
127131
}
128132

129133
/// Updates the points for the given address.

executor/src/executors/mod.rs

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
use crate::ModelProvider;
1+
use crate::{Model, ModelProvider, TaskBody};
22
use rig::completion::PromptError;
3-
use std::collections::HashSet;
3+
use std::collections::{HashMap, HashSet};
44

55
mod ollama;
66
use ollama::OllamaClient;
@@ -35,7 +35,7 @@ impl DriaExecutor {
3535
}
3636

3737
/// Executes the given task using the appropriate provider.
38-
pub async fn execute(&self, task: crate::TaskBody) -> Result<String, PromptError> {
38+
pub async fn execute(&self, task: TaskBody) -> Result<String, PromptError> {
3939
match self {
4040
DriaExecutor::Ollama(provider) => provider.execute(task).await,
4141
DriaExecutor::OpenAI(provider) => provider.execute(task).await,
@@ -47,7 +47,10 @@ impl DriaExecutor {
4747
/// Checks if the requested models exist and are available in the provider's account.
4848
///
4949
/// For Ollama in particular, it also checks if the models are performant enough.
50-
pub async fn check(&self, models: &mut HashSet<crate::Model>) -> eyre::Result<()> {
50+
pub async fn check(
51+
&self,
52+
models: &mut HashSet<Model>,
53+
) -> eyre::Result<HashMap<Model, ModelPerformanceMetric>> {
5154
match self {
5255
DriaExecutor::Ollama(provider) => provider.check(models).await,
5356
DriaExecutor::OpenAI(provider) => provider.check(models).await,
@@ -56,3 +59,9 @@ impl DriaExecutor {
5659
}
5760
}
5861
}
62+
63+
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
64+
pub enum ModelPerformanceMetric {
65+
Latency(f64), // in seconds
66+
TPS(f64), // (eval) tokens per second
67+
}

0 commit comments

Comments
 (0)