diff --git a/src/sql/workbench/services/commandLine/common/commandLineService.ts b/src/sql/workbench/services/commandLine/common/commandLineService.ts index 7f4cf3fe40..1d9e85553f 100644 --- a/src/sql/workbench/services/commandLine/common/commandLineService.ts +++ b/src/sql/workbench/services/commandLine/common/commandLineService.ts @@ -5,6 +5,8 @@ 'use strict'; import * as azdata from 'azdata'; import { ConnectionProfile } from 'sql/platform/connection/common/connectionProfile'; +import {ConnectionProfileGroup} from 'sql/platform/connection/common/connectionProfileGroup'; +import { equalsIgnoreCase } from 'vs/base/common/strings'; import { ICommandLineProcessing } from 'sql/workbench/services/commandLine/common/commandLine'; import { IConnectionManagementService } from 'sql/platform/connection/common/connectionManagement'; import { ICapabilitiesService } from 'sql/platform/capabilities/common/capabilitiesService'; @@ -136,11 +138,37 @@ export class CommandLineService implements ICommandLineProcessing { profile.serverName = args.server; profile.databaseName = args.database ? args.database : ''; profile.userName = args.user ? args.user : ''; - profile.authenticationType = args.integrated ? 'Integrated' : 'SqlLogin'; + profile.authenticationType = args.integrated ? Constants.integrated : args.aad ? Constants.azureMFA : (profile.userName.length > 0) ? Constants.sqlLogin : Constants.integrated; profile.connectionName = ''; profile.setOptionValue('applicationName', Constants.applicationName); profile.setOptionValue('databaseDisplayName', profile.databaseName); profile.setOptionValue('groupId', profile.groupId); - return profile; + return this._connectionManagementService ? this.tryMatchSavedProfile(profile) : profile; + } + + private tryMatchSavedProfile(profile: ConnectionProfile) + { + let match: ConnectionProfile = undefined; + // If we can find a saved mssql provider connection that matches the args, use it + let groups = this._connectionManagementService.getConnectionGroups([Constants.mssqlProviderName]); + if (groups && groups.length > 0) + { + let rootGroup = groups[0]; + let connections = ConnectionProfileGroup.getConnectionsInGroup(rootGroup); + match = connections.find((c) => this.matchProfile(profile, c)) ; + } + return match ? match : profile; + } + + // determines if the 2 profiles are a functional match + // profile1 is the profile generated from command line parameters + private matchProfile(profile1: ConnectionProfile, profile2: ConnectionProfile): boolean + { + return equalsIgnoreCase(profile1.serverName,profile2.serverName) + && equalsIgnoreCase(profile1.providerName, profile2.providerName) + // case sensitive servers can have 2 databases whose name differs only in case + && profile1.databaseName === profile2.databaseName + && equalsIgnoreCase(profile1.userName, profile2.userName) + && profile1.authenticationType === profile2.authenticationType; } } diff --git a/src/sqltest/parts/commandLine/commandLineService.test.ts b/src/sqltest/parts/commandLine/commandLineService.test.ts index 7d67e4e1b8..7c890a3de9 100644 --- a/src/sqltest/parts/commandLine/commandLineService.test.ts +++ b/src/sqltest/parts/commandLine/commandLineService.test.ts @@ -9,7 +9,9 @@ import * as assert from 'assert'; import * as TypeMoq from 'typemoq'; import * as azdata from 'azdata'; import { ConnectionProfile } from 'sql/platform/connection/common/connectionProfile'; +import { ConnectionProfileGroup } from 'sql/platform/connection/common/connectionProfileGroup'; import { CommandLineService } from 'sql/workbench/services/commandLine/common/commandLineService'; +import * as Constants from 'sql/platform/connection/common/constants'; import { ParsedArgs } from 'vs/platform/environment/common/environment'; import { ICapabilitiesService } from 'sql/platform/capabilities/common/capabilitiesService'; import { CapabilitiesTestService } from 'sqltest/stubs/capabilitiesTestService'; @@ -188,10 +190,12 @@ suite('commandLineService tests', () => { const args: TestParsedArgs = new TestParsedArgs(); args.server = 'myserver'; args.database = 'mydatabase'; + args.user = 'myuser'; connectionManagementService.setup((c) => c.showConnectionDialog()).verifiable(TypeMoq.Times.never()); connectionManagementService.setup(c => c.hasRegisteredServers()).returns(() => true).verifiable(TypeMoq.Times.atMostOnce()); + connectionManagementService.setup(c => c.getConnectionGroups(TypeMoq.It.isAny())).returns(() => []); let originalProfile: IConnectionProfile = undefined; - connectionManagementService.setup(c => c.connectIfNotConnected(TypeMoq.It.is(p => p.serverName === 'myserver'), 'connection', true)) + connectionManagementService.setup(c => c.connectIfNotConnected(TypeMoq.It.is(p => p.serverName === 'myserver' && p.authenticationType === Constants.sqlLogin), 'connection', true)) .returns((conn) => { originalProfile = conn; return Promise.resolve('unused'); @@ -250,7 +254,7 @@ suite('commandLineService tests', () => { }) .verifiable(TypeMoq.Times.once()); connectionManagementService.setup(c => c.getConnectionProfileById(TypeMoq.It.isAnyString())).returns(() => originalProfile); - + connectionManagementService.setup(c => c.getConnectionGroups(TypeMoq.It.isAny())).returns(() => []); let actualProfile: azdata.ConnectedContext = undefined; commandService.setup(c => c.executeCommand('mycommand', TypeMoq.It.isAny())) .returns((cmdName, profile) => { @@ -283,4 +287,70 @@ suite('commandLineService tests', () => { let service = getCommandLineService(connectionManagementService.object, configurationService.object, capabilitiesService, commandService.object); assertThrowsAsync(async () => await service.processCommandLine(args)); }); + + test('processCommandLine uses Integrated auth if no user name or auth type is passed', async () => { + const connectionManagementService: TypeMoq.Mock + = TypeMoq.Mock.ofType(TestConnectionManagementService, TypeMoq.MockBehavior.Strict); + + const args: TestParsedArgs = new TestParsedArgs(); + args.server = 'myserver'; + args.database = 'mydatabase'; + connectionManagementService.setup((c) => c.showConnectionDialog()).verifiable(TypeMoq.Times.never()); + connectionManagementService.setup(c => c.hasRegisteredServers()).returns(() => true).verifiable(TypeMoq.Times.atMostOnce()); + let originalProfile: IConnectionProfile = undefined; + connectionManagementService.setup(c => c.connectIfNotConnected(TypeMoq.It.is(p => p.serverName === 'myserver' && p.authenticationType === Constants.integrated), 'connection', true)) + .returns((conn) => { + originalProfile = conn; + return Promise.resolve('unused'); + }) + .verifiable(TypeMoq.Times.once()); + connectionManagementService.setup(c => c.getConnectionProfileById(TypeMoq.It.isAnyString())).returns(() => originalProfile); + connectionManagementService.setup(c => c.getConnectionGroups(TypeMoq.It.isAny())).returns(() => []); + const configurationService = getConfigurationServiceMock(true); + let service = getCommandLineService(connectionManagementService.object, configurationService.object, capabilitiesService); + await service.processCommandLine(args); + connectionManagementService.verifyAll(); + }); + + test('processCommandLine reuses saved connections that match args', async () => { + const connectionManagementService: TypeMoq.Mock + = TypeMoq.Mock.ofType(TestConnectionManagementService, TypeMoq.MockBehavior.Strict); + + var connection = new ConnectionProfile(capabilitiesService, { + connectionName: 'Test', + savePassword: false, + groupFullName: 'testGroup', + serverName: 'myserver', + databaseName: 'mydatabase', + authenticationType: Constants.integrated, + password: undefined, + userName: '', + groupId: undefined, + providerName: 'MSSQL', + options: {}, + saveProfile: true, + id: 'testID' + }); + var conProfGroup = new ConnectionProfileGroup('testGroup', undefined, 'testGroup', undefined, undefined); + conProfGroup.connections = [connection]; + const args: TestParsedArgs = new TestParsedArgs(); + args.server = 'myserver'; + args.database = 'mydatabase'; + connectionManagementService.setup((c) => c.showConnectionDialog()).verifiable(TypeMoq.Times.never()); + connectionManagementService.setup(c => c.hasRegisteredServers()).returns(() => true).verifiable(TypeMoq.Times.atMostOnce()); + let originalProfile: IConnectionProfile = undefined; + connectionManagementService.setup(c => c.connectIfNotConnected( + TypeMoq.It.is(p => p.serverName === 'myserver' && p.authenticationType === Constants.integrated && p.connectionName === 'Test' && p.id === 'testID'), 'connection', true)) + .returns((conn) => { + originalProfile = conn; + return Promise.resolve('unused'); + }) + .verifiable(TypeMoq.Times.once()); + connectionManagementService.setup(c => c.getConnectionProfileById('testID')).returns(() => originalProfile).verifiable(TypeMoq.Times.once()); + connectionManagementService.setup(x => x.getConnectionGroups(TypeMoq.It.isAny())).returns(() => [conProfGroup]); + const configurationService = getConfigurationServiceMock(true); + let service = getCommandLineService(connectionManagementService.object, configurationService.object, capabilitiesService); + await service.processCommandLine(args); + connectionManagementService.verifyAll(); + }); });