Training Client Python SDK
class training_client.TrainingClient(base_url: str = 'http://localhost:8321')
Bases: object
Client for the NuPIC Training Server.
__init__(base_url: str = 'http://localhost:8321')
Initialize the API client with the base URL.
- Parameters:
base_url (str, optional) – base URL of the API, defaults to “http://localhost:8321”
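The base URL usually differs between a local run and a deployed server. A minimal sketch of reading it from the environment and falling back to the documented default; the helper `resolve_base_url` and the variable name `NUPIC_TRAINING_URL` are hypothetical and not part of the SDK:

```python
import os
from urllib.parse import urlsplit

DEFAULT_BASE_URL = "http://localhost:8321"  # default documented above


def resolve_base_url(env_var: str = "NUPIC_TRAINING_URL") -> str:
    """Return the server URL from the environment, or the documented default.

    `env_var` is a hypothetical name chosen for this sketch.
    """
    url = os.environ.get(env_var, DEFAULT_BASE_URL)
    parts = urlsplit(url)
    # Reject values that are clearly not an HTTP(S) endpoint.
    if parts.scheme not in ("http", "https") or not parts.netloc:
        raise ValueError(f"not an http(s) URL: {url!r}")
    return url


# With the SDK installed, the result would be passed to the constructor:
#   client = TrainingClient(base_url=resolve_base_url())
```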
download_data()
Ask the Training Module to download the training data.
The dataset_name and datapoints_split values are provided in the
training config object.
download_model()
Download the trained model.
- Returns:
path to the model file, or None if the model is not ready.
- Return type:
str
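Because download_model() returns None while the model is not ready, callers may want to retry. A minimal sketch, assuming a simple fixed delay between attempts is acceptable; `wait_for_model` and its parameters are hypothetical helpers, not part of the SDK:

```python
import time
from typing import Callable, Optional


def wait_for_model(
    download: Callable[[], Optional[str]],
    attempts: int = 10,
    delay_s: float = 5.0,
) -> Optional[str]:
    """Call `download` until it returns a path or the attempts run out."""
    for attempt in range(attempts):
        path = download()
        if path is not None:
            return path
        if attempt < attempts - 1:
            time.sleep(delay_s)  # wait before asking the server again
    return None


# Usage with a real client (not run here):
#   model_path = wait_for_model(client.download_model)
```

Passing the bound method rather than the client keeps the helper independent of the SDK and easy to exercise with a plain function.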
download_results(output='./results.csv')
Download the prediction for the test set.
- Parameters:
output (str) – file path to store the results. Default: ./results.csv
- Returns:
path to the result file.
- Return type:
str
follow_download()
Follow the dataset download process.
follow_import()
Follow the import process.
follow_preprocess()
Follow the preprocessing process.
follow_training()
Follow the training process.
get_import_status()
Get the import status.
get_imported_model(on_disk=True, tokenizer=True)
Download the imported model.
- Parameters:
- on_disk (bool, optional) – True to make the server store the model in its models folder, False to download the model. Defaults to True, per the signature.
- tokenizer (bool, optional) – True to download the tokenizer, False otherwise.
- Returns:
the model file path.
- Return type:
str
get_parameters()
Get the training parameters.
- Returns:
parameters for the current training session.
- Return type:
dict
get_status()
Get the training status.
- Returns:
current training status.
- Return type:
dict
import_model(model_type: str, model_name: str, sequence_length: int = 256, tokenizer: str = 'True')
Import a model from the model registry.
- Parameters:
- model_type (str) – type of the model. Currently supported: “huggingface”
- model_name (str) – name of the model to import
- sequence_length (int, optional) – sequence length, defaults to 256
- tokenizer (str, optional) – whether to import the model’s tokenizer. The signature declares this as a string that defaults to 'True'
- Returns:
the server response.
- Return type:
dict
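The import-related methods in this reference (import_model, follow_import, get_imported_model) appear to form one workflow. A sketch of one plausible call order, which is an assumption and not confirmed by the reference; `import_and_fetch` and `FakeClient` are hypothetical, and the model name is only an example:

```python
def import_and_fetch(client, model_name: str, sequence_length: int = 256):
    """Import a Hugging Face model, follow the import, then fetch its path.

    The call order is an assumption of this sketch.
    """
    response = client.import_model(
        "huggingface", model_name, sequence_length=sequence_length
    )
    client.follow_import()  # assumed to report progress until the import ends
    model_path = client.get_imported_model(on_disk=True, tokenizer=True)
    return response, model_path


class FakeClient:
    """Stand-in that records calls so the sketch runs without a server."""

    def __init__(self):
        self.calls = []

    def import_model(self, model_type, model_name, sequence_length=256,
                     tokenizer="True"):
        self.calls.append(
            ("import_model", model_type, model_name, sequence_length)
        )
        return {"status": "accepted"}  # hypothetical response body

    def follow_import(self):
        self.calls.append(("follow_import",))

    def get_imported_model(self, on_disk=True, tokenizer=True):
        self.calls.append(("get_imported_model", on_disk, tokenizer))
        return "models/demo"  # hypothetical path


fake = FakeClient()
response, path = import_and_fetch(fake, "bert-base-uncased")
print([call[0] for call in fake.calls])
```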
monitor()
Get the training monitor.
- Returns:
training monitor showing progress
- Return type:
list
preprocess_data()
Preprocess the training data.
- Returns:
the server response.
- Return type:
dict
static results_accuracy(results_path: str)
Get the accuracy from the results file.
- Parameters:
results_path (str) – path to the results file.
- Returns:
accuracy
- Return type:
float
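The reference does not describe the layout of the results file. To illustrate what an accuracy figure means, the sketch below computes the fraction of matching rows from a CSV, assuming hypothetical `label` and `prediction` columns; it is not the SDK's implementation of results_accuracy:

```python
import csv
import io


def accuracy_from_csv(handle, label_col="label", pred_col="prediction"):
    """Return the fraction of rows whose prediction equals the label.

    The column names are assumptions made for this sketch.
    """
    rows = list(csv.DictReader(handle))
    if not rows:
        raise ValueError("results file has no data rows")
    correct = sum(1 for row in rows if row[label_col] == row[pred_col])
    return correct / len(rows)


# Four rows, three of which match.
sample = io.StringIO("label,prediction\n1,1\n0,1\n1,1\n0,0\n")
print(accuracy_from_csv(sample))  # 0.75
```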
set_parameters(parameters: dict)
Set the training parameters.
- Parameters:
parameters (dict) – training parameters.
- Returns:
the server response.
- Return type:
dict
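Since get_parameters() and set_parameters() both work with a dict, changing one value can be done as a read-modify-write. A minimal sketch, assuming the server accepts the full dict returned by get_parameters(); the helper, the stand-in class, and the parameter keys are hypothetical:

```python
def update_parameters(client, **overrides):
    """Merge overrides into the current parameters and send them back."""
    params = dict(client.get_parameters())  # copy, so the source is untouched
    params.update(overrides)
    return client.set_parameters(params)


class FakeClient:
    """Stand-in so the sketch runs without a server; not part of the SDK."""

    def __init__(self):
        # Hypothetical keys; the real parameter names are not documented here.
        self.params = {"learning_rate": 0.001, "epochs": 3}

    def get_parameters(self):
        return dict(self.params)

    def set_parameters(self, parameters):
        self.params = dict(parameters)
        return {"status": "ok"}  # hypothetical response body


fake = FakeClient()
update_parameters(fake, epochs=5)
print(fake.params)
```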
start_session()
Start a new session, store its key, and return it.
- Returns:
training session key
- Return type:
str
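A session is presumably the starting point of a training run. The sketch below strings the documented methods together in one plausible order; the ordering is an assumption, and `run_training` and `RecordingClient` are hypothetical helpers that let the sketch run without a server:

```python
def run_training(client, train_path: str, test_path: str, parameters: dict):
    """Drive one training run from session start to downloaded results."""
    session_key = client.start_session()
    client.upload_data(train_path, test_path)
    client.preprocess_data()
    client.follow_preprocess()
    client.set_parameters(parameters)
    client.start_training()
    client.follow_training()
    model_path = client.download_model()      # may be None if not ready
    results_path = client.download_results()  # default output: ./results.csv
    return session_key, model_path, results_path


class RecordingClient:
    """Stand-in that records method names; not part of the SDK."""

    def __init__(self):
        self.calls = []

    def __getattr__(self, name):
        # Only reached for names not found normally, i.e. the SDK methods.
        def method(*args, **kwargs):
            self.calls.append(name)
            return f"<{name}>"  # placeholder return value
        return method


demo = RecordingClient()
print(run_training(demo, "train.csv", "test.csv", {"epochs": 1}))
print(demo.calls)
```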
start_training()
Start training.
- Returns:
the server response.
- Return type:
dict
stop_training()
Stop training.
- Returns:
the server response.
- Return type:
dict
upload_data(train_path: str, test_path)
Upload the training data.
- Parameters:
- train_path (str) – path to the training data to upload.
- test_path (str) – path to the test data to upload.
- Returns:
the server response.
- Return type:
dict
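The reference does not say how upload_data() behaves when a path is missing, so a caller may prefer to check both files first. A minimal sketch; `checked_upload` and `FakeClient` are hypothetical, and the response shape is invented for the demo:

```python
import tempfile
from pathlib import Path


def checked_upload(client, train_path, test_path):
    """Fail fast on a missing file, then delegate to client.upload_data."""
    for label, path in (("training", train_path), ("test", test_path)):
        if not Path(path).is_file():
            raise FileNotFoundError(f"{label} data not found: {path}")
    return client.upload_data(str(train_path), str(test_path))


class FakeClient:
    """Stand-in so the sketch runs without a server; not part of the SDK."""

    def upload_data(self, train_path, test_path):
        return {"train": train_path, "test": test_path}  # hypothetical response


# Demo with two throwaway files in a temporary directory.
with tempfile.TemporaryDirectory() as tmp:
    train = Path(tmp, "train.csv")
    test = Path(tmp, "test.csv")
    train.write_text("text,label\nhello,1\n")
    test.write_text("text,label\nbye,0\n")
    response = checked_upload(FakeClient(), train, test)
    print(response["train"].endswith("train.csv"))
```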
use_wandb(wandb_api_key: str, project_name: str = None, upload_files: bool = False)
Use wandb to log the training process.
- Parameters:
- wandb_api_key (str) – API key for WandB.
- project_name (str, optional) – WandB project name, defaults to None
- upload_files (bool, optional) – upload files to WandB, defaults to False
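Hard-coding an API key in a script is easy to leak. A minimal sketch that reads the key from the environment before calling use_wandb(); `WANDB_API_KEY` is the variable the wandb library itself reads, while `enable_wandb_from_env` and `FakeClient` are hypothetical and not part of the SDK:

```python
import os


def enable_wandb_from_env(client, project_name=None, upload_files=False,
                          env_var="WANDB_API_KEY"):
    """Read the API key from the environment and pass it to use_wandb."""
    api_key = os.environ.get(env_var)
    if not api_key:
        raise RuntimeError(f"environment variable {env_var} is not set")
    return client.use_wandb(api_key, project_name=project_name,
                            upload_files=upload_files)


class FakeClient:
    """Stand-in that records the arguments; not part of the SDK."""

    def __init__(self):
        self.seen = None

    def use_wandb(self, wandb_api_key, project_name=None, upload_files=False):
        self.seen = (wandb_api_key, project_name, upload_files)


# Usage with a real client (not run here):
#   enable_wandb_from_env(client, project_name="my-project")
```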