Integrate Contextualization API into Azure Data Studio to get better query recommendations from Copilot (#24044)

* Boilerplate for new metadata API endpoint

* Register all server metadata provider

* Fully registers data provider

* Registers metadata provider

* Instantiate metadata service

* Generates server metadata when connection is established

* Allow queryEditorInput to get server metadata

* Minor clean up

* Renames metadata provider and request endpoint

* Corrects documentation block

* Integrates get server metadata request endpoint

* Adjusts GetServerMetadataResult scripts type

* Add back Cargo.toml file

* Fix SQL hygiene error

* reflect changes made in in STS for table metadata

* Adds feature toggle to serverMetadataService

* Places toggle before request to get create scripts

* Fix build check issues

* Minor review changes

* Improves contextualization setting label

* Generalize contextualization service names

* Additional code review changes

* Update extensions/mssql/src/contracts.ts

Co-authored-by: Charles Gagnon <chgagnon@microsoft.com>

* Update src/sql/azdata.proposed.d.ts

Co-authored-by: Charles Gagnon <chgagnon@microsoft.com>

* Code reivew changes

* Capitalize c in contextualization

* Additional review changes

* Update provider type

* Simplify type and method names

* Unregister MSSQL ServerContextualization provider

---------

Co-authored-by: Charles Gagnon <chgagnon@microsoft.com>
This commit is contained in:
Lewis Sanchez
2023-08-21 19:54:44 -07:00
committed by GitHub
parent df5693ffd3
commit d5a9c172d1
19 changed files with 336 additions and 10 deletions

View File

@@ -887,12 +887,19 @@ declare module 'azdata' {
export enum DataProviderType {
TableDesignerProvider = 'TableDesignerProvider',
ExecutionPlanProvider = 'ExecutionPlanProvider'
ExecutionPlanProvider = 'ExecutionPlanProvider',
ServerContextualizationProvider = 'ServerContextualizationProvider'
}
export namespace dataprotocol {
export function registerTableDesignerProvider(provider: designers.TableDesignerProvider): vscode.Disposable;
export function registerExecutionPlanProvider(provider: executionPlan.ExecutionPlanProvider): vscode.Disposable;
/**
* Registers a server contextualization provider, which can provide context about a server to extensions like GitHub
* Copilot for improved suggestions.
* @param provider The provider to register
*/
export function registerServerContextualizationProvider(provider: contextualization.ServerContextualizationProvider): vscode.Disposable
}
export namespace designers {
@@ -1773,6 +1780,29 @@ declare module 'azdata' {
}
}
export namespace contextualization {
export interface GetServerContextualizationResult {
/**
* An array containing the generated server context.
*/
context: string[];
}
export interface ServerContextualizationProvider extends DataProvider {
/**
* Generates server context.
* @param ownerUri The URI of the connection to generate context for.
*/
generateServerContextualization(ownerUri: string): void;
/**
* Gets server context, which can be in the form of create scripts but is left up each provider.
* @param ownerUri The URI of the connection to get context for.
*/
getServerContextualization(ownerUri: string): Thenable<GetServerContextualizationResult>;
}
}
/**
* Component to display text with an icon representing the severity
*/

View File

@@ -46,6 +46,7 @@ export interface IQueryEditorConfiguration {
readonly tabColorMode: 'off' | 'border' | 'fill';
readonly showConnectionInfoInTitle: boolean;
readonly promptToSaveGeneratedFiles: boolean;
readonly githubCopilotContextualizationEnabled: boolean;
}
export interface IResultGridConfiguration {

View File

@@ -33,6 +33,7 @@ import { ITableDesignerService } from 'sql/workbench/services/tableDesigner/comm
import { IExecutionPlanService } from 'sql/workbench/services/executionPlan/common/interfaces';
import { extHostNamedCustomer, IExtHostContext } from 'vs/workbench/services/extensions/common/extHostCustomers';
import { SqlExtHostContext, SqlMainContext } from 'vs/workbench/api/common/extHost.protocol';
import { IServerContextualizationService } from 'sql/workbench/services/contextualization/common/interfaces';
/**
* Main thread class for handling data protocol management registration.
@@ -64,7 +65,8 @@ export class MainThreadDataProtocol extends Disposable implements MainThreadData
@IDataGridProviderService private _dataGridProviderService: IDataGridProviderService,
@IAdsTelemetryService private _telemetryService: IAdsTelemetryService,
@ITableDesignerService private _tableDesignerService: ITableDesignerService,
@IExecutionPlanService private _executionPlanService: IExecutionPlanService
@IExecutionPlanService private _executionPlanService: IExecutionPlanService,
@IServerContextualizationService private _serverContextualizationService: IServerContextualizationService
) {
super();
if (extHostContext) {
@@ -571,6 +573,14 @@ export class MainThreadDataProtocol extends Disposable implements MainThreadData
});
}
// Database server contextualization handler
public $registerServerContextualizationProvider(providerId: string, handle: number): void {
this._serverContextualizationService.registerProvider(providerId, <azdata.contextualization.ServerContextualizationProvider>{
generateServerContextualization: (ownerUri: string) => this._proxy.$generateServerContextualization(handle, ownerUri),
getServerContextualization: (ownerUri: string) => this._proxy.$getServerContextualization(handle, ownerUri)
});
}
// Connection Management handlers
public $onConnectionComplete(handle: number, connectionInfoSummary: azdata.ConnectionInfoSummary): void {
this._connectionManagementService.onConnectionComplete(handle, connectionInfoSummary);

View File

@@ -210,6 +210,12 @@ export class ExtHostDataProtocol extends ExtHostDataProtocolShape {
return rt;
}
$registerServerContextualizationProvider(provider: azdata.contextualization.ServerContextualizationProvider): vscode.Disposable {
let rt = this.registerProvider(provider, DataProviderType.ServerContextualizationProvider);
this._proxy.$registerServerContextualizationProvider(provider.providerId, provider.handle);
return rt;
}
// Capabilities Discovery handlers
override $getServerCapabilities(handle: number, client: azdata.DataProtocolClientCapabilities): Thenable<azdata.DataProtocolServerCapabilities> {
return this._resolveProvider<azdata.CapabilitiesProvider>(handle).getServerCapabilities(client);
@@ -963,4 +969,14 @@ export class ExtHostDataProtocol extends ExtHostDataProtocolShape {
public override $isExecutionPlan(handle: number, value: string): Thenable<azdata.executionPlan.IsExecutionPlanResult> {
return this._resolveProvider<azdata.executionPlan.ExecutionPlanProvider>(handle).isExecutionPlan(value);
}
// Database Server Contextualization API
public override $generateServerContextualization(handle: number, ownerUri: string): void {
this._resolveProvider<azdata.contextualization.ServerContextualizationProvider>(handle).generateServerContextualization(ownerUri);
}
public override $getServerContextualization(handle: number, ownerUri: string): Thenable<azdata.contextualization.GetServerContextualizationResult> {
return this._resolveProvider<azdata.contextualization.ServerContextualizationProvider>(handle).getServerContextualization(ownerUri);
}
}

View File

@@ -408,6 +408,10 @@ export function createAdsApiFactory(accessor: ServicesAccessor): IAdsExtensionAp
return extHostDataProvider.$registerExecutionPlanProvider(provider);
};
let registerServerContextualizationProvider = (provider: azdata.contextualization.ServerContextualizationProvider): vscode.Disposable => {
return extHostDataProvider.$registerServerContextualizationProvider(provider);
};
// namespace: dataprotocol
const dataprotocol: typeof azdata.dataprotocol = {
registerBackupProvider,
@@ -430,6 +434,7 @@ export function createAdsApiFactory(accessor: ServicesAccessor): IAdsExtensionAp
registerDataGridProvider,
registerTableDesignerProvider,
registerExecutionPlanProvider: registerExecutionPlanProvider,
registerServerContextualizationProvider: registerServerContextualizationProvider,
onDidChangeLanguageFlavor(listener: (e: azdata.DidChangeLanguageFlavorParams) => any, thisArgs?: any, disposables?: extHostTypes.Disposable[]) {
return extHostDataProvider.onDidChangeLanguageFlavor(listener, thisArgs, disposables);
},

View File

@@ -595,6 +595,14 @@ export abstract class ExtHostDataProtocolShape {
* Determines if the provided value is an execution plan and returns the appropriate file extension.
*/
$isExecutionPlan(handle: number, value: string): Thenable<azdata.executionPlan.IsExecutionPlanResult> { throw ni(); }
/**
* Generates server context.
*/
$generateServerContextualization(handle: number, ownerUri: string): void { throw ni(); }
/**
* Gets server context.
*/
$getServerContextualization(handle: number, ownerUri: string): Thenable<azdata.contextualization.GetServerContextualizationResult> { throw ni(); }
}
/**
@@ -687,6 +695,7 @@ export interface MainThreadDataProtocolShape extends IDisposable {
$registerDataGridProvider(providerId: string, title: string, handle: number): void;
$registerTableDesignerProvider(providerId: string, handle: number): Promise<any>;
$registerExecutionPlanProvider(providerId: string, handle: number): void;
$registerServerContextualizationProvider(providerId: string, handle: number): void;
$unregisterProvider(handle: number): Promise<any>;
$onConnectionComplete(handle: number, connectionInfoSummary: azdata.ConnectionInfoSummary): void;
$onIntelliSenseCacheComplete(handle: number, connectionUri: string): void;

View File

@@ -421,7 +421,8 @@ export enum DataProviderType {
SqlAssessmentServicesProvider = 'SqlAssessmentServicesProvider',
DataGridProvider = 'DataGridProvider',
TableDesignerProvider = 'TableDesignerProvider',
ExecutionPlanProvider = 'ExecutionPlanProvider'
ExecutionPlanProvider = 'ExecutionPlanProvider',
ServerContextualizationProvider = 'ServerContextualizationProvider'
}
export enum DeclarativeDataType {

View File

@@ -18,6 +18,8 @@ import { FILE_QUERY_EDITOR_TYPEID } from 'sql/workbench/common/constants';
import { IInstantiationService } from 'vs/platform/instantiation/common/instantiation';
import { EditorInput } from 'vs/workbench/common/editor/editorInput';
import { IResourceEditorInput } from 'vs/platform/editor/common/editor';
import { IServerContextualizationService } from 'sql/workbench/services/contextualization/common/interfaces';
import { IExtensionService } from 'vs/workbench/services/extensions/common/extensions';
export class FileQueryEditorInput extends QueryEditorInput {
@@ -30,9 +32,11 @@ export class FileQueryEditorInput extends QueryEditorInput {
@IConnectionManagementService connectionManagementService: IConnectionManagementService,
@IQueryModelService queryModelService: IQueryModelService,
@IConfigurationService configurationService: IConfigurationService,
@IInstantiationService instantiationService: IInstantiationService
@IInstantiationService instantiationService: IInstantiationService,
@IServerContextualizationService serverContextualizationService: IServerContextualizationService,
@IExtensionService extensionService: IExtensionService
) {
super(description, text, results, connectionManagementService, queryModelService, configurationService, instantiationService);
super(description, text, results, connectionManagementService, queryModelService, configurationService, instantiationService, serverContextualizationService, extensionService);
}
public override resolve(): Promise<ITextFileEditorModel | BinaryEditorModel> {

View File

@@ -22,6 +22,8 @@ import { IUntitledQueryEditorInput } from 'sql/workbench/common/editor/query/unt
import { IEditorResolverService } from 'vs/workbench/services/editor/common/editorResolverService';
import { Uri } from 'vscode';
import { ILogService } from 'vs/platform/log/common/log';
import { IServerContextualizationService } from 'sql/workbench/services/contextualization/common/interfaces';
import { IExtensionService } from 'vs/workbench/services/extensions/common/extensions';
export class UntitledQueryEditorInput extends QueryEditorInput implements IUntitledQueryEditorInput {
@@ -36,9 +38,11 @@ export class UntitledQueryEditorInput extends QueryEditorInput implements IUntit
@IConfigurationService configurationService: IConfigurationService,
@IInstantiationService instantiationService: IInstantiationService,
@ILogService private readonly logService: ILogService,
@IEditorResolverService private readonly editorResolverService: IEditorResolverService
@IEditorResolverService private readonly editorResolverService: IEditorResolverService,
@IServerContextualizationService serverContextualizationService: IServerContextualizationService,
@IExtensionService extensionService: IExtensionService
) {
super(description, text, results, connectionManagementService, queryModelService, configurationService, instantiationService);
super(description, text, results, connectionManagementService, queryModelService, configurationService, instantiationService, serverContextualizationService, extensionService);
// Set the mode explicitely to stop the auto language detection service from changing the mode unexpectedly.
// the auto language detection service won't do the language change only if the mode is explicitely set.
// if the mode (e.g. kusto, sql) do not exist for whatever reason, we will default it to sql.

View File

@@ -20,6 +20,8 @@ import { AbstractTextResourceEditorInput } from 'vs/workbench/common/editor/text
import { IQueryEditorConfiguration } from 'sql/platform/query/common/query';
import { EditorInput } from 'vs/workbench/common/editor/editorInput';
import { IInstantiationService } from 'vs/platform/instantiation/common/instantiation';
import { IServerContextualizationService } from 'sql/workbench/services/contextualization/common/interfaces';
import { IExtensionService } from 'vs/workbench/services/extensions/common/extensions';
const MAX_SIZE = 13;
@@ -142,6 +144,8 @@ export abstract class QueryEditorInput extends EditorInput implements IConnectab
private _state = this._register(new QueryEditorState());
public get state(): QueryEditorState { return this._state; }
private _serverContext: string[];
constructor(
private _description: string | undefined,
protected _text: AbstractTextResourceEditorInput,
@@ -149,7 +153,9 @@ export abstract class QueryEditorInput extends EditorInput implements IConnectab
@IConnectionManagementService private readonly connectionManagementService: IConnectionManagementService,
@IQueryModelService private readonly queryModelService: IQueryModelService,
@IConfigurationService private readonly configurationService: IConfigurationService,
@IInstantiationService protected readonly instantiationService: IInstantiationService
@IInstantiationService protected readonly instantiationService: IInstantiationService,
@IServerContextualizationService private readonly serverContextualizationService: IServerContextualizationService,
@IExtensionService private readonly extensionService: IExtensionService
) {
super();
@@ -235,6 +241,27 @@ export abstract class QueryEditorInput extends EditorInput implements IConnectab
public override isDirty(): boolean { return this._text.isDirty(); }
public get resource(): URI { return this._text.resource; }
public async getServerContext(): Promise<string[]> {
const copilotExt = await this.extensionService.getExtension('github.copilot');
if (copilotExt && this.configurationService.getValue<IQueryEditorConfiguration>('queryEditor').githubCopilotContextualizationEnabled) {
if (!this._serverContext) {
const result = await this.serverContextualizationService.getServerContextualization(this.uri);
// TODO lewissanchez - Remove this from here once Copilot starts pulling context. That isn't implemented yet, so
// getting scripts this way for now.
this._serverContext = result.context;
return this._serverContext;
}
else {
return this._serverContext;
}
}
else {
return Promise.resolve([]);
}
}
public override getName(longForm?: boolean): string {
if (this.configurationService.getValue<IQueryEditorConfiguration>('queryEditor').showConnectionInfoInTitle) {
let profile = this.connectionManagementService.getConnectionProfile(this.uri);

View File

@@ -467,7 +467,7 @@ suite('commandLineService tests', () => {
let uri = URI.file(args._[0]);
const workbenchinstantiationService = workbenchInstantiationService();
const editorInput = workbenchinstantiationService.createInstance(FileEditorInput, uri, undefined, undefined, undefined, undefined, undefined, undefined);
const queryInput = new FileQueryEditorInput(undefined, editorInput, undefined, connectionManagementService.object, querymodelService.object, configurationService.object, workbenchinstantiationService);
const queryInput = new FileQueryEditorInput(undefined, editorInput, undefined, connectionManagementService.object, querymodelService.object, configurationService.object, workbenchinstantiationService, undefined, undefined);
queryInput.state.connected = true;
const editorService: TypeMoq.Mock<IEditorService> = TypeMoq.Mock.ofType<IEditorService>(TestEditorService, TypeMoq.MockBehavior.Strict);
editorService.setup(e => e.editors).returns(() => [queryInput]);

View File

@@ -374,6 +374,11 @@ const queryEditorConfiguration: IConfigurationNode = {
'type': 'boolean',
'default': false,
'description': localize('queryEditor.promptToSaveGeneratedFiles', "Prompt to save generated SQL files")
},
'queryEditor.githubCopilotContextualizationEnabled': {
'type': 'boolean',
'default': false,
'description': localize('queryEditor.githubCopilotContextualizationEnabled', "(Preview) Enable contextualization of queries for GitHub Copilot. This setting helps GitHub Copilot to return improved suggestions, if the Copilot extension is installed and providers have implemented contextualization.")
}
}
};

View File

@@ -316,6 +316,8 @@ suite('SQL QueryEditor Tests', () => {
configurationService.object,
testinstantiationService,
undefined,
undefined,
undefined,
undefined
);
});

View File

@@ -0,0 +1,43 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as azdata from 'azdata';
import { createDecorator } from 'vs/platform/instantiation/common/instantiation';
export const SERVICE_ID = 'serverContextualizationService';
export const IServerContextualizationService = createDecorator<IServerContextualizationService>(SERVICE_ID);
export interface IServerContextualizationService {
_serviceBrand: undefined;
/**
* Register a server contextualization service provider
*/
registerProvider(providerId: string, provider: azdata.contextualization.ServerContextualizationProvider): void;
/**
* Unregister a server contextualization service provider
*/
unregisterProvider(providerId: string): void;
/**
* Gets a registered server contextualization service provider. An exception is thrown if a provider isn't registered with the specified ID
* @param providerId The ID of the registered provider
*/
getProvider(providerId: string): azdata.contextualization.ServerContextualizationProvider;
/**
* Generates server context
* @param ownerUri The URI of the connection to generate context for.
*/
generateServerContextualization(ownerUri: string): void;
/**
* Gets all database context.
* @param ownerUri The URI of the connection to get context for.
*/
getServerContextualization(ownerUri: string): Promise<azdata.contextualization.GetServerContextualizationResult>;
}

View File

@@ -0,0 +1,94 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as azdata from 'azdata';
import { invalidProvider } from 'sql/base/common/errors';
import { IConnectionManagementService, IConnectionParams } from 'sql/platform/connection/common/connectionManagement';
import { IQueryEditorConfiguration } from 'sql/platform/query/common/query';
import { IServerContextualizationService } from 'sql/workbench/services/contextualization/common/interfaces';
import { Disposable } from 'vs/base/common/lifecycle';
import { IConfigurationService } from 'vs/platform/configuration/common/configuration';
import { IExtensionService } from 'vs/workbench/services/extensions/common/extensions';
export class ServerContextualizationService extends Disposable implements IServerContextualizationService {
public _serviceBrand: undefined;
private _providers = new Map<string, azdata.contextualization.ServerContextualizationProvider>();
constructor(
@IConnectionManagementService private readonly _connectionManagementService: IConnectionManagementService,
@IConfigurationService private readonly _configurationService: IConfigurationService,
@IExtensionService private readonly _extensionService: IExtensionService
) {
super();
this._register(this._connectionManagementService.onConnect(async (e: IConnectionParams) => {
const copilotExt = await this._extensionService.getExtension('github.copilot');
if (copilotExt && this._configurationService.getValue<IQueryEditorConfiguration>('queryEditor').githubCopilotContextualizationEnabled) {
const ownerUri = e.connectionUri;
await this.generateServerContextualization(ownerUri);
}
}));
}
/**
* Register a server contextualization service provider
*/
public registerProvider(providerId: string, provider: azdata.contextualization.ServerContextualizationProvider): void {
if (this._providers.has(providerId)) {
throw new Error(`A server contextualization provider with ID "${providerId}" is already registered`);
}
this._providers.set(providerId, provider);
}
/**
* Unregister a server contextualization service provider.
*/
public unregisterProvider(providerId: string): void {
this._providers.delete(providerId);
}
/**
* Gets a registered server contextualization service provider. An exception is thrown if a provider isn't registered with the specified ID.
* @param providerId The ID of the registered provider.
*/
public getProvider(providerId: string): azdata.contextualization.ServerContextualizationProvider {
const provider = this._providers.get(providerId);
if (provider) {
return provider;
}
throw invalidProvider(providerId);
}
/**
* Generates server context
* @param ownerUri The URI of the connection to generate context for.
*/
public generateServerContextualization(ownerUri: string): void {
const providerName = this._connectionManagementService.getProviderIdFromUri(ownerUri);
const handler = this.getProvider(providerName);
if (handler) {
handler.generateServerContextualization(ownerUri);
}
}
/**
* Gets all database context.
* @param ownerUri The URI of the connection to get context for.
*/
public async getServerContextualization(ownerUri: string): Promise<azdata.contextualization.GetServerContextualizationResult> {
const providerName = this._connectionManagementService.getProviderIdFromUri(ownerUri);
const handler = this.getProvider(providerName);
if (handler) {
return await handler.getServerContextualization(ownerUri);
}
else {
return Promise.resolve({
context: []
});
}
}
}