要在AWS SageMaker训练脚本中传递自定义用户参数,可以使用argparse模块来解析命令行参数。以下是一个示例:
import argparse
# 创建argparse解析器
parser = argparse.ArgumentParser()
# 添加自定义参数
parser.add_argument('--param1', type=str, default='default_value1', help='参数1的帮助信息')
parser.add_argument('--param2', type=int, default=0, help='参数2的帮助信息')
# 解析命令行参数
args = parser.parse_args()
# 使用参数
print(f'参数1: {args.param1}')
print(f'参数2: {args.param2}')
将上述代码保存为一个.py文件,例如train.py
。在SageMaker训练作业中,可以通过指定hyperparameters
参数传递自定义参数。例如:
from sagemaker.estimator import Estimator
# 创建Estimator对象
estimator = Estimator(role=role,
train_instance_count=1,
train_instance_type='ml.c4.xlarge',
image_name='image_name',
hyperparameters={
'param1': 'custom_value1',
'param2': 123
})
# 启动训练作业
estimator.fit(inputs)
在上述示例中,param1
和param2
是自定义参数,可以根据需要进行调整。在创建Estimator对象时,通过hyperparameters
参数传递自定义参数的值。在训练脚本中使用args.param1
和args.param2
即可获取传递的参数值。