dogy_backend_api/service/assistant/handlers.rs

1use axum::{
2    extract::{Path, State},
3    Extension, Json,
4};
5use futures::future::join_all;
6use reqwest::Client;
7use serde_json::{json, Value};
8use sqlx::{query, query_as};
9use tokio::task;
10use uuid::Uuid;
11
12use crate::{
13    config::load_config, middleware::auth::layer::CurrentUser,
14    service::assistant::models::MessageType, AppState,
15};
16
17use super::models::{AllThreadResponse, DbThread, Message, Thread, ThreadResponse};
18
19pub async fn link_user_to_thread(
20    Extension(current_user): Extension<CurrentUser>,
21    State(state): State<AppState>,
22    Path(thread_id): Path<Uuid>,
23    Json(thread): Json<Thread>,
24) -> Json<Value> {
25    let conn = &*state.db;
26    query(
27        "INSERT INTO user_assistant_threads (user_id, thread_id, thread_title)
28        VALUES ($1, $2, $3);",
29    )
30    .bind(current_user.internal_id.unwrap())
31    .bind(thread_id)
32    .bind(thread.title)
33    .execute(conn)
34    .await
35    .unwrap();
36
37    Json(json!({
38        "message": format!("Thread {} successfully linked to user {}", thread_id, current_user.user_id)
39    }))
40}
41
42pub async fn update_thread_title(
43    State(state): State<AppState>,
44    Path(thread_id): Path<Uuid>,
45    Json(thread): Json<Thread>,
46) -> Json<Value> {
47    let conn = &*state.db;
48    query(
49        r#"UPDATE user_assistant_threads
50        SET thread_title = $2
51        WHERE thread_id = $1;
52        "#,
53    )
54    .bind(thread_id)
55    .bind(&thread.title)
56    .execute(conn)
57    .await
58    .unwrap();
59
60    Json(json!({
61        "message": format!("Thread {} successfully updated title to {}", thread_id, thread.title)
62    }))
63}
64
65pub async fn unlink_thread_from_user(
66    Extension(current_user): Extension<CurrentUser>,
67    State(state): State<AppState>,
68    Path(thread_id): Path<Uuid>,
69) -> Json<Value> {
70    let conn = &*state.db;
71    query("DELETE FROM user_assistant_threads WHERE thread_id = $1;")
72        .bind(thread_id)
73        .execute(conn)
74        .await
75        .unwrap();
76
77    Json(json!({
78        "message": format!("Thread {} has been successfully unlinked from user {}.", thread_id, current_user.user_id)
79    }))
80}
81
82async fn retrieve_thread_history(client: &Client, thread_id: Uuid) -> Vec<Message> {
83    let config = load_config();
84    let res: serde_json::Value = client
85        .get(format!(
86            "{}/threads/{}",
87            config.LANGGRAPH_ASSISTANT_ENDPOINT, thread_id
88        ))
89        .send()
90        .await
91        .unwrap()
92        .json()
93        .await
94        .unwrap();
95
96    let mut parsed_messages: Vec<Message> = vec![];
97
98    if let Some(messages) = res
99        .get("values")
100        .and_then(|v| v.get("messages"))
101        .and_then(|m| m.as_array())
102    {
103        for message in messages {
104            match message.get("type").unwrap().as_str().unwrap() {
105                "human" => parsed_messages.push(Message {
106                    id: message
107                        .get("id")
108                        .unwrap()
109                        .as_str()
110                        .unwrap()
111                        .parse()
112                        .unwrap(),
113                    is_bot_message: false,
114                    text: message.get("content").unwrap().as_array().unwrap()[0]
115                        .get("text")
116                        .unwrap()
117                        .as_str()
118                        .unwrap()
119                        .to_string(),
120                    title: MessageType::User,
121                }),
122                "ai" => parsed_messages.push(Message {
123                    id: message
124                        .get("id")
125                        .unwrap()
126                        .as_str()
127                        .unwrap()
128                        .strip_prefix("run-")
129                        .unwrap()
130                        .parse()
131                        .unwrap(),
132                    is_bot_message: true,
133                    text: message.get("content").unwrap().to_string(),
134                    title: MessageType::Bot,
135                }),
136                _ => (),
137            }
138        }
139    }
140
141    parsed_messages
142}
143
144pub async fn get_all_threads_from_user(
145    Extension(current_user): Extension<CurrentUser>,
146    State(state): State<AppState>,
147) -> Json<AllThreadResponse> {
148    let conn = &*state.db;
149    let db_threads = query_as::<_, DbThread>(
150        r#"
151        SELECT thread_id, thread_title
152        FROM user_assistant_threads
153        WHERE user_id = $1;
154    "#,
155    )
156    .bind(current_user.internal_id.unwrap())
157    .fetch_all(conn)
158    .await
159    .unwrap();
160
161    let client = Client::new();
162    let threads: Vec<_> = db_threads
163        .into_iter()
164        .map(|thread| {
165            let client = client.clone();
166            let user_id = current_user.user_id.clone();
167            task::spawn(async move {
168                let messages = retrieve_thread_history(&client, thread.thread_id).await;
169                ThreadResponse {
170                    thread_id: thread.thread_id,
171                    user_id,
172                    title: thread.thread_title,
173                    messages,
174                }
175            })
176        })
177        .collect();
178
179    let results = join_all(threads).await;
180    let thread_responses = results.into_iter().map(|r| r.unwrap()).collect::<Vec<_>>();
181
182    Json(AllThreadResponse {
183        threads: thread_responses,
184    })
185}