diff --git a/extensions/import/config.json b/extensions/import/config.json index 0a5e5b2bab..1e52232bf2 100644 --- a/extensions/import/config.json +++ b/extensions/import/config.json @@ -1,7 +1,7 @@ { "downloadUrl": "https://sqlopsextensions.blob.core.windows.net/extensions/import/service/{#version#}/{#fileName#}", "useDefaultLinuxRuntime": true, - "version": "0.0.7", + "version": "0.0.8", "downloadFileNames": { "Windows_64": "win-x64.zip", "Windows_86": "win-x86.zip", diff --git a/extensions/import/package.json b/extensions/import/package.json index b5088380c6..a23652cb37 100644 --- a/extensions/import/package.json +++ b/extensions/import/package.json @@ -2,7 +2,7 @@ "name": "import", "displayName": "SQL Server Import", "description": "SQL Server Import for Azure Data Studio supports importing CSV or JSON files into SQL Server.", - "version": "1.3.1", + "version": "1.4.0", "publisher": "Microsoft", "preview": true, "engines": { diff --git a/extensions/import/src/services/contracts.ts b/extensions/import/src/services/contracts.ts index d95bc234f3..e4ac76a1a6 100644 --- a/extensions/import/src/services/contracts.ts +++ b/extensions/import/src/services/contracts.ts @@ -78,6 +78,11 @@ const insertDataRequestName = 'flatfile/insertData'; export interface InsertDataParams { connectionString: string; batchSize: number; + /** + * For azure MFA connections we need to send the account token to establish a connection + * from flatFile service without doing Oauth. + */ + azureAccessToken: string | undefined; } export interface InsertDataResponse { diff --git a/extensions/import/src/test/utils.test.ts b/extensions/import/src/test/utils.test.ts index db82573d57..d31cec47eb 100644 --- a/extensions/import/src/test/utils.test.ts +++ b/extensions/import/src/test/utils.test.ts @@ -185,3 +185,37 @@ export class TestFlatFileProvider implements FlatFileProvider { } } + +export function getAzureAccounts(): azdata.Account[] { + return [ + { + isStale: false, + key: { + providerId: 'account1Provider', + accountId: 'account1Id' + }, + displayInfo: { + accountType: 'account1Type', + contextualDisplayName: 'account1ContextualDisplayName', + displayName: 'account1DisplayName', + userId: 'account1@microsoft.com' + }, + properties: {} + }, + { + isStale: false, + key: { + providerId: 'account2Provider', + accountId: 'account2Id' + }, + displayInfo: { + accountType: 'account2Type', + contextualDisplayName: 'account2ContextualDisplayName', + displayName: 'account2DisplayName', + userId: 'account2@microsoft.com' + }, + properties: {} + }, + ]; +} + diff --git a/extensions/import/src/test/wizard/pages/summaryPage.test.ts b/extensions/import/src/test/wizard/pages/summaryPage.test.ts index e869d2cb57..b859ac09a0 100644 --- a/extensions/import/src/test/wizard/pages/summaryPage.test.ts +++ b/extensions/import/src/test/wizard/pages/summaryPage.test.ts @@ -5,11 +5,12 @@ import * as TypeMoq from 'typemoq'; import * as azdata from 'azdata'; +import * as sinon from 'sinon'; import * as constants from '../../../common/constants'; import { FlatFileWizard } from '../../../wizard/flatFileWizard'; import * as should from 'should'; -import { ImportDataModel } from '../../../wizard/api/models'; -import { TestImportDataModel, TestFlatFileProvider } from '../../utils.test'; +import { ColumnMetadata, ImportDataModel } from '../../../wizard/api/models'; +import { TestImportDataModel, TestFlatFileProvider, getAzureAccounts } from '../../utils.test'; import { ImportPage } from '../../../wizard/api/importPage'; import { SummaryPage } from '../../../wizard/pages/summaryPage'; import { FlatFileProvider, InsertDataResponse } from '../../../services/contracts'; @@ -34,6 +35,17 @@ describe('import extension summary page tests', function () { wizard = azdata.window.createWizard(constants.wizardNameText); page = azdata.window.createWizardPage(constants.page4NameText); + sinon.stub(azdata.accounts, 'getAllAccounts').returns(Promise.resolve(getAzureAccounts())); + + sinon.stub(azdata.accounts, 'getAccountSecurityToken').returns(Promise.resolve({ + token: 'azureToken', + tokenType: 'token' + })); + sinon.stub(azdata.connection, 'getConnectionString').returns(Promise.resolve('testConnectionString')); + }); + + this.afterEach(async () => { + sinon.restore(); }); it('checking if all components are initialized properly', async function () { @@ -141,4 +153,80 @@ describe('import extension summary page tests', function () { should.equal(summaryPage.statusText.value, constants.summaryErrorSymbol + 'testError'); }); + + it('Data is inserted with correct account access token in case of Azure MFA connections', async() => { + + // Creating a test AAD MFA connection + let testServerConnection: azdata.connection.Connection = { + providerName: 'testProviderName', + connectionId: 'testConnectionId', + options: { + azureAccount: getAzureAccounts()[1].key.accountId, + azureTenantId: 'azureAccount2Tenant', + authenticationType: 'AzureMFA' + } + }; + + // Overriding the behavior of getAccountSecurityToken and making sure + // it returns only when called with second azure test account and + // azureTenantId from test connection. + (azdata.accounts)['getAccountSecurityToken'].restore(); + sinon.stub(azdata.accounts, 'getAccountSecurityToken') + .withArgs( + getAzureAccounts()[1], + 'azureAccount2Tenant', + sinon.match.any + ).returns( + Promise.resolve({ + token: 'token2', + tokenType: 'azureTokenType' + })); + + // setting up connection objects in model + mockImportModel.object.server = testServerConnection; + mockImportModel.object.database = 'testDatabase'; + mockImportModel.object.schema = 'testSchema'; + mockImportModel.object.filePath = 'testFilePath'; + + + let testSendInsertDataRequestResponse: InsertDataResponse = { + result: { + success: true, + errorMessage: '' + } + }; + + // Creating test columns + let testProseColumns: ColumnMetadata[] = [ + ]; + mockImportModel.object.proseColumns = testProseColumns; + + // Creating a test request params with azure account 2 token + let testSendInsertDataRequest = { + connectionString: 'testConnectionString', + batchSize: 500, + azureAccessToken: 'token2' + }; + + mockFlatFileProvider.setup(x => x.sendInsertDataRequest(testSendInsertDataRequest)).returns(async () => { return testSendInsertDataRequestResponse; }); + + await new Promise(function (resolve) { + page.registerContent(async (view) => { + summaryPage = new SummaryPage(mockFlatFileWizard.object, page, mockImportModel.object, view, mockFlatFileProvider.object); + pages.set(1, summaryPage); + await summaryPage.start(); + summaryPage.setupNavigationValidator(); + resolve(); + }); + wizard.generateScriptButton.hidden = true; + + wizard.pages = [page]; + wizard.open(); + }); + + await summaryPage.onPageEnter(); + + // Verifying insert data request is called with expected parameters once. + mockFlatFileProvider.verify(x => x.sendInsertDataRequest(testSendInsertDataRequest), TypeMoq.Times.once()); + }); }); diff --git a/extensions/import/src/wizard/pages/summaryPage.ts b/extensions/import/src/wizard/pages/summaryPage.ts index c1ed3fbae5..506a06e25a 100644 --- a/extensions/import/src/wizard/pages/summaryPage.ts +++ b/extensions/import/src/wizard/pages/summaryPage.ts @@ -129,13 +129,23 @@ export class SummaryPage extends ImportPage { let result: InsertDataResponse; let err; - let includePasswordInConnectionString = (this.model.server.options.authenticationType === 'Integrated') ? false : true; + + const currentServer = this.model.server; + const includePasswordInConnectionString = (currentServer.options.authenticationType === 'Integrated') ? false : true; + const connectionString = await azdata.connection.getConnectionString(currentServer.connectionId, includePasswordInConnectionString); + + let accessToken = undefined; + if (currentServer.options.authenticationType = 'AzureMFA') { + const azureAccount = (await azdata.accounts.getAllAccounts()).filter(v => v.key.accountId === currentServer.options.azureAccount)[0]; + accessToken = (await azdata.accounts.getAccountSecurityToken(azureAccount, currentServer.options.azureTenantId, azdata.AzureResource.Sql)).token; + } try { result = await this.provider.sendInsertDataRequest({ - connectionString: await azdata.connection.getConnectionString(this.model.server.connectionId, includePasswordInConnectionString), + connectionString: connectionString, //TODO check what SSMS uses as batch size - batchSize: 500 + batchSize: 500, + azureAccessToken: accessToken }); } catch (e) { err = e.toString();