Skip to content

Commit 6905cd1

Browse files
committed
added performance reports for service checks
1 parent 683b8a9 commit 6905cd1

File tree

13 files changed

+204
-75
lines changed

13 files changed

+204
-75
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

compute/src/main.rs

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -86,21 +86,27 @@ async fn main() -> Result<()> {
8686

8787
// check services & models, will exit if there is an error
8888
// since service check can take time, we allow early-exit here as well
89-
tokio::select! {
89+
let model_perf = tokio::select! {
9090
result = config.executors.check_services() => result,
9191
_ = cancellation.cancelled() => {
9292
log::info!("Service check cancelled, exiting.");
9393
return Ok(());
9494
}
9595
}?;
96-
log::warn!(
97-
"Using models: {}",
98-
config.executors.get_model_names().join(", ")
96+
log::info!(
97+
"Using models: {}\n{}",
98+
config.executors.get_model_names().join(", "),
99+
model_perf
100+
.iter()
101+
.map(|(model, perf)| format!("{}: {}", model, perf))
102+
.collect::<Vec<_>>()
103+
.join("\n")
99104
);
100105

101106
// create the node
102107
let batch_size = config.batch_size;
103-
let (mut node, p2p, worker_batch, worker_single) = DriaComputeNode::new(config).await?;
108+
let (mut node, p2p, worker_batch, worker_single) =
109+
DriaComputeNode::new(config, model_perf).await?;
104110

105111
// spawn p2p client first
106112
log::info!("Spawning peer-to-peer client thread.");

compute/src/node/mod.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
1+
use dkn_executor::Model;
12
use dkn_p2p::{
23
libp2p::PeerId, DriaP2PClient, DriaP2PCommander, DriaP2PProtocol, DriaReqResMessage,
34
};
4-
use dkn_utils::crypto::secret_to_keypair;
5+
use dkn_utils::{crypto::secret_to_keypair, payloads::SpecModelPerformance};
56
use eyre::Result;
67
use std::collections::{HashMap, HashSet};
78
use tokio::sync::mpsc;
@@ -68,6 +69,7 @@ impl DriaComputeNode {
6869
/// Returns the node instance and p2p client together. P2p MUST be run in a separate task before this node is used at all.
6970
pub async fn new(
7071
mut config: DriaComputeNodeConfig,
72+
model_perf: HashMap<Model, SpecModelPerformance>,
7173
) -> Result<(
7274
DriaComputeNode,
7375
DriaP2PClient,
@@ -124,7 +126,7 @@ impl DriaComputeNode {
124126
let model_names = config.executors.get_model_names();
125127
let points_client = DriaPointsClient::new(&config.address, &config.network)?;
126128

127-
let spec_collector = SpecCollector::new(model_names.clone(), config.version);
129+
let spec_collector = SpecCollector::new(model_names.clone(), model_perf, config.version);
128130
Ok((
129131
DriaComputeNode {
130132
config,

compute/src/utils/specs.rs

Lines changed: 33 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,10 @@
1-
use dkn_utils::{payloads::Specs, SemanticVersion};
1+
use std::collections::HashMap;
2+
3+
use dkn_executor::Model;
4+
use dkn_utils::{
5+
payloads::{SpecModelPerformance, Specs},
6+
SemanticVersion,
7+
};
28
use sysinfo::{CpuRefreshKind, MemoryRefreshKind, RefreshKind};
39

410
pub struct SpecCollector {
@@ -7,6 +13,8 @@ pub struct SpecCollector {
713
system: sysinfo::System,
814
/// Used models.
915
models: Vec<String>,
16+
/// Model performances.
17+
model_perf: HashMap<String, SpecModelPerformance>,
1018
/// Version string.
1119
version: String,
1220
// GPU adapter infos, showing information about the available GPUs.
@@ -20,10 +28,18 @@ pub struct SpecCollector {
2028
// }
2129

2230
impl SpecCollector {
23-
pub fn new(models: Vec<String>, version: SemanticVersion) -> Self {
31+
pub fn new(
32+
models: Vec<String>,
33+
model_perf: HashMap<Model, SpecModelPerformance>,
34+
version: SemanticVersion,
35+
) -> Self {
2436
SpecCollector {
2537
system: sysinfo::System::new_with_specifics(Self::get_refresh_specifics()),
2638
models,
39+
model_perf: model_perf
40+
.into_iter()
41+
.map(|(k, v)| (k.to_string(), v))
42+
.collect(),
2743
version: version.to_string(),
2844
// gpus: wgpu::Instance::default()
2945
// .enumerate_adapters(wgpu::Backends::all())
@@ -55,6 +71,7 @@ impl SpecCollector {
5571
lookup: public_ip_address::perform_lookup(None).await.ok(),
5672
models: self.models.clone(),
5773
version: self.version.clone(),
74+
model_perf: self.model_perf.clone(),
5875
// gpus: self.gpus.clone(),
5976
}
6077
}
@@ -64,13 +81,18 @@ mod tests {
6481
use super::*;
6582

6683
#[tokio::test]
67-
async fn test_print_specs() {
84+
async fn test_specs_serialization() {
6885
let mut spec_collector = SpecCollector::new(
69-
vec!["gpt-4o".to_string()],
86+
vec![Model::Gemma3_4b.to_string()],
87+
HashMap::from_iter([
88+
(Model::Gemma3_4b, SpecModelPerformance::PassedWithTPS(100.0)),
89+
(Model::GPT4oMini, SpecModelPerformance::NotFound),
90+
(Model::Gemma3_27b, SpecModelPerformance::ExecutionFailed),
91+
]),
7092
SemanticVersion {
71-
major: 0,
72-
minor: 1,
73-
patch: 0,
93+
major: 4,
94+
minor: 5,
95+
patch: 1,
7496
},
7597
);
7698
let specs = spec_collector.collect().await;
@@ -82,6 +104,9 @@ mod tests {
82104
assert!(!specs.arch.is_empty());
83105
assert!(specs.lookup.is_some());
84106
assert!(!specs.models.is_empty());
85-
assert_eq!(specs.version, "0.1.0");
107+
assert_eq!(specs.version, "4.5.1");
108+
109+
// should be serializable to JSON
110+
assert!(serde_json::to_string_pretty(&specs).is_ok())
86111
}
87112
}

executor/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ thiserror.workspace = true
2929
enum-iterator = "2.1.0"
3030
rig-core = "0.11.1"
3131
ollama-rs = { version = "0.3.0", features = ["tokio", "rustls", "stream"] }
32+
dkn-utils = { path = "../utils" }
3233

3334
[dev-dependencies]
3435
# only used for tests

executor/src/executors/gemini.rs

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
1+
use dkn_utils::payloads::SpecModelPerformance;
12
use eyre::{eyre, Context, Result};
23
use reqwest::Client;
34
use rig::{
45
completion::{Chat, PromptError},
56
providers::gemini,
67
};
78
use serde::Deserialize;
8-
use std::collections::HashSet;
9+
use std::collections::{HashMap, HashSet};
910

1011
use crate::{Model, TaskBody};
1112

@@ -43,8 +44,12 @@ impl GeminiClient {
4344
}
4445

4546
/// Check if requested models exist & are available in the Gemini account.
46-
pub async fn check(&self, models: &mut HashSet<Model>) -> Result<()> {
47+
pub async fn check(
48+
&self,
49+
models: &mut HashSet<Model>,
50+
) -> Result<HashMap<Model, SpecModelPerformance>> {
4751
let mut models_to_remove = Vec::new();
52+
let mut model_performances = HashMap::new();
4853
log::info!("Checking Gemini requirements");
4954

5055
// check if models exist and select those that are available
@@ -61,7 +66,10 @@ impl GeminiClient {
6166
requested_model
6267
);
6368
models_to_remove.push(requested_model);
64-
} else
69+
model_performances.insert(requested_model, SpecModelPerformance::NotFound);
70+
continue;
71+
}
72+
6573
// make a dummy request
6674
if let Err(err) = self
6775
.execute(TaskBody::new_prompt("What is 2 + 2?", requested_model))
@@ -73,15 +81,20 @@ impl GeminiClient {
7381
err
7482
);
7583
models_to_remove.push(requested_model);
84+
model_performances.insert(requested_model, SpecModelPerformance::ExecutionFailed);
85+
continue;
7686
}
87+
88+
// record the performance of the model
89+
model_performances.insert(requested_model, SpecModelPerformance::Passed);
7790
}
7891

7992
// remove models that are not available
8093
for model in models_to_remove.iter() {
8194
models.remove(model);
8295
}
8396

84-
Ok(())
97+
Ok(model_performances)
8598
}
8699

87100
/// Returns the list of models available to this account.

executor/src/executors/mod.rs

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
use crate::{Model, ModelProvider, TaskBody};
2+
use dkn_utils::payloads::SpecModelPerformance;
23
use rig::completion::PromptError;
34
use std::collections::{HashMap, HashSet};
45

@@ -50,7 +51,7 @@ impl DriaExecutor {
5051
pub async fn check(
5152
&self,
5253
models: &mut HashSet<Model>,
53-
) -> eyre::Result<HashMap<Model, ModelPerformanceMetric>> {
54+
) -> eyre::Result<HashMap<Model, SpecModelPerformance>> {
5455
match self {
5556
DriaExecutor::Ollama(provider) => provider.check(models).await,
5657
DriaExecutor::OpenAI(provider) => provider.check(models).await,
@@ -59,9 +60,3 @@ impl DriaExecutor {
5960
}
6061
}
6162
}
62-
63-
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
64-
pub enum ModelPerformanceMetric {
65-
Latency(f64), // in seconds
66-
TPS(f64), // (eval) tokens per second
67-
}

executor/src/executors/ollama.rs

Lines changed: 48 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
1+
use dkn_utils::payloads::SpecModelPerformance;
12
use eyre::{Context, Result};
23
use ollama_rs::generation::completion::request::GenerationRequest;
34
use rig::completion::{Chat, PromptError};
45
use rig::providers::ollama;
6+
use std::collections::HashMap;
57
use std::time::Duration;
68
use std::{collections::HashSet, env};
79

@@ -78,7 +80,10 @@ impl OllamaClient {
7880
}
7981

8082
/// Check if requested models exist in Ollama & test them using a dummy prompt.
81-
pub async fn check(&self, models: &mut HashSet<Model>) -> Result<()> {
83+
pub async fn check(
84+
&self,
85+
models: &mut HashSet<Model>,
86+
) -> Result<HashMap<Model, SpecModelPerformance>> {
8287
log::info!(
8388
"Checking Ollama requirements (auto-pull {}, timeout: {}s, min tps: {})",
8489
if self.auto_pull { "on" } else { "off" },
@@ -101,6 +106,7 @@ impl OllamaClient {
101106
// check external models & pull them if available
102107
// iterate over models and remove bad ones
103108
let mut models_to_remove = Vec::new();
109+
let mut model_performances = HashMap::new();
104110
for model in models.iter() {
105111
// pull the model if it is not in the local models
106112
if !local_models.contains(&model.to_string()) {
@@ -117,8 +123,13 @@ impl OllamaClient {
117123
}
118124

119125
// test its performance
120-
if !self.test_performance(model).await {
126+
let perf = self.measure_tps_with_warmup(model).await;
127+
if let SpecModelPerformance::PassedWithTPS(_) = perf {
128+
model_performances.insert(*model, perf);
129+
} else {
130+
// if it's anything but PassedWithTPS, remove the model
121131
models_to_remove.push(*model);
132+
model_performances.insert(*model, perf);
122133
}
123134
}
124135

@@ -133,7 +144,7 @@ impl OllamaClient {
133144
log::info!("Ollama checks are finished, using models: {:#?}", models);
134145
}
135146

136-
Ok(())
147+
Ok(model_performances)
137148
}
138149

139150
/// Pulls a model from Ollama.
@@ -154,7 +165,7 @@ impl OllamaClient {
154165
///
155166
/// This is to see if a given system can execute tasks for their chosen models,
156167
/// e.g. if they have enough RAM/CPU and such.
157-
pub async fn test_performance(&self, model: &Model) -> bool {
168+
pub async fn measure_tps_with_warmup(&self, model: &Model) -> SpecModelPerformance {
158169
const TEST_PROMPT: &str = "Please write a poem about Kapadokya.";
159170
const WARMUP_PROMPT: &str = "Write a short poem about hedgehogs and squirrels.";
160171

@@ -171,44 +182,46 @@ impl OllamaClient {
171182
.await
172183
{
173184
log::warn!("Ignoring model {model}: {err}");
174-
return false;
185+
return SpecModelPerformance::ExecutionFailed;
175186
}
176187

177188
// then, run a sample generation with timeout and measure tps
178-
tokio::select! {
179-
_ = tokio::time::sleep(PERFORMANCE_TIMEOUT) => {
180-
log::warn!("Ignoring model {model}: Timed out");
181-
},
182-
result = self.ollama_rs_client.generate(GenerationRequest::new(
189+
let Ok(result) = tokio::time::timeout(
190+
PERFORMANCE_TIMEOUT,
191+
self.ollama_rs_client.generate(GenerationRequest::new(
183192
model.to_string(),
184193
TEST_PROMPT.to_string(),
185-
)) => {
186-
match result {
187-
Ok(response) => {
188-
let tps = (response.eval_count.unwrap_or_default() as f64)
189-
/ (response.eval_duration.unwrap_or(1) as f64)
190-
* 1_000_000_000f64;
191-
192-
if tps >= PERFORMANCE_MIN_TPS {
193-
log::info!("Model {} passed the test with tps: {}", model, tps);
194-
return true;
195-
}
196-
197-
log::warn!(
198-
"Ignoring model {}: tps too low ({:.3} < {:.3})",
199-
model,
200-
tps,
201-
PERFORMANCE_MIN_TPS
202-
);
203-
}
204-
Err(e) => {
205-
log::warn!("Ignoring model {}: Task failed with error {}", model, e);
206-
}
207-
}
208-
}
194+
)),
195+
)
196+
.await
197+
else {
198+
log::warn!("Ignoring model {model}: Timed out");
199+
return SpecModelPerformance::Timeout;
209200
};
210201

211-
false
202+
// check the result
203+
match result {
204+
Ok(response) => {
205+
let tps = (response.eval_count.unwrap_or_default() as f64)
206+
/ (response.eval_duration.unwrap_or(1) as f64)
207+
* 1_000_000_000f64;
208+
209+
if tps >= PERFORMANCE_MIN_TPS {
210+
log::info!("Model {model} passed the test with tps: {tps}");
211+
SpecModelPerformance::PassedWithTPS(tps)
212+
} else {
213+
log::warn!(
214+
"Ignoring model {model}: tps too low ({tps:.3} < {:.3})",
215+
PERFORMANCE_MIN_TPS
216+
);
217+
SpecModelPerformance::FailedWithTPS(tps)
218+
}
219+
}
220+
Err(err) => {
221+
log::warn!("Ignoring model {model} due to: {err}");
222+
SpecModelPerformance::ExecutionFailed
223+
}
224+
}
212225
}
213226
}
214227

0 commit comments

Comments
 (0)