
Commit 6997c8f

add model name to response payloads
1 parent 936a9b6 commit 6997c8f

File tree

7 files changed (+38 -19 lines changed)


Cargo.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default.

Cargo.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [package]
 name = "dkn-compute"
-version = "0.2.7"
+version = "0.2.8"
 edition = "2021"
 license = "Apache-2.0"
 readme = "README.md"

src/config/ollama.rs

Lines changed: 9 additions & 4 deletions
@@ -1,4 +1,4 @@
-use eyre::{eyre, Result};
+use eyre::{eyre, Context, Result};
 use ollama_workflows::{
     ollama_rs::{
         generation::{
@@ -63,9 +63,10 @@ impl OllamaConfig {
         // Ollama workflows may require specific models to be loaded regardless of the choices
         let hardcoded_models = HARDCODED_MODELS.iter().map(|s| s.to_string()).collect();

+        // auto-pull, its true by default
         let auto_pull = std::env::var("OLLAMA_AUTO_PULL")
             .map(|s| s == "true")
-            .unwrap_or_default();
+            .unwrap_or(true);

         Self {
             host,
@@ -109,7 +110,9 @@ impl OllamaConfig {
         // we dont check workflows for hardcoded models
         for model in &self.hardcoded_models {
             if !local_models.contains(model) {
-                self.try_pull(&ollama, model.to_owned()).await?;
+                self.try_pull(&ollama, model.to_owned())
+                    .await
+                    .wrap_err("Could not pull model")?;
             }
         }

@@ -118,7 +121,9 @@ impl OllamaConfig {
         let mut good_models = Vec::new();
         for model in external_models {
             if !local_models.contains(&model.to_string()) {
-                self.try_pull(&ollama, model.to_string()).await?;
+                self.try_pull(&ollama, model.to_string())
+                    .await
+                    .wrap_err("Could not pull model")?;
             }

             if self
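
The `OLLAMA_AUTO_PULL` default flip is the behavioural change in this file: before, an unset variable meant no auto-pull (`unwrap_or_default()` yields `false` for `bool`); now an unset variable means pulling is enabled. A minimal standalone sketch of the same parsing pattern (this is not the node's config code, just the pattern in isolation):

    // Standalone illustration of the env-var parsing used above.
    fn auto_pull_enabled() -> bool {
        std::env::var("OLLAMA_AUTO_PULL")
            .map(|s| s == "true")
            // previously `.unwrap_or_default()`: unset -> false
            .unwrap_or(true) // now: unset -> true, so models are pulled by default
    }

    fn main() {
        // Prints `true` when OLLAMA_AUTO_PULL is unset, `false` only when it is
        // set to anything other than the literal string "true".
        println!("auto-pull: {}", auto_pull_enabled());
    }

To opt out after this change, `OLLAMA_AUTO_PULL` must be set explicitly to a value other than `true`.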

src/handlers/pingpong.rs

Lines changed: 1 addition & 2 deletions
@@ -1,3 +1,4 @@
+use super::ComputeHandler;
 use crate::{
     utils::{get_current_time_nanos, DKNMessage},
     DriaComputeNode,
@@ -8,8 +9,6 @@ use libp2p::gossipsub::MessageAcceptance;
 use ollama_workflows::{Model, ModelProvider};
 use serde::{Deserialize, Serialize};

-use super::ComputeHandler;
-
 pub struct PingpongHandler;

 #[derive(Serialize, Deserialize, Debug, Clone)]

src/handlers/workflow.rs

Lines changed: 9 additions & 6 deletions
@@ -15,7 +15,7 @@ pub struct WorkflowHandler;

 #[derive(Debug, Deserialize)]
 struct WorkflowPayload {
-    /// Workflow object to be parsed.
+    /// [Workflow](https://github.com/andthattoo/ollama-workflows/) object to be parsed.
     pub(crate) workflow: Workflow,
     /// A lıst of model (that can be parsed into `Model`) or model provider names.
     /// If model provider is given, the first matching model in the node config is used for that.
@@ -69,7 +69,8 @@ impl ComputeHandler for WorkflowHandler {
         let (model_provider, model) = config
             .model_config
             .get_any_matching_model(task.input.model)?;
-        log::info!("Using model {} for task {}", model, task.task_id);
+        let model_name = model.to_string(); // get model name, we will pass it in payload
+        log::info!("Using model {} for task {}", model_name, task.task_id);

         // prepare workflow executor
         let executor = if model_provider == ModelProvider::Ollama {
@@ -108,9 +109,10 @@ impl ComputeHandler for WorkflowHandler {
                 &task.task_id,
                 &task_public_key,
                 &config.secret_key,
+                model_name,
             )?;
-            let payload_str =
-                serde_json::to_string(&payload).wrap_err("Could not serialize payload")?;
+            let payload_str = serde_json::to_string(&payload)
+                .wrap_err("Could not serialize response payload")?;

             // publish the result
             let message = DKNMessage::new(payload_str, Self::RESPONSE_TOPIC);
@@ -125,8 +127,9 @@ impl ComputeHandler for WorkflowHandler {
             log::error!("Task {} failed: {}", task.task_id, err_string);

             // prepare error payload
-            let error_payload = TaskErrorPayload::new(task.task_id, err_string);
-            let error_payload_str = serde_json::to_string(&error_payload)?;
+            let error_payload = TaskErrorPayload::new(task.task_id, err_string, model_name);
+            let error_payload_str = serde_json::to_string(&error_payload)
+                .wrap_err("Could not serialize error payload")?;

             // publish the error result for diagnostics
             let message = DKNMessage::new(error_payload_str, Self::RESPONSE_TOPIC);

src/payloads/error.rs

Lines changed: 9 additions & 3 deletions
@@ -8,11 +8,17 @@ pub struct TaskErrorPayload {
     /// The unique identifier of the task.
     pub task_id: String,
     /// The stringified error object
-    pub(crate) error: String,
+    pub error: String,
+    /// Name of the model that caused the error.
+    pub model: String,
 }

 impl TaskErrorPayload {
-    pub fn new(task_id: String, error: String) -> Self {
-        Self { task_id, error }
+    pub fn new(task_id: String, error: String, model: String) -> Self {
+        Self {
+            task_id,
+            error,
+            model,
+        }
     }
 }
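
With this change the error payload names the offending model in plain JSON. A self-contained sketch of the resulting shape, using a stand-in struct that mirrors the fields above (the literal values are placeholders, not taken from the repository):

    use serde::Serialize;

    // Stand-in mirroring the fields of TaskErrorPayload after this commit.
    #[derive(Serialize)]
    struct TaskErrorPayload {
        task_id: String,
        error: String,
        model: String,
    }

    fn main() {
        let payload = TaskErrorPayload {
            task_id: "some-task-id".into(),
            error: "model backend returned an error".into(),
            model: "llama3.1:latest".into(), // the field added by this commit
        };
        // With serde's default field names this prints:
        // {"task_id":"some-task-id","error":"model backend returned an error","model":"llama3.1:latest"}
        println!("{}", serde_json::to_string(&payload).unwrap());
    }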

src/payloads/response.rs

Lines changed: 8 additions & 2 deletions
@@ -17,6 +17,8 @@ pub struct TaskResponsePayload {
     pub signature: String,
     /// Result encrypted with the public key of the task, Hexadecimally encoded.
     pub ciphertext: String,
+    /// Name of the model used for this task.
+    pub model: String,
 }

 impl TaskResponsePayload {
@@ -29,6 +31,7 @@ impl TaskResponsePayload {
         task_id: &str,
         encrypting_public_key: &PublicKey,
         signing_secret_key: &SecretKey,
+        model: String,
     ) -> Result<Self> {
         // create the message `task_id || payload`
         let mut preimage = Vec::new();
@@ -43,6 +46,7 @@ impl TaskResponsePayload {
             task_id,
             signature,
             ciphertext,
+            model,
         })
     }
 }
@@ -58,6 +62,7 @@ mod tests {
     fn test_task_response_payload() {
         // this is the result that we are "sending"
         const RESULT: &[u8; 44] = b"hey im an LLM and I came up with this output";
+        const MODEL: &str = "gpt-4-turbo";

         // the signer will sign the payload, and it will be verified
         let signer_sk = SecretKey::random(&mut thread_rng());
@@ -69,8 +74,9 @@ mod tests {
         let task_id = uuid::Uuid::new_v4().to_string();

         // creates a signed and encrypted payload
-        let payload = TaskResponsePayload::new(RESULT, &task_id, &task_pk, &signer_sk)
-            .expect("Should create payload");
+        let payload =
+            TaskResponsePayload::new(RESULT, &task_id, &task_pk, &signer_sk, MODEL.to_string())
+                .expect("Should create payload");

         // decrypt result and compare it to plaintext
         let ciphertext_bytes = hex::decode(payload.ciphertext).unwrap();
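
As the updated test shows, callers now pass the model name into `TaskResponsePayload::new`, and the name rides next to `signature` and `ciphertext` as a plain (unencrypted) string. A rough sketch of the resulting wire shape, using only the fields visible in this diff (the real struct may contain more):

    use serde::Serialize;

    // Partial stand-in for TaskResponsePayload; only fields visible in the diff are shown.
    #[derive(Serialize)]
    struct TaskResponsePayload {
        task_id: String,
        signature: String,
        ciphertext: String,
        model: String, // added by this commit
    }

    fn main() {
        let payload = TaskResponsePayload {
            task_id: "some-task-id".into(),
            signature: "<hex-encoded signature>".into(),
            ciphertext: "<hex-encoded ciphertext>".into(),
            model: "gpt-4-turbo".into(), // same value the updated test uses
        };
        // Consumers of the response topic can now tell which model produced a result
        // without decrypting the ciphertext.
        println!("{}", serde_json::to_string(&payload).unwrap());
    }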
