mas_storage_pg/oauth2/
client.rs

1// Copyright 2024 New Vector Ltd.
2// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only
5// Please see LICENSE in the repository root for full details.
6
7use std::{
8    collections::{BTreeMap, BTreeSet},
9    str::FromStr,
10    string::ToString,
11};
12
13use async_trait::async_trait;
14use mas_data_model::{Client, JwksOrJwksUri, User};
15use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod};
16use mas_jose::jwk::PublicJsonWebKeySet;
17use mas_storage::{Clock, oauth2::OAuth2ClientRepository};
18use oauth2_types::{
19    oidc::ApplicationType,
20    requests::GrantType,
21    scope::{Scope, ScopeToken},
22};
23use opentelemetry_semantic_conventions::attribute::DB_QUERY_TEXT;
24use rand::RngCore;
25use sqlx::PgConnection;
26use tracing::{Instrument, info_span};
27use ulid::Ulid;
28use url::Url;
29use uuid::Uuid;
30
31use crate::{DatabaseError, DatabaseInconsistencyError, tracing::ExecuteExt};
32
33/// An implementation of [`OAuth2ClientRepository`] for a PostgreSQL connection
34pub struct PgOAuth2ClientRepository<'c> {
35    conn: &'c mut PgConnection,
36}
37
38impl<'c> PgOAuth2ClientRepository<'c> {
39    /// Create a new [`PgOAuth2ClientRepository`] from an active PostgreSQL
40    /// connection
41    pub fn new(conn: &'c mut PgConnection) -> Self {
42        Self { conn }
43    }
44}
45
/// Raw row of the `oauth2_clients` table, as selected by the `query_as!`
/// lookups in this file.
///
/// Values are kept in their database representation (strings, booleans, raw
/// JSON) and are only parsed into domain types when the row is converted into
/// a [`Client`].
// The table stores one boolean column per supported grant type, so this row
// struct necessarily carries several bools.
#[allow(clippy::struct_excessive_bools)]
#[derive(Debug)]
struct OAuth2ClientLookup {
    /// Primary key; converted to a `Ulid` and used as both the client's `id`
    /// and its `client_id` string.
    oauth2_client_id: Uuid,
    /// Stored client secret, passed through to the `Client` unchanged.
    encrypted_client_secret: Option<String>,
    /// Textual application type, parsed during conversion.
    application_type: Option<String>,
    /// Registered redirect URIs; each entry is parsed as a URL.
    redirect_uris: Vec<String>,
    /// When set, `GrantType::AuthorizationCode` is added to the grant types.
    grant_type_authorization_code: bool,
    /// When set, `GrantType::RefreshToken` is added to the grant types.
    grant_type_refresh_token: bool,
    /// When set, `GrantType::ClientCredentials` is added to the grant types.
    grant_type_client_credentials: bool,
    /// When set, `GrantType::DeviceCode` is added to the grant types.
    grant_type_device_code: bool,
    /// Human-readable client name, passed through unchanged.
    client_name: Option<String>,
    /// Optional URLs, each parsed during conversion.
    logo_uri: Option<String>,
    client_uri: Option<String>,
    policy_uri: Option<String>,
    tos_uri: Option<String>,
    /// URL of the client's key set; mutually exclusive with `jwks` (a row
    /// with both set is reported as an inconsistency on conversion).
    jwks_uri: Option<String>,
    /// Inline key set as raw JSON, deserialized during conversion; mutually
    /// exclusive with `jwks_uri`.
    jwks: Option<serde_json::Value>,
    /// Algorithm names and auth method, stored as strings and parsed during
    /// conversion.
    id_token_signed_response_alg: Option<String>,
    userinfo_signed_response_alg: Option<String>,
    token_endpoint_auth_method: Option<String>,
    token_endpoint_auth_signing_alg: Option<String>,
    /// Optional login-initiation URL, parsed during conversion.
    initiate_login_uri: Option<String>,
}
70
71impl TryInto<Client> for OAuth2ClientLookup {
72    type Error = DatabaseInconsistencyError;
73
74    #[allow(clippy::too_many_lines)] // TODO: refactor some of the field parsing
75    fn try_into(self) -> Result<Client, Self::Error> {
76        let id = Ulid::from(self.oauth2_client_id);
77
78        let redirect_uris: Result<Vec<Url>, _> =
79            self.redirect_uris.iter().map(|s| s.parse()).collect();
80        let redirect_uris = redirect_uris.map_err(|e| {
81            DatabaseInconsistencyError::on("oauth2_clients")
82                .column("redirect_uris")
83                .row(id)
84                .source(e)
85        })?;
86
87        let application_type = self
88            .application_type
89            .map(|s| s.parse())
90            .transpose()
91            .map_err(|e| {
92                DatabaseInconsistencyError::on("oauth2_clients")
93                    .column("application_type")
94                    .row(id)
95                    .source(e)
96            })?;
97
98        let mut grant_types = Vec::new();
99        if self.grant_type_authorization_code {
100            grant_types.push(GrantType::AuthorizationCode);
101        }
102        if self.grant_type_refresh_token {
103            grant_types.push(GrantType::RefreshToken);
104        }
105        if self.grant_type_client_credentials {
106            grant_types.push(GrantType::ClientCredentials);
107        }
108        if self.grant_type_device_code {
109            grant_types.push(GrantType::DeviceCode);
110        }
111
112        let logo_uri = self.logo_uri.map(|s| s.parse()).transpose().map_err(|e| {
113            DatabaseInconsistencyError::on("oauth2_clients")
114                .column("logo_uri")
115                .row(id)
116                .source(e)
117        })?;
118
119        let client_uri = self
120            .client_uri
121            .map(|s| s.parse())
122            .transpose()
123            .map_err(|e| {
124                DatabaseInconsistencyError::on("oauth2_clients")
125                    .column("client_uri")
126                    .row(id)
127                    .source(e)
128            })?;
129
130        let policy_uri = self
131            .policy_uri
132            .map(|s| s.parse())
133            .transpose()
134            .map_err(|e| {
135                DatabaseInconsistencyError::on("oauth2_clients")
136                    .column("policy_uri")
137                    .row(id)
138                    .source(e)
139            })?;
140
141        let tos_uri = self.tos_uri.map(|s| s.parse()).transpose().map_err(|e| {
142            DatabaseInconsistencyError::on("oauth2_clients")
143                .column("tos_uri")
144                .row(id)
145                .source(e)
146        })?;
147
148        let id_token_signed_response_alg = self
149            .id_token_signed_response_alg
150            .map(|s| s.parse())
151            .transpose()
152            .map_err(|e| {
153                DatabaseInconsistencyError::on("oauth2_clients")
154                    .column("id_token_signed_response_alg")
155                    .row(id)
156                    .source(e)
157            })?;
158
159        let userinfo_signed_response_alg = self
160            .userinfo_signed_response_alg
161            .map(|s| s.parse())
162            .transpose()
163            .map_err(|e| {
164                DatabaseInconsistencyError::on("oauth2_clients")
165                    .column("userinfo_signed_response_alg")
166                    .row(id)
167                    .source(e)
168            })?;
169
170        let token_endpoint_auth_method = self
171            .token_endpoint_auth_method
172            .map(|s| s.parse())
173            .transpose()
174            .map_err(|e| {
175                DatabaseInconsistencyError::on("oauth2_clients")
176                    .column("token_endpoint_auth_method")
177                    .row(id)
178                    .source(e)
179            })?;
180
181        let token_endpoint_auth_signing_alg = self
182            .token_endpoint_auth_signing_alg
183            .map(|s| s.parse())
184            .transpose()
185            .map_err(|e| {
186                DatabaseInconsistencyError::on("oauth2_clients")
187                    .column("token_endpoint_auth_signing_alg")
188                    .row(id)
189                    .source(e)
190            })?;
191
192        let initiate_login_uri = self
193            .initiate_login_uri
194            .map(|s| s.parse())
195            .transpose()
196            .map_err(|e| {
197                DatabaseInconsistencyError::on("oauth2_clients")
198                    .column("initiate_login_uri")
199                    .row(id)
200                    .source(e)
201            })?;
202
203        let jwks = match (self.jwks, self.jwks_uri) {
204            (None, None) => None,
205            (Some(jwks), None) => {
206                let jwks = serde_json::from_value(jwks).map_err(|e| {
207                    DatabaseInconsistencyError::on("oauth2_clients")
208                        .column("jwks")
209                        .row(id)
210                        .source(e)
211                })?;
212                Some(JwksOrJwksUri::Jwks(jwks))
213            }
214            (None, Some(jwks_uri)) => {
215                let jwks_uri = jwks_uri.parse().map_err(|e| {
216                    DatabaseInconsistencyError::on("oauth2_clients")
217                        .column("jwks_uri")
218                        .row(id)
219                        .source(e)
220                })?;
221
222                Some(JwksOrJwksUri::JwksUri(jwks_uri))
223            }
224            _ => {
225                return Err(DatabaseInconsistencyError::on("oauth2_clients")
226                    .column("jwks(_uri)")
227                    .row(id));
228            }
229        };
230
231        Ok(Client {
232            id,
233            client_id: id.to_string(),
234            encrypted_client_secret: self.encrypted_client_secret,
235            application_type,
236            redirect_uris,
237            grant_types,
238            client_name: self.client_name,
239            logo_uri,
240            client_uri,
241            policy_uri,
242            tos_uri,
243            jwks,
244            id_token_signed_response_alg,
245            userinfo_signed_response_alg,
246            token_endpoint_auth_method,
247            token_endpoint_auth_signing_alg,
248            initiate_login_uri,
249        })
250    }
251}
252
253#[async_trait]
254impl OAuth2ClientRepository for PgOAuth2ClientRepository<'_> {
255    type Error = DatabaseError;
256
257    #[tracing::instrument(
258        name = "db.oauth2_client.lookup",
259        skip_all,
260        fields(
261            db.query.text,
262            oauth2_client.id = %id,
263        ),
264        err,
265    )]
266    async fn lookup(&mut self, id: Ulid) -> Result<Option<Client>, Self::Error> {
267        let res = sqlx::query_as!(
268            OAuth2ClientLookup,
269            r#"
270                SELECT oauth2_client_id
271                     , encrypted_client_secret
272                     , application_type
273                     , redirect_uris
274                     , grant_type_authorization_code
275                     , grant_type_refresh_token
276                     , grant_type_client_credentials
277                     , grant_type_device_code
278                     , client_name
279                     , logo_uri
280                     , client_uri
281                     , policy_uri
282                     , tos_uri
283                     , jwks_uri
284                     , jwks
285                     , id_token_signed_response_alg
286                     , userinfo_signed_response_alg
287                     , token_endpoint_auth_method
288                     , token_endpoint_auth_signing_alg
289                     , initiate_login_uri
290                FROM oauth2_clients c
291
292                WHERE oauth2_client_id = $1
293            "#,
294            Uuid::from(id),
295        )
296        .traced()
297        .fetch_optional(&mut *self.conn)
298        .await?;
299
300        let Some(res) = res else { return Ok(None) };
301
302        Ok(Some(res.try_into()?))
303    }
304
305    #[tracing::instrument(
306        name = "db.oauth2_client.load_batch",
307        skip_all,
308        fields(
309            db.query.text,
310        ),
311        err,
312    )]
313    async fn load_batch(
314        &mut self,
315        ids: BTreeSet<Ulid>,
316    ) -> Result<BTreeMap<Ulid, Client>, Self::Error> {
317        let ids: Vec<Uuid> = ids.into_iter().map(Uuid::from).collect();
318        let res = sqlx::query_as!(
319            OAuth2ClientLookup,
320            r#"
321                SELECT oauth2_client_id
322                     , encrypted_client_secret
323                     , application_type
324                     , redirect_uris
325                     , grant_type_authorization_code
326                     , grant_type_refresh_token
327                     , grant_type_client_credentials
328                     , grant_type_device_code
329                     , client_name
330                     , logo_uri
331                     , client_uri
332                     , policy_uri
333                     , tos_uri
334                     , jwks_uri
335                     , jwks
336                     , id_token_signed_response_alg
337                     , userinfo_signed_response_alg
338                     , token_endpoint_auth_method
339                     , token_endpoint_auth_signing_alg
340                     , initiate_login_uri
341                FROM oauth2_clients c
342
343                WHERE oauth2_client_id = ANY($1::uuid[])
344            "#,
345            &ids,
346        )
347        .traced()
348        .fetch_all(&mut *self.conn)
349        .await?;
350
351        res.into_iter()
352            .map(|r| {
353                r.try_into()
354                    .map(|c: Client| (c.id, c))
355                    .map_err(DatabaseError::from)
356            })
357            .collect()
358    }
359
360    #[tracing::instrument(
361        name = "db.oauth2_client.add",
362        skip_all,
363        fields(
364            db.query.text,
365            client.id,
366            client.name = client_name
367        ),
368        err,
369    )]
370    #[allow(clippy::too_many_lines)]
371    async fn add(
372        &mut self,
373        rng: &mut (dyn RngCore + Send),
374        clock: &dyn Clock,
375        redirect_uris: Vec<Url>,
376        encrypted_client_secret: Option<String>,
377        application_type: Option<ApplicationType>,
378        grant_types: Vec<GrantType>,
379        client_name: Option<String>,
380        logo_uri: Option<Url>,
381        client_uri: Option<Url>,
382        policy_uri: Option<Url>,
383        tos_uri: Option<Url>,
384        jwks_uri: Option<Url>,
385        jwks: Option<PublicJsonWebKeySet>,
386        id_token_signed_response_alg: Option<JsonWebSignatureAlg>,
387        userinfo_signed_response_alg: Option<JsonWebSignatureAlg>,
388        token_endpoint_auth_method: Option<OAuthClientAuthenticationMethod>,
389        token_endpoint_auth_signing_alg: Option<JsonWebSignatureAlg>,
390        initiate_login_uri: Option<Url>,
391    ) -> Result<Client, Self::Error> {
392        let now = clock.now();
393        let id = Ulid::from_datetime_with_source(now.into(), rng);
394        tracing::Span::current().record("client.id", tracing::field::display(id));
395
396        let jwks_json = jwks
397            .as_ref()
398            .map(serde_json::to_value)
399            .transpose()
400            .map_err(DatabaseError::to_invalid_operation)?;
401
402        let redirect_uris_array = redirect_uris.iter().map(Url::to_string).collect::<Vec<_>>();
403
404        sqlx::query!(
405            r#"
406                INSERT INTO oauth2_clients
407                    ( oauth2_client_id
408                    , encrypted_client_secret
409                    , application_type
410                    , redirect_uris
411                    , grant_type_authorization_code
412                    , grant_type_refresh_token
413                    , grant_type_client_credentials
414                    , grant_type_device_code
415                    , client_name
416                    , logo_uri
417                    , client_uri
418                    , policy_uri
419                    , tos_uri
420                    , jwks_uri
421                    , jwks
422                    , id_token_signed_response_alg
423                    , userinfo_signed_response_alg
424                    , token_endpoint_auth_method
425                    , token_endpoint_auth_signing_alg
426                    , initiate_login_uri
427                    , is_static
428                    )
429                VALUES
430                    ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, FALSE)
431            "#,
432            Uuid::from(id),
433            encrypted_client_secret,
434            application_type.as_ref().map(ToString::to_string),
435            &redirect_uris_array,
436            grant_types.contains(&GrantType::AuthorizationCode),
437            grant_types.contains(&GrantType::RefreshToken),
438            grant_types.contains(&GrantType::ClientCredentials),
439            grant_types.contains(&GrantType::DeviceCode),
440            client_name,
441            logo_uri.as_ref().map(Url::as_str),
442            client_uri.as_ref().map(Url::as_str),
443            policy_uri.as_ref().map(Url::as_str),
444            tos_uri.as_ref().map(Url::as_str),
445            jwks_uri.as_ref().map(Url::as_str),
446            jwks_json,
447            id_token_signed_response_alg
448                .as_ref()
449                .map(ToString::to_string),
450            userinfo_signed_response_alg
451                .as_ref()
452                .map(ToString::to_string),
453            token_endpoint_auth_method.as_ref().map(ToString::to_string),
454            token_endpoint_auth_signing_alg
455                .as_ref()
456                .map(ToString::to_string),
457            initiate_login_uri.as_ref().map(Url::as_str),
458        )
459        .traced()
460        .execute(&mut *self.conn)
461        .await?;
462
463        let jwks = match (jwks, jwks_uri) {
464            (None, None) => None,
465            (Some(jwks), None) => Some(JwksOrJwksUri::Jwks(jwks)),
466            (None, Some(jwks_uri)) => Some(JwksOrJwksUri::JwksUri(jwks_uri)),
467            _ => return Err(DatabaseError::invalid_operation()),
468        };
469
470        Ok(Client {
471            id,
472            client_id: id.to_string(),
473            encrypted_client_secret,
474            application_type,
475            redirect_uris,
476            grant_types,
477            client_name,
478            logo_uri,
479            client_uri,
480            policy_uri,
481            tos_uri,
482            jwks,
483            id_token_signed_response_alg,
484            userinfo_signed_response_alg,
485            token_endpoint_auth_method,
486            token_endpoint_auth_signing_alg,
487            initiate_login_uri,
488        })
489    }
490
491    #[tracing::instrument(
492        name = "db.oauth2_client.upsert_static",
493        skip_all,
494        fields(
495            db.query.text,
496            client.id = %client_id,
497        ),
498        err,
499    )]
500    async fn upsert_static(
501        &mut self,
502        client_id: Ulid,
503        client_auth_method: OAuthClientAuthenticationMethod,
504        encrypted_client_secret: Option<String>,
505        jwks: Option<PublicJsonWebKeySet>,
506        jwks_uri: Option<Url>,
507        redirect_uris: Vec<Url>,
508    ) -> Result<Client, Self::Error> {
509        let jwks_json = jwks
510            .as_ref()
511            .map(serde_json::to_value)
512            .transpose()
513            .map_err(DatabaseError::to_invalid_operation)?;
514
515        let client_auth_method = client_auth_method.to_string();
516        let redirect_uris_array = redirect_uris.iter().map(Url::to_string).collect::<Vec<_>>();
517
518        sqlx::query!(
519            r#"
520                INSERT INTO oauth2_clients
521                    ( oauth2_client_id
522                    , encrypted_client_secret
523                    , redirect_uris
524                    , grant_type_authorization_code
525                    , grant_type_refresh_token
526                    , grant_type_client_credentials
527                    , grant_type_device_code
528                    , token_endpoint_auth_method
529                    , jwks
530                    , jwks_uri
531                    , is_static
532                    )
533                VALUES
534                    ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, TRUE)
535                ON CONFLICT (oauth2_client_id)
536                DO
537                    UPDATE SET encrypted_client_secret = EXCLUDED.encrypted_client_secret
538                             , redirect_uris = EXCLUDED.redirect_uris
539                             , grant_type_authorization_code = EXCLUDED.grant_type_authorization_code
540                             , grant_type_refresh_token = EXCLUDED.grant_type_refresh_token
541                             , grant_type_client_credentials = EXCLUDED.grant_type_client_credentials
542                             , grant_type_device_code = EXCLUDED.grant_type_device_code
543                             , token_endpoint_auth_method = EXCLUDED.token_endpoint_auth_method
544                             , jwks = EXCLUDED.jwks
545                             , jwks_uri = EXCLUDED.jwks_uri
546                             , is_static = TRUE
547            "#,
548            Uuid::from(client_id),
549            encrypted_client_secret,
550            &redirect_uris_array,
551            true,
552            true,
553            true,
554            true,
555            client_auth_method,
556            jwks_json,
557            jwks_uri.as_ref().map(Url::as_str),
558        )
559        .traced()
560        .execute(&mut *self.conn)
561        .await?;
562
563        let jwks = match (jwks, jwks_uri) {
564            (None, None) => None,
565            (Some(jwks), None) => Some(JwksOrJwksUri::Jwks(jwks)),
566            (None, Some(jwks_uri)) => Some(JwksOrJwksUri::JwksUri(jwks_uri)),
567            _ => return Err(DatabaseError::invalid_operation()),
568        };
569
570        Ok(Client {
571            id: client_id,
572            client_id: client_id.to_string(),
573            encrypted_client_secret,
574            application_type: None,
575            redirect_uris,
576            grant_types: vec![
577                GrantType::AuthorizationCode,
578                GrantType::RefreshToken,
579                GrantType::ClientCredentials,
580            ],
581            client_name: None,
582            logo_uri: None,
583            client_uri: None,
584            policy_uri: None,
585            tos_uri: None,
586            jwks,
587            id_token_signed_response_alg: None,
588            userinfo_signed_response_alg: None,
589            token_endpoint_auth_method: None,
590            token_endpoint_auth_signing_alg: None,
591            initiate_login_uri: None,
592        })
593    }
594
595    #[tracing::instrument(
596        name = "db.oauth2_client.all_static",
597        skip_all,
598        fields(
599            db.query.text,
600        ),
601        err,
602    )]
603    async fn all_static(&mut self) -> Result<Vec<Client>, Self::Error> {
604        let res = sqlx::query_as!(
605            OAuth2ClientLookup,
606            r#"
607                SELECT oauth2_client_id
608                     , encrypted_client_secret
609                     , application_type
610                     , redirect_uris
611                     , grant_type_authorization_code
612                     , grant_type_refresh_token
613                     , grant_type_client_credentials
614                     , grant_type_device_code
615                     , client_name
616                     , logo_uri
617                     , client_uri
618                     , policy_uri
619                     , tos_uri
620                     , jwks_uri
621                     , jwks
622                     , id_token_signed_response_alg
623                     , userinfo_signed_response_alg
624                     , token_endpoint_auth_method
625                     , token_endpoint_auth_signing_alg
626                     , initiate_login_uri
627                FROM oauth2_clients c
628                WHERE is_static = TRUE
629            "#,
630        )
631        .traced()
632        .fetch_all(&mut *self.conn)
633        .await?;
634
635        res.into_iter()
636            .map(|r| r.try_into().map_err(DatabaseError::from))
637            .collect()
638    }
639
640    #[tracing::instrument(
641        name = "db.oauth2_client.get_consent_for_user",
642        skip_all,
643        fields(
644            db.query.text,
645            %user.id,
646            %client.id,
647        ),
648        err,
649    )]
650    async fn get_consent_for_user(
651        &mut self,
652        client: &Client,
653        user: &User,
654    ) -> Result<Scope, Self::Error> {
655        let scope_tokens: Vec<String> = sqlx::query_scalar!(
656            r#"
657                SELECT scope_token
658                FROM oauth2_consents
659                WHERE user_id = $1 AND oauth2_client_id = $2
660            "#,
661            Uuid::from(user.id),
662            Uuid::from(client.id),
663        )
664        .fetch_all(&mut *self.conn)
665        .await?;
666
667        let scope: Result<Scope, _> = scope_tokens
668            .into_iter()
669            .map(|s| ScopeToken::from_str(&s))
670            .collect();
671
672        let scope = scope.map_err(|e| {
673            DatabaseInconsistencyError::on("oauth2_consents")
674                .column("scope_token")
675                .source(e)
676        })?;
677
678        Ok(scope)
679    }
680
681    #[tracing::instrument(
682        name = "db.oauth2_client.give_consent_for_user",
683        skip_all,
684        fields(
685            db.query.text,
686            %user.id,
687            %client.id,
688            %scope,
689        ),
690        err,
691    )]
692    async fn give_consent_for_user(
693        &mut self,
694        rng: &mut (dyn RngCore + Send),
695        clock: &dyn Clock,
696        client: &Client,
697        user: &User,
698        scope: &Scope,
699    ) -> Result<(), Self::Error> {
700        let now = clock.now();
701        let (tokens, ids): (Vec<String>, Vec<Uuid>) = scope
702            .iter()
703            .map(|token| {
704                (
705                    token.to_string(),
706                    Uuid::from(Ulid::from_datetime_with_source(now.into(), rng)),
707                )
708            })
709            .unzip();
710
711        sqlx::query!(
712            r#"
713                INSERT INTO oauth2_consents
714                    (oauth2_consent_id, user_id, oauth2_client_id, scope_token, created_at)
715                SELECT id, $2, $3, scope_token, $5 FROM UNNEST($1::uuid[], $4::text[]) u(id, scope_token)
716                ON CONFLICT (user_id, oauth2_client_id, scope_token) DO UPDATE SET refreshed_at = $5
717            "#,
718            &ids,
719            Uuid::from(user.id),
720            Uuid::from(client.id),
721            &tokens,
722            now,
723        )
724        .traced()
725        .execute(&mut *self.conn)
726        .await?;
727
728        Ok(())
729    }
730
731    #[tracing::instrument(
732        name = "db.oauth2_client.delete_by_id",
733        skip_all,
734        fields(
735            db.query.text,
736            client.id = %id,
737        ),
738        err,
739    )]
740    async fn delete_by_id(&mut self, id: Ulid) -> Result<(), Self::Error> {
741        // Delete the authorization grants
742        {
743            let span = info_span!(
744                "db.oauth2_client.delete_by_id.authorization_grants",
745                { DB_QUERY_TEXT } = tracing::field::Empty,
746            );
747
748            sqlx::query!(
749                r#"
750                    DELETE FROM oauth2_authorization_grants
751                    WHERE oauth2_client_id = $1
752                "#,
753                Uuid::from(id),
754            )
755            .record(&span)
756            .execute(&mut *self.conn)
757            .instrument(span)
758            .await?;
759        }
760
761        // Delete the user consents
762        {
763            let span = info_span!(
764                "db.oauth2_client.delete_by_id.consents",
765                { DB_QUERY_TEXT } = tracing::field::Empty,
766            );
767
768            sqlx::query!(
769                r#"
770                    DELETE FROM oauth2_consents
771                    WHERE oauth2_client_id = $1
772                "#,
773                Uuid::from(id),
774            )
775            .record(&span)
776            .execute(&mut *self.conn)
777            .instrument(span)
778            .await?;
779        }
780
781        // Delete the OAuth 2 sessions related data
782        {
783            let span = info_span!(
784                "db.oauth2_client.delete_by_id.access_tokens",
785                { DB_QUERY_TEXT } = tracing::field::Empty,
786            );
787
788            sqlx::query!(
789                r#"
790                    DELETE FROM oauth2_access_tokens
791                    WHERE oauth2_session_id IN (
792                        SELECT oauth2_session_id
793                        FROM oauth2_sessions
794                        WHERE oauth2_client_id = $1
795                    )
796                "#,
797                Uuid::from(id),
798            )
799            .record(&span)
800            .execute(&mut *self.conn)
801            .instrument(span)
802            .await?;
803        }
804
805        {
806            let span = info_span!(
807                "db.oauth2_client.delete_by_id.refresh_tokens",
808                { DB_QUERY_TEXT } = tracing::field::Empty,
809            );
810
811            sqlx::query!(
812                r#"
813                    DELETE FROM oauth2_refresh_tokens
814                    WHERE oauth2_session_id IN (
815                        SELECT oauth2_session_id
816                        FROM oauth2_sessions
817                        WHERE oauth2_client_id = $1
818                    )
819                "#,
820                Uuid::from(id),
821            )
822            .record(&span)
823            .execute(&mut *self.conn)
824            .instrument(span)
825            .await?;
826        }
827
828        {
829            let span = info_span!(
830                "db.oauth2_client.delete_by_id.sessions",
831                { DB_QUERY_TEXT } = tracing::field::Empty,
832            );
833
834            sqlx::query!(
835                r#"
836                    DELETE FROM oauth2_sessions
837                    WHERE oauth2_client_id = $1
838                "#,
839                Uuid::from(id),
840            )
841            .record(&span)
842            .execute(&mut *self.conn)
843            .instrument(span)
844            .await?;
845        }
846
847        // Now delete the client itself
848        let res = sqlx::query!(
849            r#"
850                DELETE FROM oauth2_clients
851                WHERE oauth2_client_id = $1
852            "#,
853            Uuid::from(id),
854        )
855        .traced()
856        .execute(&mut *self.conn)
857        .await?;
858
859        DatabaseError::ensure_affected_rows(&res, 1)
860    }
861}