1use std::net::IpAddr;
8
9use async_trait::async_trait;
10use chrono::{DateTime, Utc};
11use mas_data_model::{BrowserSession, Client, Session, SessionState, User};
12use mas_storage::{
13 Clock, Page, Pagination,
14 oauth2::{OAuth2SessionFilter, OAuth2SessionRepository},
15};
16use oauth2_types::scope::{Scope, ScopeToken};
17use rand::RngCore;
18use sea_query::{
19 Condition, Expr, PgFunc, PostgresQueryBuilder, Query, SimpleExpr, enum_def,
20 extension::postgres::PgExpr,
21};
22use sea_query_binder::SqlxBinder;
23use sqlx::PgConnection;
24use ulid::Ulid;
25use uuid::Uuid;
26
27use crate::{
28 DatabaseError, DatabaseInconsistencyError,
29 filter::{Filter, StatementExt},
30 iden::{OAuth2Clients, OAuth2Sessions, UserSessions},
31 pagination::QueryBuilderExt,
32 tracing::ExecuteExt,
33};
34
35pub struct PgOAuth2SessionRepository<'c> {
37 conn: &'c mut PgConnection,
38}
39
40impl<'c> PgOAuth2SessionRepository<'c> {
41 pub fn new(conn: &'c mut PgConnection) -> Self {
44 Self { conn }
45 }
46}
47
48#[derive(sqlx::FromRow)]
49#[enum_def]
50struct OAuthSessionLookup {
51 oauth2_session_id: Uuid,
52 user_id: Option<Uuid>,
53 user_session_id: Option<Uuid>,
54 oauth2_client_id: Uuid,
55 scope_list: Vec<String>,
56 created_at: DateTime<Utc>,
57 finished_at: Option<DateTime<Utc>>,
58 user_agent: Option<String>,
59 last_active_at: Option<DateTime<Utc>>,
60 last_active_ip: Option<IpAddr>,
61 human_name: Option<String>,
62}
63
64impl TryFrom<OAuthSessionLookup> for Session {
65 type Error = DatabaseInconsistencyError;
66
67 fn try_from(value: OAuthSessionLookup) -> Result<Self, Self::Error> {
68 let id = Ulid::from(value.oauth2_session_id);
69 let scope: Result<Scope, _> = value
70 .scope_list
71 .iter()
72 .map(|s| s.parse::<ScopeToken>())
73 .collect();
74 let scope = scope.map_err(|e| {
75 DatabaseInconsistencyError::on("oauth2_sessions")
76 .column("scope")
77 .row(id)
78 .source(e)
79 })?;
80
81 let state = match value.finished_at {
82 None => SessionState::Valid,
83 Some(finished_at) => SessionState::Finished { finished_at },
84 };
85
86 Ok(Session {
87 id,
88 state,
89 created_at: value.created_at,
90 client_id: value.oauth2_client_id.into(),
91 user_id: value.user_id.map(Ulid::from),
92 user_session_id: value.user_session_id.map(Ulid::from),
93 scope,
94 user_agent: value.user_agent,
95 last_active_at: value.last_active_at,
96 last_active_ip: value.last_active_ip,
97 human_name: value.human_name,
98 })
99 }
100}
101
102impl Filter for OAuth2SessionFilter<'_> {
103 fn generate_condition(&self, _has_joins: bool) -> impl sea_query::IntoCondition {
104 sea_query::Condition::all()
105 .add_option(self.user().map(|user| {
106 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserId)).eq(Uuid::from(user.id))
107 }))
108 .add_option(self.client().map(|client| {
109 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2ClientId))
110 .eq(Uuid::from(client.id))
111 }))
112 .add_option(self.client_kind().map(|client_kind| {
113 let static_clients = Query::select()
117 .expr(Expr::col((
118 OAuth2Clients::Table,
119 OAuth2Clients::OAuth2ClientId,
120 )))
121 .and_where(Expr::col((OAuth2Clients::Table, OAuth2Clients::IsStatic)).into())
122 .from(OAuth2Clients::Table)
123 .take();
124 if client_kind.is_static() {
125 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2ClientId))
126 .eq(Expr::any(static_clients))
127 } else {
128 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2ClientId))
129 .ne(Expr::all(static_clients))
130 }
131 }))
132 .add_option(self.device().map(|device| -> SimpleExpr {
133 if let Ok([stable_scope_token, unstable_scope_token]) = device.to_scope_token() {
134 Condition::any()
135 .add(
136 Expr::val(stable_scope_token.to_string()).eq(PgFunc::any(Expr::col((
137 OAuth2Sessions::Table,
138 OAuth2Sessions::ScopeList,
139 )))),
140 )
141 .add(Expr::val(unstable_scope_token.to_string()).eq(PgFunc::any(
142 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::ScopeList)),
143 )))
144 .into()
145 } else {
146 Expr::val(false).into()
148 }
149 }))
150 .add_option(self.browser_session().map(|browser_session| {
151 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserSessionId))
152 .eq(Uuid::from(browser_session.id))
153 }))
154 .add_option(self.browser_session_filter().map(|browser_session_filter| {
155 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserSessionId)).in_subquery(
156 Query::select()
157 .expr(Expr::col((
158 UserSessions::Table,
159 UserSessions::UserSessionId,
160 )))
161 .apply_filter(browser_session_filter)
162 .from(UserSessions::Table)
163 .take(),
164 )
165 }))
166 .add_option(self.state().map(|state| {
167 if state.is_active() {
168 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::FinishedAt)).is_null()
169 } else {
170 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::FinishedAt)).is_not_null()
171 }
172 }))
173 .add_option(self.scope().map(|scope| {
174 let scope: Vec<String> = scope.iter().map(|s| s.as_str().to_owned()).collect();
175 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::ScopeList)).contains(scope)
176 }))
177 .add_option(self.any_user().map(|any_user| {
178 if any_user {
179 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserId)).is_not_null()
180 } else {
181 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserId)).is_null()
182 }
183 }))
184 .add_option(self.last_active_after().map(|last_active_after| {
185 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::LastActiveAt))
186 .gt(last_active_after)
187 }))
188 .add_option(self.last_active_before().map(|last_active_before| {
189 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::LastActiveAt))
190 .lt(last_active_before)
191 }))
192 }
193}
194
195#[async_trait]
196impl OAuth2SessionRepository for PgOAuth2SessionRepository<'_> {
197 type Error = DatabaseError;
198
199 #[tracing::instrument(
200 name = "db.oauth2_session.lookup",
201 skip_all,
202 fields(
203 db.query.text,
204 session.id = %id,
205 ),
206 err,
207 )]
208 async fn lookup(&mut self, id: Ulid) -> Result<Option<Session>, Self::Error> {
209 let res = sqlx::query_as!(
210 OAuthSessionLookup,
211 r#"
212 SELECT oauth2_session_id
213 , user_id
214 , user_session_id
215 , oauth2_client_id
216 , scope_list
217 , created_at
218 , finished_at
219 , user_agent
220 , last_active_at
221 , last_active_ip as "last_active_ip: IpAddr"
222 , human_name
223 FROM oauth2_sessions
224
225 WHERE oauth2_session_id = $1
226 "#,
227 Uuid::from(id),
228 )
229 .traced()
230 .fetch_optional(&mut *self.conn)
231 .await?;
232
233 let Some(session) = res else { return Ok(None) };
234
235 Ok(Some(session.try_into()?))
236 }
237
238 #[tracing::instrument(
239 name = "db.oauth2_session.add",
240 skip_all,
241 fields(
242 db.query.text,
243 %client.id,
244 session.id,
245 session.scope = %scope,
246 ),
247 err,
248 )]
249 async fn add(
250 &mut self,
251 rng: &mut (dyn RngCore + Send),
252 clock: &dyn Clock,
253 client: &Client,
254 user: Option<&User>,
255 user_session: Option<&BrowserSession>,
256 scope: Scope,
257 ) -> Result<Session, Self::Error> {
258 let created_at = clock.now();
259 let id = Ulid::from_datetime_with_source(created_at.into(), rng);
260 tracing::Span::current().record("session.id", tracing::field::display(id));
261
262 let scope_list: Vec<String> = scope.iter().map(|s| s.as_str().to_owned()).collect();
263
264 sqlx::query!(
265 r#"
266 INSERT INTO oauth2_sessions
267 ( oauth2_session_id
268 , user_id
269 , user_session_id
270 , oauth2_client_id
271 , scope_list
272 , created_at
273 )
274 VALUES ($1, $2, $3, $4, $5, $6)
275 "#,
276 Uuid::from(id),
277 user.map(|u| Uuid::from(u.id)),
278 user_session.map(|s| Uuid::from(s.id)),
279 Uuid::from(client.id),
280 &scope_list,
281 created_at,
282 )
283 .traced()
284 .execute(&mut *self.conn)
285 .await?;
286
287 Ok(Session {
288 id,
289 state: SessionState::Valid,
290 created_at,
291 user_id: user.map(|u| u.id),
292 user_session_id: user_session.map(|s| s.id),
293 client_id: client.id,
294 scope,
295 user_agent: None,
296 last_active_at: None,
297 last_active_ip: None,
298 human_name: None,
299 })
300 }
301
302 #[tracing::instrument(
303 name = "db.oauth2_session.finish_bulk",
304 skip_all,
305 fields(
306 db.query.text,
307 ),
308 err,
309 )]
310 async fn finish_bulk(
311 &mut self,
312 clock: &dyn Clock,
313 filter: OAuth2SessionFilter<'_>,
314 ) -> Result<usize, Self::Error> {
315 let finished_at = clock.now();
316 let (sql, arguments) = Query::update()
317 .table(OAuth2Sessions::Table)
318 .value(OAuth2Sessions::FinishedAt, finished_at)
319 .apply_filter(filter)
320 .build_sqlx(PostgresQueryBuilder);
321
322 let res = sqlx::query_with(&sql, arguments)
323 .traced()
324 .execute(&mut *self.conn)
325 .await?;
326
327 Ok(res.rows_affected().try_into().unwrap_or(usize::MAX))
328 }
329
330 #[tracing::instrument(
331 name = "db.oauth2_session.finish",
332 skip_all,
333 fields(
334 db.query.text,
335 %session.id,
336 %session.scope,
337 client.id = %session.client_id,
338 ),
339 err,
340 )]
341 async fn finish(
342 &mut self,
343 clock: &dyn Clock,
344 session: Session,
345 ) -> Result<Session, Self::Error> {
346 let finished_at = clock.now();
347 let res = sqlx::query!(
348 r#"
349 UPDATE oauth2_sessions
350 SET finished_at = $2
351 WHERE oauth2_session_id = $1
352 "#,
353 Uuid::from(session.id),
354 finished_at,
355 )
356 .traced()
357 .execute(&mut *self.conn)
358 .await?;
359
360 DatabaseError::ensure_affected_rows(&res, 1)?;
361
362 session
363 .finish(finished_at)
364 .map_err(DatabaseError::to_invalid_operation)
365 }
366
367 #[tracing::instrument(
368 name = "db.oauth2_session.list",
369 skip_all,
370 fields(
371 db.query.text,
372 ),
373 err,
374 )]
375 async fn list(
376 &mut self,
377 filter: OAuth2SessionFilter<'_>,
378 pagination: Pagination,
379 ) -> Result<Page<Session>, Self::Error> {
380 let (sql, arguments) = Query::select()
381 .expr_as(
382 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2SessionId)),
383 OAuthSessionLookupIden::Oauth2SessionId,
384 )
385 .expr_as(
386 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserId)),
387 OAuthSessionLookupIden::UserId,
388 )
389 .expr_as(
390 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserSessionId)),
391 OAuthSessionLookupIden::UserSessionId,
392 )
393 .expr_as(
394 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2ClientId)),
395 OAuthSessionLookupIden::Oauth2ClientId,
396 )
397 .expr_as(
398 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::ScopeList)),
399 OAuthSessionLookupIden::ScopeList,
400 )
401 .expr_as(
402 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::CreatedAt)),
403 OAuthSessionLookupIden::CreatedAt,
404 )
405 .expr_as(
406 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::FinishedAt)),
407 OAuthSessionLookupIden::FinishedAt,
408 )
409 .expr_as(
410 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserAgent)),
411 OAuthSessionLookupIden::UserAgent,
412 )
413 .expr_as(
414 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::LastActiveAt)),
415 OAuthSessionLookupIden::LastActiveAt,
416 )
417 .expr_as(
418 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::LastActiveIp)),
419 OAuthSessionLookupIden::LastActiveIp,
420 )
421 .expr_as(
422 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::HumanName)),
423 OAuthSessionLookupIden::HumanName,
424 )
425 .from(OAuth2Sessions::Table)
426 .apply_filter(filter)
427 .generate_pagination(
428 (OAuth2Sessions::Table, OAuth2Sessions::OAuth2SessionId),
429 pagination,
430 )
431 .build_sqlx(PostgresQueryBuilder);
432
433 let edges: Vec<OAuthSessionLookup> = sqlx::query_as_with(&sql, arguments)
434 .traced()
435 .fetch_all(&mut *self.conn)
436 .await?;
437
438 let page = pagination.process(edges).try_map(Session::try_from)?;
439
440 Ok(page)
441 }
442
443 #[tracing::instrument(
444 name = "db.oauth2_session.count",
445 skip_all,
446 fields(
447 db.query.text,
448 ),
449 err,
450 )]
451 async fn count(&mut self, filter: OAuth2SessionFilter<'_>) -> Result<usize, Self::Error> {
452 let (sql, arguments) = Query::select()
453 .expr(Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2SessionId)).count())
454 .from(OAuth2Sessions::Table)
455 .apply_filter(filter)
456 .build_sqlx(PostgresQueryBuilder);
457
458 let count: i64 = sqlx::query_scalar_with(&sql, arguments)
459 .traced()
460 .fetch_one(&mut *self.conn)
461 .await?;
462
463 count
464 .try_into()
465 .map_err(DatabaseError::to_invalid_operation)
466 }
467
468 #[tracing::instrument(
469 name = "db.oauth2_session.record_batch_activity",
470 skip_all,
471 fields(
472 db.query.text,
473 ),
474 err,
475 )]
476 async fn record_batch_activity(
477 &mut self,
478 mut activities: Vec<(Ulid, DateTime<Utc>, Option<IpAddr>)>,
479 ) -> Result<(), Self::Error> {
480 activities.sort_unstable();
483 let mut ids = Vec::with_capacity(activities.len());
484 let mut last_activities = Vec::with_capacity(activities.len());
485 let mut ips = Vec::with_capacity(activities.len());
486
487 for (id, last_activity, ip) in activities {
488 ids.push(Uuid::from(id));
489 last_activities.push(last_activity);
490 ips.push(ip);
491 }
492
493 let res = sqlx::query!(
494 r#"
495 UPDATE oauth2_sessions
496 SET last_active_at = GREATEST(t.last_active_at, oauth2_sessions.last_active_at)
497 , last_active_ip = COALESCE(t.last_active_ip, oauth2_sessions.last_active_ip)
498 FROM (
499 SELECT *
500 FROM UNNEST($1::uuid[], $2::timestamptz[], $3::inet[])
501 AS t(oauth2_session_id, last_active_at, last_active_ip)
502 ) AS t
503 WHERE oauth2_sessions.oauth2_session_id = t.oauth2_session_id
504 "#,
505 &ids,
506 &last_activities,
507 &ips as &[Option<IpAddr>],
508 )
509 .traced()
510 .execute(&mut *self.conn)
511 .await?;
512
513 DatabaseError::ensure_affected_rows(&res, ids.len().try_into().unwrap_or(u64::MAX))?;
514
515 Ok(())
516 }
517
518 #[tracing::instrument(
519 name = "db.oauth2_session.record_user_agent",
520 skip_all,
521 fields(
522 db.query.text,
523 %session.id,
524 %session.scope,
525 client.id = %session.client_id,
526 session.user_agent = user_agent,
527 ),
528 err,
529 )]
530 async fn record_user_agent(
531 &mut self,
532 mut session: Session,
533 user_agent: String,
534 ) -> Result<Session, Self::Error> {
535 let res = sqlx::query!(
536 r#"
537 UPDATE oauth2_sessions
538 SET user_agent = $2
539 WHERE oauth2_session_id = $1
540 "#,
541 Uuid::from(session.id),
542 &*user_agent,
543 )
544 .traced()
545 .execute(&mut *self.conn)
546 .await?;
547
548 session.user_agent = Some(user_agent);
549
550 DatabaseError::ensure_affected_rows(&res, 1)?;
551
552 Ok(session)
553 }
554
555 #[tracing::instrument(
556 name = "repository.oauth2_session.set_human_name",
557 skip(self),
558 fields(
559 client.id = %session.client_id,
560 session.human_name = ?human_name,
561 ),
562 err,
563 )]
564 async fn set_human_name(
565 &mut self,
566 mut session: Session,
567 human_name: Option<String>,
568 ) -> Result<Session, Self::Error> {
569 let res = sqlx::query!(
570 r#"
571 UPDATE oauth2_sessions
572 SET human_name = $2
573 WHERE oauth2_session_id = $1
574 "#,
575 Uuid::from(session.id),
576 human_name.as_deref(),
577 )
578 .traced()
579 .execute(&mut *self.conn)
580 .await?;
581
582 session.human_name = human_name;
583
584 DatabaseError::ensure_affected_rows(&res, 1)?;
585
586 Ok(session)
587 }
588}