Commit 6e973420 authored by Stan Hu's avatar Stan Hu

Fix SAML SSO login redirects not working

When a user without a SSO session attempted to access anything in a SAML
group, previously GitLab would redirect the user back to the dashboard
and lose the original link. This was happening because the `RelayState`
wasn't being used since the `SessionsController#create` took over. To
fix this, we have to do things:

1. Store the `RelayState` in the Devise session helper before the
session is created. `ApplicationController#after_sign_in_path_for` will
run and retrieve this value.

2. Ensure the `RelayState` is the right value. Previously this value
would be set to the path of the project (e.g. mygroup/project), but this
drops the full path that was being accessed. Now we retain the entire
path so the `redirect` parameter is passed on to the IdP properly.

Relates to https://gitlab.com/gitlab-org/gitlab/-/issues/247674

Changelog: fixed
EE: true
parent d26b12d6
...@@ -304,7 +304,7 @@ class Admin::UsersController < Admin::ApplicationController ...@@ -304,7 +304,7 @@ class Admin::UsersController < Admin::ApplicationController
end end
def user def user
@user ||= find_routable!(User, params[:id]) @user ||= find_routable!(User, params[:id], request.path_info)
end end
def build_canonical_path(user) def build_canonical_path(user)
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
module ProjectUnauthorized module ProjectUnauthorized
module ControllerActions module ControllerActions
def self.on_routable_not_found def self.on_routable_not_found
lambda do |routable| lambda do |routable, path_info|
return unless routable.is_a?(Project) return unless routable.is_a?(Project)
label = routable.external_authorization_classification_label label = routable.external_authorization_classification_label
......
...@@ -3,13 +3,13 @@ ...@@ -3,13 +3,13 @@
module RoutableActions module RoutableActions
extend ActiveSupport::Concern extend ActiveSupport::Concern
def find_routable!(routable_klass, requested_full_path, extra_authorization_proc: nil) def find_routable!(routable_klass, requested_full_path, path_info, extra_authorization_proc: nil)
routable = routable_klass.find_by_full_path(requested_full_path, follow_redirects: request.get?) routable = routable_klass.find_by_full_path(requested_full_path, follow_redirects: request.get?)
if routable_authorized?(routable, extra_authorization_proc) if routable_authorized?(routable, extra_authorization_proc)
ensure_canonical_path(routable, requested_full_path) ensure_canonical_path(routable, requested_full_path)
routable routable
else else
perform_not_found_actions(routable, not_found_actions) perform_not_found_actions(routable, not_found_actions, path_info)
route_not_found unless performed? route_not_found unless performed?
...@@ -21,11 +21,11 @@ module RoutableActions ...@@ -21,11 +21,11 @@ module RoutableActions
[ProjectUnauthorized::ControllerActions.on_routable_not_found] [ProjectUnauthorized::ControllerActions.on_routable_not_found]
end end
def perform_not_found_actions(routable, actions) def perform_not_found_actions(routable, actions, path_info)
actions.each do |action| actions.each do |action|
break if performed? break if performed?
instance_exec(routable, &action) instance_exec(routable, path_info, &action)
end end
end end
......
...@@ -24,7 +24,7 @@ class Groups::ApplicationController < ApplicationController ...@@ -24,7 +24,7 @@ class Groups::ApplicationController < ApplicationController
end end
def group def group
@group ||= find_routable!(Group, params[:group_id] || params[:id]) @group ||= find_routable!(Group, params[:group_id] || params[:id], request.path_info)
end end
def group_projects def group_projects
......
...@@ -13,6 +13,6 @@ class Groups::Clusters::ApplicationsController < Clusters::ApplicationsControlle ...@@ -13,6 +13,6 @@ class Groups::Clusters::ApplicationsController < Clusters::ApplicationsControlle
end end
def group def group
@group ||= find_routable!(Group, params[:group_id] || params[:id]) @group ||= find_routable!(Group, params[:group_id] || params[:id], request.path_info)
end end
end end
...@@ -13,6 +13,6 @@ class Groups::Clusters::IntegrationsController < Clusters::IntegrationsControlle ...@@ -13,6 +13,6 @@ class Groups::Clusters::IntegrationsController < Clusters::IntegrationsControlle
end end
def group def group
@group ||= find_routable!(Group, params[:group_id] || params[:id]) @group ||= find_routable!(Group, params[:group_id] || params[:id], request.path_info)
end end
end end
...@@ -15,7 +15,7 @@ class Groups::ClustersController < Clusters::ClustersController ...@@ -15,7 +15,7 @@ class Groups::ClustersController < Clusters::ClustersController
end end
def group def group
@group ||= find_routable!(Group, params[:group_id] || params[:id]) @group ||= find_routable!(Group, params[:group_id] || params[:id], request.path_info)
end end
def metrics_dashboard_params def metrics_dashboard_params
......
...@@ -6,7 +6,7 @@ class Profiles::GroupsController < Profiles::ApplicationController ...@@ -6,7 +6,7 @@ class Profiles::GroupsController < Profiles::ApplicationController
feature_category :users feature_category :users
def update def update
group = find_routable!(Group, params[:id]) group = find_routable!(Group, params[:id], request.path_info)
notification_setting = current_user.notification_settings_for(group) notification_setting = current_user.notification_settings_for(group)
if notification_setting.update(update_params) if notification_setting.update(update_params)
......
...@@ -26,7 +26,7 @@ class Projects::ApplicationController < ApplicationController ...@@ -26,7 +26,7 @@ class Projects::ApplicationController < ApplicationController
path = File.join(params[:namespace_id], params[:project_id] || params[:id]) path = File.join(params[:namespace_id], params[:project_id] || params[:id])
auth_proc = ->(project) { !project.pending_delete? } auth_proc = ->(project) { !project.pending_delete? }
@project = find_routable!(Project, path, extra_authorization_proc: auth_proc) @project = find_routable!(Project, path, request.path_info, extra_authorization_proc: auth_proc)
end end
def build_canonical_path(project) def build_canonical_path(project)
......
...@@ -10,6 +10,6 @@ class Projects::Clusters::ApplicationsController < Clusters::ApplicationsControl ...@@ -10,6 +10,6 @@ class Projects::Clusters::ApplicationsController < Clusters::ApplicationsControl
end end
def project def project
@project ||= find_routable!(Project, File.join(params[:namespace_id], params[:project_id])) @project ||= find_routable!(Project, File.join(params[:namespace_id], params[:project_id]), request.path_info)
end end
end end
...@@ -10,6 +10,6 @@ class Projects::Clusters::IntegrationsController < ::Clusters::IntegrationsContr ...@@ -10,6 +10,6 @@ class Projects::Clusters::IntegrationsController < ::Clusters::IntegrationsContr
end end
def project def project
@project ||= find_routable!(Project, File.join(params[:namespace_id], params[:project_id])) @project ||= find_routable!(Project, File.join(params[:namespace_id], params[:project_id]), request.path_info)
end end
end end
...@@ -17,7 +17,7 @@ class Projects::ClustersController < Clusters::ClustersController ...@@ -17,7 +17,7 @@ class Projects::ClustersController < Clusters::ClustersController
end end
def project def project
@project ||= find_routable!(Project, File.join(params[:namespace_id], params[:project_id])) @project ||= find_routable!(Project, File.join(params[:namespace_id], params[:project_id]), request.path_info)
end end
def repository def repository
......
...@@ -172,7 +172,7 @@ class UsersController < ApplicationController ...@@ -172,7 +172,7 @@ class UsersController < ApplicationController
private private
def user def user
@user ||= find_routable!(User, params[:username]) @user ||= find_routable!(User, params[:username], request.path_info)
end end
def personal_projects def personal_projects
......
...@@ -6,10 +6,11 @@ module EE ...@@ -6,10 +6,11 @@ module EE
include ::Gitlab::Routing include ::Gitlab::Routing
include ::Gitlab::Utils::StrongMemoize include ::Gitlab::Utils::StrongMemoize
attr_reader :routable attr_reader :routable, :path_info
def initialize(routable) def initialize(routable, path_info)
@routable = routable @routable = routable
@path_info = path_info
end end
def should_redirect_to_group_saml_sso?(current_user, request) def should_redirect_to_group_saml_sso?(current_user, request)
...@@ -25,8 +26,8 @@ module EE ...@@ -25,8 +26,8 @@ module EE
module ControllerActions module ControllerActions
def self.on_routable_not_found def self.on_routable_not_found
lambda do |routable| lambda do |routable, path_info|
redirector = SsoEnforcementRedirect.new(routable) redirector = SsoEnforcementRedirect.new(routable, path_info)
if redirector.should_redirect_to_group_saml_sso?(current_user, request) if redirector.should_redirect_to_group_saml_sso?(current_user, request)
redirect_to redirector.sso_redirect_url redirect_to redirector.sso_redirect_url
...@@ -63,7 +64,7 @@ module EE ...@@ -63,7 +64,7 @@ module EE
def url_params def url_params
{ {
token: root_group.saml_discovery_token, token: root_group.saml_discovery_token,
redirect: "/#{routable.full_path}" redirect: path_info
} }
end end
end end
......
...@@ -23,13 +23,13 @@ class Groups::Analytics::ApplicationController < ApplicationController ...@@ -23,13 +23,13 @@ class Groups::Analytics::ApplicationController < ApplicationController
def load_group def load_group
return unless params['group_id'] return unless params['group_id']
@group = find_routable!(Group, params['group_id']) @group = find_routable!(Group, params['group_id'], request.path_info)
end end
def load_project def load_project
return unless @group && params['project_id'] return unless @group && params['project_id']
@project = find_routable!(@group.projects, params['project_id']) @project = find_routable!(@group.projects, params['project_id'], request.path_info)
end end
private_class_method :increment_usage_counter private_class_method :increment_usage_counter
......
...@@ -13,6 +13,7 @@ class Groups::OmniauthCallbacksController < OmniauthCallbacksController ...@@ -13,6 +13,7 @@ class Groups::OmniauthCallbacksController < OmniauthCallbacksController
identity_linker = Gitlab::Auth::GroupSaml::IdentityLinker.new(current_user, oauth, session, @saml_provider) identity_linker = Gitlab::Auth::GroupSaml::IdentityLinker.new(current_user, oauth, session, @saml_provider)
store_location_for(:redirect, saml_redirect_path)
omniauth_flow(Gitlab::Auth::GroupSaml, identity_linker: identity_linker) omniauth_flow(Gitlab::Auth::GroupSaml, identity_linker: identity_linker)
rescue Gitlab::Auth::Saml::IdentityLinker::UnverifiedRequest rescue Gitlab::Auth::Saml::IdentityLinker::UnverifiedRequest
redirect_unverified_saml_initiation redirect_unverified_saml_initiation
...@@ -131,7 +132,7 @@ class Groups::OmniauthCallbacksController < OmniauthCallbacksController ...@@ -131,7 +132,7 @@ class Groups::OmniauthCallbacksController < OmniauthCallbacksController
end end
def saml_redirect_path def saml_redirect_path
params['RelayState'].presence if current_user params['RelayState'].presence
end end
override :find_message override :find_message
......
...@@ -31,7 +31,7 @@ module Subscriptions ...@@ -31,7 +31,7 @@ module Subscriptions
private private
def find_group def find_group
@group ||= find_routable!(Group, params[:id]) @group ||= find_routable!(Group, params[:id], request.path_info)
end end
def group_params def group_params
......
...@@ -16,13 +16,13 @@ RSpec.describe EE::RoutableActions::SsoEnforcementRedirect do ...@@ -16,13 +16,13 @@ RSpec.describe EE::RoutableActions::SsoEnforcementRedirect do
it 'returns false for User routables' do it 'returns false for User routables' do
routable = build_stubbed(:user) routable = build_stubbed(:user)
subject = described_class.new(routable) subject = described_class.new(routable, '/')
expect(subject.should_redirect_to_group_saml_sso?(double, double)).to eq(false) expect(subject.should_redirect_to_group_saml_sso?(double, double)).to eq(false)
end end
it 'returns false when routable is nil' do it 'returns false when routable is nil' do
subject = described_class.new(nil) subject = described_class.new(nil, '/')
expect(subject.should_redirect_to_group_saml_sso?(double, double)).to eq(false) expect(subject.should_redirect_to_group_saml_sso?(double, double)).to eq(false)
end end
...@@ -46,19 +46,19 @@ RSpec.describe EE::RoutableActions::SsoEnforcementRedirect do ...@@ -46,19 +46,19 @@ RSpec.describe EE::RoutableActions::SsoEnforcementRedirect do
end end
context 'with a project' do context 'with a project' do
subject { described_class.new(project) } subject { described_class.new(project, '/') }
it_behaves_like 'a routable with SSO enforcement redirect' it_behaves_like 'a routable with SSO enforcement redirect'
end end
context 'with a nested project' do context 'with a nested project' do
subject { described_class.new(nested_project) } subject { described_class.new(nested_project, '/') }
it_behaves_like 'a routable with SSO enforcement redirect' it_behaves_like 'a routable with SSO enforcement redirect'
end end
context 'with a project in a personal namespace' do context 'with a project in a personal namespace' do
subject { described_class.new(create(:project)) } subject { described_class.new(create(:project), '/') }
it 'returns false' do it 'returns false' do
expect(subject.should_redirect_to_group_saml_sso?(user, double)).to eq false expect(subject.should_redirect_to_group_saml_sso?(user, double)).to eq false
...@@ -66,13 +66,13 @@ RSpec.describe EE::RoutableActions::SsoEnforcementRedirect do ...@@ -66,13 +66,13 @@ RSpec.describe EE::RoutableActions::SsoEnforcementRedirect do
end end
context 'with a group' do context 'with a group' do
subject { described_class.new(root_group) } subject { described_class.new(root_group, '/') }
it_behaves_like 'a routable with SSO enforcement redirect' it_behaves_like 'a routable with SSO enforcement redirect'
end end
context 'with a nested group' do context 'with a nested group' do
subject { described_class.new(nested_group) } subject { described_class.new(nested_group, '/') }
it_behaves_like 'a routable with SSO enforcement redirect' it_behaves_like 'a routable with SSO enforcement redirect'
end end
...@@ -88,25 +88,25 @@ RSpec.describe EE::RoutableActions::SsoEnforcementRedirect do ...@@ -88,25 +88,25 @@ RSpec.describe EE::RoutableActions::SsoEnforcementRedirect do
end end
context 'with a group' do context 'with a group' do
subject { described_class.new(root_group) } subject { described_class.new(root_group, "/#{root_group.full_path}") }
it_behaves_like 'a routable SSO url' it_behaves_like 'a routable SSO url'
end end
context 'with a nested group' do context 'with a nested group' do
subject { described_class.new(nested_group) } subject { described_class.new(nested_group, "/#{nested_group.full_path}") }
it_behaves_like 'a routable SSO url' it_behaves_like 'a routable SSO url'
end end
context 'with a project' do context 'with a project' do
subject { described_class.new(project) } subject { described_class.new(project, "/#{project.full_path}") }
it_behaves_like 'a routable SSO url' it_behaves_like 'a routable SSO url'
end end
context 'with a nested project' do context 'with a nested project' do
subject { described_class.new(nested_project) } subject { described_class.new(nested_project, "/#{nested_project.full_path}") }
it_behaves_like 'a routable SSO url' it_behaves_like 'a routable SSO url'
end end
......
...@@ -10,7 +10,7 @@ RSpec.describe RoutableActions do ...@@ -10,7 +10,7 @@ RSpec.describe RoutableActions do
def routable def routable
@klass = params[:type].constantize @klass = params[:type].constantize
@routable = find_routable!(params[:type].constantize, params[:id]) @routable = find_routable!(params[:type].constantize, params[:id], '/')
end end
def show def show
......
...@@ -119,5 +119,20 @@ RSpec.describe 'SAML access enforcement' do ...@@ -119,5 +119,20 @@ RSpec.describe 'SAML access enforcement' do
expect(page).to have_selector('#js-auto-redirect-to-provider', visible: false) expect(page).to have_selector('#js-auto-redirect-to-provider', visible: false)
end end
end end
context 'with a merge request' do
let!(:merge_request) { create(:merge_request, source_project: project, target_project: project) }
let(:resource_path) { project_merge_request_path(project, merge_request) }
it 'redirects to the SSO page and then merge request page after login' do
visit resource_path
expect(current_url).to include("redirect=#{CGI.escape(resource_path)}")
click_link 'Sign in with Single Sign-On'
expect(current_path).to eq(resource_path)
end
end
end end
end end
...@@ -229,7 +229,7 @@ RSpec.describe 'Login' do ...@@ -229,7 +229,7 @@ RSpec.describe 'Login' do
fake_successful_u2f_authentication fake_successful_u2f_authentication
expect(current_path).to eq root_path expect(current_path).to eq group_path(group)
end end
end end
...@@ -263,7 +263,7 @@ RSpec.describe 'Login' do ...@@ -263,7 +263,7 @@ RSpec.describe 'Login' do
fake_successful_webauthn_authentication fake_successful_webauthn_authentication
expect(current_path).to eq root_path expect(current_path).to eq group_path(group)
end end
end end
end end
......
...@@ -10,7 +10,7 @@ RSpec.describe RoutableActions do ...@@ -10,7 +10,7 @@ RSpec.describe RoutableActions do
def routable def routable
@klass = params[:type].constantize @klass = params[:type].constantize
@routable = find_routable!(params[:type].constantize, params[:id]) @routable = find_routable!(params[:type].constantize, params[:id], '/')
end end
def show def show
...@@ -135,7 +135,7 @@ RSpec.describe RoutableActions do ...@@ -135,7 +135,7 @@ RSpec.describe RoutableActions do
end end
it 'performs checks in the context of the controller' do it 'performs checks in the context of the controller' do
check = lambda { |routable| redirect_to routable } check = lambda { |routable, path_info| redirect_to routable }
allow(subject).to receive(:not_found_actions).and_return([check]) allow(subject).to receive(:not_found_actions).and_return([check])
get_routable(routable) get_routable(routable)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment