dogy_backend_api/service/assistant/
handlers.rs1use axum::{
2 extract::{Path, State},
3 Extension, Json,
4};
5use futures::future::join_all;
6use reqwest::Client;
7use serde_json::{json, Value};
8use sqlx::{query, query_as};
9use tokio::task;
10use uuid::Uuid;
11
12use crate::{
13 config::load_config, middleware::auth::layer::CurrentUser,
14 service::assistant::models::MessageType, AppState,
15};
16
17use super::models::{AllThreadResponse, DbThread, Message, Thread, ThreadResponse};
18
19pub async fn link_user_to_thread(
20 Extension(current_user): Extension<CurrentUser>,
21 State(state): State<AppState>,
22 Path(thread_id): Path<Uuid>,
23 Json(thread): Json<Thread>,
24) -> Json<Value> {
25 let conn = &*state.db;
26 query(
27 "INSERT INTO user_assistant_threads (user_id, thread_id, thread_title)
28 VALUES ($1, $2, $3);",
29 )
30 .bind(current_user.internal_id.unwrap())
31 .bind(thread_id)
32 .bind(thread.title)
33 .execute(conn)
34 .await
35 .unwrap();
36
37 Json(json!({
38 "message": format!("Thread {} successfully linked to user {}", thread_id, current_user.user_id)
39 }))
40}
41
42pub async fn update_thread_title(
43 State(state): State<AppState>,
44 Path(thread_id): Path<Uuid>,
45 Json(thread): Json<Thread>,
46) -> Json<Value> {
47 let conn = &*state.db;
48 query(
49 r#"UPDATE user_assistant_threads
50 SET thread_title = $2
51 WHERE thread_id = $1;
52 "#,
53 )
54 .bind(thread_id)
55 .bind(&thread.title)
56 .execute(conn)
57 .await
58 .unwrap();
59
60 Json(json!({
61 "message": format!("Thread {} successfully updated title to {}", thread_id, thread.title)
62 }))
63}
64
65pub async fn unlink_thread_from_user(
66 Extension(current_user): Extension<CurrentUser>,
67 State(state): State<AppState>,
68 Path(thread_id): Path<Uuid>,
69) -> Json<Value> {
70 let conn = &*state.db;
71 query("DELETE FROM user_assistant_threads WHERE thread_id = $1;")
72 .bind(thread_id)
73 .execute(conn)
74 .await
75 .unwrap();
76
77 Json(json!({
78 "message": format!("Thread {} has been successfully unlinked from user {}.", thread_id, current_user.user_id)
79 }))
80}
81
82async fn retrieve_thread_history(client: &Client, thread_id: Uuid) -> Vec<Message> {
83 let config = load_config();
84 let res: serde_json::Value = client
85 .get(format!(
86 "{}/threads/{}",
87 config.LANGGRAPH_ASSISTANT_ENDPOINT, thread_id
88 ))
89 .send()
90 .await
91 .unwrap()
92 .json()
93 .await
94 .unwrap();
95
96 let mut parsed_messages: Vec<Message> = vec![];
97
98 if let Some(messages) = res
99 .get("values")
100 .and_then(|v| v.get("messages"))
101 .and_then(|m| m.as_array())
102 {
103 for message in messages {
104 match message.get("type").unwrap().as_str().unwrap() {
105 "human" => parsed_messages.push(Message {
106 id: message
107 .get("id")
108 .unwrap()
109 .as_str()
110 .unwrap()
111 .parse()
112 .unwrap(),
113 is_bot_message: false,
114 text: message.get("content").unwrap().as_array().unwrap()[0]
115 .get("text")
116 .unwrap()
117 .as_str()
118 .unwrap()
119 .to_string(),
120 title: MessageType::User,
121 }),
122 "ai" => parsed_messages.push(Message {
123 id: message
124 .get("id")
125 .unwrap()
126 .as_str()
127 .unwrap()
128 .strip_prefix("run-")
129 .unwrap()
130 .parse()
131 .unwrap(),
132 is_bot_message: true,
133 text: message.get("content").unwrap().to_string(),
134 title: MessageType::Bot,
135 }),
136 _ => (),
137 }
138 }
139 }
140
141 parsed_messages
142}
143
144pub async fn get_all_threads_from_user(
145 Extension(current_user): Extension<CurrentUser>,
146 State(state): State<AppState>,
147) -> Json<AllThreadResponse> {
148 let conn = &*state.db;
149 let db_threads = query_as::<_, DbThread>(
150 r#"
151 SELECT thread_id, thread_title
152 FROM user_assistant_threads
153 WHERE user_id = $1;
154 "#,
155 )
156 .bind(current_user.internal_id.unwrap())
157 .fetch_all(conn)
158 .await
159 .unwrap();
160
161 let client = Client::new();
162 let threads: Vec<_> = db_threads
163 .into_iter()
164 .map(|thread| {
165 let client = client.clone();
166 let user_id = current_user.user_id.clone();
167 task::spawn(async move {
168 let messages = retrieve_thread_history(&client, thread.thread_id).await;
169 ThreadResponse {
170 thread_id: thread.thread_id,
171 user_id,
172 title: thread.thread_title,
173 messages,
174 }
175 })
176 })
177 .collect();
178
179 let results = join_all(threads).await;
180 let thread_responses = results.into_iter().map(|r| r.unwrap()).collect::<Vec<_>>();
181
182 Json(AllThreadResponse {
183 threads: thread_responses,
184 })
185}