diff --git a/extensions/machine-learning-services/images/arrow.svg b/extensions/machine-learning-services/images/arrow.svg
new file mode 100644
index 0000000000..91626b1129
--- /dev/null
+++ b/extensions/machine-learning-services/images/arrow.svg
@@ -0,0 +1,3 @@
+
diff --git a/extensions/machine-learning-services/src/common/constants.ts b/extensions/machine-learning-services/src/common/constants.ts
index e8d55ba9be..6711463a77 100644
--- a/extensions/machine-learning-services/src/common/constants.ts
+++ b/extensions/machine-learning-services/src/common/constants.ts
@@ -60,7 +60,7 @@ export function confirmInstallPythonPackages(packages: string): string {
export const installDependenciesPackages = localize('mls.installDependencies.packages', "Installing required packages ...");
export const installDependenciesPackagesAlreadyInstalled = localize('mls.installDependencies.packagesAlreadyInstalled', "Required packages are already installed.");
export function installDependenciesGetPackagesError(err: string): string { return localize('mls.installDependencies.getPackagesError', "Failed to get installed python packages. Error: {0}", err); }
-export const packageManagerNoConnection = localize('mls.packageManager.NoConnection', "No connection selected");
+export const noConnectionError = localize('mls.packageManager.NoConnection', "No connection selected");
export const notebookExtensionNotLoaded = localize('mls.notebookExtensionNotLoaded', "Notebook extension is not loaded");
export const mssqlExtensionNotLoaded = localize('mls.mssqlExtensionNotLoaded', "MSSQL extension is not loaded");
export const mlsEnabledMessage = localize('mls.enabledMessage', "Machine Learning Services Enabled");
@@ -74,6 +74,8 @@ export const mlsExternalExecuteScriptTitle = localize('mls.externalExecuteScript
export const mlsPythonLanguageTitle = localize('mls.pythonLanguageTitle', "Python");
export const mlsRLanguageTitle = localize('mls.rLanguageTitle', "R");
export const downloadError = localize('mls.downloadError', "Error while downloading");
+export function invalidModelIdError(modelUrl: string | undefined): string { return localize('mls.invalidModelIdError', "Invalid model id. model url: {0}", modelUrl || ''); }
+export function noArtifactError(modelUrl: string | undefined): string { return localize('mls.noArtifactError', "Model doesn't have any artifact. model url: {0}", modelUrl || ''); }
export const downloadingProgress = localize('mls.downloadingProgress', "Downloading");
export const pythonConfigError = localize('mls.pythonConfigError', "Python executable is not configured");
export const rConfigError = localize('mls.rConfigError', "R executable is not configured");
@@ -119,12 +121,15 @@ export const modelCreated = localize('models.created', "Date Created");
export const modelVersion = localize('models.version', "Version");
export const browseModels = localize('models.browseButton', "...");
export const azureAccount = localize('models.azureAccount', "Azure account");
-export const columnDatabase = localize('predict.columnDatabase', "Database");
-export const columnTable = localize('predict.columnTable', "Table");
-export const inputColumns = localize('predict.inputColumns', "Input columns");
-export const outputColumns = localize('predict.outputColumns', "Output column");
-export const columnName = localize('predict.columnName', "Name");
-export const inputName = localize('predict.inputName', "Input Name");
+export const columnDatabase = localize('predict.columnDatabase', "Target database");
+export const columnTable = localize('predict.columnTable', "Target table");
+export const inputColumns = localize('predict.inputColumns', "Model input mapping");
+export const outputColumns = localize('predict.outputColumns', "Model output");
+export const columnName = localize('predict.columnName', "Target columns");
+export const dataTypeName = localize('predict.dataTypeName', "Type");
+export const displayName = localize('predict.displayName', "Display name");
+export const inputName = localize('predict.inputName', "Required model input features");
+export const outputName = localize('predict.outputName', "Name");
export const azureSubscription = localize('models.azureSubscription', "Azure subscription");
export const azureGroup = localize('models.azureGroup', "Azure resource group");
export const azureModelWorkspace = localize('models.azureModelWorkspace', "Azure ML workspace");
@@ -134,7 +139,7 @@ export const azureModelsTitle = localize('models.azureModelsTitle', "Azure model
export const localModelsTitle = localize('models.localModelsTitle', "Local models");
export const modelSourcesTitle = localize('models.modelSourcesTitle', "Source location");
export const modelSourcePageTitle = localize('models.modelSourcePageTitle', "Enter model source details");
-export const columnSelectionPageTitle = localize('models.columnSelectionPageTitle', "Select input columns");
+export const columnSelectionPageTitle = localize('models.columnSelectionPageTitle', "Map predictions target data to model input");
export const modelDetailsPageTitle = localize('models.modelDetailsPageTitle', "Provide model details");
export const modelLocalSourceTitle = localize('models.modelLocalSourceTitle', "Source file");
export const currentModelsTitle = localize('models.currentModelsTitle', "Models");
@@ -156,6 +161,8 @@ export const invalidModelToSelectError = localize('models.invalidModelToSelectEr
export const modelNameRequiredError = localize('models.modelNameRequiredError', "Model name is required.");
export const updateModelFailedError = localize('models.updateModelFailedError', "Failed to update the model");
export const importModelFailedError = localize('models.importModelFailedError', "Failed to register the model");
+export const loadModelParameterFailedError = localize('models.loadModelParameterFailedError', "Failed to load model parameters'");
+export const unsupportedModelParameterType = localize('models.unsupportedModelParameterType', "unsupported");
diff --git a/extensions/machine-learning-services/src/common/processService.ts b/extensions/machine-learning-services/src/common/processService.ts
index f6a95031f4..9dc050cf4c 100644
--- a/extensions/machine-learning-services/src/common/processService.ts
+++ b/extensions/machine-learning-services/src/common/processService.ts
@@ -23,16 +23,19 @@ export class ProcessService {
scriptExecution.stdin.end();
// Add listeners to print stdout and stderr if an output channel was provided
- if (outputChannel) {
- scriptExecution.stdout.on('data', data => {
+
+ scriptExecution.stdout.on('data', data => {
+ if (outputChannel) {
this.outputDataChunk(data, outputChannel, ' stdout: ');
- output = output + data.toString();
- });
- scriptExecution.stderr.on('data', data => {
+ }
+ output = output + data.toString();
+ });
+ scriptExecution.stderr.on('data', data => {
+ if (outputChannel) {
this.outputDataChunk(data, outputChannel, ' stderr: ');
- output = output + data.toString();
- });
- }
+ }
+ output = output + data.toString();
+ });
scriptExecution.on('exit', (code) => {
if (timer) {
diff --git a/extensions/machine-learning-services/src/common/utils.ts b/extensions/machine-learning-services/src/common/utils.ts
index 483eb9440c..96c8697c38 100644
--- a/extensions/machine-learning-services/src/common/utils.ts
+++ b/extensions/machine-learning-services/src/common/utils.ts
@@ -22,7 +22,17 @@ export async function execCommandOnTempFile(content: string, command: (filePa
return result;
}
finally {
- await fs.promises.unlink(tempFilePath);
+ await deleteFile(tempFilePath);
+ }
+}
+
+/**
+ * Deletes a file
+ * @param filePath file path
+ */
+export async function deleteFile(filePath: string) {
+ if (filePath) {
+ await fs.promises.unlink(filePath);
}
}
@@ -215,7 +225,7 @@ export function getRegisteredModelsThreePartsName(config: Config) {
const dbName = doubleEscapeSingleBrackets(config.registeredModelDatabaseName);
const schema = doubleEscapeSingleBrackets(config.registeredModelTableSchemaName);
const tableName = doubleEscapeSingleBrackets(config.registeredModelTableName);
- return `[${dbName}].${schema}.[${tableName}]`;
+ return `[${dbName}].[${schema}].[${tableName}]`;
}
/**
@@ -227,3 +237,14 @@ export function getRegisteredModelsTowPartsName(config: Config) {
const tableName = doubleEscapeSingleBrackets(config.registeredModelTableName);
return `[${schema}].[${tableName}]`;
}
+
+/**
+ * Write a file using a hex string
+ * @param content file content
+ */
+export async function writeFileFromHex(content: string): Promise {
+ content = content.startsWith('0x') || content.startsWith('0X') ? content.substr(2) : content;
+ const tempFilePath = path.join(os.tmpdir(), `ads_ml_temp_${UUID.generateUuid()}`);
+ await fs.promises.writeFile(tempFilePath, Buffer.from(content, 'hex'));
+ return tempFilePath;
+}
diff --git a/extensions/machine-learning-services/src/controllers/mainController.ts b/extensions/machine-learning-services/src/controllers/mainController.ts
index 390dfb3d1e..a9cc6babf6 100644
--- a/extensions/machine-learning-services/src/controllers/mainController.ts
+++ b/extensions/machine-learning-services/src/controllers/mainController.ts
@@ -18,9 +18,9 @@ import { HttpClient } from '../common/httpClient';
import { LanguageController } from '../views/externalLanguages/languageController';
import { LanguageService } from '../externalLanguage/languageService';
import { ModelManagementController } from '../views/models/modelManagementController';
-import { RegisteredModelService } from '../modelManagement/registeredModelService';
+import { DeployedModelService } from '../modelManagement/deployedModelService';
import { AzureModelRegistryService } from '../modelManagement/azureModelRegistryService';
-import { ModelImporter } from '../modelManagement/modelImporter';
+import { ModelPythonClient } from '../modelManagement/modelPythonClient';
import { PredictService } from '../prediction/predictService';
/**
@@ -100,11 +100,11 @@ export default class MainController implements vscode.Disposable {
let mssqlService = await this.getLanguageExtensionService();
let languagesModel = new LanguageService(this._apiWrapper, mssqlService);
let languageController = new LanguageController(this._apiWrapper, this._rootPath, languagesModel);
- let modelImporter = new ModelImporter(this._outputChannel, this._apiWrapper, this._processService, this._config, packageManager);
+ let modelImporter = new ModelPythonClient(this._outputChannel, this._apiWrapper, this._processService, this._config, packageManager);
// Model Management
//
- let registeredModelService = new RegisteredModelService(this._apiWrapper, this._config, this._queryRunner, modelImporter);
+ let registeredModelService = new DeployedModelService(this._apiWrapper, this._config, this._queryRunner, modelImporter);
let azureModelsService = new AzureModelRegistryService(this._apiWrapper, this._config, this.httpClient, this._outputChannel);
let predictService = new PredictService(this._apiWrapper, this._queryRunner, this._config);
let modelManagementController = new ModelManagementController(this._apiWrapper, this._rootPath,
diff --git a/extensions/machine-learning-services/src/modelManagement/azureModelRegistryService.ts b/extensions/machine-learning-services/src/modelManagement/azureModelRegistryService.ts
index b63bead25e..47f24c0532 100644
--- a/extensions/machine-learning-services/src/modelManagement/azureModelRegistryService.ts
+++ b/extensions/machine-learning-services/src/modelManagement/azureModelRegistryService.ts
@@ -28,10 +28,16 @@ import * as utils from '../common/utils';
*/
export class AzureModelRegistryService {
+ private _amlClient: AzureMachineLearningWorkspaces | undefined;
+ private _modelClient: WorkspaceModels | undefined;
/**
- *
+ * Creates new service
*/
- constructor(private _apiWrapper: ApiWrapper, private _config: Config, private _httpClient: HttpClient, private _outputChannel: vscode.OutputChannel) {
+ constructor(
+ private _apiWrapper: ApiWrapper,
+ private _config: Config,
+ private _httpClient: HttpClient,
+ private _outputChannel: vscode.OutputChannel) {
}
/**
@@ -120,10 +126,18 @@ export class AzureModelRegistryService {
return downloadedFilePath;
}
+ public set AzureMachineLearningClient(value: AzureMachineLearningWorkspaces) {
+ this._amlClient = value;
+ }
+
+ public set ModelClient(value: WorkspaceModels) {
+ this._modelClient = value;
+ }
+
/**
- * Installs dependencies for the extension
+ * Execute the background task to download the artifact
*/
- public async execDownloadArtifactTask(downloadUrl: string): Promise {
+ private async execDownloadArtifactTask(downloadUrl: string): Promise {
let results = await utils.executeTasks(this._apiWrapper, constants.downloadModelMsgTaskName, [this.downloadArtifact(downloadUrl)], true);
return results && results.length > 0 ? results[0] : constants.noResultError;
}
@@ -139,15 +153,14 @@ export class AzureModelRegistryService {
try {
for (const tenant of account.properties.tenants) {
- const tokens = await this._apiWrapper.getSecurityToken(account, azdata.AzureResource.ResourceManagement);
- const token = tokens[tenant.id].token;
- const tokenType = tokens[tenant.id].tokenType;
- const client = new AzureMachineLearningWorkspaces(new TokenCredentials(token, tokenType), subscription.id);
+ const client = await this.getAmlClient(account, subscription, tenant);
let result = resourceGroup ? await client.workspaces.listByResourceGroup(resourceGroup.name) : await client.workspaces.listBySubscription();
- resources.push(...result);
+ if (result) {
+ resources.push(...result);
+ }
}
} catch (error) {
-
+ console.log(error);
}
return resources;
}
@@ -161,9 +174,11 @@ export class AzureModelRegistryService {
for (const tenant of account.properties.tenants) {
try {
- let baseUri = this.getBaseUrl(workspace, this._config.amlModelManagementUrl);
- const client = await this.getClient(baseUri, account, subscription, tenant);
- let modelsClient = new WorkspaceModels(client);
+ let options: AzureMachineLearningWorkspacesOptions = {
+ baseUri: this.getBaseUrl(workspace, this._config.amlModelManagementUrl)
+ };
+ const client = await this.getAmlClient(account, subscription, tenant, options, this._config.amlApiVersion);
+ let modelsClient = this.getModelClient(client);
resources = resources.concat(await modelsClient.listModels(resourceGroup.name, workspace.name || ''));
} catch (error) {
@@ -182,22 +197,28 @@ export class AzureModelRegistryService {
client: AzureMachineLearningWorkspaces): Promise {
const modelId = this.getModelId(model);
- let modelsClient = new Assets(client);
- return await modelsClient.queryById(subscription.id, resourceGroup.name, workspace.name || '', modelId);
+ if (modelId) {
+ let modelsClient = new Assets(client);
+ return await modelsClient.queryById(subscription.id, resourceGroup.name, workspace.name || '', modelId);
+ } else {
+ throw Error(constants.invalidModelIdError(model.url));
+ }
}
- public async getAssetArtifactsDownloadLinks(
+ private async getAssetArtifactsDownloadLinks(
account: azdata.Account,
subscription: azureResource.AzureResourceSubscription,
resourceGroup: azureResource.AzureResource,
workspace: Workspace,
model: WorkspaceModel,
tenant: any): Promise {
- let baseUri = this.getBaseUrl(workspace, this._config.amlModelManagementUrl);
- const modelManagementClient = await this.getClient(baseUri, account, subscription, tenant);
+ let options: AzureMachineLearningWorkspacesOptions = {
+ baseUri: this.getBaseUrl(workspace, this._config.amlModelManagementUrl)
+ };
+ const modelManagementClient = await this.getAmlClient(account, subscription, tenant, options, this._config.amlApiVersion);
const asset = await this.fetchModelAsset(subscription, resourceGroup, workspace, model, modelManagementClient);
- baseUri = this.getBaseUrl(workspace, this._config.amlExperienceUrl);
- const experienceClient = await this.getClient(baseUri, account, subscription, tenant);
+ options.baseUri = this.getBaseUrl(workspace, this._config.amlExperienceUrl);
+ const experienceClient = await this.getAmlClient(account, subscription, tenant, options, this._config.amlApiVersion);
const artifactClient = new Artifacts(experienceClient);
let downloadLinks: string[] = [];
if (asset && asset.artifacts) {
@@ -230,17 +251,19 @@ export class AzureModelRegistryService {
downloadLinkPromises.push(promise);
}
}
-
try {
downloadLinks = await Promise.all(downloadLinkPromises);
} catch (rejectedPromiseError) {
return rejectedPromiseError;
}
+ return downloadLinks;
+
+ } else {
+ throw Error(constants.noArtifactError(model.url));
}
- return downloadLinks;
}
- public getPartsFromAssetIdOrPrefix(idOrPrefix: string | undefined): IArtifactParts | undefined {
+ private getPartsFromAssetIdOrPrefix(idOrPrefix: string | undefined): IArtifactParts | undefined {
const artifactRegex = /^(.+?)\/(.+?)\/(.+?)$/;
if (idOrPrefix) {
const parts = artifactRegex.exec(idOrPrefix);
@@ -263,16 +286,35 @@ export class AzureModelRegistryService {
return baseUri;
}
- private async getClient(baseUri: string, account: azdata.Account, subscription: azureResource.AzureResourceSubscription, tenant: any): Promise {
- const tokens = await this._apiWrapper.getSecurityToken(account, azdata.AzureResource.ResourceManagement);
- const token = tokens[tenant.id].token;
- const tokenType = tokens[tenant.id].tokenType;
- const options: AzureMachineLearningWorkspacesOptions = {
- baseUri: baseUri
- };
- const client = new AzureMachineLearningWorkspaces(new TokenCredentials(token, tokenType), subscription.id, options);
- client.apiVersion = this._config.amlApiVersion;
- return client;
+ private getModelClient(amlClient: AzureMachineLearningWorkspaces) {
+ return this._modelClient ?? new WorkspaceModels(amlClient);
+ }
+
+ private async getAmlClient(
+ account: azdata.Account,
+ subscription: azureResource.AzureResourceSubscription,
+ tenant: any,
+ options: AzureMachineLearningWorkspacesOptions | undefined = undefined,
+ apiVersion: string | undefined = undefined): Promise {
+ if (this._amlClient) {
+ return this._amlClient;
+ } else {
+ const tokens = await this._apiWrapper.getSecurityToken(account, azdata.AzureResource.ResourceManagement);
+ let token: string = '';
+ let tokenType: string | undefined = undefined;
+ if (tokens && tenant.id in tokens) {
+ const tokenForId = tokens[tenant.id];
+ if (tokenForId) {
+ token = tokenForId.token;
+ tokenType = tokenForId.tokenType;
+ }
+ }
+ const client = new AzureMachineLearningWorkspaces(new TokenCredentials(token, tokenType), subscription.id, options);
+ if (apiVersion) {
+ client.apiVersion = apiVersion;
+ }
+ return client;
+ }
}
private getModelId(model: WorkspaceModel): string {
diff --git a/extensions/machine-learning-services/src/modelManagement/registeredModelService.ts b/extensions/machine-learning-services/src/modelManagement/deployedModelService.ts
similarity index 73%
rename from extensions/machine-learning-services/src/modelManagement/registeredModelService.ts
rename to extensions/machine-learning-services/src/modelManagement/deployedModelService.ts
index 6720bc2563..7bee455235 100644
--- a/extensions/machine-learning-services/src/modelManagement/registeredModelService.ts
+++ b/extensions/machine-learning-services/src/modelManagement/deployedModelService.ts
@@ -9,73 +9,85 @@ import { ApiWrapper } from '../common/apiWrapper';
import * as utils from '../common/utils';
import { Config } from '../configurations/config';
import { QueryRunner } from '../common/queryRunner';
-import { RegisteredModel, RegisteredModelDetails } from './interfaces';
-import { ModelImporter } from './modelImporter';
+import { RegisteredModel, RegisteredModelDetails, ModelParameters } from './interfaces';
+import { ModelPythonClient } from './modelPythonClient';
import * as constants from '../common/constants';
/**
- * Service to registered models
+ * Service to deployed models
*/
-export class RegisteredModelService {
+export class DeployedModelService {
/**
- *
+ * Creates new instance
*/
constructor(
private _apiWrapper: ApiWrapper,
private _config: Config,
private _queryRunner: QueryRunner,
- private _modelImporter: ModelImporter) {
+ private _modelClient: ModelPythonClient) {
}
- public async getRegisteredModels(): Promise {
+ /**
+ * Returns deployed models
+ */
+ public async getDeployedModels(): Promise {
let connection = await this.getCurrentConnection();
let list: RegisteredModel[] = [];
if (connection) {
let query = this.getConfigureQuery(connection.databaseName);
await this._queryRunner.safeRunQuery(connection, query);
- query = this.registeredModelsQuery();
+ query = this.getDeployedModelsQuery();
let result = await this._queryRunner.safeRunQuery(connection, query);
if (result && result.rows && result.rows.length > 0) {
result.rows.forEach(row => {
list.push(this.loadModelData(row));
});
}
+ } else {
+ throw Error(constants.noConnectionError);
}
return list;
}
- private loadModelData(row: azdata.DbCellValue[]): RegisteredModel {
- return {
- id: +row[0].displayValue,
- artifactName: row[1].displayValue,
- title: row[2].displayValue,
- description: row[3].displayValue,
- version: row[4].displayValue,
- created: row[5].displayValue
- };
- }
-
- public async updateModel(model: RegisteredModel): Promise {
+ /**
+ * Downloads model
+ * @param model model object
+ */
+ public async downloadModel(model: RegisteredModel): Promise {
let connection = await this.getCurrentConnection();
- let updatedModel: RegisteredModel | undefined = undefined;
if (connection) {
- const query = this.getUpdateModelScript(connection.databaseName, model);
+ const query = this.getModelContentQuery(model);
let result = await this._queryRunner.safeRunQuery(connection, query);
if (result && result.rows && result.rows.length > 0) {
- const row = result.rows[0];
- updatedModel = this.loadModelData(row);
+ const content = result.rows[0][0].displayValue;
+ return await utils.writeFileFromHex(content);
+ } else {
+ throw Error(constants.invalidModelToSelectError);
}
+ } else {
+ throw Error(constants.noConnectionError);
}
- return updatedModel;
}
- public async registerLocalModel(filePath: string, details: RegisteredModelDetails | undefined) {
+ /**
+ * Loads model parameters
+ */
+ public async loadModelParameters(filePath: string): Promise {
+ return await this._modelClient.loadModelParameters(filePath);
+ }
+
+ /**
+ * Deploys local model
+ * @param filePath model file path
+ * @param details model details
+ */
+ public async deployLocalModel(filePath: string, details: RegisteredModelDetails | undefined) {
let connection = await this.getCurrentConnection();
if (connection) {
- let currentModels = await this.getRegisteredModels();
- await this._modelImporter.registerModel(connection, filePath);
- let updatedModels = await this.getRegisteredModels();
+ let currentModels = await this.getDeployedModels();
+ await this._modelClient.deployModel(connection, filePath);
+ let updatedModels = await this.getDeployedModels();
if (details && updatedModels.length >= currentModels.length + 1) {
updatedModels.sort((a, b) => a.id && b.id ? a.id - b.id : 0);
const addedModel = updatedModels[updatedModels.length - 1];
@@ -92,16 +104,40 @@ export class RegisteredModelService {
}
}
}
+ private loadModelData(row: azdata.DbCellValue[]): RegisteredModel {
+ return {
+ id: +row[0].displayValue,
+ artifactName: row[1].displayValue,
+ title: row[2].displayValue,
+ description: row[3].displayValue,
+ version: row[4].displayValue,
+ created: row[5].displayValue
+ };
+ }
+
+ private async updateModel(model: RegisteredModel): Promise {
+ let connection = await this.getCurrentConnection();
+ let updatedModel: RegisteredModel | undefined = undefined;
+ if (connection) {
+ const query = this.getUpdateModelQuery(connection.databaseName, model);
+ let result = await this._queryRunner.safeRunQuery(connection, query);
+ if (result?.rows && result.rows.length > 0) {
+ const row = result.rows[0];
+ updatedModel = this.loadModelData(row);
+ }
+ }
+ return updatedModel;
+ }
private async getCurrentConnection(): Promise {
return await this._apiWrapper.getCurrentConnection();
}
- private getConfigureQuery(currentDatabaseName: string): string {
- return utils.getScriptWithDBChange(currentDatabaseName, this._config.registeredModelDatabaseName, this.configureTable());
+ public getConfigureQuery(currentDatabaseName: string): string {
+ return utils.getScriptWithDBChange(currentDatabaseName, this._config.registeredModelDatabaseName, this.getConfigureTableQuery());
}
- private registeredModelsQuery(): string {
+ public getDeployedModelsQuery(): string {
return `
SELECT artifact_id, artifact_name, name, description, version, created
FROM ${utils.getRegisteredModelsThreePartsName(this._config)}
@@ -116,7 +152,7 @@ export class RegisteredModelService {
* @param databaseName
* @param tableName
*/
- private configureTable(): string {
+ public getConfigureTableQuery(): string {
let databaseName = this._config.registeredModelDatabaseName;
let tableName = this._config.registeredModelTableName;
let schemaName = this._config.registeredModelTableSchemaName;
@@ -171,7 +207,7 @@ export class RegisteredModelService {
`;
}
- private getUpdateModelScript(currentDatabaseName: string, model: RegisteredModel): string {
+ public getUpdateModelQuery(currentDatabaseName: string, model: RegisteredModel): string {
let updateScript = `
UPDATE ${utils.getRegisteredModelsTowPartsName(this._config)}
SET
@@ -187,4 +223,12 @@ export class RegisteredModelService {
WHERE artifact_id = ${model.id};
`;
}
+
+ public getModelContentQuery(model: RegisteredModel): string {
+ return `
+ SELECT artifact_content
+ FROM ${utils.getRegisteredModelsThreePartsName(this._config)}
+ WHERE artifact_id = ${model.id};
+ `;
+ }
}
diff --git a/extensions/machine-learning-services/src/modelManagement/interfaces.ts b/extensions/machine-learning-services/src/modelManagement/interfaces.ts
index 212c3adc34..f827bffc34 100644
--- a/extensions/machine-learning-services/src/modelManagement/interfaces.ts
+++ b/extensions/machine-learning-services/src/modelManagement/interfaces.ts
@@ -53,6 +53,16 @@ export interface RegisteredModel extends RegisteredModelDetails {
artifactName: string;
}
+export interface ModelParameter {
+ name: string;
+ type: string;
+}
+
+export interface ModelParameters {
+ inputs: ModelParameter[],
+ outputs: ModelParameter[]
+}
+
/**
* An interface representing registered model
*/
diff --git a/extensions/machine-learning-services/src/modelManagement/modelImporter.ts b/extensions/machine-learning-services/src/modelManagement/modelPythonClient.ts
similarity index 53%
rename from extensions/machine-learning-services/src/modelManagement/modelImporter.ts
rename to extensions/machine-learning-services/src/modelManagement/modelPythonClient.ts
index ad00576055..1b5022b554 100644
--- a/extensions/machine-learning-services/src/modelManagement/modelImporter.ts
+++ b/extensions/machine-learning-services/src/modelManagement/modelPythonClient.ts
@@ -13,33 +13,89 @@ import * as utils from '../common/utils';
import { PackageManager } from '../packageManagement/packageManager';
import * as constants from '../common/constants';
import * as os from 'os';
+import { ModelParameters } from './interfaces';
/**
- * Service to import model to database
+ * Python client for ONNX models
*/
-export class ModelImporter {
+export class ModelPythonClient {
/**
- *
+ * Creates new instance
*/
constructor(private _outputChannel: vscode.OutputChannel, private _apiWrapper: ApiWrapper, private _processService: ProcessService, private _config: Config, private _packageManager: PackageManager) {
}
- public async registerModel(connection: azdata.connection.ConnectionProfile, modelFolderPath: string): Promise {
+ /**
+ * Deploys models in the SQL database using mlflow
+ * @param connection
+ * @param modelPath
+ */
+ public async deployModel(connection: azdata.connection.ConnectionProfile, modelPath: string): Promise {
await this.installDependencies();
- await this.executeScripts(connection, modelFolderPath);
+ await this.executeDeployScripts(connection, modelPath);
}
/**
- * Installs dependencies for model importer
+ * Installs dependencies for python client
*/
- public async installDependencies(): Promise {
+ private async installDependencies(): Promise {
await utils.executeTasks(this._apiWrapper, constants.installDependenciesMsgTaskName, [
this._packageManager.installRequiredPythonPackages(this._config.modelsRequiredPythonPackages)], true);
}
- protected async executeScripts(connection: azdata.connection.ConnectionProfile, modelFolderPath: string): Promise {
+ /**
+ *
+ * @param modelPath Loads model parameters
+ */
+ public async loadModelParameters(modelPath: string): Promise {
+ await this.installDependencies();
+ return await this.executeModelParametersScripts(modelPath);
+ }
+ private async executeModelParametersScripts(modelFolderPath: string): Promise {
+ modelFolderPath = utils.makeLinuxPath(modelFolderPath);
+
+ let scripts: string[] = [
+ 'import onnx',
+ 'import json',
+ `onnx_model_path = '${modelFolderPath}'`,
+ `onnx_model = onnx.load_model(onnx_model_path)`,
+ `type_map = {
+ onnx.TensorProto.DataType.FLOAT: 'real',
+ onnx.TensorProto.DataType.UINT8: 'tinyint',
+ onnx.TensorProto.DataType.INT16: 'smallint',
+ onnx.TensorProto.DataType.INT32: 'int',
+ onnx.TensorProto.DataType.INT64: 'bigint',
+ onnx.TensorProto.DataType.STRING: 'varchar(MAX)',
+ onnx.TensorProto.DataType.DOUBLE: 'float'}`,
+ `parameters = {
+ "inputs": [],
+ "outputs": []
+ }`,
+ `def addParameters(list, paramType):
+ for id, p in enumerate(list):
+ p_type = ''
+
+ if p.type.tensor_type.elem_type in type_map:
+ p_type = type_map[p.type.tensor_type.elem_type]
+
+ parameters[paramType].append({
+ 'name': p.name,
+ 'type': p_type
+ })`,
+
+ 'addParameters(onnx_model.graph.input, "inputs")',
+ 'addParameters(onnx_model.graph.output, "outputs")',
+ 'print(json.dumps(parameters))'
+ ];
+ let pythonExecutable = this._config.pythonExecutable;
+ let output = await this._processService.execScripts(pythonExecutable, scripts, [], undefined);
+ let parametersJson = JSON.parse(output);
+ return Object.assign({}, parametersJson);
+ }
+
+ private async executeDeployScripts(connection: azdata.connection.ConnectionProfile, modelFolderPath: string): Promise {
let home = utils.makeLinuxPath(os.homedir());
modelFolderPath = utils.makeLinuxPath(modelFolderPath);
diff --git a/extensions/machine-learning-services/src/packageManagement/SqlPackageManageProviderBase.ts b/extensions/machine-learning-services/src/packageManagement/SqlPackageManageProviderBase.ts
index 703c2d5d4c..90870aa84a 100644
--- a/extensions/machine-learning-services/src/packageManagement/SqlPackageManageProviderBase.ts
+++ b/extensions/machine-learning-services/src/packageManagement/SqlPackageManageProviderBase.ts
@@ -30,7 +30,7 @@ export abstract class SqlPackageManageProviderBase {
if (connection) {
return `${connection.serverName} ${connection.databaseName ? connection.databaseName : ''}`;
}
- return constants.packageManagerNoConnection;
+ return constants.noConnectionError;
}
protected async getCurrentConnection(): Promise {
diff --git a/extensions/machine-learning-services/src/prediction/interfaces.ts b/extensions/machine-learning-services/src/prediction/interfaces.ts
index 5fcb789edd..2274a0d8b1 100644
--- a/extensions/machine-learning-services/src/prediction/interfaces.ts
+++ b/extensions/machine-learning-services/src/prediction/interfaces.ts
@@ -3,10 +3,13 @@
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
-export interface PredictColumn {
- name: string;
+export interface TableColumn {
+ columnName: string;
dataType?: string;
- displayName?: string;
+}
+
+export interface PredictColumn extends TableColumn {
+ paramName?: string;
}
export interface DatabaseTable {
diff --git a/extensions/machine-learning-services/src/prediction/predictService.ts b/extensions/machine-learning-services/src/prediction/predictService.ts
index 31b714c82d..eebb89bd29 100644
--- a/extensions/machine-learning-services/src/prediction/predictService.ts
+++ b/extensions/machine-learning-services/src/prediction/predictService.ts
@@ -9,7 +9,7 @@ import { ApiWrapper } from '../common/apiWrapper';
import { QueryRunner } from '../common/queryRunner';
import * as utils from '../common/utils';
import { RegisteredModel } from '../modelManagement/interfaces';
-import { PredictParameters, PredictColumn, DatabaseTable } from '../prediction/interfaces';
+import { PredictParameters, PredictColumn, DatabaseTable, TableColumn } from '../prediction/interfaces';
import { Config } from '../configurations/config';
/**
@@ -98,15 +98,18 @@ export class PredictService {
*Returns list of column names of a database
* @param databaseTable table info
*/
- public async getTableColumnsList(databaseTable: DatabaseTable): Promise {
+ public async getTableColumnsList(databaseTable: DatabaseTable): Promise {
let connection = await this.getCurrentConnection();
- let list: string[] = [];
+ let list: TableColumn[] = [];
if (connection && databaseTable.databaseName) {
const query = utils.getScriptWithDBChange(connection.databaseName, databaseTable.databaseName, this.getTableColumnsScript(databaseTable));
let result = await this._queryRunner.safeRunQuery(connection, query);
if (result && result.rows && result.rows.length > 0) {
result.rows.forEach(row => {
- list.push(row[0].displayValue);
+ list.push({
+ columnName: row[0].displayValue,
+ dataType: row[1].displayValue
+ });
});
}
}
@@ -119,7 +122,7 @@ export class PredictService {
private getTableColumnsScript(databaseTable: DatabaseTable): string {
return `
-SELECT COLUMN_NAME,*
+SELECT COLUMN_NAME,DATA_TYPE
FROM INFORMATION_SCHEMA.COLUMNS
WHERE TABLE_NAME='${utils.doubleEscapeSingleQuotes(databaseTable.tableName)}'
AND TABLE_SCHEMA='${utils.doubleEscapeSingleQuotes(databaseTable.schema)}'
@@ -149,14 +152,14 @@ DECLARE @model VARBINARY(max) = (
WITH predict_input
AS (
SELECT TOP 1000
- ${this.getColumnNames(columns, 'pi')}
+ ${this.getInputColumnNames(columns, 'pi')}
FROM [${utils.doubleEscapeSingleBrackets(databaseNameTable.databaseName)}].[${databaseNameTable.schema}].[${utils.doubleEscapeSingleBrackets(databaseNameTable.tableName)}] as pi
)
SELECT
-${this.getInputColumnNames(columns, 'predict_input')}, ${this.getColumnNames(outputColumns, 'p')}
+${this.getPredictColumnNames(columns, 'predict_input')}, ${this.getInputColumnNames(outputColumns, 'p')}
FROM PREDICT(MODEL = @model, DATA = predict_input)
WITH (
- ${this.getColumnTypes(outputColumns)}
+ ${this.getOutputParameters(outputColumns)}
) AS p
`;
}
@@ -170,33 +173,43 @@ WITH (
WITH predict_input
AS (
SELECT TOP 1000
- ${this.getColumnNames(columns, 'pi')}
+ ${this.getInputColumnNames(columns, 'pi')}
FROM [${utils.doubleEscapeSingleBrackets(databaseNameTable.databaseName)}].[${databaseNameTable.schema}].[${utils.doubleEscapeSingleBrackets(databaseNameTable.tableName)}] as pi
)
SELECT
-${this.getInputColumnNames(columns, 'predict_input')}, ${this.getColumnNames(outputColumns, 'p')}
+${this.getPredictColumnNames(columns, 'predict_input')}, ${this.getOutputColumnNames(outputColumns, 'p')}
FROM PREDICT(MODEL = ${modelBytes}, DATA = predict_input)
WITH (
- ${this.getColumnTypes(outputColumns)}
+ ${this.getOutputParameters(outputColumns)}
) AS p
`;
}
- private getColumnNames(columns: PredictColumn[], tableName: string) {
- return columns.map(c => {
- return c.displayName ? `${tableName}.${c.name} AS ${c.displayName}` : `${tableName}.${c.name}`;
- }).join(',\n');
- }
-
private getInputColumnNames(columns: PredictColumn[], tableName: string) {
return columns.map(c => {
- return c.displayName ? `${tableName}.${c.displayName}` : `${tableName}.${c.name}`;
+ return this.getColumnName(tableName, c.paramName || '', c.columnName);
}).join(',\n');
}
- private getColumnTypes(columns: PredictColumn[]) {
+ private getOutputColumnNames(columns: PredictColumn[], tableName: string) {
return columns.map(c => {
- return `${c.name} ${c.dataType}`;
+ return this.getColumnName(tableName, c.columnName, c.paramName || '');
+ }).join(',\n');
+ }
+
+ private getColumnName(tableName: string, columnName: string, displayName: string) {
+ return columnName && columnName !== displayName ? `${tableName}.${columnName} AS ${displayName}` : `${tableName}.${columnName}`;
+ }
+
+ private getPredictColumnNames(columns: PredictColumn[], tableName: string) {
+ return columns.map(c => {
+ return c.paramName ? `${tableName}.${c.paramName}` : `${tableName}.${c.columnName}`;
+ }).join(',\n');
+ }
+
+ private getOutputParameters(columns: PredictColumn[]) {
+ return columns.map(c => {
+ return `${c.paramName} ${c.dataType}`;
}).join(',\n');
}
}
diff --git a/extensions/machine-learning-services/src/test/modelManagement/azureModelRegistryService.test.ts b/extensions/machine-learning-services/src/test/modelManagement/azureModelRegistryService.test.ts
new file mode 100644
index 0000000000..e2a92f9169
--- /dev/null
+++ b/extensions/machine-learning-services/src/test/modelManagement/azureModelRegistryService.test.ts
@@ -0,0 +1,232 @@
+/*---------------------------------------------------------------------------------------------
+ * Copyright (c) Microsoft Corporation. All rights reserved.
+ * Licensed under the Source EULA. See License.txt in the project root for license information.
+ *--------------------------------------------------------------------------------------------*/
+
+import * as azdata from 'azdata';
+import * as vscode from 'vscode';
+import { ApiWrapper } from '../../common/apiWrapper';
+import * as TypeMoq from 'typemoq';
+import * as should from 'should';
+import { AzureModelRegistryService } from '../../modelManagement/azureModelRegistryService';
+import { Config } from '../../configurations/config';
+import { HttpClient } from '../../common/httpClient';
+import { azureResource } from '../../typings/azure-resource';
+
+import * as utils from '../utils';
+import { Workspace, WorkspacesListByResourceGroupResponse } from '@azure/arm-machinelearningservices/esm/models';
+import { WorkspaceModel, AssetsQueryByIdResponse, Asset, GetArtifactContentInformation2Response } from '../../modelManagement/interfaces';
+import { AzureMachineLearningWorkspaces, Workspaces } from '@azure/arm-machinelearningservices';
+import { WorkspaceModels } from '../../modelManagement/workspacesModels';
+
+interface TestContext {
+
+ apiWrapper: TypeMoq.IMock;
+ config: TypeMoq.IMock;
+ httpClient: TypeMoq.IMock;
+ outputChannel: vscode.OutputChannel;
+ op: azdata.BackgroundOperation;
+ accounts: azdata.Account[];
+ subscriptions: azureResource.AzureResourceSubscription[];
+ groups: azureResource.AzureResourceResourceGroup[];
+ workspaces: Workspace[];
+ models: WorkspaceModel[];
+ client: TypeMoq.IMock;
+ workspacesClient: TypeMoq.IMock;
+ modelClient: TypeMoq.IMock;
+}
+
+function createContext(): TestContext {
+ const context = utils.createContext();
+ const workspaces = TypeMoq.Mock.ofType(Workspaces);
+ const credentials = {
+ signRequest: () => {
+ return Promise.resolve(undefined!!);
+ }
+ };
+ const client = TypeMoq.Mock.ofInstance(new AzureMachineLearningWorkspaces(credentials, 'subscription'));
+ client.setup(x => x.apiVersion).returns(() => '20180101');
+
+ return {
+ apiWrapper: TypeMoq.Mock.ofType(ApiWrapper),
+ config: TypeMoq.Mock.ofType(Config),
+ httpClient: TypeMoq.Mock.ofType(HttpClient),
+ outputChannel: context.outputChannel,
+ op: context.op,
+ accounts: [
+ {
+ key: {
+ providerId: '',
+ accountId: 'a1'
+ },
+ displayInfo: {
+ contextualDisplayName: '',
+ accountType: '',
+ displayName: 'a1',
+ userId: 'a1'
+ },
+ properties:
+ {
+ tenants: [
+ {
+ id: '1',
+ }
+ ]
+ }
+ ,
+ isStale: true
+ }
+ ],
+ subscriptions: [
+ {
+ name: 's1',
+ id: 's1'
+ }
+ ],
+ groups: [
+ {
+ name: 'g1',
+ id: 'g1'
+ }
+ ],
+ workspaces: [{
+ name: 'w1',
+ id: 'w1'
+ }
+ ],
+ models: [
+ {
+ name: 'm1',
+ id: 'm1',
+ url: 'aml://asset/test.test'
+ }
+ ],
+ client: client,
+ workspacesClient: workspaces,
+ modelClient: TypeMoq.Mock.ofInstance(new WorkspaceModels(client.object))
+ };
+}
+
+describe('AzureModelRegistryService', () => {
+ it('getAccounts should return the list of accounts successfully', async function (): Promise {
+ let testContext = createContext();
+ const accounts = testContext.accounts;
+ let service = new AzureModelRegistryService(
+ testContext.apiWrapper.object,
+ testContext.config.object,
+ testContext.httpClient.object,
+ testContext.outputChannel);
+ testContext.apiWrapper.setup(x => x.getAllAccounts()).returns(() => Promise.resolve(accounts));
+ let actual = await service.getAccounts();
+ should.deepEqual(actual, testContext.accounts);
+ });
+
+ it('getSubscriptions should return the list of subscriptions successfully', async function (): Promise {
+ let testContext = createContext();
+ const expected = testContext.subscriptions;
+ let service = new AzureModelRegistryService(
+ testContext.apiWrapper.object,
+ testContext.config.object,
+ testContext.httpClient.object,
+ testContext.outputChannel);
+ testContext.apiWrapper.setup(x => x.executeCommand(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve({ subscriptions: expected, errors: [] }));
+ let actual = await service.getSubscriptions(testContext.accounts[0]);
+ should.deepEqual(actual, expected);
+ });
+
+ it('getGroups should return the list of groups successfully', async function (): Promise {
+ let testContext = createContext();
+ const expected = testContext.groups;
+ let service = new AzureModelRegistryService(
+ testContext.apiWrapper.object,
+ testContext.config.object,
+ testContext.httpClient.object,
+ testContext.outputChannel);
+ testContext.apiWrapper.setup(x => x.executeCommand(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve({ resourceGroups: expected, errors: [] }));
+ let actual = await service.getGroups(testContext.accounts[0], testContext.subscriptions[0]);
+ should.deepEqual(actual, expected);
+ });
+
+ it('getWorkspaces should return the list of workspaces successfully', async function (): Promise {
+ let testContext = createContext();
+ const response: WorkspacesListByResourceGroupResponse = Object.assign(new Array(...testContext.workspaces), {
+ _response: undefined!
+ });
+ const expected = testContext.workspaces;
+ testContext.workspacesClient.setup(x => x.listByResourceGroup(TypeMoq.It.isAny())).returns(() => Promise.resolve(response));
+ testContext.workspacesClient.setup(x => x.listBySubscription()).returns(() => Promise.resolve(response));
+ testContext.client.setup(x => x.workspaces).returns(() => testContext.workspacesClient.object);
+ let service = new AzureModelRegistryService(
+ testContext.apiWrapper.object,
+ testContext.config.object,
+ testContext.httpClient.object,
+ testContext.outputChannel);
+
+
+ service.AzureMachineLearningClient = testContext.client.object;
+ let actual = await service.getWorkspaces(testContext.accounts[0], testContext.subscriptions[0], testContext.groups[0]);
+ should.deepEqual(actual, expected);
+ });
+
+ it('getModels should return the list of models successfully', async function (): Promise {
+ let testContext = createContext();
+ testContext.config.setup(x => x.amlApiVersion).returns(() => '2018');
+ testContext.config.setup(x => x.amlModelManagementUrl).returns(() => 'test.url');
+ const expected = testContext.models;
+ let service = new AzureModelRegistryService(
+ testContext.apiWrapper.object,
+ testContext.config.object,
+ testContext.httpClient.object,
+ testContext.outputChannel);
+ service.AzureMachineLearningClient = testContext.client.object;
+ service.ModelClient = testContext.modelClient.object;
+ testContext.modelClient.setup(x => x.listModels(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(testContext.models));
+ let actual = await service.getModels(testContext.accounts[0], testContext.subscriptions[0], testContext.groups[0], testContext.workspaces[0]);
+ should.deepEqual(actual, expected);
+ });
+
+ it('downloadModel should download model artifact successfully', async function (): Promise {
+ let testContext = createContext();
+ const asset: Asset =
+ {
+ id: '1',
+ name: 'asset',
+ artifacts: [
+ {
+ id: '/1/2/3/4/5/'
+ }
+ ]
+ };
+ const assetResponse: AssetsQueryByIdResponse = Object.assign(asset, {
+ _response: undefined!
+ });
+ const artifactResponse: GetArtifactContentInformation2Response = Object.assign({
+ contentUri: 'downloadUrl'
+ }, {
+ _response: undefined!
+ });
+
+ testContext.config.setup(x => x.amlApiVersion).returns(() => '2018');
+ testContext.config.setup(x => x.amlModelManagementUrl).returns(() => 'test.url');
+ testContext.config.setup(x => x.amlExperienceUrl).returns(() => 'test.url');
+ testContext.client.setup(x => x.sendOperationRequest(TypeMoq.It.isAny(),
+ TypeMoq.It.is(p => p.path !== undefined && p.path.startsWith('modelmanagement')), TypeMoq.It.isAny())).returns(() => Promise.resolve(assetResponse));
+ testContext.client.setup(x => x.sendOperationRequest(TypeMoq.It.isAny(),
+ TypeMoq.It.is(p => p.path !== undefined && p.path.startsWith('artifact')), TypeMoq.It.isAny())).returns(() => Promise.resolve(artifactResponse));
+ testContext.apiWrapper.setup(x => x.startBackgroundOperation(TypeMoq.It.isAny())).returns((operationInfo: azdata.BackgroundOperationInfo) => {
+ operationInfo.operation(testContext.op);
+ });
+ testContext.httpClient.setup(x => x.download(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve());
+ let service = new AzureModelRegistryService(
+ testContext.apiWrapper.object,
+ testContext.config.object,
+ testContext.httpClient.object,
+ testContext.outputChannel);
+ service.AzureMachineLearningClient = testContext.client.object;
+ service.ModelClient = testContext.modelClient.object;
+ testContext.modelClient.setup(x => x.listModels(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(testContext.models));
+ let actual = await service.downloadModel(testContext.accounts[0], testContext.subscriptions[0], testContext.groups[0], testContext.workspaces[0], testContext.models[0]);
+ should.notEqual(actual, undefined);
+ testContext.httpClient.verify(x => x.download(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny()), TypeMoq.Times.once());
+ });
+});
diff --git a/extensions/machine-learning-services/src/test/modelManagement/deployedModelService.test.ts b/extensions/machine-learning-services/src/test/modelManagement/deployedModelService.test.ts
new file mode 100644
index 0000000000..f324ed11b1
--- /dev/null
+++ b/extensions/machine-learning-services/src/test/modelManagement/deployedModelService.test.ts
@@ -0,0 +1,410 @@
+/*---------------------------------------------------------------------------------------------
+ * Copyright (c) Microsoft Corporation. All rights reserved.
+ * Licensed under the Source EULA. See License.txt in the project root for license information.
+ *--------------------------------------------------------------------------------------------*/
+
+import * as azdata from 'azdata';
+import * as utils from '../../common/utils';
+import { ApiWrapper } from '../../common/apiWrapper';
+import * as TypeMoq from 'typemoq';
+import * as should from 'should';
+import { Config } from '../../configurations/config';
+import { DeployedModelService } from '../../modelManagement/deployedModelService';
+import { QueryRunner } from '../../common/queryRunner';
+import { RegisteredModel } from '../../modelManagement/interfaces';
+import { ModelPythonClient } from '../../modelManagement/modelPythonClient';
+import * as path from 'path';
+import * as os from 'os';
+import * as UUID from 'vscode-languageclient/lib/utils/uuid';
+import * as fs from 'fs';
+
+interface TestContext {
+
+ apiWrapper: TypeMoq.IMock;
+ config: TypeMoq.IMock;
+ queryRunner: TypeMoq.IMock;
+ modelClient: TypeMoq.IMock;
+}
+
+function createContext(): TestContext {
+
+ return {
+ apiWrapper: TypeMoq.Mock.ofType(ApiWrapper),
+ config: TypeMoq.Mock.ofType(Config),
+ queryRunner: TypeMoq.Mock.ofType(QueryRunner),
+ modelClient: TypeMoq.Mock.ofType(ModelPythonClient)
+ };
+}
+
+describe('DeployedModelService', () => {
+ it('getDeployedModels should fail with no connection', async function (): Promise {
+ const testContext = createContext();
+ let connection: azdata.connection.ConnectionProfile;
+
+ testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); });
+ let service = new DeployedModelService(
+ testContext.apiWrapper.object,
+ testContext.config.object,
+ testContext.queryRunner.object,
+ testContext.modelClient.object);
+ await should(service.getDeployedModels()).rejected();
+ });
+
+ it('getDeployedModels should returns models successfully', async function (): Promise {
+ const testContext = createContext();
+ const connection = new azdata.connection.ConnectionProfile();
+ testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); });
+ const expected: RegisteredModel[] = [
+ {
+ id: 1,
+ artifactName: 'name1',
+ title: 'title1',
+ description: 'desc1',
+ created: '2018-01-01',
+ version: '1.1'
+ }
+ ];
+ const result = {
+ rowCount: 1,
+ columnInfo: [],
+ rows: [
+ [
+ {
+ displayValue: '1',
+ isNull: false,
+ invariantCultureDisplayValue: ''
+ },
+ {
+ displayValue: 'name1',
+ isNull: false,
+ invariantCultureDisplayValue: ''
+ },
+ {
+ displayValue: 'title1',
+ isNull: false,
+ invariantCultureDisplayValue: ''
+ },
+ {
+ displayValue: 'desc1',
+ isNull: false,
+ invariantCultureDisplayValue: ''
+ },
+ {
+ displayValue: '1.1',
+ isNull: false,
+ invariantCultureDisplayValue: ''
+ },
+ {
+ displayValue: '2018-01-01',
+ isNull: false,
+ invariantCultureDisplayValue: ''
+ }
+ ]
+ ]
+ };
+ let service = new DeployedModelService(
+ testContext.apiWrapper.object,
+ testContext.config.object,
+ testContext.queryRunner.object,
+ testContext.modelClient.object);
+ testContext.queryRunner.setup(x => x.safeRunQuery(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(result));
+
+ testContext.config.setup(x => x.registeredModelDatabaseName).returns(() => 'db');
+ testContext.config.setup(x => x.registeredModelTableName).returns(() => 'table');
+ const actual = await service.getDeployedModels();
+ should.deepEqual(actual, expected);
+ });
+
+ it('loadModelParameters should load parameters using python client successfully', async function (): Promise {
+ const testContext = createContext();
+ const expected = {
+ inputs: [
+ {
+ 'name': 'p1',
+ 'type': 'int'
+ },
+ {
+ 'name': 'p2',
+ 'type': 'varchar'
+ }
+ ],
+ outputs: [
+ {
+ 'name': 'o1',
+ 'type': 'int'
+ },
+ ]
+ };
+ testContext.modelClient.setup(x => x.loadModelParameters(TypeMoq.It.isAny())).returns(() => Promise.resolve(expected));
+ let service = new DeployedModelService(
+ testContext.apiWrapper.object,
+ testContext.config.object,
+ testContext.queryRunner.object,
+ testContext.modelClient.object);
+ const actual = await service.loadModelParameters('');
+ should.deepEqual(actual, expected);
+ });
+
+ it('downloadModel should download model successfully', async function (): Promise {
+ const testContext = createContext();
+ const connection = new azdata.connection.ConnectionProfile();
+ const tempFilePath = path.join(os.tmpdir(), `ads_ml_temp_${UUID.generateUuid()}`);
+ await fs.promises.writeFile(tempFilePath, 'test');
+ testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); });
+ const model: RegisteredModel =
+ {
+ id: 1,
+ artifactName: 'name1',
+ title: 'title1',
+ description: 'desc1',
+ created: '2018-01-01',
+ version: '1.1'
+ };
+ const result = {
+ rowCount: 1,
+ columnInfo: [],
+ rows: [
+ [
+ {
+ displayValue: await utils.readFileInHex(tempFilePath),
+ isNull: false,
+ invariantCultureDisplayValue: ''
+ }
+ ]
+ ]
+ };
+ let service = new DeployedModelService(
+ testContext.apiWrapper.object,
+ testContext.config.object,
+ testContext.queryRunner.object,
+ testContext.modelClient.object);
+ testContext.queryRunner.setup(x => x.safeRunQuery(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(result));
+
+ testContext.config.setup(x => x.registeredModelDatabaseName).returns(() => 'db');
+ testContext.config.setup(x => x.registeredModelTableName).returns(() => 'table');
+ testContext.config.setup(x => x.registeredModelTableSchemaName).returns(() => 'dbo');
+ const actual = await service.downloadModel(model);
+ should.notEqual(actual, undefined);
+ });
+
+ it('deployLocalModel should returns models successfully', async function (): Promise {
+ const testContext = createContext();
+ const connection = new azdata.connection.ConnectionProfile();
+ testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); });
+ const model: RegisteredModel =
+ {
+ id: 1,
+ artifactName: 'name1',
+ title: 'title1',
+ description: 'desc1',
+ created: '2018-01-01',
+ version: '1.1'
+ };
+ const row = [
+ {
+ displayValue: '1',
+ isNull: false,
+ invariantCultureDisplayValue: ''
+ },
+ {
+ displayValue: 'name1',
+ isNull: false,
+ invariantCultureDisplayValue: ''
+ },
+ {
+ displayValue: 'title1',
+ isNull: false,
+ invariantCultureDisplayValue: ''
+ },
+ {
+ displayValue: 'desc1',
+ isNull: false,
+ invariantCultureDisplayValue: ''
+ },
+ {
+ displayValue: '1.1',
+ isNull: false,
+ invariantCultureDisplayValue: ''
+ },
+ {
+ displayValue: '2018-01-01',
+ isNull: false,
+ invariantCultureDisplayValue: ''
+ }
+ ];
+ const result = {
+ rowCount: 1,
+ columnInfo: [],
+ rows: [row]
+ };
+ let updatedResult = {
+ rowCount: 1,
+ columnInfo: [],
+ rows: [row, row]
+ };
+ let deployed = false;
+ let service = new DeployedModelService(
+ testContext.apiWrapper.object,
+ testContext.config.object,
+ testContext.queryRunner.object,
+ testContext.modelClient.object);
+ testContext.modelClient.setup(x => x.deployModel(connection, '')).returns(() => {
+ deployed = true;
+ return Promise.resolve();
+ });
+ testContext.queryRunner.setup(x => x.safeRunQuery(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => {
+ return deployed ? Promise.resolve(updatedResult) : Promise.resolve(result);
+ });
+
+ testContext.config.setup(x => x.registeredModelDatabaseName).returns(() => 'db');
+ testContext.config.setup(x => x.registeredModelTableName).returns(() => 'table');
+ testContext.config.setup(x => x.registeredModelTableSchemaName).returns(() => 'dbo');
+ await should(service.deployLocalModel('', model)).resolved();
+ });
+
+ it('getConfigureQuery should escape db name', async function (): Promise {
+ const testContext = createContext();
+ const dbName = 'curre[n]tDb';
+ let service = new DeployedModelService(
+ testContext.apiWrapper.object,
+ testContext.config.object,
+ testContext.queryRunner.object,
+ testContext.modelClient.object);
+ testContext.config.setup(x => x.registeredModelDatabaseName).returns(() => 'd[]b');
+ testContext.config.setup(x => x.registeredModelTableName).returns(() => 'ta[b]le');
+ testContext.config.setup(x => x.registeredModelTableSchemaName).returns(() => 'dbo');
+ const expected = `
+ IF NOT EXISTS (
+ SELECT [name]
+ FROM sys.databases
+ WHERE [name] = N'd[]b'
+ )
+ CREATE DATABASE [d[[]]b]
+ GO
+ USE [d[[]]b]
+ IF EXISTS
+ ( SELECT [t.name], [s.name]
+ FROM sys.tables t join sys.schemas s on t.schema_id=t.schema_id
+ WHERE [t.name] = 'ta[b]le'
+ AND [s.name] = 'dbo'
+ )
+ BEGIN
+ IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('[dbo].[ta[[b]]le]') AND NAME='name')
+ ALTER TABLE [dbo].[ta[[b]]le] ADD [name] [varchar](256) NULL
+ IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('[dbo].[ta[[b]]le]') AND NAME='version')
+ ALTER TABLE [dbo].[ta[[b]]le] ADD [version] [varchar](256) NULL
+ IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('[dbo].[ta[[b]]le]') AND NAME='created')
+ BEGIN
+ ALTER TABLE [dbo].[ta[[b]]le] ADD [created] [datetime] NULL
+ ALTER TABLE [dbo].[ta[[b]]le] ADD CONSTRAINT CONSTRAINT_NAME DEFAULT GETDATE() FOR created
+ END
+ IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('[dbo].[ta[[b]]le]') AND NAME='description')
+ ALTER TABLE [dbo].[ta[[b]]le] ADD [description] [varchar](256) NULL
+ END
+ Else
+ BEGIN
+ CREATE TABLE [dbo].[ta[[b]]le](
+ [artifact_id] [int] IDENTITY(1,1) NOT NULL,
+ [artifact_name] [varchar](256) NOT NULL,
+ [group_path] [varchar](256) NOT NULL,
+ [artifact_content] [varbinary](max) NOT NULL,
+ [artifact_initial_size] [bigint] NULL,
+ [name] [varchar](256) NULL,
+ [version] [varchar](256) NULL,
+ [created] [datetime] NULL,
+ [description] [varchar](256) NULL,
+ CONSTRAINT [artifact_pk] PRIMARY KEY CLUSTERED
+ (
+ [artifact_id] ASC
+ )WITH (PAD_INDEX = OFF, STATISTICS_NORECOMPUTE = OFF, IGNORE_DUP_KEY = OFF, ALLOW_ROW_LOCKS = ON, ALLOW_PAGE_LOCKS = ON) ON [PRIMARY]
+ ) ON [PRIMARY] TEXTIMAGE_ON [PRIMARY]
+ ALTER TABLE [dbo].[artifacts] ADD CONSTRAINT [CONSTRAINT_NAME] DEFAULT (getdate()) FOR [created]
+ END
+ `;
+ const actual = service.getConfigureQuery(dbName);
+ should.equal(actual.indexOf(expected) > 0, true);
+ });
+
+ it('getDeployedModelsQuery should escape db name', async function (): Promise {
+ const testContext = createContext();
+ let service = new DeployedModelService(
+ testContext.apiWrapper.object,
+ testContext.config.object,
+ testContext.queryRunner.object,
+ testContext.modelClient.object);
+ testContext.config.setup(x => x.registeredModelDatabaseName).returns(() => 'd[]b');
+ testContext.config.setup(x => x.registeredModelTableName).returns(() => 'ta[b]le');
+ testContext.config.setup(x => x.registeredModelTableSchemaName).returns(() => 'dbo');
+ const expected = `
+ SELECT artifact_id, artifact_name, name, description, version, created
+ FROM [d[[]]b].[dbo].[ta[[b]]le]
+ WHERE artifact_name not like 'MLmodel' and artifact_name not like 'conda.yaml'
+ Order by artifact_id
+ `;
+ const actual = service.getDeployedModelsQuery();
+ should.deepEqual(expected, actual);
+ });
+
+ it('getUpdateModelQuery should escape db name', async function (): Promise {
+ const testContext = createContext();
+ const dbName = 'curre[n]tDb';
+ const model: RegisteredModel =
+ {
+ id: 1,
+ artifactName: 'name1',
+ title: 'title1',
+ description: 'desc1',
+ created: '2018-01-01',
+ version: '1.1'
+ };
+
+ let service = new DeployedModelService(
+ testContext.apiWrapper.object,
+ testContext.config.object,
+ testContext.queryRunner.object,
+ testContext.modelClient.object);
+ testContext.config.setup(x => x.registeredModelDatabaseName).returns(() => 'd[]b');
+ testContext.config.setup(x => x.registeredModelTableName).returns(() => 'ta[b]le');
+ testContext.config.setup(x => x.registeredModelTableSchemaName).returns(() => 'dbo');
+ const expected = `
+ UPDATE [dbo].[ta[[b]]le]
+ SET
+ name = 'title1',
+ version = '1.1',
+ description = 'desc1'
+ WHERE artifact_id = 1`;
+ const actual = service.getUpdateModelQuery(dbName, model);
+ should.equal(actual.indexOf(expected) > 0, true);
+ //should.deepEqual(actual, expected);
+
+ });
+
+ it('getModelContentQuery should escape db name', async function (): Promise {
+ const testContext = createContext();
+ const model: RegisteredModel =
+ {
+ id: 1,
+ artifactName: 'name1',
+ title: 'title1',
+ description: 'desc1',
+ created: '2018-01-01',
+ version: '1.1'
+ };
+
+ let service = new DeployedModelService(
+ testContext.apiWrapper.object,
+ testContext.config.object,
+ testContext.queryRunner.object,
+ testContext.modelClient.object);
+ testContext.config.setup(x => x.registeredModelDatabaseName).returns(() => 'd[]b');
+ testContext.config.setup(x => x.registeredModelTableName).returns(() => 'ta[b]le');
+ testContext.config.setup(x => x.registeredModelTableSchemaName).returns(() => 'dbo');
+ const expected = `
+ SELECT artifact_content
+ FROM [d[[]]b].[dbo].[ta[[b]]le]
+ WHERE artifact_id = 1;
+ `;
+ const actual = service.getModelContentQuery(model);
+ should.deepEqual(actual, expected);
+ });
+});
diff --git a/extensions/machine-learning-services/src/test/modelManagement/modelPythonClient.test.ts b/extensions/machine-learning-services/src/test/modelManagement/modelPythonClient.test.ts
new file mode 100644
index 0000000000..a2b985a016
--- /dev/null
+++ b/extensions/machine-learning-services/src/test/modelManagement/modelPythonClient.test.ts
@@ -0,0 +1,121 @@
+/*---------------------------------------------------------------------------------------------
+ * Copyright (c) Microsoft Corporation. All rights reserved.
+ * Licensed under the Source EULA. See License.txt in the project root for license information.
+ *--------------------------------------------------------------------------------------------*/
+
+import * as azdata from 'azdata';
+import * as vscode from 'vscode';
+import { ApiWrapper } from '../../common/apiWrapper';
+import * as TypeMoq from 'typemoq';
+import * as should from 'should';
+import { Config } from '../../configurations/config';
+
+import * as utils from '../utils';
+import { ProcessService } from '../../common/processService';
+import { PackageManager } from '../../packageManagement/packageManager';
+import { ModelPythonClient } from '../../modelManagement/modelPythonClient';
+
+interface TestContext {
+
+ apiWrapper: TypeMoq.IMock;
+ config: TypeMoq.IMock;
+ outputChannel: vscode.OutputChannel;
+ op: azdata.BackgroundOperation;
+ processService: TypeMoq.IMock;
+ packageManager: TypeMoq.IMock;
+}
+
+function createContext(): TestContext {
+ const context = utils.createContext();
+
+ return {
+ apiWrapper: TypeMoq.Mock.ofType(ApiWrapper),
+ config: TypeMoq.Mock.ofType(Config),
+ outputChannel: context.outputChannel,
+ op: context.op,
+ processService: TypeMoq.Mock.ofType(ProcessService),
+ packageManager: TypeMoq.Mock.ofType(PackageManager)
+ };
+}
+
+describe('ModelPythonClient', () => {
+ it('deployModel should deploy the model successfully', async function (): Promise {
+ const testContext = createContext();
+ const connection = new azdata.connection.ConnectionProfile();
+ const modelPath = 'C:\\test';
+ let service = new ModelPythonClient(
+ testContext.outputChannel,
+ testContext.apiWrapper.object,
+ testContext.processService.object,
+ testContext.config.object,
+ testContext.packageManager.object);
+ testContext.packageManager.setup(x => x.installRequiredPythonPackages(TypeMoq.It.isAny())).returns(() => Promise.resolve());
+ testContext.apiWrapper.setup(x => x.startBackgroundOperation(TypeMoq.It.isAny())).returns((operationInfo: azdata.BackgroundOperationInfo) => {
+ operationInfo.operation(testContext.op);
+ });
+ testContext.config.setup(x => x.pythonExecutable).returns(() => 'pythonPath');
+ testContext.processService.setup(x => x.execScripts(TypeMoq.It.isAny(), TypeMoq.It.isAny(),
+ TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(''));
+
+ await service.deployModel(connection, modelPath);
+ });
+
+ it('loadModelParameters should load model parameters successfully', async function (): Promise {
+ const testContext = createContext();
+ const modelPath = 'C:\\test';
+ const expected = {
+ inputs: [
+ {
+ 'name': 'p1',
+ 'type': 'int'
+ },
+ {
+ 'name': 'p2',
+ 'type': 'varchar'
+ }
+ ],
+ outputs: [
+ {
+ 'name': 'o1',
+ 'type': 'int'
+ },
+ ]
+ };
+ const parametersJson = `
+ {
+ "inputs": [
+ {
+ "name": "p1",
+ "type": "int"
+ },
+ {
+ "name": "p2",
+ "type": "varchar"
+ }
+ ],
+ "outputs": [
+ {
+ "name": "o1",
+ "type": "int"
+ }
+ ]
+ }
+ `;
+ let service = new ModelPythonClient(
+ testContext.outputChannel,
+ testContext.apiWrapper.object,
+ testContext.processService.object,
+ testContext.config.object,
+ testContext.packageManager.object);
+ testContext.packageManager.setup(x => x.installRequiredPythonPackages(TypeMoq.It.isAny())).returns(() => Promise.resolve());
+ testContext.config.setup(x => x.pythonExecutable).returns(() => 'pythonPath');
+ testContext.processService.setup(x => x.execScripts(TypeMoq.It.isAny(), TypeMoq.It.isAny(),
+ TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(parametersJson));
+ testContext.apiWrapper.setup(x => x.startBackgroundOperation(TypeMoq.It.isAny())).returns((operationInfo: azdata.BackgroundOperationInfo) => {
+ operationInfo.operation(testContext.op);
+ });
+
+ const actual = await service.loadModelParameters(modelPath);
+ should.deepEqual(actual, expected);
+ });
+});
diff --git a/extensions/machine-learning-services/src/test/packageManagement/sqlPythonPackageManageProvider.test.ts b/extensions/machine-learning-services/src/test/packageManagement/sqlPythonPackageManageProvider.test.ts
index 856135603a..0ee3592e29 100644
--- a/extensions/machine-learning-services/src/test/packageManagement/sqlPythonPackageManageProvider.test.ts
+++ b/extensions/machine-learning-services/src/test/packageManagement/sqlPythonPackageManageProvider.test.ts
@@ -354,7 +354,7 @@ describe('SQL Python Package Manager', () => {
let provider = createProvider(testContext);
let actual = await provider.getLocationTitle();
- should.deepEqual(actual, constants.packageManagerNoConnection);
+ should.deepEqual(actual, constants.noConnectionError);
});
it('getLocationTitle Should return connection title string for valid connection', async function (): Promise {
diff --git a/extensions/machine-learning-services/src/test/packageManagement/sqlRPackageManageProvider.test.ts b/extensions/machine-learning-services/src/test/packageManagement/sqlRPackageManageProvider.test.ts
index cbe86dc4a7..6752a00dc1 100644
--- a/extensions/machine-learning-services/src/test/packageManagement/sqlRPackageManageProvider.test.ts
+++ b/extensions/machine-learning-services/src/test/packageManagement/sqlRPackageManageProvider.test.ts
@@ -279,7 +279,7 @@ describe('SQL R Package Manager', () => {
let provider = createProvider(testContext);
let actual = await provider.getLocationTitle();
- should.deepEqual(actual, constants.packageManagerNoConnection);
+ should.deepEqual(actual, constants.noConnectionError);
});
it('getLocationTitle Should return connection title string for valid connection', async function (): Promise {
diff --git a/extensions/machine-learning-services/src/test/packageManagement/utils.ts b/extensions/machine-learning-services/src/test/packageManagement/utils.ts
index 3911af4db0..1957d38d8c 100644
--- a/extensions/machine-learning-services/src/test/packageManagement/utils.ts
+++ b/extensions/machine-learning-services/src/test/packageManagement/utils.ts
@@ -11,6 +11,7 @@ import { QueryRunner } from '../../common/queryRunner';
import { ProcessService } from '../../common/processService';
import { Config } from '../../configurations/config';
import { HttpClient } from '../../common/httpClient';
+import * as utils from '../utils';
import { PackageManagementService } from '../../packageManagement/packageManagementService';
export interface TestContext {
@@ -27,31 +28,18 @@ export interface TestContext {
}
export function createContext(): TestContext {
- let opStatus: azdata.TaskStatus;
+ const context = utils.createContext();
return {
- outputChannel: {
- name: '',
- append: () => { },
- appendLine: () => { },
- clear: () => { },
- show: () => { },
- hide: () => { },
- dispose: () => { }
- },
+
+ outputChannel: context.outputChannel,
processService: TypeMoq.Mock.ofType(ProcessService),
apiWrapper: TypeMoq.Mock.ofType(ApiWrapper),
queryRunner: TypeMoq.Mock.ofType(QueryRunner),
config: TypeMoq.Mock.ofType(Config),
httpClient: TypeMoq.Mock.ofType(HttpClient),
- op: {
- updateStatus: (status: azdata.TaskStatus) => {
- opStatus = status;
- },
- id: '',
- onCanceled: new vscode.EventEmitter().event,
- },
- getOpStatus: () => { return opStatus; },
+ op: context.op,
+ getOpStatus: context.getOpStatus,
serverConfigManager: TypeMoq.Mock.ofType(PackageManagementService)
};
}
diff --git a/extensions/machine-learning-services/src/test/prediction/predictService.test.ts b/extensions/machine-learning-services/src/test/prediction/predictService.test.ts
new file mode 100644
index 0000000000..94b3a8cd84
--- /dev/null
+++ b/extensions/machine-learning-services/src/test/prediction/predictService.test.ts
@@ -0,0 +1,303 @@
+/*---------------------------------------------------------------------------------------------
+ * Copyright (c) Microsoft Corporation. All rights reserved.
+ * Licensed under the Source EULA. See License.txt in the project root for license information.
+ *--------------------------------------------------------------------------------------------*/
+
+import * as azdata from 'azdata';
+import * as vscode from 'vscode';
+import { ApiWrapper } from '../../common/apiWrapper';
+import * as TypeMoq from 'typemoq';
+import * as should from 'should';
+import { Config } from '../../configurations/config';
+import { PredictService } from '../../prediction/predictService';
+import { QueryRunner } from '../../common/queryRunner';
+import { RegisteredModel } from '../../modelManagement/interfaces';
+import { PredictParameters, DatabaseTable, TableColumn } from '../../prediction/interfaces';
+import * as path from 'path';
+import * as os from 'os';
+import * as UUID from 'vscode-languageclient/lib/utils/uuid';
+import * as fs from 'fs';
+
+
+interface TestContext {
+
+ apiWrapper: TypeMoq.IMock;
+ config: TypeMoq.IMock;
+ queryRunner: TypeMoq.IMock;
+}
+
+function createContext(): TestContext {
+
+ return {
+ apiWrapper: TypeMoq.Mock.ofType(ApiWrapper),
+ config: TypeMoq.Mock.ofType(Config),
+ queryRunner: TypeMoq.Mock.ofType(QueryRunner)
+ };
+}
+
+describe('PredictService', () => {
+
+ it('getDatabaseList should return databases successfully', async function (): Promise {
+ const testContext = createContext();
+ const expected: string[] = [
+ 'db1',
+ 'db2'
+ ];
+ const connection = new azdata.connection.ConnectionProfile();
+ testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); });
+ testContext.apiWrapper.setup(x => x.listDatabases(TypeMoq.It.isAny())).returns(() => { return Promise.resolve(expected); });
+
+ let service = new PredictService(
+ testContext.apiWrapper.object,
+ testContext.queryRunner.object,
+ testContext.config.object);
+ const actual = await service.getDatabaseList();
+ should.deepEqual(actual, expected);
+ });
+
+ it('getTableList should return tables successfully', async function (): Promise {
+ const testContext = createContext();
+ const expected: DatabaseTable[] = [
+ {
+ databaseName: 'db1',
+ schema: 'dbo',
+ tableName: 'tb1'
+ },
+ {
+ databaseName: 'db1',
+ tableName: 'tb2',
+ schema: 'dbo'
+ }
+ ];
+
+ const result = {
+ rowCount: 1,
+ columnInfo: [],
+ rows: [[
+ {
+ displayValue: 'tb1',
+ isNull: false,
+ invariantCultureDisplayValue: ''
+ },
+ {
+ displayValue: 'dbo',
+ isNull: false,
+ invariantCultureDisplayValue: ''
+ }
+ ], [
+ {
+ displayValue: 'tb2',
+ isNull: false,
+ invariantCultureDisplayValue: ''
+ },
+ {
+ displayValue: 'dbo',
+ isNull: false,
+ invariantCultureDisplayValue: ''
+ }
+ ]]
+ };
+ const connection = new azdata.connection.ConnectionProfile();
+ testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); });
+ testContext.queryRunner.setup(x => x.safeRunQuery(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(result));
+ let service = new PredictService(
+ testContext.apiWrapper.object,
+ testContext.queryRunner.object,
+ testContext.config.object);
+ const actual = await service.getTableList('db1');
+ should.deepEqual(actual, expected);
+ });
+
+ it('getTableColumnsList should return table columns successfully', async function (): Promise {
+ const testContext = createContext();
+ const expected: TableColumn[] = [
+ {
+ columnName: 'c1',
+ dataType: 'int'
+ },
+ {
+ columnName: 'c2',
+ dataType: 'varchar'
+ }
+ ];
+ const table: DatabaseTable =
+ {
+ databaseName: 'db1',
+ schema: 'dbo',
+ tableName: 'tb1'
+ };
+
+ const result = {
+ rowCount: 1,
+ columnInfo: [],
+ rows: [[
+ {
+ displayValue: 'c1',
+ isNull: false,
+ invariantCultureDisplayValue: ''
+ },
+ {
+ displayValue: 'int',
+ isNull: false,
+ invariantCultureDisplayValue: ''
+ }
+ ], [
+ {
+ displayValue: 'c2',
+ isNull: false,
+ invariantCultureDisplayValue: ''
+ },
+ {
+ displayValue: 'varchar',
+ isNull: false,
+ invariantCultureDisplayValue: ''
+ }
+ ]]
+ };
+ const connection = new azdata.connection.ConnectionProfile();
+ testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); });
+
+ testContext.queryRunner.setup(x => x.safeRunQuery(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(result));
+ let service = new PredictService(
+ testContext.apiWrapper.object,
+ testContext.queryRunner.object,
+ testContext.config.object);
+ const actual = await service.getTableColumnsList(table);
+ should.deepEqual(actual, expected);
+ });
+
+ it('generatePredictScript should generate the script successfully using model', async function (): Promise {
+ const testContext = createContext();
+ const connection = new azdata.connection.ConnectionProfile();
+ testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); });
+ const predictParams: PredictParameters = {
+ inputColumns: [
+ {
+ paramName: 'p1',
+ dataType: 'int',
+ columnName: ''
+ },
+ {
+ paramName: 'p2',
+ dataType: 'varchar',
+ columnName: ''
+ }
+ ],
+ outputColumns: [
+ {
+ paramName: 'o1',
+ dataType: 'int',
+ columnName: ''
+ },
+ ],
+ databaseName: '',
+ tableName: '',
+ schema: ''
+ };
+ const model: RegisteredModel =
+ {
+ id: 1,
+ artifactName: 'name1',
+ title: 'title1',
+ description: 'desc1',
+ created: '2018-01-01',
+ version: '1.1'
+ };
+
+ let service = new PredictService(
+ testContext.apiWrapper.object,
+ testContext.queryRunner.object,
+ testContext.config.object);
+
+ const document: vscode.TextDocument = {
+ uri: vscode.Uri.parse('file:///usr/home'),
+ fileName: '',
+ isUntitled: true,
+ languageId: 'sql',
+ version: 1,
+ isDirty: true,
+ isClosed: false,
+ save: undefined!,
+ eol: undefined!,
+ lineCount: 1,
+ lineAt: undefined!,
+ offsetAt: undefined!,
+ positionAt: undefined!,
+ getText: undefined!,
+ getWordRangeAtPosition: undefined!,
+ validateRange: undefined!,
+ validatePosition: undefined!
+ };
+ testContext.apiWrapper.setup(x => x.openTextDocument(TypeMoq.It.isAny())).returns(() => Promise.resolve(document));
+ testContext.apiWrapper.setup(x => x.connect(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve());
+ testContext.apiWrapper.setup(x => x.runQuery(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => { });
+
+ const actual = await service.generatePredictScript(predictParams, model, undefined);
+ should.notEqual(actual, undefined);
+ should.equal(actual.indexOf('FROM PREDICT(MODEL = @model') > 0, true);
+ });
+
+ it('generatePredictScript should generate the script successfully using file', async function (): Promise {
+ const testContext = createContext();
+ const connection = new azdata.connection.ConnectionProfile();
+ testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); });
+ const predictParams: PredictParameters = {
+ inputColumns: [
+ {
+ paramName: 'p1',
+ dataType: 'int',
+ columnName: ''
+ },
+ {
+ paramName: 'p2',
+ dataType: 'varchar',
+ columnName: ''
+ }
+ ],
+ outputColumns: [
+ {
+ paramName: 'o1',
+ dataType: 'int',
+ columnName: ''
+ },
+ ],
+ databaseName: '',
+ tableName: '',
+ schema: ''
+ };
+ const tempFilePath = path.join(os.tmpdir(), `ads_ml_temp_${UUID.generateUuid()}`);
+ await fs.promises.writeFile(tempFilePath, 'test');
+
+ let service = new PredictService(
+ testContext.apiWrapper.object,
+ testContext.queryRunner.object,
+ testContext.config.object);
+
+ const document: vscode.TextDocument = {
+ uri: vscode.Uri.parse('file:///usr/home'),
+ fileName: '',
+ isUntitled: true,
+ languageId: 'sql',
+ version: 1,
+ isDirty: true,
+ isClosed: false,
+ save: undefined!,
+ eol: undefined!,
+ lineCount: 1,
+ lineAt: undefined!,
+ offsetAt: undefined!,
+ positionAt: undefined!,
+ getText: undefined!,
+ getWordRangeAtPosition: undefined!,
+ validateRange: undefined!,
+ validatePosition: undefined!
+ };
+ testContext.apiWrapper.setup(x => x.openTextDocument(TypeMoq.It.isAny())).returns(() => Promise.resolve(document));
+ testContext.apiWrapper.setup(x => x.connect(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve());
+ testContext.apiWrapper.setup(x => x.runQuery(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => { });
+
+ const actual = await service.generatePredictScript(predictParams, undefined, tempFilePath);
+ should.notEqual(actual, undefined);
+ should.equal(actual.indexOf('FROM PREDICT(MODEL = 0X') > 0, true);
+ });
+});
diff --git a/extensions/machine-learning-services/src/test/utils.ts b/extensions/machine-learning-services/src/test/utils.ts
new file mode 100644
index 0000000000..420e35a506
--- /dev/null
+++ b/extensions/machine-learning-services/src/test/utils.ts
@@ -0,0 +1,38 @@
+/*---------------------------------------------------------------------------------------------
+ * Copyright (c) Microsoft Corporation. All rights reserved.
+ * Licensed under the Source EULA. See License.txt in the project root for license information.
+ *--------------------------------------------------------------------------------------------*/
+
+import * as vscode from 'vscode';
+import * as azdata from 'azdata';
+
+export interface TestContext {
+
+ outputChannel: vscode.OutputChannel;
+ op: azdata.BackgroundOperation;
+ getOpStatus: () => azdata.TaskStatus;
+}
+
+export function createContext(): TestContext {
+ let opStatus: azdata.TaskStatus;
+
+ return {
+ outputChannel: {
+ name: '',
+ append: () => { },
+ appendLine: () => { },
+ clear: () => { },
+ show: () => { },
+ hide: () => { },
+ dispose: () => { }
+ },
+ op: {
+ updateStatus: (status: azdata.TaskStatus) => {
+ opStatus = status;
+ },
+ id: '',
+ onCanceled: new vscode.EventEmitter().event,
+ },
+ getOpStatus: () => { return opStatus; }
+ };
+}
diff --git a/extensions/machine-learning-services/src/test/views/models/predictWizard.test.ts b/extensions/machine-learning-services/src/test/views/models/predictWizard.test.ts
new file mode 100644
index 0000000000..c5b40ceb11
--- /dev/null
+++ b/extensions/machine-learning-services/src/test/views/models/predictWizard.test.ts
@@ -0,0 +1,178 @@
+/*---------------------------------------------------------------------------------------------
+ * Copyright (c) Microsoft Corporation. All rights reserved.
+ * Licensed under the Source EULA. See License.txt in the project root for license information.
+ *--------------------------------------------------------------------------------------------*/
+
+import * as azdata from 'azdata';
+import * as should from 'should';
+import 'mocha';
+import { createContext } from './utils';
+import {
+ ListModelsEventName, ListAccountsEventName, ListSubscriptionsEventName, ListGroupsEventName, ListWorkspacesEventName,
+ ListAzureModelsEventName, ListDatabaseNamesEventName, ListTableNamesEventName, ListColumnNamesEventName, LoadModelParametersEventName, DownloadAzureModelEventName, DownloadRegisteredModelEventName
+}
+ from '../../../views/models/modelViewBase';
+import { RegisteredModel, ModelParameters } from '../../../modelManagement/interfaces';
+import { azureResource } from '../../../typings/azure-resource';
+import { Workspace } from '@azure/arm-machinelearningservices/esm/models';
+import { ViewBase } from '../../../views/viewBase';
+import { WorkspaceModel } from '../../../modelManagement/interfaces';
+import { PredictWizard } from '../../../views/models/prediction/predictWizard';
+import { DatabaseTable, TableColumn } from '../../../prediction/interfaces';
+
+describe('Predict Wizard', () => {
+ it('Should create view components successfully ', async function (): Promise {
+ let testContext = createContext();
+
+ let view = new PredictWizard(testContext.apiWrapper.object, '');
+ await view.open();
+ should.notEqual(view.wizardView, undefined);
+ should.notEqual(view.modelSourcePage, undefined);
+ });
+
+ it('Should load data successfully ', async function (): Promise {
+ let testContext = createContext();
+
+ let view = new PredictWizard(testContext.apiWrapper.object, '');
+ await view.open();
+ let accounts: azdata.Account[] = [
+ {
+ key: {
+ accountId: '1',
+ providerId: ''
+ },
+ displayInfo: {
+ displayName: 'account',
+ userId: '',
+ accountType: '',
+ contextualDisplayName: ''
+ },
+ isStale: false,
+ properties: []
+ }
+ ];
+ let subscriptions: azureResource.AzureResourceSubscription[] = [
+ {
+ name: 'subscription',
+ id: '2'
+ }
+ ];
+ let groups: azureResource.AzureResourceResourceGroup[] = [
+ {
+ name: 'group',
+ id: '3'
+ }
+ ];
+ let workspaces: Workspace[] = [
+ {
+ name: 'workspace',
+ id: '4'
+ }
+ ];
+ let models: WorkspaceModel[] = [
+ {
+ id: '5',
+ name: 'model'
+ }
+ ];
+ let localModels: RegisteredModel[] = [
+ {
+ id: 1,
+ artifactName: 'model',
+ title: 'model'
+ }
+ ];
+ const dbNames: string[] = [
+ 'db1',
+ 'db2'
+ ];
+ const tableNames: DatabaseTable[] = [
+ {
+ databaseName: 'db1',
+ schema: 'dbo',
+ tableName: 'tb1'
+ },
+ {
+ databaseName: 'db1',
+ tableName: 'tb2',
+ schema: 'dbo'
+ }
+ ];
+ const columnNames: TableColumn[] = [
+ {
+ columnName: 'c1',
+ dataType: 'int'
+ },
+ {
+ columnName: 'c2',
+ dataType: 'varchar'
+ }
+ ];
+ const modelParameters: ModelParameters = {
+ inputs: [
+ {
+ 'name': 'p1',
+ 'type': 'int'
+ },
+ {
+ 'name': 'p2',
+ 'type': 'varchar'
+ }
+ ],
+ outputs: [
+ {
+ 'name': 'o1',
+ 'type': 'int'
+ }
+ ]
+ };
+
+ view.on(ListModelsEventName, () => {
+ view.sendCallbackRequest(ViewBase.getCallbackEventName(ListModelsEventName), { data: localModels });
+ });
+ view.on(ListAccountsEventName, () => {
+ view.sendCallbackRequest(ViewBase.getCallbackEventName(ListAccountsEventName), { data: accounts });
+ });
+ view.on(ListSubscriptionsEventName, () => {
+
+ view.sendCallbackRequest(ViewBase.getCallbackEventName(ListSubscriptionsEventName), { data: subscriptions });
+ });
+ view.on(ListGroupsEventName, () => {
+ view.sendCallbackRequest(ViewBase.getCallbackEventName(ListGroupsEventName), { data: groups });
+ });
+ view.on(ListWorkspacesEventName, () => {
+ view.sendCallbackRequest(ViewBase.getCallbackEventName(ListWorkspacesEventName), { data: workspaces });
+ });
+ view.on(ListAzureModelsEventName, () => {
+ view.sendCallbackRequest(ViewBase.getCallbackEventName(ListAzureModelsEventName), { data: models });
+ });
+ view.on(ListDatabaseNamesEventName, () => {
+ view.sendCallbackRequest(ViewBase.getCallbackEventName(ListDatabaseNamesEventName), { data: dbNames });
+ });
+ view.on(ListTableNamesEventName, () => {
+ view.sendCallbackRequest(ViewBase.getCallbackEventName(ListTableNamesEventName), { data: tableNames });
+ });
+ view.on(ListColumnNamesEventName, () => {
+ view.sendCallbackRequest(ViewBase.getCallbackEventName(ListColumnNamesEventName), { data: columnNames });
+ });
+ view.on(LoadModelParametersEventName, () => {
+ view.sendCallbackRequest(ViewBase.getCallbackEventName(LoadModelParametersEventName), { data: modelParameters });
+ });
+ view.on(DownloadAzureModelEventName, () => {
+ view.sendCallbackRequest(ViewBase.getCallbackEventName(DownloadAzureModelEventName), { data: 'path' });
+ });
+ view.on(DownloadRegisteredModelEventName, () => {
+ view.sendCallbackRequest(ViewBase.getCallbackEventName(DownloadRegisteredModelEventName), { data: 'path' });
+ });
+ await view.refresh();
+ should.notEqual(view.azureModelsComponent?.data, undefined);
+ should.notEqual(view.localModelsComponent?.data, undefined);
+
+ should.notEqual(await view.getModelFileName(), undefined);
+ await view.columnsSelectionPage?.onEnter();
+
+ should.notEqual(view.columnsSelectionPage?.data, undefined);
+ should.equal(view.columnsSelectionPage?.data?.inputColumns?.length, modelParameters.inputs.length, modelParameters.inputs[0].name);
+ should.equal(view.columnsSelectionPage?.data?.outputColumns?.length, modelParameters.outputs.length);
+ });
+});
diff --git a/extensions/machine-learning-services/src/test/views/models/registerModelWizard.test.ts b/extensions/machine-learning-services/src/test/views/models/registerModelWizard.test.ts
index 21c37a93ac..d4f542fc0c 100644
--- a/extensions/machine-learning-services/src/test/views/models/registerModelWizard.test.ts
+++ b/extensions/machine-learning-services/src/test/views/models/registerModelWizard.test.ts
@@ -20,7 +20,7 @@ describe('Register Model Wizard', () => {
let testContext = createContext();
let view = new RegisterModelWizard(testContext.apiWrapper.object, '');
- view.open();
+ await view.open();
await view.refresh();
should.notEqual(view.wizardView, undefined);
should.notEqual(view.modelSourcePage, undefined);
@@ -30,7 +30,7 @@ describe('Register Model Wizard', () => {
let testContext = createContext();
let view = new RegisterModelWizard(testContext.apiWrapper.object, '');
- view.open();
+ await view.open();
let accounts: azdata.Account[] = [
{
key: {
@@ -98,5 +98,7 @@ describe('Register Model Wizard', () => {
view.sendCallbackRequest(ViewBase.getCallbackEventName(ListAzureModelsEventName), { data: models });
});
await view.refresh();
+ should.notEqual(view.azureModelsComponent?.data ,undefined);
+ should.notEqual(view.localModelsComponent?.data, undefined);
});
});
diff --git a/extensions/machine-learning-services/src/test/views/models/utils.ts b/extensions/machine-learning-services/src/test/views/models/utils.ts
index 964ecd93e8..c65150a998 100644
--- a/extensions/machine-learning-services/src/test/views/models/utils.ts
+++ b/extensions/machine-learning-services/src/test/views/models/utils.ts
@@ -7,14 +7,12 @@ import * as azdata from 'azdata';
import * as vscode from 'vscode';
import * as TypeMoq from 'typemoq';
import { ApiWrapper } from '../../../common/apiWrapper';
-import * as mssql from '../../../../../mssql/src/mssql';
import { createViewContext } from '../utils';
import { ModelViewBase } from '../../../views/models/modelViewBase';
export interface TestContext {
apiWrapper: TypeMoq.IMock;
view: azdata.ModelView;
- languageExtensionService: mssql.ILanguageExtensionService;
onClick: vscode.EventEmitter;
}
@@ -34,16 +32,10 @@ export class ParentDialog extends ModelViewBase {
export function createContext(): TestContext {
let viewTestContext = createViewContext();
- let languageExtensionService: mssql.ILanguageExtensionService = {
- listLanguages: () => { return Promise.resolve([]); },
- deleteLanguage: () => { return Promise.resolve(); },
- updateLanguage: () => { return Promise.resolve(); }
- };
return {
apiWrapper: viewTestContext.apiWrapper,
view: viewTestContext.view,
- languageExtensionService: languageExtensionService,
onClick: viewTestContext.onClick
};
}
diff --git a/extensions/machine-learning-services/src/test/views/utils.ts b/extensions/machine-learning-services/src/test/views/utils.ts
index 3ae57623f4..7f36593597 100644
--- a/extensions/machine-learning-services/src/test/views/utils.ts
+++ b/extensions/machine-learning-services/src/test/views/utils.ts
@@ -62,6 +62,9 @@ export function createViewContext(): ViewTestContext {
onTextChanged: undefined!,
onEnterKeyPressed: undefined!,
value: ''
+ });
+ let image: () => azdata.ImageComponent = () => Object.assign({}, componentBase, {
+
});
let dropdown: () => azdata.DropDownComponent = () => Object.assign({}, componentBase, {
onValueChanged: onClick.event,
@@ -124,6 +127,14 @@ export function createViewContext(): ViewTestContext {
withProperties: () => inputBoxBuilder,
withValidation: () => inputBoxBuilder
};
+ let imageBuilder: azdata.ComponentBuilder = {
+ component: () => {
+ let r = image();
+ return r;
+ },
+ withProperties: () => imageBuilder,
+ withValidation: () => imageBuilder
+ };
let dropdownBuilder: azdata.ComponentBuilder = {
component: () => {
let r = dropdown();
@@ -156,7 +167,7 @@ export function createViewContext(): ViewTestContext {
editor: undefined!,
diffeditor: undefined!,
text: () => inputBoxBuilder,
- image: undefined!,
+ image: () => imageBuilder,
button: () => buttonBuilder,
dropDown: () => dropdownBuilder,
tree: undefined!,
@@ -181,7 +192,7 @@ export function createViewContext(): ViewTestContext {
try {
await handler(view);
} catch (err) {
- console.log(err);
+ throw err;
}
},
onValidityChanged: undefined!,
@@ -242,7 +253,13 @@ export function createViewContext(): ViewTestContext {
enabled: true,
description: '',
onValidityChanged: onClick.event,
- registerContent: () => { },
+ registerContent: async (handler) => {
+ try {
+ await handler(view);
+ } catch (err) {
+ throw err;
+ }
+ },
modelView: undefined!,
valid: true
};
diff --git a/extensions/machine-learning-services/src/views/externalLanguages/languageViewBase.ts b/extensions/machine-learning-services/src/views/externalLanguages/languageViewBase.ts
index c652fcb507..9e3cc04550 100644
--- a/extensions/machine-learning-services/src/views/externalLanguages/languageViewBase.ts
+++ b/extensions/machine-learning-services/src/views/externalLanguages/languageViewBase.ts
@@ -107,14 +107,14 @@ export abstract class LanguageViewBase {
if (connection) {
return `${connection.serverName} ${connection.databaseName ? connection.databaseName : constants.extLangLocal}`;
}
- return constants.packageManagerNoConnection;
+ return constants.noConnectionError;
}
public getServerTitle(): string {
if (this.connection) {
return this.connection.serverName;
}
- return constants.packageManagerNoConnection;
+ return constants.noConnectionError;
}
private async getCurrentConnectionUrl(): Promise {
diff --git a/extensions/machine-learning-services/src/views/interfaces.ts b/extensions/machine-learning-services/src/views/interfaces.ts
index 028a656e16..bae2be5bd4 100644
--- a/extensions/machine-learning-services/src/views/interfaces.ts
+++ b/extensions/machine-learning-services/src/views/interfaces.ts
@@ -18,6 +18,7 @@ export interface IPageView {
onLeave?: () => Promise;
validate?: () => Promise;
refresh: () => Promise;
+ disposePage?: () => Promise;
viewPanel: azdata.window.ModelViewPanel | undefined;
title: string;
}
diff --git a/extensions/machine-learning-services/src/views/mainViewBase.ts b/extensions/machine-learning-services/src/views/mainViewBase.ts
index fa3e5b342e..ec996d6d30 100644
--- a/extensions/machine-learning-services/src/views/mainViewBase.ts
+++ b/extensions/machine-learning-services/src/views/mainViewBase.ts
@@ -37,6 +37,16 @@ export class MainViewBase {
}
}
+ public async disposePages(): Promise {
+ if (this._pages) {
+ await Promise.all(this._pages.map(async (p) => {
+ if (p.disposePage) {
+ await p.disposePage();
+ }
+ }));
+ }
+ }
+
public async refresh(): Promise {
if (this._pages) {
await Promise.all(this._pages.map(async (p) => await p.refresh()));
diff --git a/extensions/machine-learning-services/src/views/models/azureModelsComponent.ts b/extensions/machine-learning-services/src/views/models/azureModelsComponent.ts
index d04c1d0709..3dde5de665 100644
--- a/extensions/machine-learning-services/src/views/models/azureModelsComponent.ts
+++ b/extensions/machine-learning-services/src/views/models/azureModelsComponent.ts
@@ -9,6 +9,7 @@ import { ApiWrapper } from '../../common/apiWrapper';
import { AzureResourceFilterComponent } from './azureResourceFilterComponent';
import { AzureModelsTable } from './azureModelsTable';
import { IDataComponent, AzureModelResource } from '../interfaces';
+import { ModelArtifact } from './prediction/modelArtifact';
export class AzureModelsComponent extends ModelViewBase implements IDataComponent {
@@ -17,6 +18,7 @@ export class AzureModelsComponent extends ModelViewBase implements IDataComponen
private _loader: azdata.LoadingComponent | undefined;
private _form: azdata.FormContainer | undefined;
+ private _downloadedFile: ModelArtifact | undefined;
/**
* Component to render a view to pick an azure model
@@ -37,8 +39,14 @@ export class AzureModelsComponent extends ModelViewBase implements IDataComponen
.withProperties({
loading: true
}).component();
+ this.azureModelsTable.onModelSelectionChanged(async () => {
+ if (this._downloadedFile) {
+ await this._downloadedFile.close();
+ }
+ this._downloadedFile = undefined;
+ });
- this.azureFilterComponent.onWorkspacesSelected(async () => {
+ this.azureFilterComponent.onWorkspacesSelectedChanged(async () => {
await this.onLoading();
await this.azureModelsTable?.loadData(this.azureFilterComponent?.data);
await this.onLoaded();
@@ -107,6 +115,22 @@ export class AzureModelsComponent extends ModelViewBase implements IDataComponen
});
}
+ public async getDownloadedModel(): Promise {
+ if (!this._downloadedFile) {
+ this._downloadedFile = new ModelArtifact(await this.downloadAzureModel(this.data));
+ }
+ return this._downloadedFile;
+ }
+
+ /**
+ * disposes the view
+ */
+ public async disposeComponent(): Promise {
+ if (this._downloadedFile) {
+ await this._downloadedFile.close();
+ }
+ }
+
/**
* Refreshes the view
*/
diff --git a/extensions/machine-learning-services/src/views/models/azureModelsTable.ts b/extensions/machine-learning-services/src/views/models/azureModelsTable.ts
index 37c3caea51..ea128c7cf7 100644
--- a/extensions/machine-learning-services/src/views/models/azureModelsTable.ts
+++ b/extensions/machine-learning-services/src/views/models/azureModelsTable.ts
@@ -4,6 +4,7 @@
*--------------------------------------------------------------------------------------------*/
import * as azdata from 'azdata';
+import * as vscode from 'vscode';
import * as constants from '../../common/constants';
import { ModelViewBase } from './modelViewBase';
import { ApiWrapper } from '../../common/apiWrapper';
@@ -18,6 +19,8 @@ export class AzureModelsTable extends ModelViewBase implements IDataComponent = new vscode.EventEmitter();
+ public readonly onModelSelectionChanged: vscode.Event = this._onModelSelectionChanged.event;
/**
* Creates a view to render azure models in a table
@@ -115,6 +118,7 @@ export class AzureModelsTable extends ModelViewBase implements IDataComponent {
this._selectedModelId = model.id;
+ this._onModelSelectionChanged.fire();
});
return [model.name, model.createdTime, model.frameworkVersion, selectModelButton];
}
diff --git a/extensions/machine-learning-services/src/views/models/azureResourceFilterComponent.ts b/extensions/machine-learning-services/src/views/models/azureResourceFilterComponent.ts
index 43b9fba599..ad36ab8a33 100644
--- a/extensions/machine-learning-services/src/views/models/azureResourceFilterComponent.ts
+++ b/extensions/machine-learning-services/src/views/models/azureResourceFilterComponent.ts
@@ -27,8 +27,8 @@ export class AzureResourceFilterComponent extends ModelViewBase implements IData
private _azureSubscriptions: azureResource.AzureResourceSubscription[] = [];
private _azureGroups: azureResource.AzureResource[] = [];
private _azureWorkspaces: Workspace[] = [];
- private _onWorkspacesSelected: vscode.EventEmitter = new vscode.EventEmitter();
- public readonly onWorkspacesSelected: vscode.Event = this._onWorkspacesSelected.event;
+ private _onWorkspacesSelectedChanged: vscode.EventEmitter = new vscode.EventEmitter();
+ public readonly onWorkspacesSelectedChanged: vscode.Event = this._onWorkspacesSelectedChanged.event;
/**
* Creates a new view
@@ -59,7 +59,7 @@ export class AzureResourceFilterComponent extends ModelViewBase implements IData
await this.onGroupSelected();
});
this._workspaces.onValueChanged(async () => {
- await this.onWorkspaceSelected();
+ await this.onWorkspaceSelectedChanged();
});
this._form = this._modelBuilder.formContainer().withFormItems([{
@@ -182,26 +182,26 @@ export class AzureResourceFilterComponent extends ModelViewBase implements IData
this._workspaces.values = values;
this._workspaces.value = values[0];
}
- this.onWorkspaceSelected();
+ this.onWorkspaceSelectedChanged();
}
- private onWorkspaceSelected(): void {
- this._onWorkspacesSelected.fire();
+ private onWorkspaceSelectedChanged(): void {
+ this._onWorkspacesSelectedChanged.fire();
}
private get workspace(): Workspace | undefined {
- return this._azureWorkspaces ? this._azureWorkspaces.find(a => a.id === (this._workspaces.value).name) : undefined;
+ return this._azureWorkspaces && this._workspaces.value ? this._azureWorkspaces.find(a => a.id === (this._workspaces.value).name) : undefined;
}
private get account(): azdata.Account | undefined {
- return this._azureAccounts ? this._azureAccounts.find(a => a.key.accountId === (this._accounts.value).name) : undefined;
+ return this._azureAccounts && this._accounts.value ? this._azureAccounts.find(a => a.key.accountId === (this._accounts.value).name) : undefined;
}
private get group(): azureResource.AzureResource | undefined {
- return this._azureGroups ? this._azureGroups.find(a => a.id === (this._groups.value).name) : undefined;
+ return this._azureGroups && this._groups.value ? this._azureGroups.find(a => a.id === (this._groups.value).name) : undefined;
}
private get subscription(): azureResource.AzureResourceSubscription | undefined {
- return this._azureSubscriptions ? this._azureSubscriptions.find(a => a.id === (this._subscriptions.value).name) : undefined;
+ return this._azureSubscriptions && this._subscriptions.value ? this._azureSubscriptions.find(a => a.id === (this._subscriptions.value).name) : undefined;
}
}
diff --git a/extensions/machine-learning-services/src/views/models/modelManagementController.ts b/extensions/machine-learning-services/src/views/models/modelManagementController.ts
index 5076d56b37..97c4bcfe71 100644
--- a/extensions/machine-learning-services/src/views/models/modelManagementController.ts
+++ b/extensions/machine-learning-services/src/views/models/modelManagementController.ts
@@ -9,15 +9,15 @@ import { azureResource } from '../../typings/azure-resource';
import { ApiWrapper } from '../../common/apiWrapper';
import { AzureModelRegistryService } from '../../modelManagement/azureModelRegistryService';
import { Workspace } from '@azure/arm-machinelearningservices/esm/models';
-import { RegisteredModel, WorkspaceModel, RegisteredModelDetails } from '../../modelManagement/interfaces';
-import { PredictParameters, DatabaseTable } from '../../prediction/interfaces';
-import { RegisteredModelService } from '../../modelManagement/registeredModelService';
+import { RegisteredModel, WorkspaceModel, RegisteredModelDetails, ModelParameters } from '../../modelManagement/interfaces';
+import { PredictParameters, DatabaseTable, TableColumn } from '../../prediction/interfaces';
+import { DeployedModelService } from '../../modelManagement/deployedModelService';
import { RegisteredModelsDialog } from './registerModels/registeredModelsDialog';
import {
AzureResourceEventArgs, ListAzureModelsEventName, ListSubscriptionsEventName, ListModelsEventName, ListWorkspacesEventName,
ListGroupsEventName, ListAccountsEventName, RegisterLocalModelEventName, RegisterLocalModelEventArgs, RegisterAzureModelEventName,
RegisterAzureModelEventArgs, ModelViewBase, SourceModelSelectedEventName, RegisterModelEventName, DownloadAzureModelEventName,
- ListDatabaseNamesEventName, ListTableNamesEventName, ListColumnNamesEventName, PredictModelEventName, PredictModelEventArgs
+ ListDatabaseNamesEventName, ListTableNamesEventName, ListColumnNamesEventName, PredictModelEventName, PredictModelEventArgs, DownloadRegisteredModelEventName, LoadModelParametersEventName
} from './modelViewBase';
import { ControllerBase } from '../controllerBase';
import { RegisterModelWizard } from './registerModels/registerModelWizard';
@@ -39,7 +39,7 @@ export class ModelManagementController extends ControllerBase {
apiWrapper: ApiWrapper,
private _root: string,
private _amlService: AzureModelRegistryService,
- private _registeredModelService: RegisteredModelService,
+ private _registeredModelService: DeployedModelService,
private _predictService: PredictService) {
super(apiWrapper);
}
@@ -61,7 +61,7 @@ export class ModelManagementController extends ControllerBase {
// Open view
//
- view.open();
+ await view.open();
await view.refresh();
return view;
}
@@ -74,10 +74,15 @@ export class ModelManagementController extends ControllerBase {
let view = new PredictWizard(this._apiWrapper, this._root);
this.registerEvents(view);
+ view.on(LoadModelParametersEventName, async () => {
+ const modelArtifact = await view.getModelFileName();
+ await this.executeAction(view, LoadModelParametersEventName, this.loadModelParameters, this._registeredModelService,
+ modelArtifact?.filePath);
+ });
// Open view
//
- view.open();
+ await view.open();
await view.refresh();
return view;
}
@@ -151,6 +156,11 @@ export class ModelManagementController extends ControllerBase {
await this.executeAction(view, PredictModelEventName, this.generatePredictScript, this._predictService,
predictArgs, predictArgs.model, predictArgs.filePath);
});
+ view.on(DownloadRegisteredModelEventName, async (arg) => {
+ let model = arg;
+ await this.executeAction(view, DownloadRegisteredModelEventName, this.downloadRegisteredModel, this._registeredModelService,
+ model);
+ });
view.on(SourceModelSelectedEventName, () => {
view.refresh();
});
@@ -191,8 +201,8 @@ export class ModelManagementController extends ControllerBase {
return await service.getWorkspaces(account, subscription, group);
}
- private async getRegisteredModels(registeredModelService: RegisteredModelService): Promise {
- return registeredModelService.getRegisteredModels();
+ private async getRegisteredModels(registeredModelService: DeployedModelService): Promise {
+ return registeredModelService.getDeployedModels();
}
private async getAzureModels(
@@ -207,9 +217,9 @@ export class ModelManagementController extends ControllerBase {
return await service.getModels(account, subscription, resourceGroup, workspace) || [];
}
- private async registerLocalModel(service: RegisteredModelService, filePath: string, details: RegisteredModelDetails | undefined): Promise {
+ private async registerLocalModel(service: DeployedModelService, filePath: string, details: RegisteredModelDetails | undefined): Promise {
if (filePath) {
- await service.registerLocalModel(filePath, details);
+ await service.deployLocalModel(filePath, details);
} else {
throw Error(constants.invalidModelToRegisterError);
@@ -218,7 +228,7 @@ export class ModelManagementController extends ControllerBase {
private async registerAzureModel(
azureService: AzureModelRegistryService,
- service: RegisteredModelService,
+ service: DeployedModelService,
account: azdata.Account | undefined,
subscription: azureResource.AzureResourceSubscription | undefined,
resourceGroup: azureResource.AzureResource | undefined,
@@ -231,7 +241,7 @@ export class ModelManagementController extends ControllerBase {
const filePath = await azureService.downloadModel(account, subscription, resourceGroup, workspace, model);
if (filePath) {
- await service.registerLocalModel(filePath, details);
+ await service.deployLocalModel(filePath, details);
await fs.promises.unlink(filePath);
} else {
throw Error(constants.invalidModelToRegisterError);
@@ -246,7 +256,7 @@ export class ModelManagementController extends ControllerBase {
return await predictService.getTableList(databaseName);
}
- public async getTableColumnsList(predictService: PredictService, databaseTable: DatabaseTable): Promise {
+ public async getTableColumnsList(predictService: PredictService, databaseTable: DatabaseTable): Promise {
return await predictService.getTableColumnsList(databaseTable);
}
@@ -263,6 +273,24 @@ export class ModelManagementController extends ControllerBase {
return result;
}
+ private async downloadRegisteredModel(
+ registeredModelService: DeployedModelService,
+ model: RegisteredModel | undefined): Promise {
+ if (!model) {
+ throw Error(constants.invalidModelToPredictError);
+ }
+ return await registeredModelService.downloadModel(model);
+ }
+
+ private async loadModelParameters(
+ registeredModelService: DeployedModelService,
+ model: string | undefined): Promise {
+ if (!model) {
+ return undefined;
+ }
+ return await registeredModelService.loadModelParameters(model);
+ }
+
private async downloadAzureModel(
azureService: AzureModelRegistryService,
account: azdata.Account | undefined,
diff --git a/extensions/machine-learning-services/src/views/models/modelSourcePage.ts b/extensions/machine-learning-services/src/views/models/modelSourcePage.ts
index a3f521f4f2..c8b046b5c7 100644
--- a/extensions/machine-learning-services/src/views/models/modelSourcePage.ts
+++ b/extensions/machine-learning-services/src/views/models/modelSourcePage.ts
@@ -120,4 +120,14 @@ export class ModelSourcePage extends ModelViewBase implements IPageView, IDataCo
}
return Promise.resolve(validated);
}
+
+ public async disposePage(): Promise {
+ if (this.azureModelsComponent) {
+ await this.azureModelsComponent.disposeComponent();
+
+ }
+ if (this.registeredModelsComponent) {
+ await this.registeredModelsComponent.disposeComponent();
+ }
+ }
}
diff --git a/extensions/machine-learning-services/src/views/models/modelViewBase.ts b/extensions/machine-learning-services/src/views/models/modelViewBase.ts
index 6b1b39da06..5b69ecc19b 100644
--- a/extensions/machine-learning-services/src/views/models/modelViewBase.ts
+++ b/extensions/machine-learning-services/src/views/models/modelViewBase.ts
@@ -8,8 +8,8 @@ import * as azdata from 'azdata';
import { azureResource } from '../../typings/azure-resource';
import { ApiWrapper } from '../../common/apiWrapper';
import { ViewBase } from '../viewBase';
-import { RegisteredModel, WorkspaceModel, RegisteredModelDetails } from '../../modelManagement/interfaces';
-import { PredictParameters, DatabaseTable } from '../../prediction/interfaces';
+import { RegisteredModel, WorkspaceModel, RegisteredModelDetails, ModelParameters } from '../../modelManagement/interfaces';
+import { PredictParameters, DatabaseTable, TableColumn } from '../../prediction/interfaces';
import { Workspace } from '@azure/arm-machinelearningservices/esm/models';
import { AzureWorkspaceResource, AzureModelResource } from '../interfaces';
@@ -47,9 +47,11 @@ export const ListWorkspacesEventName = 'listWorkspaces';
export const RegisterLocalModelEventName = 'registerLocalModel';
export const RegisterAzureModelEventName = 'registerAzureLocalModel';
export const DownloadAzureModelEventName = 'downloadAzureLocalModel';
+export const DownloadRegisteredModelEventName = 'downloadRegisteredModel';
export const PredictModelEventName = 'predictModel';
export const RegisterModelEventName = 'registerModel';
export const SourceModelSelectedEventName = 'sourceModelSelected';
+export const LoadModelParametersEventName = 'loadModelParameters';
/**
* Base class for all model management views
@@ -75,7 +77,9 @@ export abstract class ModelViewBase extends ViewBase {
ListTableNamesEventName,
ListColumnNamesEventName,
PredictModelEventName,
- DownloadAzureModelEventName]);
+ DownloadAzureModelEventName,
+ DownloadRegisteredModelEventName,
+ LoadModelParametersEventName]);
}
/**
@@ -124,7 +128,7 @@ export abstract class ModelViewBase extends ViewBase {
/**
* lists column names
*/
- public async listColumnNames(table: DatabaseTable): Promise {
+ public async listColumnNames(table: DatabaseTable): Promise {
return await this.sendDataRequest(ListColumnNamesEventName, table);
}
@@ -151,6 +155,14 @@ export abstract class ModelViewBase extends ViewBase {
return await this.sendDataRequest(RegisterLocalModelEventName, args);
}
+ /**
+ * downloads registered model
+ * @param model model to download
+ */
+ public async downloadRegisteredModel(model: RegisteredModel | undefined): Promise {
+ return await this.sendDataRequest(DownloadRegisteredModelEventName, model);
+ }
+
/**
* download azure model
* @param args azure resource
@@ -159,6 +171,13 @@ export abstract class ModelViewBase extends ViewBase {
return await this.sendDataRequest(DownloadAzureModelEventName, resource);
}
+ /**
+ * Loads model parameters
+ */
+ public async loadModelParameters(): Promise {
+ return await this.sendDataRequest(LoadModelParametersEventName);
+ }
+
/**
* registers azure model
* @param args azure resource
diff --git a/extensions/machine-learning-services/src/views/models/prediction/columnsSelectionPage.ts b/extensions/machine-learning-services/src/views/models/prediction/columnsSelectionPage.ts
index f0c622442d..6eb45172df 100644
--- a/extensions/machine-learning-services/src/views/models/prediction/columnsSelectionPage.ts
+++ b/extensions/machine-learning-services/src/views/models/prediction/columnsSelectionPage.ts
@@ -8,7 +8,7 @@ import { ModelViewBase } from '../modelViewBase';
import { ApiWrapper } from '../../../common/apiWrapper';
import * as constants from '../../../common/constants';
import { IPageView, IDataComponent } from '../../interfaces';
-import { ColumnsFilterComponent } from './columnsFilterComponent';
+import { InputColumnsComponent } from './inputColumnsComponent';
import { OutputColumnsComponent } from './outputColumnsComponent';
import { PredictParameters } from '../../../prediction/interfaces';
@@ -19,7 +19,7 @@ export class ColumnsSelectionPage extends ModelViewBase implements IPageView, ID
private _form: azdata.FormContainer | undefined;
private _formBuilder: azdata.FormBuilder | undefined;
- public columnsFilterComponent: ColumnsFilterComponent | undefined;
+ public inputColumnsComponent: InputColumnsComponent | undefined;
public outputColumnsComponent: OutputColumnsComponent | undefined;
constructor(apiWrapper: ApiWrapper, parent: ModelViewBase) {
@@ -32,15 +32,14 @@ export class ColumnsSelectionPage extends ModelViewBase implements IPageView, ID
*/
public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.Component {
this._formBuilder = modelBuilder.formContainer();
- this.columnsFilterComponent = new ColumnsFilterComponent(this._apiWrapper, this);
- this.columnsFilterComponent.registerComponent(modelBuilder);
- this.columnsFilterComponent.addComponents(this._formBuilder);
- this.refresh();
+ this.inputColumnsComponent = new InputColumnsComponent(this._apiWrapper, this);
+ this.inputColumnsComponent.registerComponent(modelBuilder);
+ this.inputColumnsComponent.addComponents(this._formBuilder);
this.outputColumnsComponent = new OutputColumnsComponent(this._apiWrapper, this);
this.outputColumnsComponent.registerComponent(modelBuilder);
this.outputColumnsComponent.addComponents(this._formBuilder);
- this.refresh();
+
this._form = this._formBuilder.component();
return this._form;
}
@@ -49,8 +48,8 @@ export class ColumnsSelectionPage extends ModelViewBase implements IPageView, ID
* Returns selected data
*/
public get data(): PredictParameters | undefined {
- return this.columnsFilterComponent?.data && this.outputColumnsComponent?.data ?
- Object.assign({}, this.columnsFilterComponent.data, { outputColumns: this.outputColumnsComponent.data }) :
+ return this.inputColumnsComponent?.data && this.outputColumnsComponent?.data ?
+ Object.assign({}, this.inputColumnsComponent.data, { outputColumns: this.outputColumnsComponent.data }) :
undefined;
}
@@ -66,8 +65,8 @@ export class ColumnsSelectionPage extends ModelViewBase implements IPageView, ID
*/
public async refresh(): Promise {
if (this._formBuilder) {
- if (this.columnsFilterComponent) {
- await this.columnsFilterComponent.refresh();
+ if (this.inputColumnsComponent) {
+ await this.inputColumnsComponent.refresh();
}
if (this.outputColumnsComponent) {
await this.outputColumnsComponent.refresh();
@@ -75,6 +74,24 @@ export class ColumnsSelectionPage extends ModelViewBase implements IPageView, ID
}
}
+ public async onEnter(): Promise {
+ await this.inputColumnsComponent?.onLoading();
+ await this.outputColumnsComponent?.onLoading();
+ try {
+ const modelParameters = await this.loadModelParameters();
+ if (modelParameters && this.inputColumnsComponent && this.outputColumnsComponent) {
+ this.inputColumnsComponent.modelParameters = modelParameters;
+ this.outputColumnsComponent.modelParameters = modelParameters;
+ await this.inputColumnsComponent.refresh();
+ await this.outputColumnsComponent.refresh();
+ }
+ } catch (error) {
+ this.showErrorMessage(constants.loadModelParameterFailedError, error);
+ }
+ await this.inputColumnsComponent?.onLoaded();
+ await this.outputColumnsComponent?.onLoaded();
+ }
+
/**
* Returns page title
*/
diff --git a/extensions/machine-learning-services/src/views/models/prediction/columnsTable.ts b/extensions/machine-learning-services/src/views/models/prediction/columnsTable.ts
index 7230f7f63a..717cae893f 100644
--- a/extensions/machine-learning-services/src/views/models/prediction/columnsTable.ts
+++ b/extensions/machine-learning-services/src/views/models/prediction/columnsTable.ts
@@ -8,133 +8,280 @@ import * as constants from '../../../common/constants';
import { ModelViewBase } from '../modelViewBase';
import { ApiWrapper } from '../../../common/apiWrapper';
import { IDataComponent } from '../../interfaces';
-import { PredictColumn, DatabaseTable } from '../../../prediction/interfaces';
+import { PredictColumn, DatabaseTable, TableColumn } from '../../../prediction/interfaces';
+import { ModelParameter, ModelParameters } from '../../../modelManagement/interfaces';
/**
* View to render azure models in a table
*/
export class ColumnsTable extends ModelViewBase implements IDataComponent {
- private _table: azdata.DeclarativeTableComponent;
- private _selectedColumns: PredictColumn[] = [];
- private _columns: string[] | undefined;
+ private _table: azdata.DeclarativeTableComponent | undefined;
+ private _parameters: PredictColumn[] = [];
+ private _loader: azdata.LoadingComponent;
+ private _dataTypes: string[] = [
+ 'bigint',
+ 'int',
+ 'smallint',
+ 'real',
+ 'float',
+ 'varchar(MAX)',
+ 'bit'
+ ];
+
/**
* Creates a view to render azure models in a table
*/
- constructor(apiWrapper: ApiWrapper, private _modelBuilder: azdata.ModelBuilder, parent: ModelViewBase) {
+ constructor(apiWrapper: ApiWrapper, private _modelBuilder: azdata.ModelBuilder, parent: ModelViewBase, private _forInput: boolean = true) {
super(apiWrapper, parent.root, parent);
- this._table = this.registerComponent(this._modelBuilder);
+ this._loader = this.registerComponent(this._modelBuilder);
}
/**
* Register components
* @param modelBuilder model builder
*/
- public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.DeclarativeTableComponent {
+ public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.LoadingComponent {
+ let columnHeader: azdata.DeclarativeTableColumn[];
+ if (this._forInput) {
+ columnHeader = [
+ { // Action
+ displayName: constants.columnName,
+ ariaLabel: constants.columnName,
+ valueType: azdata.DeclarativeDataType.component,
+ isReadOnly: true,
+ width: 50,
+ headerCssStyles: {
+ ...constants.cssStyles.tableHeader
+ },
+ rowCssStyles: {
+ ...constants.cssStyles.tableRow
+ },
+ },
+ { // Name
+ displayName: '',
+ ariaLabel: '',
+ valueType: azdata.DeclarativeDataType.component,
+ isReadOnly: true,
+ width: 50,
+ headerCssStyles: {
+ ...constants.cssStyles.tableHeader
+ },
+ rowCssStyles: {
+ ...constants.cssStyles.tableRow
+ },
+ },
+ { // Name
+ displayName: constants.inputName,
+ ariaLabel: constants.inputName,
+ valueType: azdata.DeclarativeDataType.component,
+ isReadOnly: true,
+ width: 120,
+ headerCssStyles: {
+ ...constants.cssStyles.tableHeader
+ },
+ rowCssStyles: {
+ ...constants.cssStyles.tableRow
+ },
+ }
+ ];
+ } else {
+ columnHeader = [
+ { // Name
+ displayName: constants.outputName,
+ ariaLabel: constants.outputName,
+ valueType: azdata.DeclarativeDataType.string,
+ isReadOnly: true,
+ width: 200,
+ headerCssStyles: {
+ ...constants.cssStyles.tableHeader
+ },
+ rowCssStyles: {
+ ...constants.cssStyles.tableRow
+ },
+ },
+ { // Action
+ displayName: constants.displayName,
+ ariaLabel: constants.displayName,
+ valueType: azdata.DeclarativeDataType.component,
+ isReadOnly: true,
+ width: 50,
+ headerCssStyles: {
+ ...constants.cssStyles.tableHeader
+ },
+ rowCssStyles: {
+ ...constants.cssStyles.tableRow
+ },
+ },
+ { // Action
+ displayName: constants.dataTypeName,
+ ariaLabel: constants.dataTypeName,
+ valueType: azdata.DeclarativeDataType.component,
+ isReadOnly: true,
+ width: 50,
+ headerCssStyles: {
+ ...constants.cssStyles.tableHeader
+ },
+ rowCssStyles: {
+ ...constants.cssStyles.tableRow
+ },
+ }
+ ];
+ }
this._table = modelBuilder.declarativeTable()
+
.withProperties(
{
- columns: [
- { // Name
- displayName: constants.columnDatabase,
- ariaLabel: constants.columnName,
- valueType: azdata.DeclarativeDataType.string,
- isReadOnly: true,
- width: 120,
- headerCssStyles: {
- ...constants.cssStyles.tableHeader
- },
- rowCssStyles: {
- ...constants.cssStyles.tableRow
- },
- },
- { // Action
- displayName: constants.inputName,
- ariaLabel: constants.inputName,
- valueType: azdata.DeclarativeDataType.component,
- isReadOnly: true,
- width: 50,
- headerCssStyles: {
- ...constants.cssStyles.tableHeader
- },
- rowCssStyles: {
- ...constants.cssStyles.tableRow
- },
- },
- { // Action
- displayName: '',
- valueType: azdata.DeclarativeDataType.component,
- isReadOnly: true,
- width: 50,
- headerCssStyles: {
- ...constants.cssStyles.tableHeader
- },
- rowCssStyles: {
- ...constants.cssStyles.tableRow
- },
- }
- ],
+ columns: columnHeader,
data: [],
ariaLabel: constants.mlsConfigTitle
})
.component();
- return this._table;
+ this._loader = modelBuilder.loadingComponent()
+ .withItem(this._table)
+ .withProperties({
+ loading: true
+ }).component();
+ return this._loader;
}
- public get component(): azdata.DeclarativeTableComponent {
- return this._table;
+ public async onLoading(): Promise {
+ if (this._loader) {
+ await this._loader.updateProperties({ loading: true });
+ }
+ }
+
+ public async onLoaded(): Promise {
+ if (this._loader) {
+ await this._loader.updateProperties({ loading: false });
+ }
+ }
+
+ public get component(): azdata.Component {
+ return this._loader;
}
/**
* Load data in the component
* @param workspaceResource Azure workspace
*/
- public async loadData(table: DatabaseTable): Promise {
- this._selectedColumns = [];
- if (this._table) {
- this._columns = await this.listColumnNames(table);
- let tableData: any[][] = [];
+ public async loadInputs(modelParameters: ModelParameters | undefined, table: DatabaseTable): Promise {
+ await this.onLoading();
+ this._parameters = [];
+ let tableData: any[][] = [];
- if (this._columns) {
- tableData = tableData.concat(this._columns.map(model => this.createTableRow(model)));
+ if (this._table) {
+ if (this._forInput) {
+ const columns = await this.listColumnNames(table);
+ if (modelParameters?.inputs && columns) {
+ tableData = tableData.concat(modelParameters.inputs.map(input => this.createInputTableRow(input, columns)));
+ }
}
this._table.data = tableData;
}
+ await this.onLoaded();
}
- private createTableRow(column: string): any[] {
- if (this._modelBuilder) {
- let selectRowButton = this._modelBuilder.checkBox().withProperties({
+ public async loadOutputs(modelParameters: ModelParameters | undefined): Promise {
+ this.onLoading();
+ this._parameters = [];
+ let tableData: any[][] = [];
- width: 15,
- height: 15,
- checked: true
+ if (this._table) {
+ if (!this._forInput) {
+ if (modelParameters?.outputs && this._dataTypes) {
+ tableData = tableData.concat(modelParameters.outputs.map(output => this.createOutputTableRow(output, this._dataTypes)));
+ }
+ }
+
+ this._table.data = tableData;
+ }
+ this.onLoaded();
+ }
+
+ private createOutputTableRow(modelParameter: ModelParameter, dataTypes: string[]): any[] {
+ if (this._modelBuilder) {
+
+ let nameInput = this._modelBuilder.dropDown().withProperties({
+ values: dataTypes,
+ width: this.componentMaxLength
}).component();
- let nameInputBox = this._modelBuilder.inputBox().withProperties({
- value: '',
- width: 150
- }).component();
- this._selectedColumns.push({ name: column });
- selectRowButton.onChanged(() => {
- if (selectRowButton.checked) {
- if (!this._selectedColumns.find(x => x.name === column)) {
- this._selectedColumns.push({ name: column });
- }
- } else {
- if (this._selectedColumns.find(x => x.name === column)) {
- this._selectedColumns = this._selectedColumns.filter(x => x.name !== column);
+ const name = modelParameter.name;
+ const dataType = dataTypes.find(x => x === modelParameter.type);
+ if (dataType) {
+ nameInput.value = dataType;
+ }
+ this._parameters.push({ columnName: name, paramName: name, dataType: modelParameter.type });
+
+ nameInput.onValueChanged(() => {
+ const value = nameInput.value;
+ if (value !== modelParameter.type) {
+ let selectedRow = this._parameters.find(x => x.paramName === name);
+ if (selectedRow) {
+ selectedRow.dataType = value;
}
}
});
- nameInputBox.onTextChanged(() => {
- let selectedRow = this._selectedColumns.find(x => x.name === column);
+ let displayNameInput = this._modelBuilder.inputBox().withProperties({
+ value: name,
+ width: 200
+ }).component();
+ displayNameInput.onTextChanged(() => {
+ let selectedRow = this._parameters.find(x => x.paramName === name);
if (selectedRow) {
- selectedRow.displayName = nameInputBox.value;
+ selectedRow.columnName = displayNameInput.value || name;
}
});
- return [column, nameInputBox, selectRowButton];
+ return [`${name}(${modelParameter.type ? modelParameter.type : constants.unsupportedModelParameterType})`, displayNameInput, nameInput];
+ }
+
+ return [];
+ }
+
+ private createInputTableRow(modelParameter: ModelParameter, columns: TableColumn[] | undefined): any[] {
+ if (this._modelBuilder && columns) {
+ const values = columns.map(c => { return { name: c.columnName, displayName: `${c.columnName}(${c.dataType})` }; });
+ let nameInput = this._modelBuilder.dropDown().withProperties({
+ values: values,
+ width: this.componentMaxLength
+ }).component();
+ const name = modelParameter.name;
+ let column = values.find(x => x.name === modelParameter.name);
+ if (!column) {
+ column = values[0];
+ }
+ nameInput.value = column;
+
+ this._parameters.push({ columnName: column.name, paramName: name });
+
+ nameInput.onValueChanged(() => {
+ const selectedColumn = nameInput.value;
+ const value = selectedColumn ? (selectedColumn).name : undefined;
+
+ let selectedRow = this._parameters.find(x => x.paramName === name);
+ if (selectedRow) {
+ selectedRow.columnName = value || '';
+ }
+ });
+ const label = this._modelBuilder.inputBox().withProperties({
+ value: `${name}(${modelParameter.type ? modelParameter.type : constants.unsupportedModelParameterType})`,
+ enabled: false,
+ width: this.componentMaxLength
+ }).component();
+ const image = this._modelBuilder.image().withProperties({
+ width: 50,
+ height: 50,
+ iconPath: {
+ dark: this.asAbsolutePath('images/arrow.svg'),
+ light: this.asAbsolutePath('images/arrow.svg')
+ },
+ iconWidth: 20,
+ iconHeight: 20,
+ title: 'maps'
+ }).component();
+ return [nameInput, image, label];
}
return [];
@@ -144,7 +291,7 @@ export class ColumnsTable extends ModelViewBase implements IDataComponent {
+export class InputColumnsComponent extends ModelViewBase implements IDataComponent {
private _form: azdata.FormContainer | undefined;
private _databases: azdata.DropDownComponent | undefined;
@@ -22,7 +23,9 @@ export class ColumnsFilterComponent extends ModelViewBase implements IDataCompon
private _columns: ColumnsTable | undefined;
private _dbNames: string[] = [];
private _tableNames: DatabaseTable[] = [];
-
+ private _modelParameters: ModelParameters | undefined;
+ private _dbTableComponent: azdata.FlexContainer | undefined;
+ private tableMaxLength = this.componentMaxLength * 2 + 70;
/**
* Creates a new view
*/
@@ -52,27 +55,47 @@ export class ColumnsFilterComponent extends ModelViewBase implements IDataCompon
});
- this._form = modelBuilder.formContainer().withFormItems([{
- title: constants.azureAccount,
+ const databaseForm = modelBuilder.formContainer().withFormItems([{
+ title: constants.columnDatabase,
component: this._databases
- }, {
- title: constants.azureSubscription,
+ }]).withLayout({
+ padding: '0px'
+ }).component();
+ const tableForm = modelBuilder.formContainer().withFormItems([{
+ title: constants.columnTable,
component: this._tables
+ }]).withLayout({
+ padding: '0px'
+ }).component();
+ this._dbTableComponent = modelBuilder.flexContainer().withItems([
+ databaseForm,
+ tableForm
+ ], {
+ flex: '0 0 auto',
+ CSSStyles: {
+ 'align-items': 'flex-start'
+ }
+ }).withLayout({
+ flexFlow: 'row',
+ justifyContent: 'space-between',
+ width: this.tableMaxLength
+ }).component();
+
+ this._form = modelBuilder.formContainer().withFormItems([{
+ title: '',
+ component: this._dbTableComponent
}, {
- title: constants.azureGroup,
+ title: constants.inputColumns,
component: this._columns.component
}]).component();
return this._form;
}
public addComponents(formBuilder: azdata.FormBuilder) {
- if (this._databases && this._tables && this._columns) {
+ if (this._columns && this._dbTableComponent) {
formBuilder.addFormItems([{
- title: constants.columnDatabase,
- component: this._databases
- }, {
- title: constants.columnTable,
- component: this._tables
+ title: '',
+ component: this._dbTableComponent
}, {
title: constants.inputColumns,
component: this._columns.component
@@ -81,17 +104,13 @@ export class ColumnsFilterComponent extends ModelViewBase implements IDataCompon
}
public removeComponents(formBuilder: azdata.FormBuilder) {
- if (this._databases && this._tables && this._columns) {
+ if (this._columns && this._dbTableComponent) {
formBuilder.removeFormItem({
- title: constants.azureAccount,
- component: this._databases
+ title: '',
+ component: this._dbTableComponent
});
formBuilder.removeFormItem({
- title: constants.azureSubscription,
- component: this._tables
- });
- formBuilder.removeFormItem({
- title: constants.azureGroup,
+ title: constants.inputColumns,
component: this._columns.component
});
}
@@ -125,6 +144,22 @@ export class ColumnsFilterComponent extends ModelViewBase implements IDataCompon
await this.onDatabaseSelected();
}
+ public set modelParameters(value: ModelParameters) {
+ this._modelParameters = value;
+ }
+
+ public async onLoading(): Promise {
+ if (this._columns) {
+ await this._columns.onLoading();
+ }
+ }
+
+ public async onLoaded(): Promise {
+ if (this._columns) {
+ await this._columns.onLoaded();
+ }
+ }
+
/**
* refreshes the view
*/
@@ -146,7 +181,7 @@ export class ColumnsFilterComponent extends ModelViewBase implements IDataCompon
}
private async onTableSelected(): Promise {
- this._columns?.loadData(this.databaseTable);
+ this._columns?.loadInputs(this._modelParameters, this.databaseTable);
}
private get databaseName(): string | undefined {
diff --git a/extensions/machine-learning-services/src/views/models/prediction/modelArtifact.ts b/extensions/machine-learning-services/src/views/models/prediction/modelArtifact.ts
new file mode 100644
index 0000000000..4f2501d249
--- /dev/null
+++ b/extensions/machine-learning-services/src/views/models/prediction/modelArtifact.ts
@@ -0,0 +1,35 @@
+/*---------------------------------------------------------------------------------------------
+ * Copyright (c) Microsoft Corporation. All rights reserved.
+ * Licensed under the Source EULA. See License.txt in the project root for license information.
+ *--------------------------------------------------------------------------------------------*/
+
+import * as utils from '../../../common/utils';
+
+/**
+* Wizard to register a model
+*/
+export class ModelArtifact {
+
+ /**
+ * Creates new model artifact
+ */
+ constructor(private _filePath: string, private _deleteAtClose: boolean = true) {
+ }
+
+ public get filePath(): string {
+ return this._filePath;
+ }
+
+ /**
+ * Closes the artifact and disposes the resources
+ */
+ public async close(): Promise {
+ if (this._deleteAtClose) {
+ try {
+ await utils.deleteFile(this._filePath);
+ } catch {
+
+ }
+ }
+ }
+}
diff --git a/extensions/machine-learning-services/src/views/models/prediction/outputColumnsComponent.ts b/extensions/machine-learning-services/src/views/models/prediction/outputColumnsComponent.ts
index 35782e9492..7dfbbb9684 100644
--- a/extensions/machine-learning-services/src/views/models/prediction/outputColumnsComponent.ts
+++ b/extensions/machine-learning-services/src/views/models/prediction/outputColumnsComponent.ts
@@ -9,25 +9,18 @@ import { ApiWrapper } from '../../../common/apiWrapper';
import * as constants from '../../../common/constants';
import { IDataComponent } from '../../interfaces';
import { PredictColumn } from '../../../prediction/interfaces';
+import { ColumnsTable } from './columnsTable';
+import { ModelParameters } from '../../../modelManagement/interfaces';
/**
* View to render filters to pick an azure resource
*/
-const componentWidth = 60;
+
export class OutputColumnsComponent extends ModelViewBase implements IDataComponent {
private _form: azdata.FormContainer | undefined;
- private _flex: azdata.FlexContainer | undefined;
- private _columnName: azdata.InputBoxComponent | undefined;
- private _columnTypes: azdata.DropDownComponent | undefined;
- private _dataTypes: string[] = [
- 'int',
- 'nvarchar(MAX)',
- 'varchar(MAX)',
- 'float',
- 'double',
- 'bit'
- ];
+ private _columns: ColumnsTable | undefined;
+ private _modelParameters: ModelParameters | undefined;
/**
* Creates a new view
@@ -41,49 +34,29 @@ export class OutputColumnsComponent extends ModelViewBase implements IDataCompon
* @param modelBuilder model builder
*/
public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.Component {
- this._columnName = modelBuilder.inputBox().withProperties({
- width: this.componentMaxLength - componentWidth - this.spaceBetweenComponentsLength
- }).component();
- this._columnTypes = modelBuilder.dropDown().withProperties({
- width: componentWidth
- }).component();
-
- let flex = modelBuilder.flexContainer()
- .withLayout({
- width: this._columnName.width
- }).withItems([
- this._columnName]
- ).component();
- this._flex = modelBuilder.flexContainer()
- .withLayout({
- flexFlow: 'row',
- justifyContent: 'space-between',
- width: this.componentMaxLength
- }).withItems([
- flex, this._columnTypes]
- ).component();
+ this._columns = new ColumnsTable(this._apiWrapper, modelBuilder, this, false);
this._form = modelBuilder.formContainer().withFormItems([{
title: constants.azureAccount,
- component: this._flex
+ component: this._columns.component
}]).component();
return this._form;
}
public addComponents(formBuilder: azdata.FormBuilder) {
- if (this._flex) {
+ if (this._columns) {
formBuilder.addFormItems([{
title: constants.outputColumns,
- component: this._flex
+ component: this._columns.component
}]);
}
}
public removeComponents(formBuilder: azdata.FormBuilder) {
- if (this._flex) {
+ if (this._columns) {
formBuilder.removeFormItem({
title: constants.outputColumns,
- component: this._flex
+ component: this._columns.component
});
}
}
@@ -99,9 +72,24 @@ export class OutputColumnsComponent extends ModelViewBase implements IDataCompon
* loads data in the components
*/
public async loadData(): Promise {
- if (this._columnTypes) {
- this._columnTypes.values = this._dataTypes;
- this._columnTypes.value = this._dataTypes[0];
+ if (this._modelParameters) {
+ this._columns?.loadOutputs(this._modelParameters);
+ }
+ }
+
+ public set modelParameters(value: ModelParameters) {
+ this._modelParameters = value;
+ }
+
+ public async onLoading(): Promise {
+ if (this._columns) {
+ await this._columns.onLoading();
+ }
+ }
+
+ public async onLoaded(): Promise {
+ if (this._columns) {
+ await this._columns.onLoaded();
}
}
@@ -116,9 +104,6 @@ export class OutputColumnsComponent extends ModelViewBase implements IDataCompon
* Returns selected data
*/
public get data(): PredictColumn[] | undefined {
- return this._columnName && this._columnTypes ? [{
- name: this._columnName.value || '',
- dataType: this._columnTypes.value || ''
- }] : undefined;
+ return this._columns?.data;
}
}
diff --git a/extensions/machine-learning-services/src/views/models/prediction/predictWizard.ts b/extensions/machine-learning-services/src/views/models/prediction/predictWizard.ts
index c7c3784f72..61d41bebf0 100644
--- a/extensions/machine-learning-services/src/views/models/prediction/predictWizard.ts
+++ b/extensions/machine-learning-services/src/views/models/prediction/predictWizard.ts
@@ -14,6 +14,7 @@ import { WizardView } from '../../wizardView';
import { ModelSourcePage } from '../modelSourcePage';
import { ColumnsSelectionPage } from './columnsSelectionPage';
import { RegisteredModel } from '../../../modelManagement/interfaces';
+import { ModelArtifact } from './modelArtifact';
/**
* Wizard to register a model
@@ -21,7 +22,6 @@ import { RegisteredModel } from '../../../modelManagement/interfaces';
export class PredictWizard extends ModelViewBase {
public modelSourcePage: ModelSourcePage | undefined;
- //public modelDetailsPage: ModelDetailsPage | undefined;
public columnsSelectionPage: ColumnsSelectionPage | undefined;
public wizardView: WizardView | undefined;
private _parentView: ModelViewBase | undefined;
@@ -37,7 +37,7 @@ export class PredictWizard extends ModelViewBase {
/**
* Opens a dialog to manage packages used by notebooks.
*/
- public open(): void {
+ public async open(): Promise {
this.modelSourcePage = new ModelSourcePage(this._apiWrapper, this, [ModelSourceType.RegisteredModels, ModelSourceType.Local, ModelSourceType.Azure]);
this.columnsSelectionPage = new ColumnsSelectionPage(this._apiWrapper, this);
this.wizardView = new WizardView(this._apiWrapper);
@@ -50,16 +50,22 @@ export class PredictWizard extends ModelViewBase {
wizard.doneButton.label = constants.predictModel;
wizard.generateScriptButton.hidden = true;
wizard.displayPageTitles = true;
+ wizard.doneButton.onClick(async () => {
+ await this.onClose();
+ });
+ wizard.cancelButton.onClick(async () => {
+ await this.onClose();
+ });
wizard.registerNavigationValidator(async (pageInfo: azdata.window.WizardPageChangeInfo) => {
let validated = this.wizardView ? await this.wizardView.validate(pageInfo) : false;
- if (validated && pageInfo.newPage === undefined) {
- wizard.cancelButton.enabled = false;
- wizard.backButton.enabled = false;
- await this.predict();
- wizard.cancelButton.enabled = true;
- wizard.backButton.enabled = true;
- if (this._parentView) {
- this._parentView?.refresh();
+ if (validated) {
+ if (pageInfo.newPage === undefined) {
+ this.onLoading();
+ await this.predict();
+ this.onLoaded();
+ if (this._parentView) {
+ this._parentView?.refresh();
+ }
}
return true;
@@ -67,7 +73,22 @@ export class PredictWizard extends ModelViewBase {
return validated;
});
- wizard.open();
+ await wizard.open();
+ }
+
+ private onLoading(): void {
+ this.refreshButtons(true);
+ }
+
+ private onLoaded(): void {
+ this.refreshButtons(false);
+ }
+
+ private refreshButtons(loading: boolean): void {
+ if (this.wizardView && this.wizardView.wizard) {
+ this.wizardView.wizard.cancelButton.enabled = !loading;
+ this.wizardView.wizard.cancelButton.enabled = !loading;
+ }
}
public get modelResources(): ModelSourcesComponent | undefined {
@@ -82,16 +103,26 @@ export class PredictWizard extends ModelViewBase {
return this.modelSourcePage?.azureModelsComponent;
}
+ public async getModelFileName(): Promise {
+ if (this.modelResources && this.localModelsComponent && this.modelResources.data === ModelSourceType.Local) {
+ return new ModelArtifact(this.localModelsComponent.data, false);
+ } else if (this.modelResources && this.azureModelsComponent && this.modelResources.data === ModelSourceType.Azure) {
+ return await this.azureModelsComponent.getDownloadedModel();
+ } else if (this.modelSourcePage && this.modelSourcePage.registeredModelsComponent) {
+ return await this.modelSourcePage.registeredModelsComponent.getDownloadedModel();
+ }
+ return undefined;
+ }
+
private async predict(): Promise {
try {
- let modelFilePath: string = '';
+ let modelFilePath: string | undefined;
let registeredModel: RegisteredModel | undefined = undefined;
- if (this.modelResources && this.localModelsComponent && this.modelResources.data === ModelSourceType.Local) {
- modelFilePath = this.localModelsComponent.data;
- } else if (this.modelResources && this.azureModelsComponent && this.modelResources.data === ModelSourceType.Azure) {
- modelFilePath = await this.downloadAzureModel(this.azureModelsComponent?.data);
- } else {
+ if (this.modelSourcePage && this.modelSourcePage.registeredModelsComponent) {
registeredModel = this.modelSourcePage?.registeredModelsComponent?.data;
+ } else {
+ const artifact = await this.getModelFileName();
+ modelFilePath = artifact?.filePath;
}
await this.generatePredictScript(registeredModel, modelFilePath, this.columnsSelectionPage?.data);
@@ -102,6 +133,14 @@ export class PredictWizard extends ModelViewBase {
}
}
+ private async onClose(): Promise {
+ const artifact = await this.getModelFileName();
+ if (artifact) {
+ artifact.close();
+ }
+ await this.wizardView?.disposePages();
+ }
+
/**
* Refresh the pages
*/
diff --git a/extensions/machine-learning-services/src/views/models/registerModels/currentModelsPage.ts b/extensions/machine-learning-services/src/views/models/registerModels/currentModelsPage.ts
index b113858ef7..2723c82dc6 100644
--- a/extensions/machine-learning-services/src/views/models/registerModels/currentModelsPage.ts
+++ b/extensions/machine-learning-services/src/views/models/registerModels/currentModelsPage.ts
@@ -15,7 +15,7 @@ import { IPageView } from '../../interfaces';
* View to render current registered models
*/
export class CurrentModelsPage extends ModelViewBase implements IPageView {
- private _tableComponent: azdata.DeclarativeTableComponent | undefined;
+ private _tableComponent: azdata.Component | undefined;
private _dataTable: CurrentModelsTable | undefined;
private _loader: azdata.LoadingComponent | undefined;
diff --git a/extensions/machine-learning-services/src/views/models/registerModels/currentModelsTable.ts b/extensions/machine-learning-services/src/views/models/registerModels/currentModelsTable.ts
index 34b91d488c..7efd3bdc8b 100644
--- a/extensions/machine-learning-services/src/views/models/registerModels/currentModelsTable.ts
+++ b/extensions/machine-learning-services/src/views/models/registerModels/currentModelsTable.ts
@@ -4,11 +4,13 @@
*--------------------------------------------------------------------------------------------*/
import * as azdata from 'azdata';
+import * as vscode from 'vscode';
import * as constants from '../../../common/constants';
import { ModelViewBase } from '../modelViewBase';
import { ApiWrapper } from '../../../common/apiWrapper';
import { RegisteredModel } from '../../../modelManagement/interfaces';
import { IDataComponent } from '../../interfaces';
+import { ModelArtifact } from '../prediction/modelArtifact';
/**
* View to render registered models table
@@ -18,6 +20,10 @@ export class CurrentModelsTable extends ModelViewBase implements IDataComponent<
private _table: azdata.DeclarativeTableComponent | undefined;
private _modelBuilder: azdata.ModelBuilder | undefined;
private _selectedModel: any;
+ private _loader: azdata.LoadingComponent | undefined;
+ private _downloadedFile: ModelArtifact | undefined;
+ private _onModelSelectionChanged: vscode.EventEmitter = new vscode.EventEmitter();
+ public readonly onModelSelectionChanged: vscode.Event = this._onModelSelectionChanged.event;
/**
* Creates new view
@@ -30,7 +36,7 @@ export class CurrentModelsTable extends ModelViewBase implements IDataComponent<
*
* @param modelBuilder register the components
*/
- public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.DeclarativeTableComponent {
+ public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.Component {
this._modelBuilder = modelBuilder;
this._table = modelBuilder.declarativeTable()
.withProperties(
@@ -92,7 +98,12 @@ export class CurrentModelsTable extends ModelViewBase implements IDataComponent<
ariaLabel: constants.mlsConfigTitle
})
.component();
- return this._table;
+ this._loader = modelBuilder.loadingComponent()
+ .withItem(this._table)
+ .withProperties({
+ loading: true
+ }).component();
+ return this._loader;
}
public addComponents(formBuilder: azdata.FormBuilder) {
@@ -111,14 +122,15 @@ export class CurrentModelsTable extends ModelViewBase implements IDataComponent<
/**
* Returns the component
*/
- public get component(): azdata.DeclarativeTableComponent | undefined {
- return this._table;
+ public get component(): azdata.Component | undefined {
+ return this._loader;
}
/**
* Loads the data in the component
*/
public async loadData(): Promise {
+ await this.onLoading();
if (this._table) {
let models: RegisteredModel[] | undefined;
@@ -131,6 +143,20 @@ export class CurrentModelsTable extends ModelViewBase implements IDataComponent<
this._table.data = tableData;
}
+ this.onModelSelected();
+ await this.onLoaded();
+ }
+
+ public async onLoading(): Promise {
+ if (this._loader) {
+ await this._loader.updateProperties({ loading: true });
+ }
+ }
+
+ public async onLoaded(): Promise {
+ if (this._loader) {
+ await this._loader.updateProperties({ loading: false });
+ }
}
private createTableRow(model: RegisteredModel): any[] {
@@ -142,8 +168,9 @@ export class CurrentModelsTable extends ModelViewBase implements IDataComponent<
height: 15,
checked: false
}).component();
- selectModelButton.onDidClick(() => {
+ selectModelButton.onDidClick(async () => {
this._selectedModel = model;
+ await this.onModelSelected();
});
return [model.artifactName, model.title, model.created, selectModelButton];
}
@@ -151,6 +178,14 @@ export class CurrentModelsTable extends ModelViewBase implements IDataComponent<
return [];
}
+ private async onModelSelected(): Promise {
+ this._onModelSelectionChanged.fire();
+ if (this._downloadedFile) {
+ await this._downloadedFile.close();
+ }
+ this._downloadedFile = undefined;
+ }
+
/**
* Returns selected data
*/
@@ -158,6 +193,22 @@ export class CurrentModelsTable extends ModelViewBase implements IDataComponent<
return this._selectedModel;
}
+ public async getDownloadedModel(): Promise {
+ if (!this._downloadedFile) {
+ this._downloadedFile = new ModelArtifact(await this.downloadRegisteredModel(this.data));
+ }
+ return this._downloadedFile;
+ }
+
+ /**
+ * disposes the view
+ */
+ public async disposeComponent(): Promise {
+ if (this._downloadedFile) {
+ await this._downloadedFile.close();
+ }
+ }
+
/**
* Refreshes the view
*/
diff --git a/extensions/machine-learning-services/src/views/models/registerModels/registerModelWizard.ts b/extensions/machine-learning-services/src/views/models/registerModels/registerModelWizard.ts
index 73013595f1..2d42bf00e7 100644
--- a/extensions/machine-learning-services/src/views/models/registerModels/registerModelWizard.ts
+++ b/extensions/machine-learning-services/src/views/models/registerModels/registerModelWizard.ts
@@ -35,7 +35,7 @@ export class RegisterModelWizard extends ModelViewBase {
/**
* Opens a dialog to manage packages used by notebooks.
*/
- public open(): void {
+ public async open(): Promise {
this.modelSourcePage = new ModelSourcePage(this._apiWrapper, this);
this.modelDetailsPage = new ModelDetailsPage(this._apiWrapper, this);
this.wizardView = new WizardView(this._apiWrapper);
@@ -63,7 +63,7 @@ export class RegisterModelWizard extends ModelViewBase {
return validated;
});
- wizard.open();
+ await wizard.open();
}
public get modelResources(): ModelSourcesComponent | undefined {
diff --git a/extensions/machine-learning-services/src/views/viewBase.ts b/extensions/machine-learning-services/src/views/viewBase.ts
index 3702bab3bf..8cc36d9862 100644
--- a/extensions/machine-learning-services/src/views/viewBase.ts
+++ b/extensions/machine-learning-services/src/views/viewBase.ts
@@ -128,14 +128,14 @@ export abstract class ViewBase extends EventEmitterCollection {
if (connection) {
return `${connection.serverName} ${connection.databaseName ? connection.databaseName : ''}`;
}
- return constants.packageManagerNoConnection;
+ return constants.noConnectionError;
}
public getServerTitle(): string {
if (this.connection) {
return this.connection.serverName;
}
- return constants.packageManagerNoConnection;
+ return constants.noConnectionError;
}
private async getCurrentConnectionUrl(): Promise {
diff --git a/extensions/machine-learning-services/src/views/wizardView.ts b/extensions/machine-learning-services/src/views/wizardView.ts
index 33976ec0fb..49b1adef5a 100644
--- a/extensions/machine-learning-services/src/views/wizardView.ts
+++ b/extensions/machine-learning-services/src/views/wizardView.ts
@@ -68,7 +68,7 @@ export class WizardView extends MainViewBase {
this._pages = pages;
this._wizard.pages = pages.map(x => this.createWizardPage(x.title || '', x));
this._wizard.onPageChanged(async (info) => {
- this.onWizardPageChanged(info);
+ await this.onWizardPageChanged(info);
});
return this._wizard;
@@ -85,17 +85,17 @@ export class WizardView extends MainViewBase {
return true;
}
- private onWizardPageChanged(pageInfo: azdata.window.WizardPageChangeInfo) {
+ private async onWizardPageChanged(pageInfo: azdata.window.WizardPageChangeInfo) {
let idxLast = pageInfo.lastPage;
let lastPage = this._pages[idxLast];
if (lastPage && lastPage.onLeave) {
- lastPage.onLeave();
+ await lastPage.onLeave();
}
let idx = pageInfo.newPage;
let page = this._pages[idx];
if (page && page.onEnter) {
- page.onEnter();
+ await page.onEnter();
}
}