Training Client Python SDK

class training_client.TrainingClient(base_url: str = 'http://localhost:8321')

Bases: object

Client for the NuPIC Training Server.

__init__(base_url: str = 'http://localhost:8321')

Initialize the API client with the base URL.
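The methods documented for this class are usually combined into one training run. The helper below is a minimal sketch of one plausible sequence, written against any object that exposes the documented method names. The call order, the helper name `run_training`, the assumption that the `follow_*` methods block until their phase finishes, and the contents of `params` are all assumptions, not documented behavior.

```python
from typing import Any, Optional


def run_training(client: Any, train_path: str, test_path: str,
                 params: dict, results_path: str = "./results.csv") -> Optional[float]:
    """Hypothetical helper: one training run using the documented methods.

    The call order is an assumption; check which sequence your server expects.
    """
    client.start_session()                  # obtain and store a session key
    client.upload_data(train_path, test_path)
    client.set_parameters(params)           # training configuration (contents assumed)
    client.preprocess_data()
    client.follow_preprocess()              # assumed to block until preprocessing ends
    client.start_training()
    client.follow_training()                # assumed to block until training ends
    if client.download_model() is None:     # documented: None means the model is not ready
        return None
    path = client.download_results(output=results_path)
    return client.results_accuracy(path)
```

With a live server this might be called as `run_training(TrainingClient(), "train.csv", "test.csv", params)`, where `params` is whatever training configuration your server accepts.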

download_data()

Ask the Training Module to download the training data.
The dataset_name and datapoints_split values are read from the
training config object.
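Because `download_data()` takes no arguments, the dataset is chosen through the training configuration. A minimal sketch, assuming `dataset_name` and `datapoints_split` are top-level keys of the dictionary accepted by `set_parameters()`; the helper name `download_dataset` is hypothetical, and no particular value format for `datapoints_split` is implied.

```python
from typing import Any


def download_dataset(client: Any, dataset_name: str, datapoints_split: Any) -> None:
    """Hypothetical helper: set the dataset in the config, then download it."""
    params = dict(client.get_parameters())          # shallow copy of the current config
    params["dataset_name"] = dataset_name            # assumed top-level key
    params["datapoints_split"] = datapoints_split    # assumed top-level key
    client.set_parameters(params)
    client.download_data()                           # server fetches the dataset
    client.follow_download()                         # follow the download progress
```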

download_model()

Download the trained model.

  • Returns:
    path to the model file, or None if the model is not ready.
  • Return type:
    str or None
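Since `download_model()` returns `None` while the model is not ready, a caller may want to retry with a deadline. A minimal sketch; the helper name `wait_for_model` and the default `timeout` and `interval` values are hypothetical, and it assumes that calling `download_model()` before the model is ready is cheap and has no side effects.

```python
import time
from typing import Callable, Optional


def wait_for_model(download_model: Callable[[], Optional[str]],
                   timeout: float = 600.0, interval: float = 10.0) -> str:
    """Hypothetical helper: retry download_model until it returns a path."""
    deadline = time.monotonic() + timeout
    while True:
        path = download_model()
        if path is not None:          # the model was ready and has been downloaded
            return path
        if time.monotonic() >= deadline:
            raise TimeoutError("model was not ready before the timeout")
        time.sleep(interval)          # wait before asking again
```

Usage: `model_path = wait_for_model(client.download_model)`. Passing the bound method rather than the client keeps the helper independent of the SDK.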

download_results(output='./results.csv')

Download the predictions for the test set.

  • Parameters:
    output (str) – file path to store the results. Default: ./results.csv
  • Returns:
    path to the result file.
  • Return type:
    str
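The documentation does not say whether `download_results()` creates missing folders in `output`. A defensive sketch that creates the parent directory first; the helper name `download_results_to` is hypothetical.

```python
import os
from typing import Any


def download_results_to(client: Any, output: str) -> str:
    """Hypothetical helper: ensure the target folder exists, then download results."""
    parent = os.path.dirname(os.path.abspath(output))
    os.makedirs(parent, exist_ok=True)    # no-op if the folder already exists
    return client.download_results(output=output)
```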

follow_download()

Follow the dataset download process.

follow_import()

Follow the import process.

follow_preprocess()

Follow the preprocessing of the training data.

follow_training()

Follow the training process.
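Judging by their names, each `follow_*` method pairs with the call that starts the corresponding phase. The sketch below encodes that pairing as an assumption, not documented behavior; `PHASES` and `run_phase` are hypothetical names, and the import phase is left out because `import_model()` requires arguments.

```python
from typing import Any

# Hypothetical mapping: phase name -> (method that starts it, method that follows it)
PHASES = {
    "download": ("download_data", "follow_download"),
    "preprocess": ("preprocess_data", "follow_preprocess"),
    "training": ("start_training", "follow_training"),
}


def run_phase(client: Any, phase: str) -> Any:
    """Start a phase, follow it, and return the start call's response."""
    start_name, follow_name = PHASES[phase]
    response = getattr(client, start_name)()   # e.g. client.preprocess_data()
    getattr(client, follow_name)()             # e.g. client.follow_preprocess()
    return response
```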

get_import_status()

Get the import status.

get_imported_model(on_disk=True, tokenizer=True)

Download the imported model.

  • Parameters:
    • on_disk (bool , optional) – True to make the server store the model in its models folder,
      False to download the model. Defaults to True.
    • tokenizer (bool , optional) – True to download the tokenizer, False otherwise. Defaults to True.
  • Returns:
    the model file path.
  • Return type:
    str

get_parameters()

Get the training parameters.

  • Returns:
    parameters for the current training session.
  • Return type:
    dict

get_status()

Get the training status.

  • Returns:
    current training status.
  • Return type:
    dict
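`get_status()` can serve as a lightweight alternative to `follow_training()` when you only want to react to changes. A minimal sketch of a polling generator; the name `status_changes` and its defaults are hypothetical, and since the keys of the status dictionary are not documented here, it compares whole dictionaries rather than specific fields.

```python
import time
from typing import Callable, Iterator, Optional


def status_changes(get_status: Callable[[], dict], interval: float = 5.0,
                   max_polls: int = 720) -> Iterator[dict]:
    """Hypothetical helper: yield the status each time it differs from the last poll."""
    previous: Optional[dict] = None
    for i in range(max_polls):
        status = get_status()
        if status != previous:
            yield status
            previous = dict(status)    # copy, in case the same object is mutated later
        if i + 1 < max_polls:
            time.sleep(interval)
```

Usage: `for status in status_changes(client.get_status): print(status)`, breaking out of the loop once the status shows that the run has finished.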

import_model(model_type: str, model_name: str, sequence_length: int = 256, tokenizer: str = 'True')

Import a model from the model registry.

  • Parameters:
    • model_type (str) – type of the model. Currently only “huggingface” is supported.
    • model_name (str) – name of the model to import
    • sequence_length (int , optional) – sequence length, defaults to 256
    • tokenizer (str , optional) – whether to import the model’s tokenizer, defaults to 'True'
  • Returns:
    the server response.
  • Return type:
    dict
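A minimal sketch of importing a Hugging Face model end to end, assuming `follow_import()` returns once the import has finished; the helper name `import_hf_model` is hypothetical. It passes `on_disk=False` so that `get_imported_model()` downloads the model instead of leaving it in the server's models folder.

```python
from typing import Any


def import_hf_model(client: Any, model_name: str, sequence_length: int = 256) -> str:
    """Hypothetical helper: import a Hugging Face model and download it locally."""
    client.import_model("huggingface", model_name, sequence_length=sequence_length)
    client.follow_import()    # assumed to block until the import finishes
    # on_disk=False downloads the model instead of storing it server-side
    return client.get_imported_model(on_disk=False, tokenizer=True)
```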

monitor()

Get the training monitor.

  • Returns:
    training monitor showing progress
  • Return type:
    list

preprocess_data()

Preprocess the training data.

  • Returns:
    the server response.
  • Return type:
    dict

static results_accuracy(results_path: str)

Get the accuracy from the results file.

  • Parameters:
    results_path (str) – path to the results file.
  • Returns:
    accuracy
  • Return type:
    float
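The format of the results file is not documented here, so the following is only a sketch of how an accuracy figure like the one `results_accuracy()` returns could be computed, assuming a CSV with hypothetical `label` and `prediction` columns. It is not the SDK's implementation; in practice, call `TrainingClient.results_accuracy(path)` directly.

```python
import csv


def accuracy_from_csv(path: str, label_col: str = "label",
                      pred_col: str = "prediction") -> float:
    """Hypothetical sketch: fraction of rows whose prediction equals the label."""
    total = correct = 0
    with open(path, newline="") as f:
        for row in csv.DictReader(f):
            total += 1
            correct += row[label_col] == row[pred_col]   # bool counts as 0 or 1
    if total == 0:
        raise ValueError("results file has no rows")
    return correct / total
```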

set_parameters(parameters: dict)

Set the training parameters.

  • Parameters:
    parameters (dict) – training parameters.
  • Returns:
    the server response.
  • Return type:
    dict

start_session()

Start a new session, store the session key, and return it.

  • Returns:
    training session key
  • Return type:
    str

start_training()

Start training.

  • Returns:
    the server response.
  • Return type:
    dict

stop_training()

Stop training.

  • Returns:
    the server response.
  • Return type:
    dict
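If a script is interrupted while following training, training may continue on the server. A minimal sketch that calls `stop_training()` on Ctrl+C, assuming `follow_training()` blocks until training ends and that the server keeps training when the client stops following; the helper name `train_interruptibly` is hypothetical.

```python
from typing import Any


def train_interruptibly(client: Any) -> dict:
    """Hypothetical helper: start training and stop it on the server if interrupted."""
    client.start_training()
    try:
        client.follow_training()      # assumed to block until training ends
    except KeyboardInterrupt:
        # Ctrl+C: ask the server to stop instead of leaving the run going
        return client.stop_training()
    return client.get_status()        # report the final status on normal completion
```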

upload_data(train_path: str, test_path)

Upload the training data.

  • Parameters:
    • train_path (str) – path to the training data to upload.
    • test_path (str) – path to the test data to upload.
  • Returns:
    the server response.
  • Return type:
    dict
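A small pre-flight check can catch typos in file paths before anything is sent to the server. A sketch; the helper name `upload_checked` is hypothetical.

```python
import os
from typing import Any


def upload_checked(client: Any, train_path: str, test_path: str) -> dict:
    """Hypothetical helper: verify both files exist, then upload them."""
    missing = [p for p in (train_path, test_path) if not os.path.isfile(p)]
    if missing:
        raise FileNotFoundError(f"missing data file(s): {', '.join(missing)}")
    return client.upload_data(train_path, test_path)
```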

use_wandb(wandb_api_key: str, project_name: str = None, upload_files: bool = False)

Use Weights & Biases (WandB) to log the training process.

  • Parameters:
    • wandb_api_key (str) – API key for WandB.
    • project_name (str , optional) – WandB project name, defaults to None
    • upload_files (bool , optional) – upload files to WandB, defaults to False
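Rather than hard-coding the API key, it can be read from the `WANDB_API_KEY` environment variable, which is the variable the WandB client library itself reads. A minimal sketch; the helper name `enable_wandb` is hypothetical, and it simply returns whatever `use_wandb()` returns.

```python
import os
from typing import Any, Optional


def enable_wandb(client: Any, project_name: Optional[str] = None,
                 upload_files: bool = False) -> Any:
    """Hypothetical helper: enable WandB logging using the key from the environment."""
    api_key = os.environ.get("WANDB_API_KEY")
    if not api_key:
        raise RuntimeError("set WANDB_API_KEY before enabling WandB logging")
    return client.use_wandb(api_key, project_name=project_name, upload_files=upload_files)
```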