Inference Client Python SDK
class inference_client.BaseClient(model: str, url: str, protocol: str, max_batch_size: int = 64, concurrency: int = 1, connection_config: dict | None = None)
Bases: ABC
Base class to generate clients that communicate with a Triton server.
- Parameters:
ABC – Abstract class from which abstract methods are inherited.
- Raises:
- TypeError – If the given protocol is not supported by the server.
- Exception – If a connection error occurs while querying the server.
__init__(model: str, url: str, protocol: str, max_batch_size: int = 64, concurrency: int = 1, connection_config: dict | None = None)
Initialize a BaseClient object, and load model into server.
- Parameters:
- model (str) – Name of the model to be loaded into the inference server.
- url (str) – URL of the connection with the inference server.
- protocol (str) – Protocol with which to connect with the inference server.
May be “http”, “https” or “grpc”.
- connection_config (dict, optional) – Dictionary containing connection-specific
configurations, such as SSL protocol certificates:
root_certificates (str): path to SSL root certificate.
private_key (str): path to SSL private key.
certificate_chain (str): path to SSL certificate chain.
- Raises:
- ValueError – If the number of concurrent connections is invalid.
- Exception – If a connection error occurs while querying the server.
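The documented error conditions can be checked before a client is constructed. The following is a minimal sketch, not the SDK's own validation: `check_client_args` is a hypothetical helper, and it assumes that an "invalid" number of concurrent connections means anything other than a positive integer.

```python
# Protocol names taken from the parameter description above.
SUPPORTED_PROTOCOLS = ("http", "https", "grpc")


def check_client_args(protocol: str, concurrency: int) -> None:
    """Hypothetical pre-flight check; not part of the SDK."""
    if protocol not in SUPPORTED_PROTOCOLS:
        raise TypeError(f"Unsupported protocol: {protocol!r}")
    # Assumption: only positive integers are valid (bool is rejected explicitly).
    if isinstance(concurrency, bool) or not isinstance(concurrency, int) or concurrency < 1:
        raise ValueError(f"Invalid concurrency: {concurrency!r}")


check_client_args("grpc", 1)  # passes silently
```

Running such a check first keeps configuration mistakes separate from genuine connection errors, which the client reports as a plain Exception.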
adapt_params(model_parameters: dict)
Convert float model parameters to strings so that they can be sent to the
model through tritonclient’s inference functions.
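The conversion described here can be pictured with a small standalone function. This is only a sketch of the idea: `adapt_params_sketch` is a hypothetical name, and the real method may treat other value types differently.

```python
def adapt_params_sketch(model_parameters: dict) -> dict:
    """Return a copy of model_parameters with float values rendered as strings."""
    return {
        name: str(value) if isinstance(value, float) else value
        for name, value in model_parameters.items()
    }


adapted = adapt_params_sketch({"temperature": 0.7, "top_k": 50, "do_sample": True})
print(adapted)  # {'temperature': '0.7', 'top_k': 50, 'do_sample': True}
```

Only float values change; integers and booleans pass through untouched in this sketch.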
get_triton_client(model: str, url: str, protocol: str, connection_config: dict)
Get the Triton client to use for the given protocol.
abstract infer(sentences: list)
Query the server to obtain the embeddings of a list of strings.
- Parameters:
sentences (list) – List of N strings to obtain embeddings from.
- Returns:
NumPy array of shape (N, M) containing the embedding of each
sentence in the corresponding row of the array.
N: number of sentences embedded.
M: number of features in each embedding.
- Return type:
np.ndarray
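Because infer is abstract, every concrete client has to supply its own implementation. The toy classes below illustrate that contract with the standard library only. They are hypothetical stand-ins, not SDK classes, and they return nested lists rather than an np.ndarray to stay dependency-free.

```python
from abc import ABC, abstractmethod


class ToyBaseClient(ABC):
    """Stand-in for BaseClient; only the abstract infer contract is modelled."""

    @abstractmethod
    def infer(self, sentences: list):
        """Return one embedding row per sentence."""


class ToyLengthClient(ToyBaseClient):
    """Hypothetical subclass: 'embeds' a sentence as [char count, word count]."""

    def infer(self, sentences: list):
        return [[float(len(s)), float(len(s.split()))] for s in sentences]


rows = ToyLengthClient().infer(["hello world", "hi"])
# N = 2 sentences in, 2 rows out; M = 2 features per row.
```

Instantiating `ToyBaseClient` directly raises TypeError, which is how Python enforces that subclasses implement the abstract method.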
unload_model()
Unload model from server.
class inference_client.ClientFactory
Bases: object
Implements the Factory pattern for the inference client generation.
- Raises:
TypeError – If the tokenizer_config dict is not given, or
if the given protocol is not supported by the server.
- Returns:
Client object that inherits from BaseClient.
- Return type:
class
static get_client(model: str, url: str, protocol: str, connection_config: dict | None = None, max_batch_size: int = 64, concurrency: int = 1, tokenizer_config: dict | None = None, logging_level: int | str = 30)
Return the appropriate client for the given model.
Clients include TextInferClient, TokensInferClient or StreamingClient.
- Parameters:
- model (str) – Name of the model to be loaded into the inference server.
- url (str) – URL of the connection with the inference server.
- protocol (str) – Protocol with which to connect with the inference server.
May be “http”, “https” or “grpc”.
- connection_config (dict, optional) – Dictionary containing connection-specific
configurations, such as SSL protocol certificates:
root_certificates (str): path to SSL root certificate.
private_key (str): path to SSL private key.
certificate_chain (str): path to SSL certificate chain.
- max_batch_size (int, optional) – Maximum batch size to use for inference. Defaults
to 64.
- concurrency (int, optional) – Number of concurrent connections to the server.
Defaults to 1.
- tokenizer_config (dict, optional) – Dictionary containing specific configurations
for the tokenizer to use:
path (str): Path to the pretrained transformers tokenizer. Raises
TypeError or KeyError if not present.
max_tokens_per_string (int, optional): Maximum number of tokens the
tokenizer should return per string given to it. Defaults to 64.
- logging_level (Union[int, str], optional) – Logging level for the client. Defaults to
logging.WARNING.
- Raises:
- TypeError or KeyError – If the tokenizer path is not given in the dictionary.
- TypeError – If the tokenizer_config dict is not given, or if the given
protocol is not supported by the server.
- Returns:
Client object that inherits from BaseClient.
- Return type:
class
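A call to the factory might look like the sketch below. The model name, URL and file paths are hypothetical placeholders, and the import is deferred into the function so that the configuration dictionaries can be built without the SDK or a server present. Whether a given protocol accepts the SSL settings is not stated here, so treat the combination shown as an assumption.

```python
import logging

# Hypothetical values: replace with your own certificate and tokenizer paths.
connection_config = {
    "root_certificates": "/path/to/ca.pem",
    "private_key": "/path/to/client.key",
    "certificate_chain": "/path/to/client.pem",
}
tokenizer_config = {
    "path": "/path/to/tokenizer",   # required key
    "max_tokens_per_string": 64,    # documented default
}


def build_client():
    """Create a client through the factory (needs a reachable inference server)."""
    from inference_client import ClientFactory  # lazy import of the SDK

    return ClientFactory.get_client(
        model="my-embedding-model",   # hypothetical model name
        url="localhost:8001",         # hypothetical server address
        protocol="grpc",
        connection_config=connection_config,
        max_batch_size=64,
        concurrency=1,
        tokenizer_config=tokenizer_config,
        logging_level=logging.WARNING,  # same value as the default of 30
    )
```

The returned object is one of the BaseClient subclasses, so the calling code can use it without knowing which concrete client was chosen.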
class inference_client.StreamingClient(model_name: str, url: str)
Bases: BaseClient
Client for streaming inference requests to a Triton server.
__init__(model_name: str, url: str)
Initialize a StreamingClient object, and load model into server.
- Parameters:
- model_name (str) – Name of the model to be loaded into the inference server.
- url (str) – URL of the connection with the inference server.
generate(prompt: list, callback: Callable, inference_parameters: dict | None = None)
Generate a response from the model given a prompt.
Each partial response is sent to the callback function as a string.
- Parameters:
- prompt (str) – The prompt to generate a response from.
- inference_parameters (dict , optional) – Dictionary containing optional parameters
to be used for inference by the model at the inference server. These
are model dependent; see examples for GPT models:
min_new_tokens (int) = 0: Minimum number of tokens to generate by the
model.
max_new_tokens (int) = 512: Maximum number of tokens to generate by
the model.
do_sample (bool) = True: Whether or not to use sampling; use greedy
decoding otherwise.
temperature (float) = 1.0: The value used to modulate the next token
probabilities.
top_k (int) = 50: The number of highest probability vocabulary tokens
to keep for top-k-filtering.
top_p (float) = 1.0: If set to float < 1, only the smallest set of
most probable tokens with probabilities that add up to top_p or higher
are kept for generation.
repetition_penalty (float) = 1.0: The parameter for repetition penalty.
1.0 means no penalty.
- callback (function) – The function to call with the generated text.
infer(sentences)
Inference is not supported for streaming clients.
is_generating()
Check if the generation process is still running.
If not, close the client.
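One way to consume a streamed response is to collect the partial strings in a callback and poll is_generating until the stream ends. The sketch below makes several assumptions: that generate returns before streaming finishes, that is_generating returns a truthy value while the stream is active, and that the model name and URL shown (both hypothetical) exist. The signature annotates prompt as a list while the parameter description says str, so pass whichever form your model expects.

```python
import time


def make_collector():
    """Return (chunks, callback); the callback appends each partial string."""
    chunks = []

    def on_partial(text: str) -> None:
        chunks.append(text)

    return chunks, on_partial


def stream_prompt(prompt, poll_seconds: float = 0.1) -> str:
    """Run one streaming generation and return the joined text (needs a server)."""
    from inference_client import StreamingClient  # lazy import of the SDK

    chunks, on_partial = make_collector()
    client = StreamingClient(model_name="my-gpt-model", url="localhost:8001")
    client.generate(prompt, on_partial, inference_parameters={"max_new_tokens": 64})
    # Assumption: is_generating() is truthy while partial responses are arriving.
    while client.is_generating():
        time.sleep(poll_seconds)
    return "".join(chunks)


# Offline illustration of the callback with simulated partial responses.
demo_chunks, demo_callback = make_collector()
for piece in ["Hel", "lo"]:
    demo_callback(piece)
```

Since is_generating closes the client once generation has stopped, the polling loop also serves as the cleanup step in this sketch.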
class inference_client.TextInferClient(model: str, url: str, protocol: str, max_batch_size: int = 64, concurrency: int = 1, connection_config: dict | None = None)
Bases: BaseClient
Class to send text queries to inference models in the server.
This requires models that are bundled with a tokenizer.
- Parameters:
BaseClient (class) – Base class to generate clients that communicate with a
Triton server.
__init__(model: str, url: str, protocol: str, max_batch_size: int = 64, concurrency: int = 1, connection_config: dict | None = None)
Initialize a BaseClient object, and load model into server.
- Parameters:
- model (str) – Name of the model to be loaded into the inference server.
- url (str) – URL of the connection with the inference server.
- protocol (str) – Protocol with which to connect with the inference server.
May be “http”, “https” or “grpc”.
- connection_config (dict, optional) – Dictionary containing connection-specific
configurations, such as SSL protocol certificates:
root_certificates (str): path to SSL root certificate.
private_key (str): path to SSL private key.
certificate_chain (str): path to SSL certificate chain.
- Raises:
- TypeError – If the given protocol is not supported by the server.
- Exception – If a connection error occurs while querying the server.
infer(sentences: list, inference_parameters: dict | None = None)
Query the server to obtain the embeddings of a list of strings.
- Parameters:
- sentences (list) – List of N strings to obtain embeddings from.
- inference_parameters (dict , optional) – Dictionary containing optional parameters to
be used for inference by the model at the inference server. These are
model dependent; see examples for GPT models:
min_new_tokens (int) = 0: Minimum number of tokens to generate by the
model.
max_new_tokens (int) = 512: Maximum number of tokens to generate by
the model.
do_sample (bool) = True: Whether or not to use sampling; use greedy
decoding otherwise.
temperature (float) = 1.0: The value used to modulate the next token
probabilities.
top_k (int) = 50: The number of highest probability vocabulary tokens
to keep for top-k-filtering.
top_p (float) = 1.0: If set to float < 1, only the smallest set of
most probable tokens with probabilities that add up to top_p or higher
are kept for generation.
repetition_penalty (float) = 1.0: The parameter for repetition penalty.
1.0 means no penalty.
- Returns:
NumPy array of shape (N, M) containing the embedding of each
sentence in the corresponding row of the array.
N: number of sentences embedded.
M: number of features in each embedding.
- Return type:
np.ndarray
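The inference_parameters dictionary can be built from the listed defaults and overridden per call. The sketch below assumes a GPT-style model that accepts these keys. `GPT_DEFAULTS`, `build_inference_parameters` and `embed` are hypothetical names, and `client` stands for an already constructed client instance.

```python
# Defaults as listed above for GPT-style models; other models may differ.
GPT_DEFAULTS = {
    "min_new_tokens": 0,
    "max_new_tokens": 512,
    "do_sample": True,
    "temperature": 1.0,
    "top_k": 50,
    "top_p": 1.0,
    "repetition_penalty": 1.0,
}


def build_inference_parameters(**overrides) -> dict:
    """Start from the documented defaults and apply the caller's overrides."""
    return {**GPT_DEFAULTS, **overrides}


# Greedy decoding with a shorter output budget.
params = build_inference_parameters(do_sample=False, max_new_tokens=128)


def embed(client, sentences: list):
    """Send sentences to the server; client is assumed to expose infer()."""
    return client.infer(sentences, inference_parameters=params)
```

Keeping the defaults in one place makes it easy to see which values a particular call actually changed.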
class inference_client.TokensInferClient(model: str, url: str, protocol: str, max_batch_size: int = 64, concurrency: int = 1, connection_config: dict | None = None, tokenizer_config: dict | None = None)
Bases: BaseClient
Class to query the token-based inference models in the server.
This requires a client-side tokenizer; i.e., texts must be tokenized before
sending them to the server.
- Parameters:
BaseClient (class) – Base class to generate clients that communicate with a
Triton server.
__init__(model: str, url: str, protocol: str, max_batch_size: int = 64, concurrency: int = 1, connection_config: dict | None = None, tokenizer_config: dict | None = None)
Initialize a BaseClient object, and load model into server.
Also initializes the tokenizer to use in the infer function.
- Parameters:
- model (str) – Name of the model to be loaded into the inference server.
- url (str) – URL of the connection with the inference server.
- protocol (str) – Protocol with which to connect with the inference server.
May be “http”, “https” or “grpc”.
- connection_config (dict, optional) – Dictionary containing connection-specific
configurations, such as SSL protocol certificates:
root_certificates (str): path to SSL root certificate.
private_key (str): path to SSL private key.
certificate_chain (str): path to SSL certificate chain.
- tokenizer_config (dict) – Dictionary containing specific configurations
for the tokenizer to use:
path (str): Path to the pretrained transformers tokenizer. Raises
TypeError or KeyError if not present.
max_tokens_per_string (int, optional): Maximum number of tokens the
tokenizer should return per string given to it. Defaults to 64.
- Raises:
- TypeError or KeyError – If the tokenizer path is not given in the dictionary.
- TypeError – If the tokenizer_config dict is not given, or if the given
protocol is not supported by the server.
- Exception – If a connection error occurs while querying the server.
infer(sentences: list, inference_parameters: dict | None = None)
Query the server to obtain the embeddings of a list of strings.
- Parameters:
- sentences (list) – List of N strings to obtain embeddings from.
- inference_parameters (dict , optional) – Dictionary containing optional parameters
to be used for inference by the model at the inference server. These
are model dependent; see examples for GPT models:
min_new_tokens (int) = 0: Minimum number of tokens to generate by the
model.
max_new_tokens (int) = 512: Maximum number of tokens to generate by
the model.
do_sample (bool) = True: Whether or not to use sampling; use greedy
decoding otherwise.
temperature (float) = 1.0: The value used to modulate the next token
probabilities.
top_k (int) = 50: The number of highest probability vocabulary tokens
to keep for top-k-filtering.
top_p (float) = 1.0: If set to float < 1, only the smallest set of
most probable tokens with probabilities that add up to top_p or higher
are kept for generation.
repetition_penalty (float) = 1.0: The parameter for repetition penalty.
1.0 means no penalty.
- Returns:
NumPy array of shape (N, M) containing the embedding of each
sentence in the corresponding row of the array.
N: number of sentences embedded.
M: number of features in each embedding.
- Return type:
np.ndarray
infer_from_tokens(tokens: dict, inference_parameters: dict | None = None)
Query the server to obtain the embeddings of a set of precomputed tokens.
- Parameters:
- tokens (dict) – Dictionary of tokens ready for formatting and sending to the
inference server. The dict should have the following structure:
{"input_ids": NxP tensor,
"attention_mask": NxP tensor,
"token_type_ids": NxP tensor}
N: number of tokenized sentences to embed.
P: maximum number of tokens across all the tokenized sentences.
- inference_parameters (dict, optional) – Dictionary containing optional parameters
to be used for inference by the model at the inference server. These
are model dependent; see examples for GPT models:
min_new_tokens (int) = 0: Minimum number of tokens to generate by the
model.
max_new_tokens (int) = 512: Maximum number of tokens to generate by
the model.
do_sample (bool) = True: Whether or not to use sampling; use greedy
decoding otherwise.
temperature (float) = 1.0: The value used to modulate the next token
probabilities.
top_k (int) = 50: The number of highest probability vocabulary tokens
to keep for top-k-filtering.
top_p (float) = 1.0: If set to float < 1, only the smallest set of
most probable tokens with probabilities that add up to top_p or
higher are kept for generation.
repetition_penalty (float) = 1.0: The parameter for repetition penalty.
1.0 means no penalty.
- Returns:
NumPy array of shape (N, M) containing the embedding of each
sentence in the corresponding row of the array.
N: number of sentences embedded.
M: number of features in each embedding.
- Return type:
np.ndarray
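The N x P layout of the tokens dictionary can be illustrated by padding a batch of token-id lists to a common length. This is a sketch using nested Python lists: `pad_batch` is a hypothetical helper, the token ids and padding id are made up, and the tensor type the client actually expects is not specified here. In practice a transformers tokenizer would normally produce these fields.

```python
def pad_batch(token_id_lists, pad_id=0):
    """Pad N token-id lists to a common length P and build the matching masks."""
    width = max(len(ids) for ids in token_id_lists)  # P
    input_ids, attention_mask, token_type_ids = [], [], []
    for ids in token_id_lists:
        padding = width - len(ids)
        input_ids.append(list(ids) + [pad_id] * padding)
        # 1 marks a real token, 0 marks padding.
        attention_mask.append([1] * len(ids) + [0] * padding)
        # Single-segment input: every position belongs to segment 0.
        token_type_ids.append([0] * width)
    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "token_type_ids": token_type_ids,
    }


# Two tokenized sentences (N = 2); the longest has three tokens (P = 3).
tokens = pad_batch([[101, 7592, 102], [101, 102]])
```

All three fields end up with the same N x P shape, which matches the structure that infer_from_tokens describes.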