tetratto_core/database/
messages.rs

1use std::collections::HashMap;
2use oiseau::cache::Cache;
3use crate::model::auth::Notification;
4use crate::model::moderation::AuditLogEntry;
5use crate::model::socket::{SocketMessage, SocketMethod};
6use crate::model::{
7    Error, Result, auth::User, permissions::FinePermission,
8    communities_permissions::CommunityPermission, channels::Message,
9};
10use serde::Serialize;
11use tetratto_shared::unix_epoch_timestamp;
12use crate::{auto_method, DataManager};
13
14use oiseau::{PostgresRow, cache::redis::Commands};
15
16use oiseau::{execute, get, query_rows, params};
17
/// Payload serialized into the `data` field of the socket event that is
/// published when a message is deleted.
#[derive(Serialize)]
struct DeleteMessageEvent {
    /// ID of the deleted message, carried as a string (`delete_message`
    /// fills it with `id.to_string()`).
    pub id: String,
}
22
23impl DataManager {
24    /// Get a [`Message`] from an SQL row.
25    pub(crate) fn get_message_from_row(x: &PostgresRow) -> Message {
26        Message {
27            id: get!(x->0(i64)) as usize,
28            channel: get!(x->1(i64)) as usize,
29            owner: get!(x->2(i64)) as usize,
30            created: get!(x->3(i64)) as usize,
31            edited: get!(x->4(i64)) as usize,
32            content: get!(x->5(String)),
33            context: serde_json::from_str(&get!(x->6(String))).unwrap(),
34            reactions: serde_json::from_str(&get!(x->7(String))).unwrap(),
35        }
36    }
37
    // Generates `get_message_by_id` (called as `self.get_message_by_id(id).await?`
    // elsewhere in this impl): selects the message with the given ID, maps the row
    // through `get_message_from_row`, and uses the cache key template
    // `atto.message:{}`. The exact caching behavior lives in `auto_method!`.
    auto_method!(get_message_by_id(usize as i64)@get_message_from_row -> "SELECT * FROM messages WHERE id = $1" --name="message" --returns=Message --cache-key-tmpl="atto.message:{}");
39
40    /// Complete a vector of just messages with their owner as well.
41    ///
42    /// # Returns
43    /// `(message, owner, group with previous messages in ui)`
44    pub async fn fill_messages(
45        &self,
46        messages: Vec<Message>,
47        ignore_users: &[usize],
48    ) -> Result<Vec<(Message, User, bool)>> {
49        let mut out: Vec<(Message, User, bool)> = Vec::new();
50
51        let mut users: HashMap<usize, User> = HashMap::new();
52        for (i, message) in messages.iter().enumerate() {
53            let next_owner: usize = match messages.get(i + 1) {
54                Some(m) => m.owner,
55                None => 0,
56            };
57
58            let owner = message.owner;
59
60            if ignore_users.contains(&owner) {
61                continue;
62            }
63
64            if let Some(user) = users.get(&owner) {
65                out.push((message.to_owned(), user.clone(), next_owner == owner));
66            } else {
67                let user = self.get_user_by_id_with_void(owner).await?;
68                users.insert(owner, user.clone());
69                out.push((message.to_owned(), user, next_owner == owner));
70            }
71        }
72
73        Ok(out)
74    }
75
76    /// Get all messages by channel (paginated).
77    ///
78    /// # Arguments
79    /// * `channel` - the ID of the community to fetch channels for
80    /// * `batch` - the limit of items in each page
81    /// * `page` - the page number
82    pub async fn get_messages_by_channel(
83        &self,
84        channel: usize,
85        batch: usize,
86        page: usize,
87    ) -> Result<Vec<Message>> {
88        let conn = match self.0.connect().await {
89            Ok(c) => c,
90            Err(e) => return Err(Error::DatabaseConnection(e.to_string())),
91        };
92
93        let res = query_rows!(
94            &conn,
95            "SELECT * FROM messages WHERE channel = $1 ORDER BY created DESC LIMIT $2 OFFSET $3",
96            &[&(channel as i64), &(batch as i64), &((page * batch) as i64)],
97            |x| { Self::get_message_from_row(x) }
98        );
99
100        if res.is_err() {
101            return Err(Error::GeneralNotFound("message".to_string()));
102        }
103
104        Ok(res.unwrap())
105    }
106
107    /// Create a new message in the database.
108    ///
109    /// # Arguments
110    /// * `data` - a mock [`Message`] object to insert
111    pub async fn create_message(&self, mut data: Message) -> Result<()> {
112        if data.content.len() < 2 {
113            return Err(Error::DataTooLong("content".to_string()));
114        }
115
116        if data.content.len() > 2048 {
117            return Err(Error::DataTooLong("content".to_string()));
118        }
119
120        let owner = self.get_user_by_id(data.owner).await?;
121        let channel = self.get_channel_by_id(data.channel).await?;
122
123        // check user permission in community
124        let membership = self
125            .get_membership_by_owner_community(owner.id, channel.community)
126            .await?;
127
128        // check user permission to post in channel
129        if !channel.check_post(owner.id, Some(membership.role)) {
130            return Err(Error::NotAllowed);
131        }
132
133        // send mention notifications
134        let mut already_notified: HashMap<String, User> = HashMap::new();
135        for username in User::parse_mentions(&data.content) {
136            let user = {
137                if let Some(ua) = already_notified.get(&username) {
138                    ua.to_owned()
139                } else {
140                    let user = self.get_user_by_username(&username).await?;
141
142                    // check blocked status
143                    if self
144                        .get_userblock_by_initiator_receiver(user.id, data.owner)
145                        .await
146                        .is_ok()
147                    {
148                        return Err(Error::NotAllowed);
149                    }
150
151                    // check private status
152                    if user.settings.private_profile {
153                        if self
154                            .get_userfollow_by_initiator_receiver(user.id, data.owner)
155                            .await
156                            .is_err()
157                        {
158                            return Err(Error::NotAllowed);
159                        }
160                    }
161
162                    // check if the user can read the channel
163                    let membership = self
164                        .get_membership_by_owner_community(user.id, channel.community)
165                        .await?;
166
167                    if !channel.check_read(user.id, Some(membership.role)) {
168                        continue;
169                    }
170
171                    // create notif
172                    self.create_notification(Notification::new(
173                        "You've been mentioned in a message!".to_string(),
174                        format!(
175                            "[@{}](/api/v1/auth/user/find/{}) has mentioned you in their [message](/chats/{}/{}?message={}).",
176                            owner.username, owner.id, channel.community, data.channel, data.id
177                        ),
178                        user.id,
179                    ))
180                    .await?;
181
182                    // ...
183                    already_notified.insert(username.to_owned(), user.clone());
184                    user
185                }
186            };
187
188            data.content = data.content.replace(
189                &format!("@{username}"),
190                &format!(
191                    "<a href=\"/api/v1/auth/user/find/{}\" target=\"_top\">@{username}</a>",
192                    user.id
193                ),
194            );
195        }
196
197        // send notifs to members (if this message isn't associated with a channel)
198        if channel.community == 0 {
199            for member in [channel.members, vec![channel.owner]].concat() {
200                if member == owner.id {
201                    continue;
202                }
203
204                let user = self.get_user_by_id(member).await?;
205                if user.channel_mutes.contains(&channel.id) {
206                    continue;
207                }
208
209                let mut notif = Notification::new(
210                    "You've received a new message!".to_string(),
211                    format!(
212                        "[@{}](/api/v1/auth/user/find/{}) has sent a [message](/chats/{}/{}?message={}) in [{}](/chats/{}/{}).",
213                        owner.username,
214                        owner.id,
215                        channel.community,
216                        data.channel,
217                        data.id,
218                        channel.title,
219                        channel.community,
220                        data.channel
221                    ),
222                    member,
223                );
224
225                notif.tag = format!("chats/{}", channel.id);
226                self.create_notification(notif).await?;
227            }
228        }
229
230        // ...
231        let conn = match self.0.connect().await {
232            Ok(c) => c,
233            Err(e) => return Err(Error::DatabaseConnection(e.to_string())),
234        };
235
236        let res = execute!(
237            &conn,
238            "INSERT INTO messages VALUES ($1, $2, $3, $4, $5, $6, $7, $8)",
239            params![
240                &(data.id as i64),
241                &(data.channel as i64),
242                &(data.owner as i64),
243                &(data.created as i64),
244                &(data.edited as i64),
245                &data.content,
246                &serde_json::to_string(&data.context).unwrap(),
247                &serde_json::to_string(&data.reactions).unwrap(),
248            ]
249        );
250
251        if let Err(e) = res {
252            return Err(Error::DatabaseError(e.to_string()));
253        }
254
255        // post event
256        let mut con = self.0.1.get_con().await;
257
258        if let Err(e) = con.publish::<String, String, ()>(
259            if channel.community != 0 {
260                // broadcast to community ws
261                format!("chats/{}", channel.community)
262            } else {
263                // broadcast to channel ws
264                format!("chats/{}", channel.id)
265            },
266            serde_json::to_string(&SocketMessage {
267                method: SocketMethod::Message,
268                data: serde_json::to_string(&(data.channel.to_string(), data)).unwrap(),
269            })
270            .unwrap(),
271        ) {
272            return Err(Error::MiscError(e.to_string()));
273        }
274
275        // update channel position
276        self.update_channel_last_message(channel.id, unix_epoch_timestamp() as i64)
277            .await?;
278
279        // ...
280        Ok(())
281    }
282
283    pub async fn delete_message(&self, id: usize, user: User) -> Result<()> {
284        let message = self.get_message_by_id(id).await?;
285        let channel = self.get_channel_by_id(message.channel).await?;
286
287        // check user permission in community
288        if user.id != message.owner {
289            let membership = self
290                .get_membership_by_owner_community(user.id, channel.community)
291                .await?;
292
293            if !membership.role.check(CommunityPermission::MANAGE_MESSAGES)
294                && !user.permissions.check(FinePermission::MANAGE_MESSAGES)
295            {
296                return Err(Error::NotAllowed);
297            } else if user.permissions.check(FinePermission::MANAGE_MESSAGES) {
298                self.create_audit_log_entry(AuditLogEntry::new(
299                    user.id,
300                    format!("invoked `delete_message` with x value `{id}`"),
301                ))
302                .await?
303            }
304        }
305
306        // ...
307        let conn = match self.0.connect().await {
308            Ok(c) => c,
309            Err(e) => return Err(Error::DatabaseConnection(e.to_string())),
310        };
311
312        let res = execute!(&conn, "DELETE FROM messages WHERE id = $1", &[&(id as i64)]);
313
314        if let Err(e) = res {
315            return Err(Error::DatabaseError(e.to_string()));
316        }
317
318        self.0.1.remove(format!("atto.message:{}", id)).await;
319
320        // post event
321        let mut con = self.0.1.get_con().await;
322
323        if let Err(e) = con.publish::<String, String, ()>(
324            if channel.community != 0 {
325                // broadcast to community ws
326                format!("chats/{}", channel.community)
327            } else {
328                // broadcast to channel ws
329                format!("chats/{}", channel.id)
330            },
331            serde_json::to_string(&SocketMessage {
332                method: SocketMethod::Delete,
333                data: serde_json::to_string(&DeleteMessageEvent { id: id.to_string() }).unwrap(),
334            })
335            .unwrap(),
336        ) {
337            return Err(Error::MiscError(e.to_string()));
338        }
339
340        // ...
341        Ok(())
342    }
343
344    pub async fn update_message_content(&self, id: usize, user: User, x: String) -> Result<()> {
345        let y = self.get_message_by_id(id).await?;
346
347        if user.id != y.owner {
348            if !user.permissions.check(FinePermission::MANAGE_MESSAGES) {
349                return Err(Error::NotAllowed);
350            } else {
351                self.create_audit_log_entry(AuditLogEntry::new(
352                    user.id,
353                    format!("invoked `update_message_content` with x value `{id}`"),
354                ))
355                .await?
356            }
357        }
358
359        // ...
360        let conn = match self.0.connect().await {
361            Ok(c) => c,
362            Err(e) => return Err(Error::DatabaseConnection(e.to_string())),
363        };
364
365        let res = execute!(
366            &conn,
367            "UPDATE messages SET content = $1, edited = $2 WHERE id = $2",
368            params![&x, &(unix_epoch_timestamp() as i64), &(id as i64)]
369        );
370
371        if let Err(e) = res {
372            return Err(Error::DatabaseError(e.to_string()));
373        }
374
375        // return
376        Ok(())
377    }
378
    // Generates `update_message_reactions`: writes a `HashMap<String, usize>` into
    // the `reactions` column of the message with the given ID. `--serde` presumably
    // serializes the map before binding it, and the cache key template is
    // `atto.message:{}` - see `auto_method!` for the exact behavior.
    auto_method!(update_message_reactions(HashMap<String, usize>) -> "UPDATE messages SET reactions = $1 WHERE id = $2" --serde --cache-key-tmpl="atto.message:{}");
380}