mirror of
https://github.com/ckaczor/azuredatastudio.git
synced 2026-02-06 17:23:53 -05:00
Machine Learning Services - Model detection in predict wizard (#9609)
* Machine Learning Services - Model detection in predict wizard
This commit is contained in:
3
extensions/machine-learning-services/images/arrow.svg
Normal file
3
extensions/machine-learning-services/images/arrow.svg
Normal file
@@ -0,0 +1,3 @@
|
||||
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<path d="M15.9531 8L8.60156 15.3516L7.89844 14.6484L14.0469 8.5H0V7.5H14.0469L7.89844 1.35156L8.60156 0.648438L15.9531 8Z" fill="#0078D4"/>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 243 B |
@@ -60,7 +60,7 @@ export function confirmInstallPythonPackages(packages: string): string {
|
||||
export const installDependenciesPackages = localize('mls.installDependencies.packages', "Installing required packages ...");
|
||||
export const installDependenciesPackagesAlreadyInstalled = localize('mls.installDependencies.packagesAlreadyInstalled', "Required packages are already installed.");
|
||||
export function installDependenciesGetPackagesError(err: string): string { return localize('mls.installDependencies.getPackagesError', "Failed to get installed python packages. Error: {0}", err); }
|
||||
export const packageManagerNoConnection = localize('mls.packageManager.NoConnection', "No connection selected");
|
||||
export const noConnectionError = localize('mls.packageManager.NoConnection', "No connection selected");
|
||||
export const notebookExtensionNotLoaded = localize('mls.notebookExtensionNotLoaded', "Notebook extension is not loaded");
|
||||
export const mssqlExtensionNotLoaded = localize('mls.mssqlExtensionNotLoaded', "MSSQL extension is not loaded");
|
||||
export const mlsEnabledMessage = localize('mls.enabledMessage', "Machine Learning Services Enabled");
|
||||
@@ -74,6 +74,8 @@ export const mlsExternalExecuteScriptTitle = localize('mls.externalExecuteScript
|
||||
export const mlsPythonLanguageTitle = localize('mls.pythonLanguageTitle', "Python");
|
||||
export const mlsRLanguageTitle = localize('mls.rLanguageTitle', "R");
|
||||
export const downloadError = localize('mls.downloadError', "Error while downloading");
|
||||
export function invalidModelIdError(modelUrl: string | undefined): string { return localize('mls.invalidModelIdError', "Invalid model id. model url: {0}", modelUrl || ''); }
|
||||
export function noArtifactError(modelUrl: string | undefined): string { return localize('mls.noArtifactError', "Model doesn't have any artifact. model url: {0}", modelUrl || ''); }
|
||||
export const downloadingProgress = localize('mls.downloadingProgress', "Downloading");
|
||||
export const pythonConfigError = localize('mls.pythonConfigError', "Python executable is not configured");
|
||||
export const rConfigError = localize('mls.rConfigError', "R executable is not configured");
|
||||
@@ -119,12 +121,15 @@ export const modelCreated = localize('models.created', "Date Created");
|
||||
export const modelVersion = localize('models.version', "Version");
|
||||
export const browseModels = localize('models.browseButton', "...");
|
||||
export const azureAccount = localize('models.azureAccount', "Azure account");
|
||||
export const columnDatabase = localize('predict.columnDatabase', "Database");
|
||||
export const columnTable = localize('predict.columnTable', "Table");
|
||||
export const inputColumns = localize('predict.inputColumns', "Input columns");
|
||||
export const outputColumns = localize('predict.outputColumns', "Output column");
|
||||
export const columnName = localize('predict.columnName', "Name");
|
||||
export const inputName = localize('predict.inputName', "Input Name");
|
||||
export const columnDatabase = localize('predict.columnDatabase', "Target database");
|
||||
export const columnTable = localize('predict.columnTable', "Target table");
|
||||
export const inputColumns = localize('predict.inputColumns', "Model input mapping");
|
||||
export const outputColumns = localize('predict.outputColumns', "Model output");
|
||||
export const columnName = localize('predict.columnName', "Target columns");
|
||||
export const dataTypeName = localize('predict.dataTypeName', "Type");
|
||||
export const displayName = localize('predict.displayName', "Display name");
|
||||
export const inputName = localize('predict.inputName', "Required model input features");
|
||||
export const outputName = localize('predict.outputName', "Name");
|
||||
export const azureSubscription = localize('models.azureSubscription', "Azure subscription");
|
||||
export const azureGroup = localize('models.azureGroup', "Azure resource group");
|
||||
export const azureModelWorkspace = localize('models.azureModelWorkspace', "Azure ML workspace");
|
||||
@@ -134,7 +139,7 @@ export const azureModelsTitle = localize('models.azureModelsTitle', "Azure model
|
||||
export const localModelsTitle = localize('models.localModelsTitle', "Local models");
|
||||
export const modelSourcesTitle = localize('models.modelSourcesTitle', "Source location");
|
||||
export const modelSourcePageTitle = localize('models.modelSourcePageTitle', "Enter model source details");
|
||||
export const columnSelectionPageTitle = localize('models.columnSelectionPageTitle', "Select input columns");
|
||||
export const columnSelectionPageTitle = localize('models.columnSelectionPageTitle', "Map predictions target data to model input");
|
||||
export const modelDetailsPageTitle = localize('models.modelDetailsPageTitle', "Provide model details");
|
||||
export const modelLocalSourceTitle = localize('models.modelLocalSourceTitle', "Source file");
|
||||
export const currentModelsTitle = localize('models.currentModelsTitle', "Models");
|
||||
@@ -156,6 +161,8 @@ export const invalidModelToSelectError = localize('models.invalidModelToSelectEr
|
||||
export const modelNameRequiredError = localize('models.modelNameRequiredError', "Model name is required.");
|
||||
export const updateModelFailedError = localize('models.updateModelFailedError', "Failed to update the model");
|
||||
export const importModelFailedError = localize('models.importModelFailedError', "Failed to register the model");
|
||||
export const loadModelParameterFailedError = localize('models.loadModelParameterFailedError', "Failed to load model parameters'");
|
||||
export const unsupportedModelParameterType = localize('models.unsupportedModelParameterType', "unsupported");
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -23,16 +23,19 @@ export class ProcessService {
|
||||
scriptExecution.stdin.end();
|
||||
|
||||
// Add listeners to print stdout and stderr if an output channel was provided
|
||||
if (outputChannel) {
|
||||
scriptExecution.stdout.on('data', data => {
|
||||
|
||||
scriptExecution.stdout.on('data', data => {
|
||||
if (outputChannel) {
|
||||
this.outputDataChunk(data, outputChannel, ' stdout: ');
|
||||
output = output + data.toString();
|
||||
});
|
||||
scriptExecution.stderr.on('data', data => {
|
||||
}
|
||||
output = output + data.toString();
|
||||
});
|
||||
scriptExecution.stderr.on('data', data => {
|
||||
if (outputChannel) {
|
||||
this.outputDataChunk(data, outputChannel, ' stderr: ');
|
||||
output = output + data.toString();
|
||||
});
|
||||
}
|
||||
}
|
||||
output = output + data.toString();
|
||||
});
|
||||
|
||||
scriptExecution.on('exit', (code) => {
|
||||
if (timer) {
|
||||
|
||||
@@ -22,7 +22,17 @@ export async function execCommandOnTempFile<T>(content: string, command: (filePa
|
||||
return result;
|
||||
}
|
||||
finally {
|
||||
await fs.promises.unlink(tempFilePath);
|
||||
await deleteFile(tempFilePath);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Deletes a file
|
||||
* @param filePath file path
|
||||
*/
|
||||
export async function deleteFile(filePath: string) {
|
||||
if (filePath) {
|
||||
await fs.promises.unlink(filePath);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -215,7 +225,7 @@ export function getRegisteredModelsThreePartsName(config: Config) {
|
||||
const dbName = doubleEscapeSingleBrackets(config.registeredModelDatabaseName);
|
||||
const schema = doubleEscapeSingleBrackets(config.registeredModelTableSchemaName);
|
||||
const tableName = doubleEscapeSingleBrackets(config.registeredModelTableName);
|
||||
return `[${dbName}].${schema}.[${tableName}]`;
|
||||
return `[${dbName}].[${schema}].[${tableName}]`;
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -227,3 +237,14 @@ export function getRegisteredModelsTowPartsName(config: Config) {
|
||||
const tableName = doubleEscapeSingleBrackets(config.registeredModelTableName);
|
||||
return `[${schema}].[${tableName}]`;
|
||||
}
|
||||
|
||||
/**
|
||||
* Write a file using a hex string
|
||||
* @param content file content
|
||||
*/
|
||||
export async function writeFileFromHex(content: string): Promise<string> {
|
||||
content = content.startsWith('0x') || content.startsWith('0X') ? content.substr(2) : content;
|
||||
const tempFilePath = path.join(os.tmpdir(), `ads_ml_temp_${UUID.generateUuid()}`);
|
||||
await fs.promises.writeFile(tempFilePath, Buffer.from(content, 'hex'));
|
||||
return tempFilePath;
|
||||
}
|
||||
|
||||
@@ -18,9 +18,9 @@ import { HttpClient } from '../common/httpClient';
|
||||
import { LanguageController } from '../views/externalLanguages/languageController';
|
||||
import { LanguageService } from '../externalLanguage/languageService';
|
||||
import { ModelManagementController } from '../views/models/modelManagementController';
|
||||
import { RegisteredModelService } from '../modelManagement/registeredModelService';
|
||||
import { DeployedModelService } from '../modelManagement/deployedModelService';
|
||||
import { AzureModelRegistryService } from '../modelManagement/azureModelRegistryService';
|
||||
import { ModelImporter } from '../modelManagement/modelImporter';
|
||||
import { ModelPythonClient } from '../modelManagement/modelPythonClient';
|
||||
import { PredictService } from '../prediction/predictService';
|
||||
|
||||
/**
|
||||
@@ -100,11 +100,11 @@ export default class MainController implements vscode.Disposable {
|
||||
let mssqlService = await this.getLanguageExtensionService();
|
||||
let languagesModel = new LanguageService(this._apiWrapper, mssqlService);
|
||||
let languageController = new LanguageController(this._apiWrapper, this._rootPath, languagesModel);
|
||||
let modelImporter = new ModelImporter(this._outputChannel, this._apiWrapper, this._processService, this._config, packageManager);
|
||||
let modelImporter = new ModelPythonClient(this._outputChannel, this._apiWrapper, this._processService, this._config, packageManager);
|
||||
|
||||
// Model Management
|
||||
//
|
||||
let registeredModelService = new RegisteredModelService(this._apiWrapper, this._config, this._queryRunner, modelImporter);
|
||||
let registeredModelService = new DeployedModelService(this._apiWrapper, this._config, this._queryRunner, modelImporter);
|
||||
let azureModelsService = new AzureModelRegistryService(this._apiWrapper, this._config, this.httpClient, this._outputChannel);
|
||||
let predictService = new PredictService(this._apiWrapper, this._queryRunner, this._config);
|
||||
let modelManagementController = new ModelManagementController(this._apiWrapper, this._rootPath,
|
||||
|
||||
@@ -28,10 +28,16 @@ import * as utils from '../common/utils';
|
||||
*/
|
||||
export class AzureModelRegistryService {
|
||||
|
||||
private _amlClient: AzureMachineLearningWorkspaces | undefined;
|
||||
private _modelClient: WorkspaceModels | undefined;
|
||||
/**
|
||||
*
|
||||
* Creates new service
|
||||
*/
|
||||
constructor(private _apiWrapper: ApiWrapper, private _config: Config, private _httpClient: HttpClient, private _outputChannel: vscode.OutputChannel) {
|
||||
constructor(
|
||||
private _apiWrapper: ApiWrapper,
|
||||
private _config: Config,
|
||||
private _httpClient: HttpClient,
|
||||
private _outputChannel: vscode.OutputChannel) {
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -120,10 +126,18 @@ export class AzureModelRegistryService {
|
||||
return downloadedFilePath;
|
||||
}
|
||||
|
||||
public set AzureMachineLearningClient(value: AzureMachineLearningWorkspaces) {
|
||||
this._amlClient = value;
|
||||
}
|
||||
|
||||
public set ModelClient(value: WorkspaceModels) {
|
||||
this._modelClient = value;
|
||||
}
|
||||
|
||||
/**
|
||||
* Installs dependencies for the extension
|
||||
* Execute the background task to download the artifact
|
||||
*/
|
||||
public async execDownloadArtifactTask(downloadUrl: string): Promise<string> {
|
||||
private async execDownloadArtifactTask(downloadUrl: string): Promise<string> {
|
||||
let results = await utils.executeTasks(this._apiWrapper, constants.downloadModelMsgTaskName, [this.downloadArtifact(downloadUrl)], true);
|
||||
return results && results.length > 0 ? results[0] : constants.noResultError;
|
||||
}
|
||||
@@ -139,15 +153,14 @@ export class AzureModelRegistryService {
|
||||
|
||||
try {
|
||||
for (const tenant of account.properties.tenants) {
|
||||
const tokens = await this._apiWrapper.getSecurityToken(account, azdata.AzureResource.ResourceManagement);
|
||||
const token = tokens[tenant.id].token;
|
||||
const tokenType = tokens[tenant.id].tokenType;
|
||||
const client = new AzureMachineLearningWorkspaces(new TokenCredentials(token, tokenType), subscription.id);
|
||||
const client = await this.getAmlClient(account, subscription, tenant);
|
||||
let result = resourceGroup ? await client.workspaces.listByResourceGroup(resourceGroup.name) : await client.workspaces.listBySubscription();
|
||||
resources.push(...result);
|
||||
if (result) {
|
||||
resources.push(...result);
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
|
||||
console.log(error);
|
||||
}
|
||||
return resources;
|
||||
}
|
||||
@@ -161,9 +174,11 @@ export class AzureModelRegistryService {
|
||||
|
||||
for (const tenant of account.properties.tenants) {
|
||||
try {
|
||||
let baseUri = this.getBaseUrl(workspace, this._config.amlModelManagementUrl);
|
||||
const client = await this.getClient(baseUri, account, subscription, tenant);
|
||||
let modelsClient = new WorkspaceModels(client);
|
||||
let options: AzureMachineLearningWorkspacesOptions = {
|
||||
baseUri: this.getBaseUrl(workspace, this._config.amlModelManagementUrl)
|
||||
};
|
||||
const client = await this.getAmlClient(account, subscription, tenant, options, this._config.amlApiVersion);
|
||||
let modelsClient = this.getModelClient(client);
|
||||
resources = resources.concat(await modelsClient.listModels(resourceGroup.name, workspace.name || ''));
|
||||
|
||||
} catch (error) {
|
||||
@@ -182,22 +197,28 @@ export class AzureModelRegistryService {
|
||||
client: AzureMachineLearningWorkspaces): Promise<Asset> {
|
||||
|
||||
const modelId = this.getModelId(model);
|
||||
let modelsClient = new Assets(client);
|
||||
return await modelsClient.queryById(subscription.id, resourceGroup.name, workspace.name || '', modelId);
|
||||
if (modelId) {
|
||||
let modelsClient = new Assets(client);
|
||||
return await modelsClient.queryById(subscription.id, resourceGroup.name, workspace.name || '', modelId);
|
||||
} else {
|
||||
throw Error(constants.invalidModelIdError(model.url));
|
||||
}
|
||||
}
|
||||
|
||||
public async getAssetArtifactsDownloadLinks(
|
||||
private async getAssetArtifactsDownloadLinks(
|
||||
account: azdata.Account,
|
||||
subscription: azureResource.AzureResourceSubscription,
|
||||
resourceGroup: azureResource.AzureResource,
|
||||
workspace: Workspace,
|
||||
model: WorkspaceModel,
|
||||
tenant: any): Promise<string[]> {
|
||||
let baseUri = this.getBaseUrl(workspace, this._config.amlModelManagementUrl);
|
||||
const modelManagementClient = await this.getClient(baseUri, account, subscription, tenant);
|
||||
let options: AzureMachineLearningWorkspacesOptions = {
|
||||
baseUri: this.getBaseUrl(workspace, this._config.amlModelManagementUrl)
|
||||
};
|
||||
const modelManagementClient = await this.getAmlClient(account, subscription, tenant, options, this._config.amlApiVersion);
|
||||
const asset = await this.fetchModelAsset(subscription, resourceGroup, workspace, model, modelManagementClient);
|
||||
baseUri = this.getBaseUrl(workspace, this._config.amlExperienceUrl);
|
||||
const experienceClient = await this.getClient(baseUri, account, subscription, tenant);
|
||||
options.baseUri = this.getBaseUrl(workspace, this._config.amlExperienceUrl);
|
||||
const experienceClient = await this.getAmlClient(account, subscription, tenant, options, this._config.amlApiVersion);
|
||||
const artifactClient = new Artifacts(experienceClient);
|
||||
let downloadLinks: string[] = [];
|
||||
if (asset && asset.artifacts) {
|
||||
@@ -230,17 +251,19 @@ export class AzureModelRegistryService {
|
||||
downloadLinkPromises.push(promise);
|
||||
}
|
||||
}
|
||||
|
||||
try {
|
||||
downloadLinks = await Promise.all(downloadLinkPromises);
|
||||
} catch (rejectedPromiseError) {
|
||||
return rejectedPromiseError;
|
||||
}
|
||||
return downloadLinks;
|
||||
|
||||
} else {
|
||||
throw Error(constants.noArtifactError(model.url));
|
||||
}
|
||||
return downloadLinks;
|
||||
}
|
||||
|
||||
public getPartsFromAssetIdOrPrefix(idOrPrefix: string | undefined): IArtifactParts | undefined {
|
||||
private getPartsFromAssetIdOrPrefix(idOrPrefix: string | undefined): IArtifactParts | undefined {
|
||||
const artifactRegex = /^(.+?)\/(.+?)\/(.+?)$/;
|
||||
if (idOrPrefix) {
|
||||
const parts = artifactRegex.exec(idOrPrefix);
|
||||
@@ -263,16 +286,35 @@ export class AzureModelRegistryService {
|
||||
return baseUri;
|
||||
}
|
||||
|
||||
private async getClient(baseUri: string, account: azdata.Account, subscription: azureResource.AzureResourceSubscription, tenant: any): Promise<AzureMachineLearningWorkspaces> {
|
||||
const tokens = await this._apiWrapper.getSecurityToken(account, azdata.AzureResource.ResourceManagement);
|
||||
const token = tokens[tenant.id].token;
|
||||
const tokenType = tokens[tenant.id].tokenType;
|
||||
const options: AzureMachineLearningWorkspacesOptions = {
|
||||
baseUri: baseUri
|
||||
};
|
||||
const client = new AzureMachineLearningWorkspaces(new TokenCredentials(token, tokenType), subscription.id, options);
|
||||
client.apiVersion = this._config.amlApiVersion;
|
||||
return client;
|
||||
private getModelClient(amlClient: AzureMachineLearningWorkspaces) {
|
||||
return this._modelClient ?? new WorkspaceModels(amlClient);
|
||||
}
|
||||
|
||||
private async getAmlClient(
|
||||
account: azdata.Account,
|
||||
subscription: azureResource.AzureResourceSubscription,
|
||||
tenant: any,
|
||||
options: AzureMachineLearningWorkspacesOptions | undefined = undefined,
|
||||
apiVersion: string | undefined = undefined): Promise<AzureMachineLearningWorkspaces> {
|
||||
if (this._amlClient) {
|
||||
return this._amlClient;
|
||||
} else {
|
||||
const tokens = await this._apiWrapper.getSecurityToken(account, azdata.AzureResource.ResourceManagement);
|
||||
let token: string = '';
|
||||
let tokenType: string | undefined = undefined;
|
||||
if (tokens && tenant.id in tokens) {
|
||||
const tokenForId = tokens[tenant.id];
|
||||
if (tokenForId) {
|
||||
token = tokenForId.token;
|
||||
tokenType = tokenForId.tokenType;
|
||||
}
|
||||
}
|
||||
const client = new AzureMachineLearningWorkspaces(new TokenCredentials(token, tokenType), subscription.id, options);
|
||||
if (apiVersion) {
|
||||
client.apiVersion = apiVersion;
|
||||
}
|
||||
return client;
|
||||
}
|
||||
}
|
||||
|
||||
private getModelId(model: WorkspaceModel): string {
|
||||
|
||||
@@ -9,73 +9,85 @@ import { ApiWrapper } from '../common/apiWrapper';
|
||||
import * as utils from '../common/utils';
|
||||
import { Config } from '../configurations/config';
|
||||
import { QueryRunner } from '../common/queryRunner';
|
||||
import { RegisteredModel, RegisteredModelDetails } from './interfaces';
|
||||
import { ModelImporter } from './modelImporter';
|
||||
import { RegisteredModel, RegisteredModelDetails, ModelParameters } from './interfaces';
|
||||
import { ModelPythonClient } from './modelPythonClient';
|
||||
import * as constants from '../common/constants';
|
||||
|
||||
/**
|
||||
* Service to registered models
|
||||
* Service to deployed models
|
||||
*/
|
||||
export class RegisteredModelService {
|
||||
export class DeployedModelService {
|
||||
|
||||
/**
|
||||
*
|
||||
* Creates new instance
|
||||
*/
|
||||
constructor(
|
||||
private _apiWrapper: ApiWrapper,
|
||||
private _config: Config,
|
||||
private _queryRunner: QueryRunner,
|
||||
private _modelImporter: ModelImporter) {
|
||||
private _modelClient: ModelPythonClient) {
|
||||
}
|
||||
|
||||
public async getRegisteredModels(): Promise<RegisteredModel[]> {
|
||||
/**
|
||||
* Returns deployed models
|
||||
*/
|
||||
public async getDeployedModels(): Promise<RegisteredModel[]> {
|
||||
let connection = await this.getCurrentConnection();
|
||||
let list: RegisteredModel[] = [];
|
||||
if (connection) {
|
||||
let query = this.getConfigureQuery(connection.databaseName);
|
||||
await this._queryRunner.safeRunQuery(connection, query);
|
||||
query = this.registeredModelsQuery();
|
||||
query = this.getDeployedModelsQuery();
|
||||
let result = await this._queryRunner.safeRunQuery(connection, query);
|
||||
if (result && result.rows && result.rows.length > 0) {
|
||||
result.rows.forEach(row => {
|
||||
list.push(this.loadModelData(row));
|
||||
});
|
||||
}
|
||||
} else {
|
||||
throw Error(constants.noConnectionError);
|
||||
}
|
||||
return list;
|
||||
}
|
||||
|
||||
private loadModelData(row: azdata.DbCellValue[]): RegisteredModel {
|
||||
return {
|
||||
id: +row[0].displayValue,
|
||||
artifactName: row[1].displayValue,
|
||||
title: row[2].displayValue,
|
||||
description: row[3].displayValue,
|
||||
version: row[4].displayValue,
|
||||
created: row[5].displayValue
|
||||
};
|
||||
}
|
||||
|
||||
public async updateModel(model: RegisteredModel): Promise<RegisteredModel | undefined> {
|
||||
/**
|
||||
* Downloads model
|
||||
* @param model model object
|
||||
*/
|
||||
public async downloadModel(model: RegisteredModel): Promise<string> {
|
||||
let connection = await this.getCurrentConnection();
|
||||
let updatedModel: RegisteredModel | undefined = undefined;
|
||||
if (connection) {
|
||||
const query = this.getUpdateModelScript(connection.databaseName, model);
|
||||
const query = this.getModelContentQuery(model);
|
||||
let result = await this._queryRunner.safeRunQuery(connection, query);
|
||||
if (result && result.rows && result.rows.length > 0) {
|
||||
const row = result.rows[0];
|
||||
updatedModel = this.loadModelData(row);
|
||||
const content = result.rows[0][0].displayValue;
|
||||
return await utils.writeFileFromHex(content);
|
||||
} else {
|
||||
throw Error(constants.invalidModelToSelectError);
|
||||
}
|
||||
} else {
|
||||
throw Error(constants.noConnectionError);
|
||||
}
|
||||
return updatedModel;
|
||||
}
|
||||
|
||||
public async registerLocalModel(filePath: string, details: RegisteredModelDetails | undefined) {
|
||||
/**
|
||||
* Loads model parameters
|
||||
*/
|
||||
public async loadModelParameters(filePath: string): Promise<ModelParameters> {
|
||||
return await this._modelClient.loadModelParameters(filePath);
|
||||
}
|
||||
|
||||
/**
|
||||
* Deploys local model
|
||||
* @param filePath model file path
|
||||
* @param details model details
|
||||
*/
|
||||
public async deployLocalModel(filePath: string, details: RegisteredModelDetails | undefined) {
|
||||
let connection = await this.getCurrentConnection();
|
||||
if (connection) {
|
||||
let currentModels = await this.getRegisteredModels();
|
||||
await this._modelImporter.registerModel(connection, filePath);
|
||||
let updatedModels = await this.getRegisteredModels();
|
||||
let currentModels = await this.getDeployedModels();
|
||||
await this._modelClient.deployModel(connection, filePath);
|
||||
let updatedModels = await this.getDeployedModels();
|
||||
if (details && updatedModels.length >= currentModels.length + 1) {
|
||||
updatedModels.sort((a, b) => a.id && b.id ? a.id - b.id : 0);
|
||||
const addedModel = updatedModels[updatedModels.length - 1];
|
||||
@@ -92,16 +104,40 @@ export class RegisteredModelService {
|
||||
}
|
||||
}
|
||||
}
|
||||
private loadModelData(row: azdata.DbCellValue[]): RegisteredModel {
|
||||
return {
|
||||
id: +row[0].displayValue,
|
||||
artifactName: row[1].displayValue,
|
||||
title: row[2].displayValue,
|
||||
description: row[3].displayValue,
|
||||
version: row[4].displayValue,
|
||||
created: row[5].displayValue
|
||||
};
|
||||
}
|
||||
|
||||
private async updateModel(model: RegisteredModel): Promise<RegisteredModel | undefined> {
|
||||
let connection = await this.getCurrentConnection();
|
||||
let updatedModel: RegisteredModel | undefined = undefined;
|
||||
if (connection) {
|
||||
const query = this.getUpdateModelQuery(connection.databaseName, model);
|
||||
let result = await this._queryRunner.safeRunQuery(connection, query);
|
||||
if (result?.rows && result.rows.length > 0) {
|
||||
const row = result.rows[0];
|
||||
updatedModel = this.loadModelData(row);
|
||||
}
|
||||
}
|
||||
return updatedModel;
|
||||
}
|
||||
|
||||
private async getCurrentConnection(): Promise<azdata.connection.ConnectionProfile> {
|
||||
return await this._apiWrapper.getCurrentConnection();
|
||||
}
|
||||
|
||||
private getConfigureQuery(currentDatabaseName: string): string {
|
||||
return utils.getScriptWithDBChange(currentDatabaseName, this._config.registeredModelDatabaseName, this.configureTable());
|
||||
public getConfigureQuery(currentDatabaseName: string): string {
|
||||
return utils.getScriptWithDBChange(currentDatabaseName, this._config.registeredModelDatabaseName, this.getConfigureTableQuery());
|
||||
}
|
||||
|
||||
private registeredModelsQuery(): string {
|
||||
public getDeployedModelsQuery(): string {
|
||||
return `
|
||||
SELECT artifact_id, artifact_name, name, description, version, created
|
||||
FROM ${utils.getRegisteredModelsThreePartsName(this._config)}
|
||||
@@ -116,7 +152,7 @@ export class RegisteredModelService {
|
||||
* @param databaseName
|
||||
* @param tableName
|
||||
*/
|
||||
private configureTable(): string {
|
||||
public getConfigureTableQuery(): string {
|
||||
let databaseName = this._config.registeredModelDatabaseName;
|
||||
let tableName = this._config.registeredModelTableName;
|
||||
let schemaName = this._config.registeredModelTableSchemaName;
|
||||
@@ -171,7 +207,7 @@ export class RegisteredModelService {
|
||||
`;
|
||||
}
|
||||
|
||||
private getUpdateModelScript(currentDatabaseName: string, model: RegisteredModel): string {
|
||||
public getUpdateModelQuery(currentDatabaseName: string, model: RegisteredModel): string {
|
||||
let updateScript = `
|
||||
UPDATE ${utils.getRegisteredModelsTowPartsName(this._config)}
|
||||
SET
|
||||
@@ -187,4 +223,12 @@ export class RegisteredModelService {
|
||||
WHERE artifact_id = ${model.id};
|
||||
`;
|
||||
}
|
||||
|
||||
public getModelContentQuery(model: RegisteredModel): string {
|
||||
return `
|
||||
SELECT artifact_content
|
||||
FROM ${utils.getRegisteredModelsThreePartsName(this._config)}
|
||||
WHERE artifact_id = ${model.id};
|
||||
`;
|
||||
}
|
||||
}
|
||||
@@ -53,6 +53,16 @@ export interface RegisteredModel extends RegisteredModelDetails {
|
||||
artifactName: string;
|
||||
}
|
||||
|
||||
export interface ModelParameter {
|
||||
name: string;
|
||||
type: string;
|
||||
}
|
||||
|
||||
export interface ModelParameters {
|
||||
inputs: ModelParameter[],
|
||||
outputs: ModelParameter[]
|
||||
}
|
||||
|
||||
/**
|
||||
* An interface representing registered model
|
||||
*/
|
||||
|
||||
@@ -13,33 +13,89 @@ import * as utils from '../common/utils';
|
||||
import { PackageManager } from '../packageManagement/packageManager';
|
||||
import * as constants from '../common/constants';
|
||||
import * as os from 'os';
|
||||
import { ModelParameters } from './interfaces';
|
||||
|
||||
/**
|
||||
* Service to import model to database
|
||||
* Python client for ONNX models
|
||||
*/
|
||||
export class ModelImporter {
|
||||
export class ModelPythonClient {
|
||||
|
||||
/**
|
||||
*
|
||||
* Creates new instance
|
||||
*/
|
||||
constructor(private _outputChannel: vscode.OutputChannel, private _apiWrapper: ApiWrapper, private _processService: ProcessService, private _config: Config, private _packageManager: PackageManager) {
|
||||
}
|
||||
|
||||
public async registerModel(connection: azdata.connection.ConnectionProfile, modelFolderPath: string): Promise<void> {
|
||||
/**
|
||||
* Deploys models in the SQL database using mlflow
|
||||
* @param connection
|
||||
* @param modelPath
|
||||
*/
|
||||
public async deployModel(connection: azdata.connection.ConnectionProfile, modelPath: string): Promise<void> {
|
||||
await this.installDependencies();
|
||||
await this.executeScripts(connection, modelFolderPath);
|
||||
await this.executeDeployScripts(connection, modelPath);
|
||||
}
|
||||
|
||||
/**
|
||||
* Installs dependencies for model importer
|
||||
* Installs dependencies for python client
|
||||
*/
|
||||
public async installDependencies(): Promise<void> {
|
||||
private async installDependencies(): Promise<void> {
|
||||
await utils.executeTasks(this._apiWrapper, constants.installDependenciesMsgTaskName, [
|
||||
this._packageManager.installRequiredPythonPackages(this._config.modelsRequiredPythonPackages)], true);
|
||||
}
|
||||
|
||||
protected async executeScripts(connection: azdata.connection.ConnectionProfile, modelFolderPath: string): Promise<void> {
|
||||
/**
|
||||
*
|
||||
* @param modelPath Loads model parameters
|
||||
*/
|
||||
public async loadModelParameters(modelPath: string): Promise<ModelParameters> {
|
||||
await this.installDependencies();
|
||||
return await this.executeModelParametersScripts(modelPath);
|
||||
}
|
||||
|
||||
private async executeModelParametersScripts(modelFolderPath: string): Promise<ModelParameters> {
|
||||
modelFolderPath = utils.makeLinuxPath(modelFolderPath);
|
||||
|
||||
let scripts: string[] = [
|
||||
'import onnx',
|
||||
'import json',
|
||||
`onnx_model_path = '${modelFolderPath}'`,
|
||||
`onnx_model = onnx.load_model(onnx_model_path)`,
|
||||
`type_map = {
|
||||
onnx.TensorProto.DataType.FLOAT: 'real',
|
||||
onnx.TensorProto.DataType.UINT8: 'tinyint',
|
||||
onnx.TensorProto.DataType.INT16: 'smallint',
|
||||
onnx.TensorProto.DataType.INT32: 'int',
|
||||
onnx.TensorProto.DataType.INT64: 'bigint',
|
||||
onnx.TensorProto.DataType.STRING: 'varchar(MAX)',
|
||||
onnx.TensorProto.DataType.DOUBLE: 'float'}`,
|
||||
`parameters = {
|
||||
"inputs": [],
|
||||
"outputs": []
|
||||
}`,
|
||||
`def addParameters(list, paramType):
|
||||
for id, p in enumerate(list):
|
||||
p_type = ''
|
||||
|
||||
if p.type.tensor_type.elem_type in type_map:
|
||||
p_type = type_map[p.type.tensor_type.elem_type]
|
||||
|
||||
parameters[paramType].append({
|
||||
'name': p.name,
|
||||
'type': p_type
|
||||
})`,
|
||||
|
||||
'addParameters(onnx_model.graph.input, "inputs")',
|
||||
'addParameters(onnx_model.graph.output, "outputs")',
|
||||
'print(json.dumps(parameters))'
|
||||
];
|
||||
let pythonExecutable = this._config.pythonExecutable;
|
||||
let output = await this._processService.execScripts(pythonExecutable, scripts, [], undefined);
|
||||
let parametersJson = JSON.parse(output);
|
||||
return Object.assign({}, parametersJson);
|
||||
}
|
||||
|
||||
private async executeDeployScripts(connection: azdata.connection.ConnectionProfile, modelFolderPath: string): Promise<void> {
|
||||
let home = utils.makeLinuxPath(os.homedir());
|
||||
modelFolderPath = utils.makeLinuxPath(modelFolderPath);
|
||||
|
||||
@@ -30,7 +30,7 @@ export abstract class SqlPackageManageProviderBase {
|
||||
if (connection) {
|
||||
return `${connection.serverName} ${connection.databaseName ? connection.databaseName : ''}`;
|
||||
}
|
||||
return constants.packageManagerNoConnection;
|
||||
return constants.noConnectionError;
|
||||
}
|
||||
|
||||
protected async getCurrentConnection(): Promise<azdata.connection.ConnectionProfile> {
|
||||
|
||||
@@ -3,10 +3,13 @@
|
||||
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
export interface PredictColumn {
|
||||
name: string;
|
||||
export interface TableColumn {
|
||||
columnName: string;
|
||||
dataType?: string;
|
||||
displayName?: string;
|
||||
}
|
||||
|
||||
export interface PredictColumn extends TableColumn {
|
||||
paramName?: string;
|
||||
}
|
||||
|
||||
export interface DatabaseTable {
|
||||
|
||||
@@ -9,7 +9,7 @@ import { ApiWrapper } from '../common/apiWrapper';
|
||||
import { QueryRunner } from '../common/queryRunner';
|
||||
import * as utils from '../common/utils';
|
||||
import { RegisteredModel } from '../modelManagement/interfaces';
|
||||
import { PredictParameters, PredictColumn, DatabaseTable } from '../prediction/interfaces';
|
||||
import { PredictParameters, PredictColumn, DatabaseTable, TableColumn } from '../prediction/interfaces';
|
||||
import { Config } from '../configurations/config';
|
||||
|
||||
/**
|
||||
@@ -98,15 +98,18 @@ export class PredictService {
|
||||
*Returns list of column names of a database
|
||||
* @param databaseTable table info
|
||||
*/
|
||||
public async getTableColumnsList(databaseTable: DatabaseTable): Promise<string[]> {
|
||||
public async getTableColumnsList(databaseTable: DatabaseTable): Promise<TableColumn[]> {
|
||||
let connection = await this.getCurrentConnection();
|
||||
let list: string[] = [];
|
||||
let list: TableColumn[] = [];
|
||||
if (connection && databaseTable.databaseName) {
|
||||
const query = utils.getScriptWithDBChange(connection.databaseName, databaseTable.databaseName, this.getTableColumnsScript(databaseTable));
|
||||
let result = await this._queryRunner.safeRunQuery(connection, query);
|
||||
if (result && result.rows && result.rows.length > 0) {
|
||||
result.rows.forEach(row => {
|
||||
list.push(row[0].displayValue);
|
||||
list.push({
|
||||
columnName: row[0].displayValue,
|
||||
dataType: row[1].displayValue
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -119,7 +122,7 @@ export class PredictService {
|
||||
|
||||
private getTableColumnsScript(databaseTable: DatabaseTable): string {
|
||||
return `
|
||||
SELECT COLUMN_NAME,*
|
||||
SELECT COLUMN_NAME,DATA_TYPE
|
||||
FROM INFORMATION_SCHEMA.COLUMNS
|
||||
WHERE TABLE_NAME='${utils.doubleEscapeSingleQuotes(databaseTable.tableName)}'
|
||||
AND TABLE_SCHEMA='${utils.doubleEscapeSingleQuotes(databaseTable.schema)}'
|
||||
@@ -149,14 +152,14 @@ DECLARE @model VARBINARY(max) = (
|
||||
WITH predict_input
|
||||
AS (
|
||||
SELECT TOP 1000
|
||||
${this.getColumnNames(columns, 'pi')}
|
||||
${this.getInputColumnNames(columns, 'pi')}
|
||||
FROM [${utils.doubleEscapeSingleBrackets(databaseNameTable.databaseName)}].[${databaseNameTable.schema}].[${utils.doubleEscapeSingleBrackets(databaseNameTable.tableName)}] as pi
|
||||
)
|
||||
SELECT
|
||||
${this.getInputColumnNames(columns, 'predict_input')}, ${this.getColumnNames(outputColumns, 'p')}
|
||||
${this.getPredictColumnNames(columns, 'predict_input')}, ${this.getInputColumnNames(outputColumns, 'p')}
|
||||
FROM PREDICT(MODEL = @model, DATA = predict_input)
|
||||
WITH (
|
||||
${this.getColumnTypes(outputColumns)}
|
||||
${this.getOutputParameters(outputColumns)}
|
||||
) AS p
|
||||
`;
|
||||
}
|
||||
@@ -170,33 +173,43 @@ WITH (
|
||||
WITH predict_input
|
||||
AS (
|
||||
SELECT TOP 1000
|
||||
${this.getColumnNames(columns, 'pi')}
|
||||
${this.getInputColumnNames(columns, 'pi')}
|
||||
FROM [${utils.doubleEscapeSingleBrackets(databaseNameTable.databaseName)}].[${databaseNameTable.schema}].[${utils.doubleEscapeSingleBrackets(databaseNameTable.tableName)}] as pi
|
||||
)
|
||||
SELECT
|
||||
${this.getInputColumnNames(columns, 'predict_input')}, ${this.getColumnNames(outputColumns, 'p')}
|
||||
${this.getPredictColumnNames(columns, 'predict_input')}, ${this.getOutputColumnNames(outputColumns, 'p')}
|
||||
FROM PREDICT(MODEL = ${modelBytes}, DATA = predict_input)
|
||||
WITH (
|
||||
${this.getColumnTypes(outputColumns)}
|
||||
${this.getOutputParameters(outputColumns)}
|
||||
) AS p
|
||||
`;
|
||||
}
|
||||
|
||||
private getColumnNames(columns: PredictColumn[], tableName: string) {
|
||||
return columns.map(c => {
|
||||
return c.displayName ? `${tableName}.${c.name} AS ${c.displayName}` : `${tableName}.${c.name}`;
|
||||
}).join(',\n');
|
||||
}
|
||||
|
||||
private getInputColumnNames(columns: PredictColumn[], tableName: string) {
|
||||
return columns.map(c => {
|
||||
return c.displayName ? `${tableName}.${c.displayName}` : `${tableName}.${c.name}`;
|
||||
return this.getColumnName(tableName, c.paramName || '', c.columnName);
|
||||
}).join(',\n');
|
||||
}
|
||||
|
||||
private getColumnTypes(columns: PredictColumn[]) {
|
||||
private getOutputColumnNames(columns: PredictColumn[], tableName: string) {
|
||||
return columns.map(c => {
|
||||
return `${c.name} ${c.dataType}`;
|
||||
return this.getColumnName(tableName, c.columnName, c.paramName || '');
|
||||
}).join(',\n');
|
||||
}
|
||||
|
||||
private getColumnName(tableName: string, columnName: string, displayName: string) {
|
||||
return columnName && columnName !== displayName ? `${tableName}.${columnName} AS ${displayName}` : `${tableName}.${columnName}`;
|
||||
}
|
||||
|
||||
private getPredictColumnNames(columns: PredictColumn[], tableName: string) {
|
||||
return columns.map(c => {
|
||||
return c.paramName ? `${tableName}.${c.paramName}` : `${tableName}.${c.columnName}`;
|
||||
}).join(',\n');
|
||||
}
|
||||
|
||||
private getOutputParameters(columns: PredictColumn[]) {
|
||||
return columns.map(c => {
|
||||
return `${c.paramName} ${c.dataType}`;
|
||||
}).join(',\n');
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,232 @@
|
||||
/*---------------------------------------------------------------------------------------------
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
import * as azdata from 'azdata';
|
||||
import * as vscode from 'vscode';
|
||||
import { ApiWrapper } from '../../common/apiWrapper';
|
||||
import * as TypeMoq from 'typemoq';
|
||||
import * as should from 'should';
|
||||
import { AzureModelRegistryService } from '../../modelManagement/azureModelRegistryService';
|
||||
import { Config } from '../../configurations/config';
|
||||
import { HttpClient } from '../../common/httpClient';
|
||||
import { azureResource } from '../../typings/azure-resource';
|
||||
|
||||
import * as utils from '../utils';
|
||||
import { Workspace, WorkspacesListByResourceGroupResponse } from '@azure/arm-machinelearningservices/esm/models';
|
||||
import { WorkspaceModel, AssetsQueryByIdResponse, Asset, GetArtifactContentInformation2Response } from '../../modelManagement/interfaces';
|
||||
import { AzureMachineLearningWorkspaces, Workspaces } from '@azure/arm-machinelearningservices';
|
||||
import { WorkspaceModels } from '../../modelManagement/workspacesModels';
|
||||
|
||||
interface TestContext {
|
||||
|
||||
apiWrapper: TypeMoq.IMock<ApiWrapper>;
|
||||
config: TypeMoq.IMock<Config>;
|
||||
httpClient: TypeMoq.IMock<HttpClient>;
|
||||
outputChannel: vscode.OutputChannel;
|
||||
op: azdata.BackgroundOperation;
|
||||
accounts: azdata.Account[];
|
||||
subscriptions: azureResource.AzureResourceSubscription[];
|
||||
groups: azureResource.AzureResourceResourceGroup[];
|
||||
workspaces: Workspace[];
|
||||
models: WorkspaceModel[];
|
||||
client: TypeMoq.IMock<AzureMachineLearningWorkspaces>;
|
||||
workspacesClient: TypeMoq.IMock<Workspaces>;
|
||||
modelClient: TypeMoq.IMock<WorkspaceModels>;
|
||||
}
|
||||
|
||||
function createContext(): TestContext {
|
||||
const context = utils.createContext();
|
||||
const workspaces = TypeMoq.Mock.ofType(Workspaces);
|
||||
const credentials = {
|
||||
signRequest: () => {
|
||||
return Promise.resolve(undefined!!);
|
||||
}
|
||||
};
|
||||
const client = TypeMoq.Mock.ofInstance(new AzureMachineLearningWorkspaces(credentials, 'subscription'));
|
||||
client.setup(x => x.apiVersion).returns(() => '20180101');
|
||||
|
||||
return {
|
||||
apiWrapper: TypeMoq.Mock.ofType(ApiWrapper),
|
||||
config: TypeMoq.Mock.ofType(Config),
|
||||
httpClient: TypeMoq.Mock.ofType(HttpClient),
|
||||
outputChannel: context.outputChannel,
|
||||
op: context.op,
|
||||
accounts: [
|
||||
{
|
||||
key: {
|
||||
providerId: '',
|
||||
accountId: 'a1'
|
||||
},
|
||||
displayInfo: {
|
||||
contextualDisplayName: '',
|
||||
accountType: '',
|
||||
displayName: 'a1',
|
||||
userId: 'a1'
|
||||
},
|
||||
properties:
|
||||
{
|
||||
tenants: [
|
||||
{
|
||||
id: '1',
|
||||
}
|
||||
]
|
||||
}
|
||||
,
|
||||
isStale: true
|
||||
}
|
||||
],
|
||||
subscriptions: [
|
||||
{
|
||||
name: 's1',
|
||||
id: 's1'
|
||||
}
|
||||
],
|
||||
groups: [
|
||||
{
|
||||
name: 'g1',
|
||||
id: 'g1'
|
||||
}
|
||||
],
|
||||
workspaces: [{
|
||||
name: 'w1',
|
||||
id: 'w1'
|
||||
}
|
||||
],
|
||||
models: [
|
||||
{
|
||||
name: 'm1',
|
||||
id: 'm1',
|
||||
url: 'aml://asset/test.test'
|
||||
}
|
||||
],
|
||||
client: client,
|
||||
workspacesClient: workspaces,
|
||||
modelClient: TypeMoq.Mock.ofInstance(new WorkspaceModels(client.object))
|
||||
};
|
||||
}
|
||||
|
||||
describe('AzureModelRegistryService', () => {
|
||||
it('getAccounts should return the list of accounts successfully', async function (): Promise<void> {
|
||||
let testContext = createContext();
|
||||
const accounts = testContext.accounts;
|
||||
let service = new AzureModelRegistryService(
|
||||
testContext.apiWrapper.object,
|
||||
testContext.config.object,
|
||||
testContext.httpClient.object,
|
||||
testContext.outputChannel);
|
||||
testContext.apiWrapper.setup(x => x.getAllAccounts()).returns(() => Promise.resolve(accounts));
|
||||
let actual = await service.getAccounts();
|
||||
should.deepEqual(actual, testContext.accounts);
|
||||
});
|
||||
|
||||
it('getSubscriptions should return the list of subscriptions successfully', async function (): Promise<void> {
|
||||
let testContext = createContext();
|
||||
const expected = testContext.subscriptions;
|
||||
let service = new AzureModelRegistryService(
|
||||
testContext.apiWrapper.object,
|
||||
testContext.config.object,
|
||||
testContext.httpClient.object,
|
||||
testContext.outputChannel);
|
||||
testContext.apiWrapper.setup(x => x.executeCommand(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve({ subscriptions: expected, errors: [] }));
|
||||
let actual = await service.getSubscriptions(testContext.accounts[0]);
|
||||
should.deepEqual(actual, expected);
|
||||
});
|
||||
|
||||
it('getGroups should return the list of groups successfully', async function (): Promise<void> {
|
||||
let testContext = createContext();
|
||||
const expected = testContext.groups;
|
||||
let service = new AzureModelRegistryService(
|
||||
testContext.apiWrapper.object,
|
||||
testContext.config.object,
|
||||
testContext.httpClient.object,
|
||||
testContext.outputChannel);
|
||||
testContext.apiWrapper.setup(x => x.executeCommand(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve({ resourceGroups: expected, errors: [] }));
|
||||
let actual = await service.getGroups(testContext.accounts[0], testContext.subscriptions[0]);
|
||||
should.deepEqual(actual, expected);
|
||||
});
|
||||
|
||||
it('getWorkspaces should return the list of workspaces successfully', async function (): Promise<void> {
|
||||
let testContext = createContext();
|
||||
const response: WorkspacesListByResourceGroupResponse = Object.assign(new Array<Workspace>(...testContext.workspaces), {
|
||||
_response: undefined!
|
||||
});
|
||||
const expected = testContext.workspaces;
|
||||
testContext.workspacesClient.setup(x => x.listByResourceGroup(TypeMoq.It.isAny())).returns(() => Promise.resolve(response));
|
||||
testContext.workspacesClient.setup(x => x.listBySubscription()).returns(() => Promise.resolve(response));
|
||||
testContext.client.setup(x => x.workspaces).returns(() => testContext.workspacesClient.object);
|
||||
let service = new AzureModelRegistryService(
|
||||
testContext.apiWrapper.object,
|
||||
testContext.config.object,
|
||||
testContext.httpClient.object,
|
||||
testContext.outputChannel);
|
||||
|
||||
|
||||
service.AzureMachineLearningClient = testContext.client.object;
|
||||
let actual = await service.getWorkspaces(testContext.accounts[0], testContext.subscriptions[0], testContext.groups[0]);
|
||||
should.deepEqual(actual, expected);
|
||||
});
|
||||
|
||||
it('getModels should return the list of models successfully', async function (): Promise<void> {
|
||||
let testContext = createContext();
|
||||
testContext.config.setup(x => x.amlApiVersion).returns(() => '2018');
|
||||
testContext.config.setup(x => x.amlModelManagementUrl).returns(() => 'test.url');
|
||||
const expected = testContext.models;
|
||||
let service = new AzureModelRegistryService(
|
||||
testContext.apiWrapper.object,
|
||||
testContext.config.object,
|
||||
testContext.httpClient.object,
|
||||
testContext.outputChannel);
|
||||
service.AzureMachineLearningClient = testContext.client.object;
|
||||
service.ModelClient = testContext.modelClient.object;
|
||||
testContext.modelClient.setup(x => x.listModels(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(testContext.models));
|
||||
let actual = await service.getModels(testContext.accounts[0], testContext.subscriptions[0], testContext.groups[0], testContext.workspaces[0]);
|
||||
should.deepEqual(actual, expected);
|
||||
});
|
||||
|
||||
it('downloadModel should download model artifact successfully', async function (): Promise<void> {
|
||||
let testContext = createContext();
|
||||
const asset: Asset =
|
||||
{
|
||||
id: '1',
|
||||
name: 'asset',
|
||||
artifacts: [
|
||||
{
|
||||
id: '/1/2/3/4/5/'
|
||||
}
|
||||
]
|
||||
};
|
||||
const assetResponse: AssetsQueryByIdResponse = Object.assign(asset, {
|
||||
_response: undefined!
|
||||
});
|
||||
const artifactResponse: GetArtifactContentInformation2Response = Object.assign({
|
||||
contentUri: 'downloadUrl'
|
||||
}, {
|
||||
_response: undefined!
|
||||
});
|
||||
|
||||
testContext.config.setup(x => x.amlApiVersion).returns(() => '2018');
|
||||
testContext.config.setup(x => x.amlModelManagementUrl).returns(() => 'test.url');
|
||||
testContext.config.setup(x => x.amlExperienceUrl).returns(() => 'test.url');
|
||||
testContext.client.setup(x => x.sendOperationRequest(TypeMoq.It.isAny(),
|
||||
TypeMoq.It.is(p => p.path !== undefined && p.path.startsWith('modelmanagement')), TypeMoq.It.isAny())).returns(() => Promise.resolve(assetResponse));
|
||||
testContext.client.setup(x => x.sendOperationRequest(TypeMoq.It.isAny(),
|
||||
TypeMoq.It.is(p => p.path !== undefined && p.path.startsWith('artifact')), TypeMoq.It.isAny())).returns(() => Promise.resolve(artifactResponse));
|
||||
testContext.apiWrapper.setup(x => x.startBackgroundOperation(TypeMoq.It.isAny())).returns((operationInfo: azdata.BackgroundOperationInfo) => {
|
||||
operationInfo.operation(testContext.op);
|
||||
});
|
||||
testContext.httpClient.setup(x => x.download(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve());
|
||||
let service = new AzureModelRegistryService(
|
||||
testContext.apiWrapper.object,
|
||||
testContext.config.object,
|
||||
testContext.httpClient.object,
|
||||
testContext.outputChannel);
|
||||
service.AzureMachineLearningClient = testContext.client.object;
|
||||
service.ModelClient = testContext.modelClient.object;
|
||||
testContext.modelClient.setup(x => x.listModels(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(testContext.models));
|
||||
let actual = await service.downloadModel(testContext.accounts[0], testContext.subscriptions[0], testContext.groups[0], testContext.workspaces[0], testContext.models[0]);
|
||||
should.notEqual(actual, undefined);
|
||||
testContext.httpClient.verify(x => x.download(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny()), TypeMoq.Times.once());
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,410 @@
|
||||
/*---------------------------------------------------------------------------------------------
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
import * as azdata from 'azdata';
|
||||
import * as utils from '../../common/utils';
|
||||
import { ApiWrapper } from '../../common/apiWrapper';
|
||||
import * as TypeMoq from 'typemoq';
|
||||
import * as should from 'should';
|
||||
import { Config } from '../../configurations/config';
|
||||
import { DeployedModelService } from '../../modelManagement/deployedModelService';
|
||||
import { QueryRunner } from '../../common/queryRunner';
|
||||
import { RegisteredModel } from '../../modelManagement/interfaces';
|
||||
import { ModelPythonClient } from '../../modelManagement/modelPythonClient';
|
||||
import * as path from 'path';
|
||||
import * as os from 'os';
|
||||
import * as UUID from 'vscode-languageclient/lib/utils/uuid';
|
||||
import * as fs from 'fs';
|
||||
|
||||
interface TestContext {
|
||||
|
||||
apiWrapper: TypeMoq.IMock<ApiWrapper>;
|
||||
config: TypeMoq.IMock<Config>;
|
||||
queryRunner: TypeMoq.IMock<QueryRunner>;
|
||||
modelClient: TypeMoq.IMock<ModelPythonClient>;
|
||||
}
|
||||
|
||||
function createContext(): TestContext {
|
||||
|
||||
return {
|
||||
apiWrapper: TypeMoq.Mock.ofType(ApiWrapper),
|
||||
config: TypeMoq.Mock.ofType(Config),
|
||||
queryRunner: TypeMoq.Mock.ofType(QueryRunner),
|
||||
modelClient: TypeMoq.Mock.ofType(ModelPythonClient)
|
||||
};
|
||||
}
|
||||
|
||||
describe('DeployedModelService', () => {
|
||||
it('getDeployedModels should fail with no connection', async function (): Promise<void> {
|
||||
const testContext = createContext();
|
||||
let connection: azdata.connection.ConnectionProfile;
|
||||
|
||||
testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); });
|
||||
let service = new DeployedModelService(
|
||||
testContext.apiWrapper.object,
|
||||
testContext.config.object,
|
||||
testContext.queryRunner.object,
|
||||
testContext.modelClient.object);
|
||||
await should(service.getDeployedModels()).rejected();
|
||||
});
|
||||
|
||||
it('getDeployedModels should returns models successfully', async function (): Promise<void> {
|
||||
const testContext = createContext();
|
||||
const connection = new azdata.connection.ConnectionProfile();
|
||||
testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); });
|
||||
const expected: RegisteredModel[] = [
|
||||
{
|
||||
id: 1,
|
||||
artifactName: 'name1',
|
||||
title: 'title1',
|
||||
description: 'desc1',
|
||||
created: '2018-01-01',
|
||||
version: '1.1'
|
||||
}
|
||||
];
|
||||
const result = {
|
||||
rowCount: 1,
|
||||
columnInfo: [],
|
||||
rows: [
|
||||
[
|
||||
{
|
||||
displayValue: '1',
|
||||
isNull: false,
|
||||
invariantCultureDisplayValue: ''
|
||||
},
|
||||
{
|
||||
displayValue: 'name1',
|
||||
isNull: false,
|
||||
invariantCultureDisplayValue: ''
|
||||
},
|
||||
{
|
||||
displayValue: 'title1',
|
||||
isNull: false,
|
||||
invariantCultureDisplayValue: ''
|
||||
},
|
||||
{
|
||||
displayValue: 'desc1',
|
||||
isNull: false,
|
||||
invariantCultureDisplayValue: ''
|
||||
},
|
||||
{
|
||||
displayValue: '1.1',
|
||||
isNull: false,
|
||||
invariantCultureDisplayValue: ''
|
||||
},
|
||||
{
|
||||
displayValue: '2018-01-01',
|
||||
isNull: false,
|
||||
invariantCultureDisplayValue: ''
|
||||
}
|
||||
]
|
||||
]
|
||||
};
|
||||
let service = new DeployedModelService(
|
||||
testContext.apiWrapper.object,
|
||||
testContext.config.object,
|
||||
testContext.queryRunner.object,
|
||||
testContext.modelClient.object);
|
||||
testContext.queryRunner.setup(x => x.safeRunQuery(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(result));
|
||||
|
||||
testContext.config.setup(x => x.registeredModelDatabaseName).returns(() => 'db');
|
||||
testContext.config.setup(x => x.registeredModelTableName).returns(() => 'table');
|
||||
const actual = await service.getDeployedModels();
|
||||
should.deepEqual(actual, expected);
|
||||
});
|
||||
|
||||
it('loadModelParameters should load parameters using python client successfully', async function (): Promise<void> {
|
||||
const testContext = createContext();
|
||||
const expected = {
|
||||
inputs: [
|
||||
{
|
||||
'name': 'p1',
|
||||
'type': 'int'
|
||||
},
|
||||
{
|
||||
'name': 'p2',
|
||||
'type': 'varchar'
|
||||
}
|
||||
],
|
||||
outputs: [
|
||||
{
|
||||
'name': 'o1',
|
||||
'type': 'int'
|
||||
},
|
||||
]
|
||||
};
|
||||
testContext.modelClient.setup(x => x.loadModelParameters(TypeMoq.It.isAny())).returns(() => Promise.resolve(expected));
|
||||
let service = new DeployedModelService(
|
||||
testContext.apiWrapper.object,
|
||||
testContext.config.object,
|
||||
testContext.queryRunner.object,
|
||||
testContext.modelClient.object);
|
||||
const actual = await service.loadModelParameters('');
|
||||
should.deepEqual(actual, expected);
|
||||
});
|
||||
|
||||
it('downloadModel should download model successfully', async function (): Promise<void> {
|
||||
const testContext = createContext();
|
||||
const connection = new azdata.connection.ConnectionProfile();
|
||||
const tempFilePath = path.join(os.tmpdir(), `ads_ml_temp_${UUID.generateUuid()}`);
|
||||
await fs.promises.writeFile(tempFilePath, 'test');
|
||||
testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); });
|
||||
const model: RegisteredModel =
|
||||
{
|
||||
id: 1,
|
||||
artifactName: 'name1',
|
||||
title: 'title1',
|
||||
description: 'desc1',
|
||||
created: '2018-01-01',
|
||||
version: '1.1'
|
||||
};
|
||||
const result = {
|
||||
rowCount: 1,
|
||||
columnInfo: [],
|
||||
rows: [
|
||||
[
|
||||
{
|
||||
displayValue: await utils.readFileInHex(tempFilePath),
|
||||
isNull: false,
|
||||
invariantCultureDisplayValue: ''
|
||||
}
|
||||
]
|
||||
]
|
||||
};
|
||||
let service = new DeployedModelService(
|
||||
testContext.apiWrapper.object,
|
||||
testContext.config.object,
|
||||
testContext.queryRunner.object,
|
||||
testContext.modelClient.object);
|
||||
testContext.queryRunner.setup(x => x.safeRunQuery(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(result));
|
||||
|
||||
testContext.config.setup(x => x.registeredModelDatabaseName).returns(() => 'db');
|
||||
testContext.config.setup(x => x.registeredModelTableName).returns(() => 'table');
|
||||
testContext.config.setup(x => x.registeredModelTableSchemaName).returns(() => 'dbo');
|
||||
const actual = await service.downloadModel(model);
|
||||
should.notEqual(actual, undefined);
|
||||
});
|
||||
|
||||
it('deployLocalModel should returns models successfully', async function (): Promise<void> {
|
||||
const testContext = createContext();
|
||||
const connection = new azdata.connection.ConnectionProfile();
|
||||
testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); });
|
||||
const model: RegisteredModel =
|
||||
{
|
||||
id: 1,
|
||||
artifactName: 'name1',
|
||||
title: 'title1',
|
||||
description: 'desc1',
|
||||
created: '2018-01-01',
|
||||
version: '1.1'
|
||||
};
|
||||
const row = [
|
||||
{
|
||||
displayValue: '1',
|
||||
isNull: false,
|
||||
invariantCultureDisplayValue: ''
|
||||
},
|
||||
{
|
||||
displayValue: 'name1',
|
||||
isNull: false,
|
||||
invariantCultureDisplayValue: ''
|
||||
},
|
||||
{
|
||||
displayValue: 'title1',
|
||||
isNull: false,
|
||||
invariantCultureDisplayValue: ''
|
||||
},
|
||||
{
|
||||
displayValue: 'desc1',
|
||||
isNull: false,
|
||||
invariantCultureDisplayValue: ''
|
||||
},
|
||||
{
|
||||
displayValue: '1.1',
|
||||
isNull: false,
|
||||
invariantCultureDisplayValue: ''
|
||||
},
|
||||
{
|
||||
displayValue: '2018-01-01',
|
||||
isNull: false,
|
||||
invariantCultureDisplayValue: ''
|
||||
}
|
||||
];
|
||||
const result = {
|
||||
rowCount: 1,
|
||||
columnInfo: [],
|
||||
rows: [row]
|
||||
};
|
||||
let updatedResult = {
|
||||
rowCount: 1,
|
||||
columnInfo: [],
|
||||
rows: [row, row]
|
||||
};
|
||||
let deployed = false;
|
||||
let service = new DeployedModelService(
|
||||
testContext.apiWrapper.object,
|
||||
testContext.config.object,
|
||||
testContext.queryRunner.object,
|
||||
testContext.modelClient.object);
|
||||
testContext.modelClient.setup(x => x.deployModel(connection, '')).returns(() => {
|
||||
deployed = true;
|
||||
return Promise.resolve();
|
||||
});
|
||||
testContext.queryRunner.setup(x => x.safeRunQuery(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => {
|
||||
return deployed ? Promise.resolve(updatedResult) : Promise.resolve(result);
|
||||
});
|
||||
|
||||
testContext.config.setup(x => x.registeredModelDatabaseName).returns(() => 'db');
|
||||
testContext.config.setup(x => x.registeredModelTableName).returns(() => 'table');
|
||||
testContext.config.setup(x => x.registeredModelTableSchemaName).returns(() => 'dbo');
|
||||
await should(service.deployLocalModel('', model)).resolved();
|
||||
});
|
||||
|
||||
it('getConfigureQuery should escape db name', async function (): Promise<void> {
|
||||
const testContext = createContext();
|
||||
const dbName = 'curre[n]tDb';
|
||||
let service = new DeployedModelService(
|
||||
testContext.apiWrapper.object,
|
||||
testContext.config.object,
|
||||
testContext.queryRunner.object,
|
||||
testContext.modelClient.object);
|
||||
testContext.config.setup(x => x.registeredModelDatabaseName).returns(() => 'd[]b');
|
||||
testContext.config.setup(x => x.registeredModelTableName).returns(() => 'ta[b]le');
|
||||
testContext.config.setup(x => x.registeredModelTableSchemaName).returns(() => 'dbo');
|
||||
const expected = `
|
||||
IF NOT EXISTS (
|
||||
SELECT [name]
|
||||
FROM sys.databases
|
||||
WHERE [name] = N'd[]b'
|
||||
)
|
||||
CREATE DATABASE [d[[]]b]
|
||||
GO
|
||||
USE [d[[]]b]
|
||||
IF EXISTS
|
||||
( SELECT [t.name], [s.name]
|
||||
FROM sys.tables t join sys.schemas s on t.schema_id=t.schema_id
|
||||
WHERE [t.name] = 'ta[b]le'
|
||||
AND [s.name] = 'dbo'
|
||||
)
|
||||
BEGIN
|
||||
IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('[dbo].[ta[[b]]le]') AND NAME='name')
|
||||
ALTER TABLE [dbo].[ta[[b]]le] ADD [name] [varchar](256) NULL
|
||||
IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('[dbo].[ta[[b]]le]') AND NAME='version')
|
||||
ALTER TABLE [dbo].[ta[[b]]le] ADD [version] [varchar](256) NULL
|
||||
IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('[dbo].[ta[[b]]le]') AND NAME='created')
|
||||
BEGIN
|
||||
ALTER TABLE [dbo].[ta[[b]]le] ADD [created] [datetime] NULL
|
||||
ALTER TABLE [dbo].[ta[[b]]le] ADD CONSTRAINT CONSTRAINT_NAME DEFAULT GETDATE() FOR created
|
||||
END
|
||||
IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('[dbo].[ta[[b]]le]') AND NAME='description')
|
||||
ALTER TABLE [dbo].[ta[[b]]le] ADD [description] [varchar](256) NULL
|
||||
END
|
||||
Else
|
||||
BEGIN
|
||||
CREATE TABLE [dbo].[ta[[b]]le](
|
||||
[artifact_id] [int] IDENTITY(1,1) NOT NULL,
|
||||
[artifact_name] [varchar](256) NOT NULL,
|
||||
[group_path] [varchar](256) NOT NULL,
|
||||
[artifact_content] [varbinary](max) NOT NULL,
|
||||
[artifact_initial_size] [bigint] NULL,
|
||||
[name] [varchar](256) NULL,
|
||||
[version] [varchar](256) NULL,
|
||||
[created] [datetime] NULL,
|
||||
[description] [varchar](256) NULL,
|
||||
CONSTRAINT [artifact_pk] PRIMARY KEY CLUSTERED
|
||||
(
|
||||
[artifact_id] ASC
|
||||
)WITH (PAD_INDEX = OFF, STATISTICS_NORECOMPUTE = OFF, IGNORE_DUP_KEY = OFF, ALLOW_ROW_LOCKS = ON, ALLOW_PAGE_LOCKS = ON) ON [PRIMARY]
|
||||
) ON [PRIMARY] TEXTIMAGE_ON [PRIMARY]
|
||||
ALTER TABLE [dbo].[artifacts] ADD CONSTRAINT [CONSTRAINT_NAME] DEFAULT (getdate()) FOR [created]
|
||||
END
|
||||
`;
|
||||
const actual = service.getConfigureQuery(dbName);
|
||||
should.equal(actual.indexOf(expected) > 0, true);
|
||||
});
|
||||
|
||||
it('getDeployedModelsQuery should escape db name', async function (): Promise<void> {
|
||||
const testContext = createContext();
|
||||
let service = new DeployedModelService(
|
||||
testContext.apiWrapper.object,
|
||||
testContext.config.object,
|
||||
testContext.queryRunner.object,
|
||||
testContext.modelClient.object);
|
||||
testContext.config.setup(x => x.registeredModelDatabaseName).returns(() => 'd[]b');
|
||||
testContext.config.setup(x => x.registeredModelTableName).returns(() => 'ta[b]le');
|
||||
testContext.config.setup(x => x.registeredModelTableSchemaName).returns(() => 'dbo');
|
||||
const expected = `
|
||||
SELECT artifact_id, artifact_name, name, description, version, created
|
||||
FROM [d[[]]b].[dbo].[ta[[b]]le]
|
||||
WHERE artifact_name not like 'MLmodel' and artifact_name not like 'conda.yaml'
|
||||
Order by artifact_id
|
||||
`;
|
||||
const actual = service.getDeployedModelsQuery();
|
||||
should.deepEqual(expected, actual);
|
||||
});
|
||||
|
||||
it('getUpdateModelQuery should escape db name', async function (): Promise<void> {
|
||||
const testContext = createContext();
|
||||
const dbName = 'curre[n]tDb';
|
||||
const model: RegisteredModel =
|
||||
{
|
||||
id: 1,
|
||||
artifactName: 'name1',
|
||||
title: 'title1',
|
||||
description: 'desc1',
|
||||
created: '2018-01-01',
|
||||
version: '1.1'
|
||||
};
|
||||
|
||||
let service = new DeployedModelService(
|
||||
testContext.apiWrapper.object,
|
||||
testContext.config.object,
|
||||
testContext.queryRunner.object,
|
||||
testContext.modelClient.object);
|
||||
testContext.config.setup(x => x.registeredModelDatabaseName).returns(() => 'd[]b');
|
||||
testContext.config.setup(x => x.registeredModelTableName).returns(() => 'ta[b]le');
|
||||
testContext.config.setup(x => x.registeredModelTableSchemaName).returns(() => 'dbo');
|
||||
const expected = `
|
||||
UPDATE [dbo].[ta[[b]]le]
|
||||
SET
|
||||
name = 'title1',
|
||||
version = '1.1',
|
||||
description = 'desc1'
|
||||
WHERE artifact_id = 1`;
|
||||
const actual = service.getUpdateModelQuery(dbName, model);
|
||||
should.equal(actual.indexOf(expected) > 0, true);
|
||||
//should.deepEqual(actual, expected);
|
||||
|
||||
});
|
||||
|
||||
it('getModelContentQuery should escape db name', async function (): Promise<void> {
|
||||
const testContext = createContext();
|
||||
const model: RegisteredModel =
|
||||
{
|
||||
id: 1,
|
||||
artifactName: 'name1',
|
||||
title: 'title1',
|
||||
description: 'desc1',
|
||||
created: '2018-01-01',
|
||||
version: '1.1'
|
||||
};
|
||||
|
||||
let service = new DeployedModelService(
|
||||
testContext.apiWrapper.object,
|
||||
testContext.config.object,
|
||||
testContext.queryRunner.object,
|
||||
testContext.modelClient.object);
|
||||
testContext.config.setup(x => x.registeredModelDatabaseName).returns(() => 'd[]b');
|
||||
testContext.config.setup(x => x.registeredModelTableName).returns(() => 'ta[b]le');
|
||||
testContext.config.setup(x => x.registeredModelTableSchemaName).returns(() => 'dbo');
|
||||
const expected = `
|
||||
SELECT artifact_content
|
||||
FROM [d[[]]b].[dbo].[ta[[b]]le]
|
||||
WHERE artifact_id = 1;
|
||||
`;
|
||||
const actual = service.getModelContentQuery(model);
|
||||
should.deepEqual(actual, expected);
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,121 @@
|
||||
/*---------------------------------------------------------------------------------------------
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
import * as azdata from 'azdata';
|
||||
import * as vscode from 'vscode';
|
||||
import { ApiWrapper } from '../../common/apiWrapper';
|
||||
import * as TypeMoq from 'typemoq';
|
||||
import * as should from 'should';
|
||||
import { Config } from '../../configurations/config';
|
||||
|
||||
import * as utils from '../utils';
|
||||
import { ProcessService } from '../../common/processService';
|
||||
import { PackageManager } from '../../packageManagement/packageManager';
|
||||
import { ModelPythonClient } from '../../modelManagement/modelPythonClient';
|
||||
|
||||
interface TestContext {
|
||||
|
||||
apiWrapper: TypeMoq.IMock<ApiWrapper>;
|
||||
config: TypeMoq.IMock<Config>;
|
||||
outputChannel: vscode.OutputChannel;
|
||||
op: azdata.BackgroundOperation;
|
||||
processService: TypeMoq.IMock<ProcessService>;
|
||||
packageManager: TypeMoq.IMock<PackageManager>;
|
||||
}
|
||||
|
||||
function createContext(): TestContext {
|
||||
const context = utils.createContext();
|
||||
|
||||
return {
|
||||
apiWrapper: TypeMoq.Mock.ofType(ApiWrapper),
|
||||
config: TypeMoq.Mock.ofType(Config),
|
||||
outputChannel: context.outputChannel,
|
||||
op: context.op,
|
||||
processService: TypeMoq.Mock.ofType(ProcessService),
|
||||
packageManager: TypeMoq.Mock.ofType(PackageManager)
|
||||
};
|
||||
}
|
||||
|
||||
describe('ModelPythonClient', () => {
|
||||
it('deployModel should deploy the model successfully', async function (): Promise<void> {
|
||||
const testContext = createContext();
|
||||
const connection = new azdata.connection.ConnectionProfile();
|
||||
const modelPath = 'C:\\test';
|
||||
let service = new ModelPythonClient(
|
||||
testContext.outputChannel,
|
||||
testContext.apiWrapper.object,
|
||||
testContext.processService.object,
|
||||
testContext.config.object,
|
||||
testContext.packageManager.object);
|
||||
testContext.packageManager.setup(x => x.installRequiredPythonPackages(TypeMoq.It.isAny())).returns(() => Promise.resolve());
|
||||
testContext.apiWrapper.setup(x => x.startBackgroundOperation(TypeMoq.It.isAny())).returns((operationInfo: azdata.BackgroundOperationInfo) => {
|
||||
operationInfo.operation(testContext.op);
|
||||
});
|
||||
testContext.config.setup(x => x.pythonExecutable).returns(() => 'pythonPath');
|
||||
testContext.processService.setup(x => x.execScripts(TypeMoq.It.isAny(), TypeMoq.It.isAny(),
|
||||
TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(''));
|
||||
|
||||
await service.deployModel(connection, modelPath);
|
||||
});
|
||||
|
||||
it('loadModelParameters should load model parameters successfully', async function (): Promise<void> {
|
||||
const testContext = createContext();
|
||||
const modelPath = 'C:\\test';
|
||||
const expected = {
|
||||
inputs: [
|
||||
{
|
||||
'name': 'p1',
|
||||
'type': 'int'
|
||||
},
|
||||
{
|
||||
'name': 'p2',
|
||||
'type': 'varchar'
|
||||
}
|
||||
],
|
||||
outputs: [
|
||||
{
|
||||
'name': 'o1',
|
||||
'type': 'int'
|
||||
},
|
||||
]
|
||||
};
|
||||
const parametersJson = `
|
||||
{
|
||||
"inputs": [
|
||||
{
|
||||
"name": "p1",
|
||||
"type": "int"
|
||||
},
|
||||
{
|
||||
"name": "p2",
|
||||
"type": "varchar"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "o1",
|
||||
"type": "int"
|
||||
}
|
||||
]
|
||||
}
|
||||
`;
|
||||
let service = new ModelPythonClient(
|
||||
testContext.outputChannel,
|
||||
testContext.apiWrapper.object,
|
||||
testContext.processService.object,
|
||||
testContext.config.object,
|
||||
testContext.packageManager.object);
|
||||
testContext.packageManager.setup(x => x.installRequiredPythonPackages(TypeMoq.It.isAny())).returns(() => Promise.resolve());
|
||||
testContext.config.setup(x => x.pythonExecutable).returns(() => 'pythonPath');
|
||||
testContext.processService.setup(x => x.execScripts(TypeMoq.It.isAny(), TypeMoq.It.isAny(),
|
||||
TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(parametersJson));
|
||||
testContext.apiWrapper.setup(x => x.startBackgroundOperation(TypeMoq.It.isAny())).returns((operationInfo: azdata.BackgroundOperationInfo) => {
|
||||
operationInfo.operation(testContext.op);
|
||||
});
|
||||
|
||||
const actual = await service.loadModelParameters(modelPath);
|
||||
should.deepEqual(actual, expected);
|
||||
});
|
||||
});
|
||||
@@ -354,7 +354,7 @@ describe('SQL Python Package Manager', () => {
|
||||
let provider = createProvider(testContext);
|
||||
let actual = await provider.getLocationTitle();
|
||||
|
||||
should.deepEqual(actual, constants.packageManagerNoConnection);
|
||||
should.deepEqual(actual, constants.noConnectionError);
|
||||
});
|
||||
|
||||
it('getLocationTitle Should return connection title string for valid connection', async function (): Promise<void> {
|
||||
|
||||
@@ -279,7 +279,7 @@ describe('SQL R Package Manager', () => {
|
||||
let provider = createProvider(testContext);
|
||||
let actual = await provider.getLocationTitle();
|
||||
|
||||
should.deepEqual(actual, constants.packageManagerNoConnection);
|
||||
should.deepEqual(actual, constants.noConnectionError);
|
||||
});
|
||||
|
||||
it('getLocationTitle Should return connection title string for valid connection', async function (): Promise<void> {
|
||||
|
||||
@@ -11,6 +11,7 @@ import { QueryRunner } from '../../common/queryRunner';
|
||||
import { ProcessService } from '../../common/processService';
|
||||
import { Config } from '../../configurations/config';
|
||||
import { HttpClient } from '../../common/httpClient';
|
||||
import * as utils from '../utils';
|
||||
import { PackageManagementService } from '../../packageManagement/packageManagementService';
|
||||
|
||||
export interface TestContext {
|
||||
@@ -27,31 +28,18 @@ export interface TestContext {
|
||||
}
|
||||
|
||||
export function createContext(): TestContext {
|
||||
let opStatus: azdata.TaskStatus;
|
||||
const context = utils.createContext();
|
||||
|
||||
return {
|
||||
outputChannel: {
|
||||
name: '',
|
||||
append: () => { },
|
||||
appendLine: () => { },
|
||||
clear: () => { },
|
||||
show: () => { },
|
||||
hide: () => { },
|
||||
dispose: () => { }
|
||||
},
|
||||
|
||||
outputChannel: context.outputChannel,
|
||||
processService: TypeMoq.Mock.ofType(ProcessService),
|
||||
apiWrapper: TypeMoq.Mock.ofType(ApiWrapper),
|
||||
queryRunner: TypeMoq.Mock.ofType(QueryRunner),
|
||||
config: TypeMoq.Mock.ofType(Config),
|
||||
httpClient: TypeMoq.Mock.ofType(HttpClient),
|
||||
op: {
|
||||
updateStatus: (status: azdata.TaskStatus) => {
|
||||
opStatus = status;
|
||||
},
|
||||
id: '',
|
||||
onCanceled: new vscode.EventEmitter<void>().event,
|
||||
},
|
||||
getOpStatus: () => { return opStatus; },
|
||||
op: context.op,
|
||||
getOpStatus: context.getOpStatus,
|
||||
serverConfigManager: TypeMoq.Mock.ofType(PackageManagementService)
|
||||
};
|
||||
}
|
||||
|
||||
@@ -0,0 +1,303 @@
|
||||
/*---------------------------------------------------------------------------------------------
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
import * as azdata from 'azdata';
|
||||
import * as vscode from 'vscode';
|
||||
import { ApiWrapper } from '../../common/apiWrapper';
|
||||
import * as TypeMoq from 'typemoq';
|
||||
import * as should from 'should';
|
||||
import { Config } from '../../configurations/config';
|
||||
import { PredictService } from '../../prediction/predictService';
|
||||
import { QueryRunner } from '../../common/queryRunner';
|
||||
import { RegisteredModel } from '../../modelManagement/interfaces';
|
||||
import { PredictParameters, DatabaseTable, TableColumn } from '../../prediction/interfaces';
|
||||
import * as path from 'path';
|
||||
import * as os from 'os';
|
||||
import * as UUID from 'vscode-languageclient/lib/utils/uuid';
|
||||
import * as fs from 'fs';
|
||||
|
||||
|
||||
interface TestContext {
|
||||
|
||||
apiWrapper: TypeMoq.IMock<ApiWrapper>;
|
||||
config: TypeMoq.IMock<Config>;
|
||||
queryRunner: TypeMoq.IMock<QueryRunner>;
|
||||
}
|
||||
|
||||
function createContext(): TestContext {
|
||||
|
||||
return {
|
||||
apiWrapper: TypeMoq.Mock.ofType(ApiWrapper),
|
||||
config: TypeMoq.Mock.ofType(Config),
|
||||
queryRunner: TypeMoq.Mock.ofType(QueryRunner)
|
||||
};
|
||||
}
|
||||
|
||||
describe('PredictService', () => {
|
||||
|
||||
it('getDatabaseList should return databases successfully', async function (): Promise<void> {
|
||||
const testContext = createContext();
|
||||
const expected: string[] = [
|
||||
'db1',
|
||||
'db2'
|
||||
];
|
||||
const connection = new azdata.connection.ConnectionProfile();
|
||||
testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); });
|
||||
testContext.apiWrapper.setup(x => x.listDatabases(TypeMoq.It.isAny())).returns(() => { return Promise.resolve(expected); });
|
||||
|
||||
let service = new PredictService(
|
||||
testContext.apiWrapper.object,
|
||||
testContext.queryRunner.object,
|
||||
testContext.config.object);
|
||||
const actual = await service.getDatabaseList();
|
||||
should.deepEqual(actual, expected);
|
||||
});
|
||||
|
||||
it('getTableList should return tables successfully', async function (): Promise<void> {
|
||||
const testContext = createContext();
|
||||
const expected: DatabaseTable[] = [
|
||||
{
|
||||
databaseName: 'db1',
|
||||
schema: 'dbo',
|
||||
tableName: 'tb1'
|
||||
},
|
||||
{
|
||||
databaseName: 'db1',
|
||||
tableName: 'tb2',
|
||||
schema: 'dbo'
|
||||
}
|
||||
];
|
||||
|
||||
const result = {
|
||||
rowCount: 1,
|
||||
columnInfo: [],
|
||||
rows: [[
|
||||
{
|
||||
displayValue: 'tb1',
|
||||
isNull: false,
|
||||
invariantCultureDisplayValue: ''
|
||||
},
|
||||
{
|
||||
displayValue: 'dbo',
|
||||
isNull: false,
|
||||
invariantCultureDisplayValue: ''
|
||||
}
|
||||
], [
|
||||
{
|
||||
displayValue: 'tb2',
|
||||
isNull: false,
|
||||
invariantCultureDisplayValue: ''
|
||||
},
|
||||
{
|
||||
displayValue: 'dbo',
|
||||
isNull: false,
|
||||
invariantCultureDisplayValue: ''
|
||||
}
|
||||
]]
|
||||
};
|
||||
const connection = new azdata.connection.ConnectionProfile();
|
||||
testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); });
|
||||
testContext.queryRunner.setup(x => x.safeRunQuery(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(result));
|
||||
let service = new PredictService(
|
||||
testContext.apiWrapper.object,
|
||||
testContext.queryRunner.object,
|
||||
testContext.config.object);
|
||||
const actual = await service.getTableList('db1');
|
||||
should.deepEqual(actual, expected);
|
||||
});
|
||||
|
||||
it('getTableColumnsList should return table columns successfully', async function (): Promise<void> {
|
||||
const testContext = createContext();
|
||||
const expected: TableColumn[] = [
|
||||
{
|
||||
columnName: 'c1',
|
||||
dataType: 'int'
|
||||
},
|
||||
{
|
||||
columnName: 'c2',
|
||||
dataType: 'varchar'
|
||||
}
|
||||
];
|
||||
const table: DatabaseTable =
|
||||
{
|
||||
databaseName: 'db1',
|
||||
schema: 'dbo',
|
||||
tableName: 'tb1'
|
||||
};
|
||||
|
||||
const result = {
|
||||
rowCount: 1,
|
||||
columnInfo: [],
|
||||
rows: [[
|
||||
{
|
||||
displayValue: 'c1',
|
||||
isNull: false,
|
||||
invariantCultureDisplayValue: ''
|
||||
},
|
||||
{
|
||||
displayValue: 'int',
|
||||
isNull: false,
|
||||
invariantCultureDisplayValue: ''
|
||||
}
|
||||
], [
|
||||
{
|
||||
displayValue: 'c2',
|
||||
isNull: false,
|
||||
invariantCultureDisplayValue: ''
|
||||
},
|
||||
{
|
||||
displayValue: 'varchar',
|
||||
isNull: false,
|
||||
invariantCultureDisplayValue: ''
|
||||
}
|
||||
]]
|
||||
};
|
||||
const connection = new azdata.connection.ConnectionProfile();
|
||||
testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); });
|
||||
|
||||
testContext.queryRunner.setup(x => x.safeRunQuery(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(result));
|
||||
let service = new PredictService(
|
||||
testContext.apiWrapper.object,
|
||||
testContext.queryRunner.object,
|
||||
testContext.config.object);
|
||||
const actual = await service.getTableColumnsList(table);
|
||||
should.deepEqual(actual, expected);
|
||||
});
|
||||
|
||||
it('generatePredictScript should generate the script successfully using model', async function (): Promise<void> {
|
||||
const testContext = createContext();
|
||||
const connection = new azdata.connection.ConnectionProfile();
|
||||
testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); });
|
||||
const predictParams: PredictParameters = {
|
||||
inputColumns: [
|
||||
{
|
||||
paramName: 'p1',
|
||||
dataType: 'int',
|
||||
columnName: ''
|
||||
},
|
||||
{
|
||||
paramName: 'p2',
|
||||
dataType: 'varchar',
|
||||
columnName: ''
|
||||
}
|
||||
],
|
||||
outputColumns: [
|
||||
{
|
||||
paramName: 'o1',
|
||||
dataType: 'int',
|
||||
columnName: ''
|
||||
},
|
||||
],
|
||||
databaseName: '',
|
||||
tableName: '',
|
||||
schema: ''
|
||||
};
|
||||
const model: RegisteredModel =
|
||||
{
|
||||
id: 1,
|
||||
artifactName: 'name1',
|
||||
title: 'title1',
|
||||
description: 'desc1',
|
||||
created: '2018-01-01',
|
||||
version: '1.1'
|
||||
};
|
||||
|
||||
let service = new PredictService(
|
||||
testContext.apiWrapper.object,
|
||||
testContext.queryRunner.object,
|
||||
testContext.config.object);
|
||||
|
||||
const document: vscode.TextDocument = {
|
||||
uri: vscode.Uri.parse('file:///usr/home'),
|
||||
fileName: '',
|
||||
isUntitled: true,
|
||||
languageId: 'sql',
|
||||
version: 1,
|
||||
isDirty: true,
|
||||
isClosed: false,
|
||||
save: undefined!,
|
||||
eol: undefined!,
|
||||
lineCount: 1,
|
||||
lineAt: undefined!,
|
||||
offsetAt: undefined!,
|
||||
positionAt: undefined!,
|
||||
getText: undefined!,
|
||||
getWordRangeAtPosition: undefined!,
|
||||
validateRange: undefined!,
|
||||
validatePosition: undefined!
|
||||
};
|
||||
testContext.apiWrapper.setup(x => x.openTextDocument(TypeMoq.It.isAny())).returns(() => Promise.resolve(document));
|
||||
testContext.apiWrapper.setup(x => x.connect(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve());
|
||||
testContext.apiWrapper.setup(x => x.runQuery(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => { });
|
||||
|
||||
const actual = await service.generatePredictScript(predictParams, model, undefined);
|
||||
should.notEqual(actual, undefined);
|
||||
should.equal(actual.indexOf('FROM PREDICT(MODEL = @model') > 0, true);
|
||||
});
|
||||
|
||||
it('generatePredictScript should generate the script successfully using file', async function (): Promise<void> {
|
||||
const testContext = createContext();
|
||||
const connection = new azdata.connection.ConnectionProfile();
|
||||
testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); });
|
||||
const predictParams: PredictParameters = {
|
||||
inputColumns: [
|
||||
{
|
||||
paramName: 'p1',
|
||||
dataType: 'int',
|
||||
columnName: ''
|
||||
},
|
||||
{
|
||||
paramName: 'p2',
|
||||
dataType: 'varchar',
|
||||
columnName: ''
|
||||
}
|
||||
],
|
||||
outputColumns: [
|
||||
{
|
||||
paramName: 'o1',
|
||||
dataType: 'int',
|
||||
columnName: ''
|
||||
},
|
||||
],
|
||||
databaseName: '',
|
||||
tableName: '',
|
||||
schema: ''
|
||||
};
|
||||
const tempFilePath = path.join(os.tmpdir(), `ads_ml_temp_${UUID.generateUuid()}`);
|
||||
await fs.promises.writeFile(tempFilePath, 'test');
|
||||
|
||||
let service = new PredictService(
|
||||
testContext.apiWrapper.object,
|
||||
testContext.queryRunner.object,
|
||||
testContext.config.object);
|
||||
|
||||
const document: vscode.TextDocument = {
|
||||
uri: vscode.Uri.parse('file:///usr/home'),
|
||||
fileName: '',
|
||||
isUntitled: true,
|
||||
languageId: 'sql',
|
||||
version: 1,
|
||||
isDirty: true,
|
||||
isClosed: false,
|
||||
save: undefined!,
|
||||
eol: undefined!,
|
||||
lineCount: 1,
|
||||
lineAt: undefined!,
|
||||
offsetAt: undefined!,
|
||||
positionAt: undefined!,
|
||||
getText: undefined!,
|
||||
getWordRangeAtPosition: undefined!,
|
||||
validateRange: undefined!,
|
||||
validatePosition: undefined!
|
||||
};
|
||||
testContext.apiWrapper.setup(x => x.openTextDocument(TypeMoq.It.isAny())).returns(() => Promise.resolve(document));
|
||||
testContext.apiWrapper.setup(x => x.connect(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve());
|
||||
testContext.apiWrapper.setup(x => x.runQuery(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => { });
|
||||
|
||||
const actual = await service.generatePredictScript(predictParams, undefined, tempFilePath);
|
||||
should.notEqual(actual, undefined);
|
||||
should.equal(actual.indexOf('FROM PREDICT(MODEL = 0X') > 0, true);
|
||||
});
|
||||
});
|
||||
38
extensions/machine-learning-services/src/test/utils.ts
Normal file
38
extensions/machine-learning-services/src/test/utils.ts
Normal file
@@ -0,0 +1,38 @@
|
||||
/*---------------------------------------------------------------------------------------------
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
import * as vscode from 'vscode';
|
||||
import * as azdata from 'azdata';
|
||||
|
||||
export interface TestContext {
|
||||
|
||||
outputChannel: vscode.OutputChannel;
|
||||
op: azdata.BackgroundOperation;
|
||||
getOpStatus: () => azdata.TaskStatus;
|
||||
}
|
||||
|
||||
export function createContext(): TestContext {
|
||||
let opStatus: azdata.TaskStatus;
|
||||
|
||||
return {
|
||||
outputChannel: {
|
||||
name: '',
|
||||
append: () => { },
|
||||
appendLine: () => { },
|
||||
clear: () => { },
|
||||
show: () => { },
|
||||
hide: () => { },
|
||||
dispose: () => { }
|
||||
},
|
||||
op: {
|
||||
updateStatus: (status: azdata.TaskStatus) => {
|
||||
opStatus = status;
|
||||
},
|
||||
id: '',
|
||||
onCanceled: new vscode.EventEmitter<void>().event,
|
||||
},
|
||||
getOpStatus: () => { return opStatus; }
|
||||
};
|
||||
}
|
||||
@@ -0,0 +1,178 @@
|
||||
/*---------------------------------------------------------------------------------------------
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
import * as azdata from 'azdata';
|
||||
import * as should from 'should';
|
||||
import 'mocha';
|
||||
import { createContext } from './utils';
|
||||
import {
|
||||
ListModelsEventName, ListAccountsEventName, ListSubscriptionsEventName, ListGroupsEventName, ListWorkspacesEventName,
|
||||
ListAzureModelsEventName, ListDatabaseNamesEventName, ListTableNamesEventName, ListColumnNamesEventName, LoadModelParametersEventName, DownloadAzureModelEventName, DownloadRegisteredModelEventName
|
||||
}
|
||||
from '../../../views/models/modelViewBase';
|
||||
import { RegisteredModel, ModelParameters } from '../../../modelManagement/interfaces';
|
||||
import { azureResource } from '../../../typings/azure-resource';
|
||||
import { Workspace } from '@azure/arm-machinelearningservices/esm/models';
|
||||
import { ViewBase } from '../../../views/viewBase';
|
||||
import { WorkspaceModel } from '../../../modelManagement/interfaces';
|
||||
import { PredictWizard } from '../../../views/models/prediction/predictWizard';
|
||||
import { DatabaseTable, TableColumn } from '../../../prediction/interfaces';
|
||||
|
||||
describe('Predict Wizard', () => {
|
||||
it('Should create view components successfully ', async function (): Promise<void> {
|
||||
let testContext = createContext();
|
||||
|
||||
let view = new PredictWizard(testContext.apiWrapper.object, '');
|
||||
await view.open();
|
||||
should.notEqual(view.wizardView, undefined);
|
||||
should.notEqual(view.modelSourcePage, undefined);
|
||||
});
|
||||
|
||||
it('Should load data successfully ', async function (): Promise<void> {
|
||||
let testContext = createContext();
|
||||
|
||||
let view = new PredictWizard(testContext.apiWrapper.object, '');
|
||||
await view.open();
|
||||
let accounts: azdata.Account[] = [
|
||||
{
|
||||
key: {
|
||||
accountId: '1',
|
||||
providerId: ''
|
||||
},
|
||||
displayInfo: {
|
||||
displayName: 'account',
|
||||
userId: '',
|
||||
accountType: '',
|
||||
contextualDisplayName: ''
|
||||
},
|
||||
isStale: false,
|
||||
properties: []
|
||||
}
|
||||
];
|
||||
let subscriptions: azureResource.AzureResourceSubscription[] = [
|
||||
{
|
||||
name: 'subscription',
|
||||
id: '2'
|
||||
}
|
||||
];
|
||||
let groups: azureResource.AzureResourceResourceGroup[] = [
|
||||
{
|
||||
name: 'group',
|
||||
id: '3'
|
||||
}
|
||||
];
|
||||
let workspaces: Workspace[] = [
|
||||
{
|
||||
name: 'workspace',
|
||||
id: '4'
|
||||
}
|
||||
];
|
||||
let models: WorkspaceModel[] = [
|
||||
{
|
||||
id: '5',
|
||||
name: 'model'
|
||||
}
|
||||
];
|
||||
let localModels: RegisteredModel[] = [
|
||||
{
|
||||
id: 1,
|
||||
artifactName: 'model',
|
||||
title: 'model'
|
||||
}
|
||||
];
|
||||
const dbNames: string[] = [
|
||||
'db1',
|
||||
'db2'
|
||||
];
|
||||
const tableNames: DatabaseTable[] = [
|
||||
{
|
||||
databaseName: 'db1',
|
||||
schema: 'dbo',
|
||||
tableName: 'tb1'
|
||||
},
|
||||
{
|
||||
databaseName: 'db1',
|
||||
tableName: 'tb2',
|
||||
schema: 'dbo'
|
||||
}
|
||||
];
|
||||
const columnNames: TableColumn[] = [
|
||||
{
|
||||
columnName: 'c1',
|
||||
dataType: 'int'
|
||||
},
|
||||
{
|
||||
columnName: 'c2',
|
||||
dataType: 'varchar'
|
||||
}
|
||||
];
|
||||
const modelParameters: ModelParameters = {
|
||||
inputs: [
|
||||
{
|
||||
'name': 'p1',
|
||||
'type': 'int'
|
||||
},
|
||||
{
|
||||
'name': 'p2',
|
||||
'type': 'varchar'
|
||||
}
|
||||
],
|
||||
outputs: [
|
||||
{
|
||||
'name': 'o1',
|
||||
'type': 'int'
|
||||
}
|
||||
]
|
||||
};
|
||||
|
||||
view.on(ListModelsEventName, () => {
|
||||
view.sendCallbackRequest(ViewBase.getCallbackEventName(ListModelsEventName), { data: localModels });
|
||||
});
|
||||
view.on(ListAccountsEventName, () => {
|
||||
view.sendCallbackRequest(ViewBase.getCallbackEventName(ListAccountsEventName), { data: accounts });
|
||||
});
|
||||
view.on(ListSubscriptionsEventName, () => {
|
||||
|
||||
view.sendCallbackRequest(ViewBase.getCallbackEventName(ListSubscriptionsEventName), { data: subscriptions });
|
||||
});
|
||||
view.on(ListGroupsEventName, () => {
|
||||
view.sendCallbackRequest(ViewBase.getCallbackEventName(ListGroupsEventName), { data: groups });
|
||||
});
|
||||
view.on(ListWorkspacesEventName, () => {
|
||||
view.sendCallbackRequest(ViewBase.getCallbackEventName(ListWorkspacesEventName), { data: workspaces });
|
||||
});
|
||||
view.on(ListAzureModelsEventName, () => {
|
||||
view.sendCallbackRequest(ViewBase.getCallbackEventName(ListAzureModelsEventName), { data: models });
|
||||
});
|
||||
view.on(ListDatabaseNamesEventName, () => {
|
||||
view.sendCallbackRequest(ViewBase.getCallbackEventName(ListDatabaseNamesEventName), { data: dbNames });
|
||||
});
|
||||
view.on(ListTableNamesEventName, () => {
|
||||
view.sendCallbackRequest(ViewBase.getCallbackEventName(ListTableNamesEventName), { data: tableNames });
|
||||
});
|
||||
view.on(ListColumnNamesEventName, () => {
|
||||
view.sendCallbackRequest(ViewBase.getCallbackEventName(ListColumnNamesEventName), { data: columnNames });
|
||||
});
|
||||
view.on(LoadModelParametersEventName, () => {
|
||||
view.sendCallbackRequest(ViewBase.getCallbackEventName(LoadModelParametersEventName), { data: modelParameters });
|
||||
});
|
||||
view.on(DownloadAzureModelEventName, () => {
|
||||
view.sendCallbackRequest(ViewBase.getCallbackEventName(DownloadAzureModelEventName), { data: 'path' });
|
||||
});
|
||||
view.on(DownloadRegisteredModelEventName, () => {
|
||||
view.sendCallbackRequest(ViewBase.getCallbackEventName(DownloadRegisteredModelEventName), { data: 'path' });
|
||||
});
|
||||
await view.refresh();
|
||||
should.notEqual(view.azureModelsComponent?.data, undefined);
|
||||
should.notEqual(view.localModelsComponent?.data, undefined);
|
||||
|
||||
should.notEqual(await view.getModelFileName(), undefined);
|
||||
await view.columnsSelectionPage?.onEnter();
|
||||
|
||||
should.notEqual(view.columnsSelectionPage?.data, undefined);
|
||||
should.equal(view.columnsSelectionPage?.data?.inputColumns?.length, modelParameters.inputs.length, modelParameters.inputs[0].name);
|
||||
should.equal(view.columnsSelectionPage?.data?.outputColumns?.length, modelParameters.outputs.length);
|
||||
});
|
||||
});
|
||||
@@ -20,7 +20,7 @@ describe('Register Model Wizard', () => {
|
||||
let testContext = createContext();
|
||||
|
||||
let view = new RegisterModelWizard(testContext.apiWrapper.object, '');
|
||||
view.open();
|
||||
await view.open();
|
||||
await view.refresh();
|
||||
should.notEqual(view.wizardView, undefined);
|
||||
should.notEqual(view.modelSourcePage, undefined);
|
||||
@@ -30,7 +30,7 @@ describe('Register Model Wizard', () => {
|
||||
let testContext = createContext();
|
||||
|
||||
let view = new RegisterModelWizard(testContext.apiWrapper.object, '');
|
||||
view.open();
|
||||
await view.open();
|
||||
let accounts: azdata.Account[] = [
|
||||
{
|
||||
key: {
|
||||
@@ -98,5 +98,7 @@ describe('Register Model Wizard', () => {
|
||||
view.sendCallbackRequest(ViewBase.getCallbackEventName(ListAzureModelsEventName), { data: models });
|
||||
});
|
||||
await view.refresh();
|
||||
should.notEqual(view.azureModelsComponent?.data ,undefined);
|
||||
should.notEqual(view.localModelsComponent?.data, undefined);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -7,14 +7,12 @@ import * as azdata from 'azdata';
|
||||
import * as vscode from 'vscode';
|
||||
import * as TypeMoq from 'typemoq';
|
||||
import { ApiWrapper } from '../../../common/apiWrapper';
|
||||
import * as mssql from '../../../../../mssql/src/mssql';
|
||||
import { createViewContext } from '../utils';
|
||||
import { ModelViewBase } from '../../../views/models/modelViewBase';
|
||||
|
||||
export interface TestContext {
|
||||
apiWrapper: TypeMoq.IMock<ApiWrapper>;
|
||||
view: azdata.ModelView;
|
||||
languageExtensionService: mssql.ILanguageExtensionService;
|
||||
onClick: vscode.EventEmitter<any>;
|
||||
}
|
||||
|
||||
@@ -34,16 +32,10 @@ export class ParentDialog extends ModelViewBase {
|
||||
export function createContext(): TestContext {
|
||||
|
||||
let viewTestContext = createViewContext();
|
||||
let languageExtensionService: mssql.ILanguageExtensionService = {
|
||||
listLanguages: () => { return Promise.resolve([]); },
|
||||
deleteLanguage: () => { return Promise.resolve(); },
|
||||
updateLanguage: () => { return Promise.resolve(); }
|
||||
};
|
||||
|
||||
return {
|
||||
apiWrapper: viewTestContext.apiWrapper,
|
||||
view: viewTestContext.view,
|
||||
languageExtensionService: languageExtensionService,
|
||||
onClick: viewTestContext.onClick
|
||||
};
|
||||
}
|
||||
|
||||
@@ -62,6 +62,9 @@ export function createViewContext(): ViewTestContext {
|
||||
onTextChanged: undefined!,
|
||||
onEnterKeyPressed: undefined!,
|
||||
value: ''
|
||||
});
|
||||
let image: () => azdata.ImageComponent = () => Object.assign({}, componentBase, {
|
||||
|
||||
});
|
||||
let dropdown: () => azdata.DropDownComponent = () => Object.assign({}, componentBase, {
|
||||
onValueChanged: onClick.event,
|
||||
@@ -124,6 +127,14 @@ export function createViewContext(): ViewTestContext {
|
||||
withProperties: () => inputBoxBuilder,
|
||||
withValidation: () => inputBoxBuilder
|
||||
};
|
||||
let imageBuilder: azdata.ComponentBuilder<azdata.ImageComponent> = {
|
||||
component: () => {
|
||||
let r = image();
|
||||
return r;
|
||||
},
|
||||
withProperties: () => imageBuilder,
|
||||
withValidation: () => imageBuilder
|
||||
};
|
||||
let dropdownBuilder: azdata.ComponentBuilder<azdata.DropDownComponent> = {
|
||||
component: () => {
|
||||
let r = dropdown();
|
||||
@@ -156,7 +167,7 @@ export function createViewContext(): ViewTestContext {
|
||||
editor: undefined!,
|
||||
diffeditor: undefined!,
|
||||
text: () => inputBoxBuilder,
|
||||
image: undefined!,
|
||||
image: () => imageBuilder,
|
||||
button: () => buttonBuilder,
|
||||
dropDown: () => dropdownBuilder,
|
||||
tree: undefined!,
|
||||
@@ -181,7 +192,7 @@ export function createViewContext(): ViewTestContext {
|
||||
try {
|
||||
await handler(view);
|
||||
} catch (err) {
|
||||
console.log(err);
|
||||
throw err;
|
||||
}
|
||||
},
|
||||
onValidityChanged: undefined!,
|
||||
@@ -242,7 +253,13 @@ export function createViewContext(): ViewTestContext {
|
||||
enabled: true,
|
||||
description: '',
|
||||
onValidityChanged: onClick.event,
|
||||
registerContent: () => { },
|
||||
registerContent: async (handler) => {
|
||||
try {
|
||||
await handler(view);
|
||||
} catch (err) {
|
||||
throw err;
|
||||
}
|
||||
},
|
||||
modelView: undefined!,
|
||||
valid: true
|
||||
};
|
||||
|
||||
@@ -107,14 +107,14 @@ export abstract class LanguageViewBase {
|
||||
if (connection) {
|
||||
return `${connection.serverName} ${connection.databaseName ? connection.databaseName : constants.extLangLocal}`;
|
||||
}
|
||||
return constants.packageManagerNoConnection;
|
||||
return constants.noConnectionError;
|
||||
}
|
||||
|
||||
public getServerTitle(): string {
|
||||
if (this.connection) {
|
||||
return this.connection.serverName;
|
||||
}
|
||||
return constants.packageManagerNoConnection;
|
||||
return constants.noConnectionError;
|
||||
}
|
||||
|
||||
private async getCurrentConnectionUrl(): Promise<string> {
|
||||
|
||||
@@ -18,6 +18,7 @@ export interface IPageView {
|
||||
onLeave?: () => Promise<void>;
|
||||
validate?: () => Promise<boolean>;
|
||||
refresh: () => Promise<void>;
|
||||
disposePage?: () => Promise<void>;
|
||||
viewPanel: azdata.window.ModelViewPanel | undefined;
|
||||
title: string;
|
||||
}
|
||||
|
||||
@@ -37,6 +37,16 @@ export class MainViewBase {
|
||||
}
|
||||
}
|
||||
|
||||
public async disposePages(): Promise<void> {
|
||||
if (this._pages) {
|
||||
await Promise.all(this._pages.map(async (p) => {
|
||||
if (p.disposePage) {
|
||||
await p.disposePage();
|
||||
}
|
||||
}));
|
||||
}
|
||||
}
|
||||
|
||||
public async refresh(): Promise<void> {
|
||||
if (this._pages) {
|
||||
await Promise.all(this._pages.map(async (p) => await p.refresh()));
|
||||
|
||||
@@ -9,6 +9,7 @@ import { ApiWrapper } from '../../common/apiWrapper';
|
||||
import { AzureResourceFilterComponent } from './azureResourceFilterComponent';
|
||||
import { AzureModelsTable } from './azureModelsTable';
|
||||
import { IDataComponent, AzureModelResource } from '../interfaces';
|
||||
import { ModelArtifact } from './prediction/modelArtifact';
|
||||
|
||||
export class AzureModelsComponent extends ModelViewBase implements IDataComponent<AzureModelResource> {
|
||||
|
||||
@@ -17,6 +18,7 @@ export class AzureModelsComponent extends ModelViewBase implements IDataComponen
|
||||
|
||||
private _loader: azdata.LoadingComponent | undefined;
|
||||
private _form: azdata.FormContainer | undefined;
|
||||
private _downloadedFile: ModelArtifact | undefined;
|
||||
|
||||
/**
|
||||
* Component to render a view to pick an azure model
|
||||
@@ -37,8 +39,14 @@ export class AzureModelsComponent extends ModelViewBase implements IDataComponen
|
||||
.withProperties({
|
||||
loading: true
|
||||
}).component();
|
||||
this.azureModelsTable.onModelSelectionChanged(async () => {
|
||||
if (this._downloadedFile) {
|
||||
await this._downloadedFile.close();
|
||||
}
|
||||
this._downloadedFile = undefined;
|
||||
});
|
||||
|
||||
this.azureFilterComponent.onWorkspacesSelected(async () => {
|
||||
this.azureFilterComponent.onWorkspacesSelectedChanged(async () => {
|
||||
await this.onLoading();
|
||||
await this.azureModelsTable?.loadData(this.azureFilterComponent?.data);
|
||||
await this.onLoaded();
|
||||
@@ -107,6 +115,22 @@ export class AzureModelsComponent extends ModelViewBase implements IDataComponen
|
||||
});
|
||||
}
|
||||
|
||||
public async getDownloadedModel(): Promise<ModelArtifact> {
|
||||
if (!this._downloadedFile) {
|
||||
this._downloadedFile = new ModelArtifact(await this.downloadAzureModel(this.data));
|
||||
}
|
||||
return this._downloadedFile;
|
||||
}
|
||||
|
||||
/**
|
||||
* disposes the view
|
||||
*/
|
||||
public async disposeComponent(): Promise<void> {
|
||||
if (this._downloadedFile) {
|
||||
await this._downloadedFile.close();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Refreshes the view
|
||||
*/
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
import * as azdata from 'azdata';
|
||||
import * as vscode from 'vscode';
|
||||
import * as constants from '../../common/constants';
|
||||
import { ModelViewBase } from './modelViewBase';
|
||||
import { ApiWrapper } from '../../common/apiWrapper';
|
||||
@@ -18,6 +19,8 @@ export class AzureModelsTable extends ModelViewBase implements IDataComponent<Wo
|
||||
private _table: azdata.DeclarativeTableComponent;
|
||||
private _selectedModelId: any;
|
||||
private _models: WorkspaceModel[] | undefined;
|
||||
private _onModelSelectionChanged: vscode.EventEmitter<void> = new vscode.EventEmitter<void>();
|
||||
public readonly onModelSelectionChanged: vscode.Event<void> = this._onModelSelectionChanged.event;
|
||||
|
||||
/**
|
||||
* Creates a view to render azure models in a table
|
||||
@@ -115,6 +118,7 @@ export class AzureModelsTable extends ModelViewBase implements IDataComponent<Wo
|
||||
|
||||
this._table.data = tableData;
|
||||
}
|
||||
this._onModelSelectionChanged.fire();
|
||||
}
|
||||
|
||||
private createTableRow(model: WorkspaceModel): any[] {
|
||||
@@ -128,6 +132,7 @@ export class AzureModelsTable extends ModelViewBase implements IDataComponent<Wo
|
||||
}).component();
|
||||
selectModelButton.onDidClick(() => {
|
||||
this._selectedModelId = model.id;
|
||||
this._onModelSelectionChanged.fire();
|
||||
});
|
||||
return [model.name, model.createdTime, model.frameworkVersion, selectModelButton];
|
||||
}
|
||||
|
||||
@@ -27,8 +27,8 @@ export class AzureResourceFilterComponent extends ModelViewBase implements IData
|
||||
private _azureSubscriptions: azureResource.AzureResourceSubscription[] = [];
|
||||
private _azureGroups: azureResource.AzureResource[] = [];
|
||||
private _azureWorkspaces: Workspace[] = [];
|
||||
private _onWorkspacesSelected: vscode.EventEmitter<void> = new vscode.EventEmitter<void>();
|
||||
public readonly onWorkspacesSelected: vscode.Event<void> = this._onWorkspacesSelected.event;
|
||||
private _onWorkspacesSelectedChanged: vscode.EventEmitter<void> = new vscode.EventEmitter<void>();
|
||||
public readonly onWorkspacesSelectedChanged: vscode.Event<void> = this._onWorkspacesSelectedChanged.event;
|
||||
|
||||
/**
|
||||
* Creates a new view
|
||||
@@ -59,7 +59,7 @@ export class AzureResourceFilterComponent extends ModelViewBase implements IData
|
||||
await this.onGroupSelected();
|
||||
});
|
||||
this._workspaces.onValueChanged(async () => {
|
||||
await this.onWorkspaceSelected();
|
||||
await this.onWorkspaceSelectedChanged();
|
||||
});
|
||||
|
||||
this._form = this._modelBuilder.formContainer().withFormItems([{
|
||||
@@ -182,26 +182,26 @@ export class AzureResourceFilterComponent extends ModelViewBase implements IData
|
||||
this._workspaces.values = values;
|
||||
this._workspaces.value = values[0];
|
||||
}
|
||||
this.onWorkspaceSelected();
|
||||
this.onWorkspaceSelectedChanged();
|
||||
}
|
||||
|
||||
private onWorkspaceSelected(): void {
|
||||
this._onWorkspacesSelected.fire();
|
||||
private onWorkspaceSelectedChanged(): void {
|
||||
this._onWorkspacesSelectedChanged.fire();
|
||||
}
|
||||
|
||||
private get workspace(): Workspace | undefined {
|
||||
return this._azureWorkspaces ? this._azureWorkspaces.find(a => a.id === (<azdata.CategoryValue>this._workspaces.value).name) : undefined;
|
||||
return this._azureWorkspaces && this._workspaces.value ? this._azureWorkspaces.find(a => a.id === (<azdata.CategoryValue>this._workspaces.value).name) : undefined;
|
||||
}
|
||||
|
||||
private get account(): azdata.Account | undefined {
|
||||
return this._azureAccounts ? this._azureAccounts.find(a => a.key.accountId === (<azdata.CategoryValue>this._accounts.value).name) : undefined;
|
||||
return this._azureAccounts && this._accounts.value ? this._azureAccounts.find(a => a.key.accountId === (<azdata.CategoryValue>this._accounts.value).name) : undefined;
|
||||
}
|
||||
|
||||
private get group(): azureResource.AzureResource | undefined {
|
||||
return this._azureGroups ? this._azureGroups.find(a => a.id === (<azdata.CategoryValue>this._groups.value).name) : undefined;
|
||||
return this._azureGroups && this._groups.value ? this._azureGroups.find(a => a.id === (<azdata.CategoryValue>this._groups.value).name) : undefined;
|
||||
}
|
||||
|
||||
private get subscription(): azureResource.AzureResourceSubscription | undefined {
|
||||
return this._azureSubscriptions ? this._azureSubscriptions.find(a => a.id === (<azdata.CategoryValue>this._subscriptions.value).name) : undefined;
|
||||
return this._azureSubscriptions && this._subscriptions.value ? this._azureSubscriptions.find(a => a.id === (<azdata.CategoryValue>this._subscriptions.value).name) : undefined;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,15 +9,15 @@ import { azureResource } from '../../typings/azure-resource';
|
||||
import { ApiWrapper } from '../../common/apiWrapper';
|
||||
import { AzureModelRegistryService } from '../../modelManagement/azureModelRegistryService';
|
||||
import { Workspace } from '@azure/arm-machinelearningservices/esm/models';
|
||||
import { RegisteredModel, WorkspaceModel, RegisteredModelDetails } from '../../modelManagement/interfaces';
|
||||
import { PredictParameters, DatabaseTable } from '../../prediction/interfaces';
|
||||
import { RegisteredModelService } from '../../modelManagement/registeredModelService';
|
||||
import { RegisteredModel, WorkspaceModel, RegisteredModelDetails, ModelParameters } from '../../modelManagement/interfaces';
|
||||
import { PredictParameters, DatabaseTable, TableColumn } from '../../prediction/interfaces';
|
||||
import { DeployedModelService } from '../../modelManagement/deployedModelService';
|
||||
import { RegisteredModelsDialog } from './registerModels/registeredModelsDialog';
|
||||
import {
|
||||
AzureResourceEventArgs, ListAzureModelsEventName, ListSubscriptionsEventName, ListModelsEventName, ListWorkspacesEventName,
|
||||
ListGroupsEventName, ListAccountsEventName, RegisterLocalModelEventName, RegisterLocalModelEventArgs, RegisterAzureModelEventName,
|
||||
RegisterAzureModelEventArgs, ModelViewBase, SourceModelSelectedEventName, RegisterModelEventName, DownloadAzureModelEventName,
|
||||
ListDatabaseNamesEventName, ListTableNamesEventName, ListColumnNamesEventName, PredictModelEventName, PredictModelEventArgs
|
||||
ListDatabaseNamesEventName, ListTableNamesEventName, ListColumnNamesEventName, PredictModelEventName, PredictModelEventArgs, DownloadRegisteredModelEventName, LoadModelParametersEventName
|
||||
} from './modelViewBase';
|
||||
import { ControllerBase } from '../controllerBase';
|
||||
import { RegisterModelWizard } from './registerModels/registerModelWizard';
|
||||
@@ -39,7 +39,7 @@ export class ModelManagementController extends ControllerBase {
|
||||
apiWrapper: ApiWrapper,
|
||||
private _root: string,
|
||||
private _amlService: AzureModelRegistryService,
|
||||
private _registeredModelService: RegisteredModelService,
|
||||
private _registeredModelService: DeployedModelService,
|
||||
private _predictService: PredictService) {
|
||||
super(apiWrapper);
|
||||
}
|
||||
@@ -61,7 +61,7 @@ export class ModelManagementController extends ControllerBase {
|
||||
|
||||
// Open view
|
||||
//
|
||||
view.open();
|
||||
await view.open();
|
||||
await view.refresh();
|
||||
return view;
|
||||
}
|
||||
@@ -74,10 +74,15 @@ export class ModelManagementController extends ControllerBase {
|
||||
let view = new PredictWizard(this._apiWrapper, this._root);
|
||||
|
||||
this.registerEvents(view);
|
||||
view.on(LoadModelParametersEventName, async () => {
|
||||
const modelArtifact = await view.getModelFileName();
|
||||
await this.executeAction(view, LoadModelParametersEventName, this.loadModelParameters, this._registeredModelService,
|
||||
modelArtifact?.filePath);
|
||||
});
|
||||
|
||||
// Open view
|
||||
//
|
||||
view.open();
|
||||
await view.open();
|
||||
await view.refresh();
|
||||
return view;
|
||||
}
|
||||
@@ -151,6 +156,11 @@ export class ModelManagementController extends ControllerBase {
|
||||
await this.executeAction(view, PredictModelEventName, this.generatePredictScript, this._predictService,
|
||||
predictArgs, predictArgs.model, predictArgs.filePath);
|
||||
});
|
||||
view.on(DownloadRegisteredModelEventName, async (arg) => {
|
||||
let model = <RegisteredModel>arg;
|
||||
await this.executeAction(view, DownloadRegisteredModelEventName, this.downloadRegisteredModel, this._registeredModelService,
|
||||
model);
|
||||
});
|
||||
view.on(SourceModelSelectedEventName, () => {
|
||||
view.refresh();
|
||||
});
|
||||
@@ -191,8 +201,8 @@ export class ModelManagementController extends ControllerBase {
|
||||
return await service.getWorkspaces(account, subscription, group);
|
||||
}
|
||||
|
||||
private async getRegisteredModels(registeredModelService: RegisteredModelService): Promise<RegisteredModel[]> {
|
||||
return registeredModelService.getRegisteredModels();
|
||||
private async getRegisteredModels(registeredModelService: DeployedModelService): Promise<RegisteredModel[]> {
|
||||
return registeredModelService.getDeployedModels();
|
||||
}
|
||||
|
||||
private async getAzureModels(
|
||||
@@ -207,9 +217,9 @@ export class ModelManagementController extends ControllerBase {
|
||||
return await service.getModels(account, subscription, resourceGroup, workspace) || [];
|
||||
}
|
||||
|
||||
private async registerLocalModel(service: RegisteredModelService, filePath: string, details: RegisteredModelDetails | undefined): Promise<void> {
|
||||
private async registerLocalModel(service: DeployedModelService, filePath: string, details: RegisteredModelDetails | undefined): Promise<void> {
|
||||
if (filePath) {
|
||||
await service.registerLocalModel(filePath, details);
|
||||
await service.deployLocalModel(filePath, details);
|
||||
} else {
|
||||
throw Error(constants.invalidModelToRegisterError);
|
||||
|
||||
@@ -218,7 +228,7 @@ export class ModelManagementController extends ControllerBase {
|
||||
|
||||
private async registerAzureModel(
|
||||
azureService: AzureModelRegistryService,
|
||||
service: RegisteredModelService,
|
||||
service: DeployedModelService,
|
||||
account: azdata.Account | undefined,
|
||||
subscription: azureResource.AzureResourceSubscription | undefined,
|
||||
resourceGroup: azureResource.AzureResource | undefined,
|
||||
@@ -231,7 +241,7 @@ export class ModelManagementController extends ControllerBase {
|
||||
const filePath = await azureService.downloadModel(account, subscription, resourceGroup, workspace, model);
|
||||
if (filePath) {
|
||||
|
||||
await service.registerLocalModel(filePath, details);
|
||||
await service.deployLocalModel(filePath, details);
|
||||
await fs.promises.unlink(filePath);
|
||||
} else {
|
||||
throw Error(constants.invalidModelToRegisterError);
|
||||
@@ -246,7 +256,7 @@ export class ModelManagementController extends ControllerBase {
|
||||
return await predictService.getTableList(databaseName);
|
||||
}
|
||||
|
||||
public async getTableColumnsList(predictService: PredictService, databaseTable: DatabaseTable): Promise<string[]> {
|
||||
public async getTableColumnsList(predictService: PredictService, databaseTable: DatabaseTable): Promise<TableColumn[]> {
|
||||
return await predictService.getTableColumnsList(databaseTable);
|
||||
}
|
||||
|
||||
@@ -263,6 +273,24 @@ export class ModelManagementController extends ControllerBase {
|
||||
return result;
|
||||
}
|
||||
|
||||
private async downloadRegisteredModel(
|
||||
registeredModelService: DeployedModelService,
|
||||
model: RegisteredModel | undefined): Promise<string> {
|
||||
if (!model) {
|
||||
throw Error(constants.invalidModelToPredictError);
|
||||
}
|
||||
return await registeredModelService.downloadModel(model);
|
||||
}
|
||||
|
||||
private async loadModelParameters(
|
||||
registeredModelService: DeployedModelService,
|
||||
model: string | undefined): Promise<ModelParameters | undefined> {
|
||||
if (!model) {
|
||||
return undefined;
|
||||
}
|
||||
return await registeredModelService.loadModelParameters(model);
|
||||
}
|
||||
|
||||
private async downloadAzureModel(
|
||||
azureService: AzureModelRegistryService,
|
||||
account: azdata.Account | undefined,
|
||||
|
||||
@@ -120,4 +120,14 @@ export class ModelSourcePage extends ModelViewBase implements IPageView, IDataCo
|
||||
}
|
||||
return Promise.resolve(validated);
|
||||
}
|
||||
|
||||
public async disposePage(): Promise<void> {
|
||||
if (this.azureModelsComponent) {
|
||||
await this.azureModelsComponent.disposeComponent();
|
||||
|
||||
}
|
||||
if (this.registeredModelsComponent) {
|
||||
await this.registeredModelsComponent.disposeComponent();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,8 +8,8 @@ import * as azdata from 'azdata';
|
||||
import { azureResource } from '../../typings/azure-resource';
|
||||
import { ApiWrapper } from '../../common/apiWrapper';
|
||||
import { ViewBase } from '../viewBase';
|
||||
import { RegisteredModel, WorkspaceModel, RegisteredModelDetails } from '../../modelManagement/interfaces';
|
||||
import { PredictParameters, DatabaseTable } from '../../prediction/interfaces';
|
||||
import { RegisteredModel, WorkspaceModel, RegisteredModelDetails, ModelParameters } from '../../modelManagement/interfaces';
|
||||
import { PredictParameters, DatabaseTable, TableColumn } from '../../prediction/interfaces';
|
||||
import { Workspace } from '@azure/arm-machinelearningservices/esm/models';
|
||||
import { AzureWorkspaceResource, AzureModelResource } from '../interfaces';
|
||||
|
||||
@@ -47,9 +47,11 @@ export const ListWorkspacesEventName = 'listWorkspaces';
|
||||
export const RegisterLocalModelEventName = 'registerLocalModel';
|
||||
export const RegisterAzureModelEventName = 'registerAzureLocalModel';
|
||||
export const DownloadAzureModelEventName = 'downloadAzureLocalModel';
|
||||
export const DownloadRegisteredModelEventName = 'downloadRegisteredModel';
|
||||
export const PredictModelEventName = 'predictModel';
|
||||
export const RegisterModelEventName = 'registerModel';
|
||||
export const SourceModelSelectedEventName = 'sourceModelSelected';
|
||||
export const LoadModelParametersEventName = 'loadModelParameters';
|
||||
|
||||
/**
|
||||
* Base class for all model management views
|
||||
@@ -75,7 +77,9 @@ export abstract class ModelViewBase extends ViewBase {
|
||||
ListTableNamesEventName,
|
||||
ListColumnNamesEventName,
|
||||
PredictModelEventName,
|
||||
DownloadAzureModelEventName]);
|
||||
DownloadAzureModelEventName,
|
||||
DownloadRegisteredModelEventName,
|
||||
LoadModelParametersEventName]);
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -124,7 +128,7 @@ export abstract class ModelViewBase extends ViewBase {
|
||||
/**
|
||||
* lists column names
|
||||
*/
|
||||
public async listColumnNames(table: DatabaseTable): Promise<string[]> {
|
||||
public async listColumnNames(table: DatabaseTable): Promise<TableColumn[]> {
|
||||
return await this.sendDataRequest(ListColumnNamesEventName, table);
|
||||
}
|
||||
|
||||
@@ -151,6 +155,14 @@ export abstract class ModelViewBase extends ViewBase {
|
||||
return await this.sendDataRequest(RegisterLocalModelEventName, args);
|
||||
}
|
||||
|
||||
/**
|
||||
* downloads registered model
|
||||
* @param model model to download
|
||||
*/
|
||||
public async downloadRegisteredModel(model: RegisteredModel | undefined): Promise<string> {
|
||||
return await this.sendDataRequest(DownloadRegisteredModelEventName, model);
|
||||
}
|
||||
|
||||
/**
|
||||
* download azure model
|
||||
* @param args azure resource
|
||||
@@ -159,6 +171,13 @@ export abstract class ModelViewBase extends ViewBase {
|
||||
return await this.sendDataRequest(DownloadAzureModelEventName, resource);
|
||||
}
|
||||
|
||||
/**
|
||||
* Loads model parameters
|
||||
*/
|
||||
public async loadModelParameters(): Promise<ModelParameters | undefined> {
|
||||
return await this.sendDataRequest(LoadModelParametersEventName);
|
||||
}
|
||||
|
||||
/**
|
||||
* registers azure model
|
||||
* @param args azure resource
|
||||
|
||||
@@ -8,7 +8,7 @@ import { ModelViewBase } from '../modelViewBase';
|
||||
import { ApiWrapper } from '../../../common/apiWrapper';
|
||||
import * as constants from '../../../common/constants';
|
||||
import { IPageView, IDataComponent } from '../../interfaces';
|
||||
import { ColumnsFilterComponent } from './columnsFilterComponent';
|
||||
import { InputColumnsComponent } from './inputColumnsComponent';
|
||||
import { OutputColumnsComponent } from './outputColumnsComponent';
|
||||
import { PredictParameters } from '../../../prediction/interfaces';
|
||||
|
||||
@@ -19,7 +19,7 @@ export class ColumnsSelectionPage extends ModelViewBase implements IPageView, ID
|
||||
|
||||
private _form: azdata.FormContainer | undefined;
|
||||
private _formBuilder: azdata.FormBuilder | undefined;
|
||||
public columnsFilterComponent: ColumnsFilterComponent | undefined;
|
||||
public inputColumnsComponent: InputColumnsComponent | undefined;
|
||||
public outputColumnsComponent: OutputColumnsComponent | undefined;
|
||||
|
||||
constructor(apiWrapper: ApiWrapper, parent: ModelViewBase) {
|
||||
@@ -32,15 +32,14 @@ export class ColumnsSelectionPage extends ModelViewBase implements IPageView, ID
|
||||
*/
|
||||
public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.Component {
|
||||
this._formBuilder = modelBuilder.formContainer();
|
||||
this.columnsFilterComponent = new ColumnsFilterComponent(this._apiWrapper, this);
|
||||
this.columnsFilterComponent.registerComponent(modelBuilder);
|
||||
this.columnsFilterComponent.addComponents(this._formBuilder);
|
||||
this.refresh();
|
||||
this.inputColumnsComponent = new InputColumnsComponent(this._apiWrapper, this);
|
||||
this.inputColumnsComponent.registerComponent(modelBuilder);
|
||||
this.inputColumnsComponent.addComponents(this._formBuilder);
|
||||
|
||||
this.outputColumnsComponent = new OutputColumnsComponent(this._apiWrapper, this);
|
||||
this.outputColumnsComponent.registerComponent(modelBuilder);
|
||||
this.outputColumnsComponent.addComponents(this._formBuilder);
|
||||
this.refresh();
|
||||
|
||||
this._form = this._formBuilder.component();
|
||||
return this._form;
|
||||
}
|
||||
@@ -49,8 +48,8 @@ export class ColumnsSelectionPage extends ModelViewBase implements IPageView, ID
|
||||
* Returns selected data
|
||||
*/
|
||||
public get data(): PredictParameters | undefined {
|
||||
return this.columnsFilterComponent?.data && this.outputColumnsComponent?.data ?
|
||||
Object.assign({}, this.columnsFilterComponent.data, { outputColumns: this.outputColumnsComponent.data }) :
|
||||
return this.inputColumnsComponent?.data && this.outputColumnsComponent?.data ?
|
||||
Object.assign({}, this.inputColumnsComponent.data, { outputColumns: this.outputColumnsComponent.data }) :
|
||||
undefined;
|
||||
}
|
||||
|
||||
@@ -66,8 +65,8 @@ export class ColumnsSelectionPage extends ModelViewBase implements IPageView, ID
|
||||
*/
|
||||
public async refresh(): Promise<void> {
|
||||
if (this._formBuilder) {
|
||||
if (this.columnsFilterComponent) {
|
||||
await this.columnsFilterComponent.refresh();
|
||||
if (this.inputColumnsComponent) {
|
||||
await this.inputColumnsComponent.refresh();
|
||||
}
|
||||
if (this.outputColumnsComponent) {
|
||||
await this.outputColumnsComponent.refresh();
|
||||
@@ -75,6 +74,24 @@ export class ColumnsSelectionPage extends ModelViewBase implements IPageView, ID
|
||||
}
|
||||
}
|
||||
|
||||
public async onEnter(): Promise<void> {
|
||||
await this.inputColumnsComponent?.onLoading();
|
||||
await this.outputColumnsComponent?.onLoading();
|
||||
try {
|
||||
const modelParameters = await this.loadModelParameters();
|
||||
if (modelParameters && this.inputColumnsComponent && this.outputColumnsComponent) {
|
||||
this.inputColumnsComponent.modelParameters = modelParameters;
|
||||
this.outputColumnsComponent.modelParameters = modelParameters;
|
||||
await this.inputColumnsComponent.refresh();
|
||||
await this.outputColumnsComponent.refresh();
|
||||
}
|
||||
} catch (error) {
|
||||
this.showErrorMessage(constants.loadModelParameterFailedError, error);
|
||||
}
|
||||
await this.inputColumnsComponent?.onLoaded();
|
||||
await this.outputColumnsComponent?.onLoaded();
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns page title
|
||||
*/
|
||||
|
||||
@@ -8,133 +8,280 @@ import * as constants from '../../../common/constants';
|
||||
import { ModelViewBase } from '../modelViewBase';
|
||||
import { ApiWrapper } from '../../../common/apiWrapper';
|
||||
import { IDataComponent } from '../../interfaces';
|
||||
import { PredictColumn, DatabaseTable } from '../../../prediction/interfaces';
|
||||
import { PredictColumn, DatabaseTable, TableColumn } from '../../../prediction/interfaces';
|
||||
import { ModelParameter, ModelParameters } from '../../../modelManagement/interfaces';
|
||||
|
||||
/**
|
||||
* View to render azure models in a table
|
||||
*/
|
||||
export class ColumnsTable extends ModelViewBase implements IDataComponent<PredictColumn[]> {
|
||||
|
||||
private _table: azdata.DeclarativeTableComponent;
|
||||
private _selectedColumns: PredictColumn[] = [];
|
||||
private _columns: string[] | undefined;
|
||||
private _table: azdata.DeclarativeTableComponent | undefined;
|
||||
private _parameters: PredictColumn[] = [];
|
||||
private _loader: azdata.LoadingComponent;
|
||||
private _dataTypes: string[] = [
|
||||
'bigint',
|
||||
'int',
|
||||
'smallint',
|
||||
'real',
|
||||
'float',
|
||||
'varchar(MAX)',
|
||||
'bit'
|
||||
];
|
||||
|
||||
|
||||
/**
|
||||
* Creates a view to render azure models in a table
|
||||
*/
|
||||
constructor(apiWrapper: ApiWrapper, private _modelBuilder: azdata.ModelBuilder, parent: ModelViewBase) {
|
||||
constructor(apiWrapper: ApiWrapper, private _modelBuilder: azdata.ModelBuilder, parent: ModelViewBase, private _forInput: boolean = true) {
|
||||
super(apiWrapper, parent.root, parent);
|
||||
this._table = this.registerComponent(this._modelBuilder);
|
||||
this._loader = this.registerComponent(this._modelBuilder);
|
||||
}
|
||||
|
||||
/**
|
||||
* Register components
|
||||
* @param modelBuilder model builder
|
||||
*/
|
||||
public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.DeclarativeTableComponent {
|
||||
public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.LoadingComponent {
|
||||
let columnHeader: azdata.DeclarativeTableColumn[];
|
||||
if (this._forInput) {
|
||||
columnHeader = [
|
||||
{ // Action
|
||||
displayName: constants.columnName,
|
||||
ariaLabel: constants.columnName,
|
||||
valueType: azdata.DeclarativeDataType.component,
|
||||
isReadOnly: true,
|
||||
width: 50,
|
||||
headerCssStyles: {
|
||||
...constants.cssStyles.tableHeader
|
||||
},
|
||||
rowCssStyles: {
|
||||
...constants.cssStyles.tableRow
|
||||
},
|
||||
},
|
||||
{ // Name
|
||||
displayName: '',
|
||||
ariaLabel: '',
|
||||
valueType: azdata.DeclarativeDataType.component,
|
||||
isReadOnly: true,
|
||||
width: 50,
|
||||
headerCssStyles: {
|
||||
...constants.cssStyles.tableHeader
|
||||
},
|
||||
rowCssStyles: {
|
||||
...constants.cssStyles.tableRow
|
||||
},
|
||||
},
|
||||
{ // Name
|
||||
displayName: constants.inputName,
|
||||
ariaLabel: constants.inputName,
|
||||
valueType: azdata.DeclarativeDataType.component,
|
||||
isReadOnly: true,
|
||||
width: 120,
|
||||
headerCssStyles: {
|
||||
...constants.cssStyles.tableHeader
|
||||
},
|
||||
rowCssStyles: {
|
||||
...constants.cssStyles.tableRow
|
||||
},
|
||||
}
|
||||
];
|
||||
} else {
|
||||
columnHeader = [
|
||||
{ // Name
|
||||
displayName: constants.outputName,
|
||||
ariaLabel: constants.outputName,
|
||||
valueType: azdata.DeclarativeDataType.string,
|
||||
isReadOnly: true,
|
||||
width: 200,
|
||||
headerCssStyles: {
|
||||
...constants.cssStyles.tableHeader
|
||||
},
|
||||
rowCssStyles: {
|
||||
...constants.cssStyles.tableRow
|
||||
},
|
||||
},
|
||||
{ // Action
|
||||
displayName: constants.displayName,
|
||||
ariaLabel: constants.displayName,
|
||||
valueType: azdata.DeclarativeDataType.component,
|
||||
isReadOnly: true,
|
||||
width: 50,
|
||||
headerCssStyles: {
|
||||
...constants.cssStyles.tableHeader
|
||||
},
|
||||
rowCssStyles: {
|
||||
...constants.cssStyles.tableRow
|
||||
},
|
||||
},
|
||||
{ // Action
|
||||
displayName: constants.dataTypeName,
|
||||
ariaLabel: constants.dataTypeName,
|
||||
valueType: azdata.DeclarativeDataType.component,
|
||||
isReadOnly: true,
|
||||
width: 50,
|
||||
headerCssStyles: {
|
||||
...constants.cssStyles.tableHeader
|
||||
},
|
||||
rowCssStyles: {
|
||||
...constants.cssStyles.tableRow
|
||||
},
|
||||
}
|
||||
];
|
||||
}
|
||||
this._table = modelBuilder.declarativeTable()
|
||||
|
||||
.withProperties<azdata.DeclarativeTableProperties>(
|
||||
{
|
||||
columns: [
|
||||
{ // Name
|
||||
displayName: constants.columnDatabase,
|
||||
ariaLabel: constants.columnName,
|
||||
valueType: azdata.DeclarativeDataType.string,
|
||||
isReadOnly: true,
|
||||
width: 120,
|
||||
headerCssStyles: {
|
||||
...constants.cssStyles.tableHeader
|
||||
},
|
||||
rowCssStyles: {
|
||||
...constants.cssStyles.tableRow
|
||||
},
|
||||
},
|
||||
{ // Action
|
||||
displayName: constants.inputName,
|
||||
ariaLabel: constants.inputName,
|
||||
valueType: azdata.DeclarativeDataType.component,
|
||||
isReadOnly: true,
|
||||
width: 50,
|
||||
headerCssStyles: {
|
||||
...constants.cssStyles.tableHeader
|
||||
},
|
||||
rowCssStyles: {
|
||||
...constants.cssStyles.tableRow
|
||||
},
|
||||
},
|
||||
{ // Action
|
||||
displayName: '',
|
||||
valueType: azdata.DeclarativeDataType.component,
|
||||
isReadOnly: true,
|
||||
width: 50,
|
||||
headerCssStyles: {
|
||||
...constants.cssStyles.tableHeader
|
||||
},
|
||||
rowCssStyles: {
|
||||
...constants.cssStyles.tableRow
|
||||
},
|
||||
}
|
||||
],
|
||||
columns: columnHeader,
|
||||
data: [],
|
||||
ariaLabel: constants.mlsConfigTitle
|
||||
})
|
||||
.component();
|
||||
return this._table;
|
||||
this._loader = modelBuilder.loadingComponent()
|
||||
.withItem(this._table)
|
||||
.withProperties({
|
||||
loading: true
|
||||
}).component();
|
||||
return this._loader;
|
||||
}
|
||||
|
||||
public get component(): azdata.DeclarativeTableComponent {
|
||||
return this._table;
|
||||
public async onLoading(): Promise<void> {
|
||||
if (this._loader) {
|
||||
await this._loader.updateProperties({ loading: true });
|
||||
}
|
||||
}
|
||||
|
||||
public async onLoaded(): Promise<void> {
|
||||
if (this._loader) {
|
||||
await this._loader.updateProperties({ loading: false });
|
||||
}
|
||||
}
|
||||
|
||||
public get component(): azdata.Component {
|
||||
return this._loader;
|
||||
}
|
||||
|
||||
/**
|
||||
* Load data in the component
|
||||
* @param workspaceResource Azure workspace
|
||||
*/
|
||||
public async loadData(table: DatabaseTable): Promise<void> {
|
||||
this._selectedColumns = [];
|
||||
if (this._table) {
|
||||
this._columns = await this.listColumnNames(table);
|
||||
let tableData: any[][] = [];
|
||||
public async loadInputs(modelParameters: ModelParameters | undefined, table: DatabaseTable): Promise<void> {
|
||||
await this.onLoading();
|
||||
this._parameters = [];
|
||||
let tableData: any[][] = [];
|
||||
|
||||
if (this._columns) {
|
||||
tableData = tableData.concat(this._columns.map(model => this.createTableRow(model)));
|
||||
if (this._table) {
|
||||
if (this._forInput) {
|
||||
const columns = await this.listColumnNames(table);
|
||||
if (modelParameters?.inputs && columns) {
|
||||
tableData = tableData.concat(modelParameters.inputs.map(input => this.createInputTableRow(input, columns)));
|
||||
}
|
||||
}
|
||||
|
||||
this._table.data = tableData;
|
||||
}
|
||||
await this.onLoaded();
|
||||
}
|
||||
|
||||
private createTableRow(column: string): any[] {
|
||||
if (this._modelBuilder) {
|
||||
let selectRowButton = this._modelBuilder.checkBox().withProperties({
|
||||
public async loadOutputs(modelParameters: ModelParameters | undefined): Promise<void> {
|
||||
this.onLoading();
|
||||
this._parameters = [];
|
||||
let tableData: any[][] = [];
|
||||
|
||||
width: 15,
|
||||
height: 15,
|
||||
checked: true
|
||||
if (this._table) {
|
||||
if (!this._forInput) {
|
||||
if (modelParameters?.outputs && this._dataTypes) {
|
||||
tableData = tableData.concat(modelParameters.outputs.map(output => this.createOutputTableRow(output, this._dataTypes)));
|
||||
}
|
||||
}
|
||||
|
||||
this._table.data = tableData;
|
||||
}
|
||||
this.onLoaded();
|
||||
}
|
||||
|
||||
private createOutputTableRow(modelParameter: ModelParameter, dataTypes: string[]): any[] {
|
||||
if (this._modelBuilder) {
|
||||
|
||||
let nameInput = this._modelBuilder.dropDown().withProperties({
|
||||
values: dataTypes,
|
||||
width: this.componentMaxLength
|
||||
}).component();
|
||||
let nameInputBox = this._modelBuilder.inputBox().withProperties({
|
||||
value: '',
|
||||
width: 150
|
||||
}).component();
|
||||
this._selectedColumns.push({ name: column });
|
||||
selectRowButton.onChanged(() => {
|
||||
if (selectRowButton.checked) {
|
||||
if (!this._selectedColumns.find(x => x.name === column)) {
|
||||
this._selectedColumns.push({ name: column });
|
||||
}
|
||||
} else {
|
||||
if (this._selectedColumns.find(x => x.name === column)) {
|
||||
this._selectedColumns = this._selectedColumns.filter(x => x.name !== column);
|
||||
const name = modelParameter.name;
|
||||
const dataType = dataTypes.find(x => x === modelParameter.type);
|
||||
if (dataType) {
|
||||
nameInput.value = dataType;
|
||||
}
|
||||
this._parameters.push({ columnName: name, paramName: name, dataType: modelParameter.type });
|
||||
|
||||
nameInput.onValueChanged(() => {
|
||||
const value = <string>nameInput.value;
|
||||
if (value !== modelParameter.type) {
|
||||
let selectedRow = this._parameters.find(x => x.paramName === name);
|
||||
if (selectedRow) {
|
||||
selectedRow.dataType = value;
|
||||
}
|
||||
}
|
||||
});
|
||||
nameInputBox.onTextChanged(() => {
|
||||
let selectedRow = this._selectedColumns.find(x => x.name === column);
|
||||
let displayNameInput = this._modelBuilder.inputBox().withProperties({
|
||||
value: name,
|
||||
width: 200
|
||||
}).component();
|
||||
displayNameInput.onTextChanged(() => {
|
||||
let selectedRow = this._parameters.find(x => x.paramName === name);
|
||||
if (selectedRow) {
|
||||
selectedRow.displayName = nameInputBox.value;
|
||||
selectedRow.columnName = displayNameInput.value || name;
|
||||
}
|
||||
});
|
||||
return [column, nameInputBox, selectRowButton];
|
||||
return [`${name}(${modelParameter.type ? modelParameter.type : constants.unsupportedModelParameterType})`, displayNameInput, nameInput];
|
||||
}
|
||||
|
||||
return [];
|
||||
}
|
||||
|
||||
private createInputTableRow(modelParameter: ModelParameter, columns: TableColumn[] | undefined): any[] {
|
||||
if (this._modelBuilder && columns) {
|
||||
const values = columns.map(c => { return { name: c.columnName, displayName: `${c.columnName}(${c.dataType})` }; });
|
||||
let nameInput = this._modelBuilder.dropDown().withProperties({
|
||||
values: values,
|
||||
width: this.componentMaxLength
|
||||
}).component();
|
||||
const name = modelParameter.name;
|
||||
let column = values.find(x => x.name === modelParameter.name);
|
||||
if (!column) {
|
||||
column = values[0];
|
||||
}
|
||||
nameInput.value = column;
|
||||
|
||||
this._parameters.push({ columnName: column.name, paramName: name });
|
||||
|
||||
nameInput.onValueChanged(() => {
|
||||
const selectedColumn = nameInput.value;
|
||||
const value = selectedColumn ? (<azdata.CategoryValue>selectedColumn).name : undefined;
|
||||
|
||||
let selectedRow = this._parameters.find(x => x.paramName === name);
|
||||
if (selectedRow) {
|
||||
selectedRow.columnName = value || '';
|
||||
}
|
||||
});
|
||||
const label = this._modelBuilder.inputBox().withProperties({
|
||||
value: `${name}(${modelParameter.type ? modelParameter.type : constants.unsupportedModelParameterType})`,
|
||||
enabled: false,
|
||||
width: this.componentMaxLength
|
||||
}).component();
|
||||
const image = this._modelBuilder.image().withProperties({
|
||||
width: 50,
|
||||
height: 50,
|
||||
iconPath: {
|
||||
dark: this.asAbsolutePath('images/arrow.svg'),
|
||||
light: this.asAbsolutePath('images/arrow.svg')
|
||||
},
|
||||
iconWidth: 20,
|
||||
iconHeight: 20,
|
||||
title: 'maps'
|
||||
}).component();
|
||||
return [nameInput, image, label];
|
||||
}
|
||||
|
||||
return [];
|
||||
@@ -144,7 +291,7 @@ export class ColumnsTable extends ModelViewBase implements IDataComponent<Predic
|
||||
* Returns selected data
|
||||
*/
|
||||
public get data(): PredictColumn[] | undefined {
|
||||
return this._selectedColumns;
|
||||
return this._parameters;
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -8,13 +8,14 @@ import { ModelViewBase } from '../modelViewBase';
|
||||
import { ApiWrapper } from '../../../common/apiWrapper';
|
||||
import * as constants from '../../../common/constants';
|
||||
import { IDataComponent } from '../../interfaces';
|
||||
import { ColumnsTable } from './columnsTable';
|
||||
import { PredictColumn, PredictInputParameters, DatabaseTable } from '../../../prediction/interfaces';
|
||||
import { ModelParameters } from '../../../modelManagement/interfaces';
|
||||
import { ColumnsTable } from './columnsTable';
|
||||
|
||||
/**
|
||||
* View to render filters to pick an azure resource
|
||||
*/
|
||||
export class ColumnsFilterComponent extends ModelViewBase implements IDataComponent<PredictInputParameters> {
|
||||
export class InputColumnsComponent extends ModelViewBase implements IDataComponent<PredictInputParameters> {
|
||||
|
||||
private _form: azdata.FormContainer | undefined;
|
||||
private _databases: azdata.DropDownComponent | undefined;
|
||||
@@ -22,7 +23,9 @@ export class ColumnsFilterComponent extends ModelViewBase implements IDataCompon
|
||||
private _columns: ColumnsTable | undefined;
|
||||
private _dbNames: string[] = [];
|
||||
private _tableNames: DatabaseTable[] = [];
|
||||
|
||||
private _modelParameters: ModelParameters | undefined;
|
||||
private _dbTableComponent: azdata.FlexContainer | undefined;
|
||||
private tableMaxLength = this.componentMaxLength * 2 + 70;
|
||||
/**
|
||||
* Creates a new view
|
||||
*/
|
||||
@@ -52,27 +55,47 @@ export class ColumnsFilterComponent extends ModelViewBase implements IDataCompon
|
||||
});
|
||||
|
||||
|
||||
this._form = modelBuilder.formContainer().withFormItems([{
|
||||
title: constants.azureAccount,
|
||||
const databaseForm = modelBuilder.formContainer().withFormItems([{
|
||||
title: constants.columnDatabase,
|
||||
component: this._databases
|
||||
}, {
|
||||
title: constants.azureSubscription,
|
||||
}]).withLayout({
|
||||
padding: '0px'
|
||||
}).component();
|
||||
const tableForm = modelBuilder.formContainer().withFormItems([{
|
||||
title: constants.columnTable,
|
||||
component: this._tables
|
||||
}]).withLayout({
|
||||
padding: '0px'
|
||||
}).component();
|
||||
this._dbTableComponent = modelBuilder.flexContainer().withItems([
|
||||
databaseForm,
|
||||
tableForm
|
||||
], {
|
||||
flex: '0 0 auto',
|
||||
CSSStyles: {
|
||||
'align-items': 'flex-start'
|
||||
}
|
||||
}).withLayout({
|
||||
flexFlow: 'row',
|
||||
justifyContent: 'space-between',
|
||||
width: this.tableMaxLength
|
||||
}).component();
|
||||
|
||||
this._form = modelBuilder.formContainer().withFormItems([{
|
||||
title: '',
|
||||
component: this._dbTableComponent
|
||||
}, {
|
||||
title: constants.azureGroup,
|
||||
title: constants.inputColumns,
|
||||
component: this._columns.component
|
||||
}]).component();
|
||||
return this._form;
|
||||
}
|
||||
|
||||
public addComponents(formBuilder: azdata.FormBuilder) {
|
||||
if (this._databases && this._tables && this._columns) {
|
||||
if (this._columns && this._dbTableComponent) {
|
||||
formBuilder.addFormItems([{
|
||||
title: constants.columnDatabase,
|
||||
component: this._databases
|
||||
}, {
|
||||
title: constants.columnTable,
|
||||
component: this._tables
|
||||
title: '',
|
||||
component: this._dbTableComponent
|
||||
}, {
|
||||
title: constants.inputColumns,
|
||||
component: this._columns.component
|
||||
@@ -81,17 +104,13 @@ export class ColumnsFilterComponent extends ModelViewBase implements IDataCompon
|
||||
}
|
||||
|
||||
public removeComponents(formBuilder: azdata.FormBuilder) {
|
||||
if (this._databases && this._tables && this._columns) {
|
||||
if (this._columns && this._dbTableComponent) {
|
||||
formBuilder.removeFormItem({
|
||||
title: constants.azureAccount,
|
||||
component: this._databases
|
||||
title: '',
|
||||
component: this._dbTableComponent
|
||||
});
|
||||
formBuilder.removeFormItem({
|
||||
title: constants.azureSubscription,
|
||||
component: this._tables
|
||||
});
|
||||
formBuilder.removeFormItem({
|
||||
title: constants.azureGroup,
|
||||
title: constants.inputColumns,
|
||||
component: this._columns.component
|
||||
});
|
||||
}
|
||||
@@ -125,6 +144,22 @@ export class ColumnsFilterComponent extends ModelViewBase implements IDataCompon
|
||||
await this.onDatabaseSelected();
|
||||
}
|
||||
|
||||
public set modelParameters(value: ModelParameters) {
|
||||
this._modelParameters = value;
|
||||
}
|
||||
|
||||
public async onLoading(): Promise<void> {
|
||||
if (this._columns) {
|
||||
await this._columns.onLoading();
|
||||
}
|
||||
}
|
||||
|
||||
public async onLoaded(): Promise<void> {
|
||||
if (this._columns) {
|
||||
await this._columns.onLoaded();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* refreshes the view
|
||||
*/
|
||||
@@ -146,7 +181,7 @@ export class ColumnsFilterComponent extends ModelViewBase implements IDataCompon
|
||||
}
|
||||
|
||||
private async onTableSelected(): Promise<void> {
|
||||
this._columns?.loadData(this.databaseTable);
|
||||
this._columns?.loadInputs(this._modelParameters, this.databaseTable);
|
||||
}
|
||||
|
||||
private get databaseName(): string | undefined {
|
||||
@@ -0,0 +1,35 @@
|
||||
/*---------------------------------------------------------------------------------------------
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
import * as utils from '../../../common/utils';
|
||||
|
||||
/**
|
||||
* Wizard to register a model
|
||||
*/
|
||||
export class ModelArtifact {
|
||||
|
||||
/**
|
||||
* Creates new model artifact
|
||||
*/
|
||||
constructor(private _filePath: string, private _deleteAtClose: boolean = true) {
|
||||
}
|
||||
|
||||
public get filePath(): string {
|
||||
return this._filePath;
|
||||
}
|
||||
|
||||
/**
|
||||
* Closes the artifact and disposes the resources
|
||||
*/
|
||||
public async close(): Promise<void> {
|
||||
if (this._deleteAtClose) {
|
||||
try {
|
||||
await utils.deleteFile(this._filePath);
|
||||
} catch {
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -9,25 +9,18 @@ import { ApiWrapper } from '../../../common/apiWrapper';
|
||||
import * as constants from '../../../common/constants';
|
||||
import { IDataComponent } from '../../interfaces';
|
||||
import { PredictColumn } from '../../../prediction/interfaces';
|
||||
import { ColumnsTable } from './columnsTable';
|
||||
import { ModelParameters } from '../../../modelManagement/interfaces';
|
||||
|
||||
/**
|
||||
* View to render filters to pick an azure resource
|
||||
*/
|
||||
const componentWidth = 60;
|
||||
|
||||
export class OutputColumnsComponent extends ModelViewBase implements IDataComponent<PredictColumn[]> {
|
||||
|
||||
private _form: azdata.FormContainer | undefined;
|
||||
private _flex: azdata.FlexContainer | undefined;
|
||||
private _columnName: azdata.InputBoxComponent | undefined;
|
||||
private _columnTypes: azdata.DropDownComponent | undefined;
|
||||
private _dataTypes: string[] = [
|
||||
'int',
|
||||
'nvarchar(MAX)',
|
||||
'varchar(MAX)',
|
||||
'float',
|
||||
'double',
|
||||
'bit'
|
||||
];
|
||||
private _columns: ColumnsTable | undefined;
|
||||
private _modelParameters: ModelParameters | undefined;
|
||||
|
||||
/**
|
||||
* Creates a new view
|
||||
@@ -41,49 +34,29 @@ export class OutputColumnsComponent extends ModelViewBase implements IDataCompon
|
||||
* @param modelBuilder model builder
|
||||
*/
|
||||
public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.Component {
|
||||
this._columnName = modelBuilder.inputBox().withProperties({
|
||||
width: this.componentMaxLength - componentWidth - this.spaceBetweenComponentsLength
|
||||
}).component();
|
||||
this._columnTypes = modelBuilder.dropDown().withProperties({
|
||||
width: componentWidth
|
||||
}).component();
|
||||
|
||||
let flex = modelBuilder.flexContainer()
|
||||
.withLayout({
|
||||
width: this._columnName.width
|
||||
}).withItems([
|
||||
this._columnName]
|
||||
).component();
|
||||
this._flex = modelBuilder.flexContainer()
|
||||
.withLayout({
|
||||
flexFlow: 'row',
|
||||
justifyContent: 'space-between',
|
||||
width: this.componentMaxLength
|
||||
}).withItems([
|
||||
flex, this._columnTypes]
|
||||
).component();
|
||||
this._columns = new ColumnsTable(this._apiWrapper, modelBuilder, this, false);
|
||||
|
||||
this._form = modelBuilder.formContainer().withFormItems([{
|
||||
title: constants.azureAccount,
|
||||
component: this._flex
|
||||
component: this._columns.component
|
||||
}]).component();
|
||||
return this._form;
|
||||
}
|
||||
|
||||
public addComponents(formBuilder: azdata.FormBuilder) {
|
||||
if (this._flex) {
|
||||
if (this._columns) {
|
||||
formBuilder.addFormItems([{
|
||||
title: constants.outputColumns,
|
||||
component: this._flex
|
||||
component: this._columns.component
|
||||
}]);
|
||||
}
|
||||
}
|
||||
|
||||
public removeComponents(formBuilder: azdata.FormBuilder) {
|
||||
if (this._flex) {
|
||||
if (this._columns) {
|
||||
formBuilder.removeFormItem({
|
||||
title: constants.outputColumns,
|
||||
component: this._flex
|
||||
component: this._columns.component
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -99,9 +72,24 @@ export class OutputColumnsComponent extends ModelViewBase implements IDataCompon
|
||||
* loads data in the components
|
||||
*/
|
||||
public async loadData(): Promise<void> {
|
||||
if (this._columnTypes) {
|
||||
this._columnTypes.values = this._dataTypes;
|
||||
this._columnTypes.value = this._dataTypes[0];
|
||||
if (this._modelParameters) {
|
||||
this._columns?.loadOutputs(this._modelParameters);
|
||||
}
|
||||
}
|
||||
|
||||
public set modelParameters(value: ModelParameters) {
|
||||
this._modelParameters = value;
|
||||
}
|
||||
|
||||
public async onLoading(): Promise<void> {
|
||||
if (this._columns) {
|
||||
await this._columns.onLoading();
|
||||
}
|
||||
}
|
||||
|
||||
public async onLoaded(): Promise<void> {
|
||||
if (this._columns) {
|
||||
await this._columns.onLoaded();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -116,9 +104,6 @@ export class OutputColumnsComponent extends ModelViewBase implements IDataCompon
|
||||
* Returns selected data
|
||||
*/
|
||||
public get data(): PredictColumn[] | undefined {
|
||||
return this._columnName && this._columnTypes ? [{
|
||||
name: this._columnName.value || '',
|
||||
dataType: <string>this._columnTypes.value || ''
|
||||
}] : undefined;
|
||||
return this._columns?.data;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -14,6 +14,7 @@ import { WizardView } from '../../wizardView';
|
||||
import { ModelSourcePage } from '../modelSourcePage';
|
||||
import { ColumnsSelectionPage } from './columnsSelectionPage';
|
||||
import { RegisteredModel } from '../../../modelManagement/interfaces';
|
||||
import { ModelArtifact } from './modelArtifact';
|
||||
|
||||
/**
|
||||
* Wizard to register a model
|
||||
@@ -21,7 +22,6 @@ import { RegisteredModel } from '../../../modelManagement/interfaces';
|
||||
export class PredictWizard extends ModelViewBase {
|
||||
|
||||
public modelSourcePage: ModelSourcePage | undefined;
|
||||
//public modelDetailsPage: ModelDetailsPage | undefined;
|
||||
public columnsSelectionPage: ColumnsSelectionPage | undefined;
|
||||
public wizardView: WizardView | undefined;
|
||||
private _parentView: ModelViewBase | undefined;
|
||||
@@ -37,7 +37,7 @@ export class PredictWizard extends ModelViewBase {
|
||||
/**
|
||||
* Opens a dialog to manage packages used by notebooks.
|
||||
*/
|
||||
public open(): void {
|
||||
public async open(): Promise<void> {
|
||||
this.modelSourcePage = new ModelSourcePage(this._apiWrapper, this, [ModelSourceType.RegisteredModels, ModelSourceType.Local, ModelSourceType.Azure]);
|
||||
this.columnsSelectionPage = new ColumnsSelectionPage(this._apiWrapper, this);
|
||||
this.wizardView = new WizardView(this._apiWrapper);
|
||||
@@ -50,16 +50,22 @@ export class PredictWizard extends ModelViewBase {
|
||||
wizard.doneButton.label = constants.predictModel;
|
||||
wizard.generateScriptButton.hidden = true;
|
||||
wizard.displayPageTitles = true;
|
||||
wizard.doneButton.onClick(async () => {
|
||||
await this.onClose();
|
||||
});
|
||||
wizard.cancelButton.onClick(async () => {
|
||||
await this.onClose();
|
||||
});
|
||||
wizard.registerNavigationValidator(async (pageInfo: azdata.window.WizardPageChangeInfo) => {
|
||||
let validated = this.wizardView ? await this.wizardView.validate(pageInfo) : false;
|
||||
if (validated && pageInfo.newPage === undefined) {
|
||||
wizard.cancelButton.enabled = false;
|
||||
wizard.backButton.enabled = false;
|
||||
await this.predict();
|
||||
wizard.cancelButton.enabled = true;
|
||||
wizard.backButton.enabled = true;
|
||||
if (this._parentView) {
|
||||
this._parentView?.refresh();
|
||||
if (validated) {
|
||||
if (pageInfo.newPage === undefined) {
|
||||
this.onLoading();
|
||||
await this.predict();
|
||||
this.onLoaded();
|
||||
if (this._parentView) {
|
||||
this._parentView?.refresh();
|
||||
}
|
||||
}
|
||||
return true;
|
||||
|
||||
@@ -67,7 +73,22 @@ export class PredictWizard extends ModelViewBase {
|
||||
return validated;
|
||||
});
|
||||
|
||||
wizard.open();
|
||||
await wizard.open();
|
||||
}
|
||||
|
||||
private onLoading(): void {
|
||||
this.refreshButtons(true);
|
||||
}
|
||||
|
||||
private onLoaded(): void {
|
||||
this.refreshButtons(false);
|
||||
}
|
||||
|
||||
private refreshButtons(loading: boolean): void {
|
||||
if (this.wizardView && this.wizardView.wizard) {
|
||||
this.wizardView.wizard.cancelButton.enabled = !loading;
|
||||
this.wizardView.wizard.cancelButton.enabled = !loading;
|
||||
}
|
||||
}
|
||||
|
||||
public get modelResources(): ModelSourcesComponent | undefined {
|
||||
@@ -82,16 +103,26 @@ export class PredictWizard extends ModelViewBase {
|
||||
return this.modelSourcePage?.azureModelsComponent;
|
||||
}
|
||||
|
||||
public async getModelFileName(): Promise<ModelArtifact | undefined> {
|
||||
if (this.modelResources && this.localModelsComponent && this.modelResources.data === ModelSourceType.Local) {
|
||||
return new ModelArtifact(this.localModelsComponent.data, false);
|
||||
} else if (this.modelResources && this.azureModelsComponent && this.modelResources.data === ModelSourceType.Azure) {
|
||||
return await this.azureModelsComponent.getDownloadedModel();
|
||||
} else if (this.modelSourcePage && this.modelSourcePage.registeredModelsComponent) {
|
||||
return await this.modelSourcePage.registeredModelsComponent.getDownloadedModel();
|
||||
}
|
||||
return undefined;
|
||||
}
|
||||
|
||||
private async predict(): Promise<boolean> {
|
||||
try {
|
||||
let modelFilePath: string = '';
|
||||
let modelFilePath: string | undefined;
|
||||
let registeredModel: RegisteredModel | undefined = undefined;
|
||||
if (this.modelResources && this.localModelsComponent && this.modelResources.data === ModelSourceType.Local) {
|
||||
modelFilePath = this.localModelsComponent.data;
|
||||
} else if (this.modelResources && this.azureModelsComponent && this.modelResources.data === ModelSourceType.Azure) {
|
||||
modelFilePath = await this.downloadAzureModel(this.azureModelsComponent?.data);
|
||||
} else {
|
||||
if (this.modelSourcePage && this.modelSourcePage.registeredModelsComponent) {
|
||||
registeredModel = this.modelSourcePage?.registeredModelsComponent?.data;
|
||||
} else {
|
||||
const artifact = await this.getModelFileName();
|
||||
modelFilePath = artifact?.filePath;
|
||||
}
|
||||
|
||||
await this.generatePredictScript(registeredModel, modelFilePath, this.columnsSelectionPage?.data);
|
||||
@@ -102,6 +133,14 @@ export class PredictWizard extends ModelViewBase {
|
||||
}
|
||||
}
|
||||
|
||||
private async onClose(): Promise<void> {
|
||||
const artifact = await this.getModelFileName();
|
||||
if (artifact) {
|
||||
artifact.close();
|
||||
}
|
||||
await this.wizardView?.disposePages();
|
||||
}
|
||||
|
||||
/**
|
||||
* Refresh the pages
|
||||
*/
|
||||
|
||||
@@ -15,7 +15,7 @@ import { IPageView } from '../../interfaces';
|
||||
* View to render current registered models
|
||||
*/
|
||||
export class CurrentModelsPage extends ModelViewBase implements IPageView {
|
||||
private _tableComponent: azdata.DeclarativeTableComponent | undefined;
|
||||
private _tableComponent: azdata.Component | undefined;
|
||||
private _dataTable: CurrentModelsTable | undefined;
|
||||
private _loader: azdata.LoadingComponent | undefined;
|
||||
|
||||
|
||||
@@ -4,11 +4,13 @@
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
import * as azdata from 'azdata';
|
||||
import * as vscode from 'vscode';
|
||||
import * as constants from '../../../common/constants';
|
||||
import { ModelViewBase } from '../modelViewBase';
|
||||
import { ApiWrapper } from '../../../common/apiWrapper';
|
||||
import { RegisteredModel } from '../../../modelManagement/interfaces';
|
||||
import { IDataComponent } from '../../interfaces';
|
||||
import { ModelArtifact } from '../prediction/modelArtifact';
|
||||
|
||||
/**
|
||||
* View to render registered models table
|
||||
@@ -18,6 +20,10 @@ export class CurrentModelsTable extends ModelViewBase implements IDataComponent<
|
||||
private _table: azdata.DeclarativeTableComponent | undefined;
|
||||
private _modelBuilder: azdata.ModelBuilder | undefined;
|
||||
private _selectedModel: any;
|
||||
private _loader: azdata.LoadingComponent | undefined;
|
||||
private _downloadedFile: ModelArtifact | undefined;
|
||||
private _onModelSelectionChanged: vscode.EventEmitter<void> = new vscode.EventEmitter<void>();
|
||||
public readonly onModelSelectionChanged: vscode.Event<void> = this._onModelSelectionChanged.event;
|
||||
|
||||
/**
|
||||
* Creates new view
|
||||
@@ -30,7 +36,7 @@ export class CurrentModelsTable extends ModelViewBase implements IDataComponent<
|
||||
*
|
||||
* @param modelBuilder register the components
|
||||
*/
|
||||
public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.DeclarativeTableComponent {
|
||||
public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.Component {
|
||||
this._modelBuilder = modelBuilder;
|
||||
this._table = modelBuilder.declarativeTable()
|
||||
.withProperties<azdata.DeclarativeTableProperties>(
|
||||
@@ -92,7 +98,12 @@ export class CurrentModelsTable extends ModelViewBase implements IDataComponent<
|
||||
ariaLabel: constants.mlsConfigTitle
|
||||
})
|
||||
.component();
|
||||
return this._table;
|
||||
this._loader = modelBuilder.loadingComponent()
|
||||
.withItem(this._table)
|
||||
.withProperties({
|
||||
loading: true
|
||||
}).component();
|
||||
return this._loader;
|
||||
}
|
||||
|
||||
public addComponents(formBuilder: azdata.FormBuilder) {
|
||||
@@ -111,14 +122,15 @@ export class CurrentModelsTable extends ModelViewBase implements IDataComponent<
|
||||
/**
|
||||
* Returns the component
|
||||
*/
|
||||
public get component(): azdata.DeclarativeTableComponent | undefined {
|
||||
return this._table;
|
||||
public get component(): azdata.Component | undefined {
|
||||
return this._loader;
|
||||
}
|
||||
|
||||
/**
|
||||
* Loads the data in the component
|
||||
*/
|
||||
public async loadData(): Promise<void> {
|
||||
await this.onLoading();
|
||||
if (this._table) {
|
||||
let models: RegisteredModel[] | undefined;
|
||||
|
||||
@@ -131,6 +143,20 @@ export class CurrentModelsTable extends ModelViewBase implements IDataComponent<
|
||||
|
||||
this._table.data = tableData;
|
||||
}
|
||||
this.onModelSelected();
|
||||
await this.onLoaded();
|
||||
}
|
||||
|
||||
public async onLoading(): Promise<void> {
|
||||
if (this._loader) {
|
||||
await this._loader.updateProperties({ loading: true });
|
||||
}
|
||||
}
|
||||
|
||||
public async onLoaded(): Promise<void> {
|
||||
if (this._loader) {
|
||||
await this._loader.updateProperties({ loading: false });
|
||||
}
|
||||
}
|
||||
|
||||
private createTableRow(model: RegisteredModel): any[] {
|
||||
@@ -142,8 +168,9 @@ export class CurrentModelsTable extends ModelViewBase implements IDataComponent<
|
||||
height: 15,
|
||||
checked: false
|
||||
}).component();
|
||||
selectModelButton.onDidClick(() => {
|
||||
selectModelButton.onDidClick(async () => {
|
||||
this._selectedModel = model;
|
||||
await this.onModelSelected();
|
||||
});
|
||||
return [model.artifactName, model.title, model.created, selectModelButton];
|
||||
}
|
||||
@@ -151,6 +178,14 @@ export class CurrentModelsTable extends ModelViewBase implements IDataComponent<
|
||||
return [];
|
||||
}
|
||||
|
||||
private async onModelSelected(): Promise<void> {
|
||||
this._onModelSelectionChanged.fire();
|
||||
if (this._downloadedFile) {
|
||||
await this._downloadedFile.close();
|
||||
}
|
||||
this._downloadedFile = undefined;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns selected data
|
||||
*/
|
||||
@@ -158,6 +193,22 @@ export class CurrentModelsTable extends ModelViewBase implements IDataComponent<
|
||||
return this._selectedModel;
|
||||
}
|
||||
|
||||
public async getDownloadedModel(): Promise<ModelArtifact> {
|
||||
if (!this._downloadedFile) {
|
||||
this._downloadedFile = new ModelArtifact(await this.downloadRegisteredModel(this.data));
|
||||
}
|
||||
return this._downloadedFile;
|
||||
}
|
||||
|
||||
/**
|
||||
* disposes the view
|
||||
*/
|
||||
public async disposeComponent(): Promise<void> {
|
||||
if (this._downloadedFile) {
|
||||
await this._downloadedFile.close();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Refreshes the view
|
||||
*/
|
||||
|
||||
@@ -35,7 +35,7 @@ export class RegisterModelWizard extends ModelViewBase {
|
||||
/**
|
||||
* Opens a dialog to manage packages used by notebooks.
|
||||
*/
|
||||
public open(): void {
|
||||
public async open(): Promise<void> {
|
||||
this.modelSourcePage = new ModelSourcePage(this._apiWrapper, this);
|
||||
this.modelDetailsPage = new ModelDetailsPage(this._apiWrapper, this);
|
||||
this.wizardView = new WizardView(this._apiWrapper);
|
||||
@@ -63,7 +63,7 @@ export class RegisterModelWizard extends ModelViewBase {
|
||||
return validated;
|
||||
});
|
||||
|
||||
wizard.open();
|
||||
await wizard.open();
|
||||
}
|
||||
|
||||
public get modelResources(): ModelSourcesComponent | undefined {
|
||||
|
||||
@@ -128,14 +128,14 @@ export abstract class ViewBase extends EventEmitterCollection {
|
||||
if (connection) {
|
||||
return `${connection.serverName} ${connection.databaseName ? connection.databaseName : ''}`;
|
||||
}
|
||||
return constants.packageManagerNoConnection;
|
||||
return constants.noConnectionError;
|
||||
}
|
||||
|
||||
public getServerTitle(): string {
|
||||
if (this.connection) {
|
||||
return this.connection.serverName;
|
||||
}
|
||||
return constants.packageManagerNoConnection;
|
||||
return constants.noConnectionError;
|
||||
}
|
||||
|
||||
private async getCurrentConnectionUrl(): Promise<string> {
|
||||
|
||||
@@ -68,7 +68,7 @@ export class WizardView extends MainViewBase {
|
||||
this._pages = pages;
|
||||
this._wizard.pages = pages.map(x => this.createWizardPage(x.title || '', x));
|
||||
this._wizard.onPageChanged(async (info) => {
|
||||
this.onWizardPageChanged(info);
|
||||
await this.onWizardPageChanged(info);
|
||||
});
|
||||
|
||||
return this._wizard;
|
||||
@@ -85,17 +85,17 @@ export class WizardView extends MainViewBase {
|
||||
return true;
|
||||
}
|
||||
|
||||
private onWizardPageChanged(pageInfo: azdata.window.WizardPageChangeInfo) {
|
||||
private async onWizardPageChanged(pageInfo: azdata.window.WizardPageChangeInfo) {
|
||||
let idxLast = pageInfo.lastPage;
|
||||
let lastPage = this._pages[idxLast];
|
||||
if (lastPage && lastPage.onLeave) {
|
||||
lastPage.onLeave();
|
||||
await lastPage.onLeave();
|
||||
}
|
||||
|
||||
let idx = pageInfo.newPage;
|
||||
let page = this._pages[idx];
|
||||
if (page && page.onEnter) {
|
||||
page.onEnter();
|
||||
await page.onEnter();
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user