如果使用SageMaker TensorFlow Estimator API作为训练脚本入口,并在输入模型函数中使用TensorFlow Serving作为输出模型函数,请检查请求API的输入格式,因为这可能导致API接收到的输入格式与输出模型的编码格式不匹配。
以下是示例代码:
import tensorflow.compat.v1 as tf
def input_fn(request_body, request_content_type):
# Parse request_body here and return a TensorFlow tensor
# Example code:
if request_content_type == 'text/csv':
input_data = tf.constant(request_body, dtype=tf.float32)
elif request_content_type == 'application/json':
input_data = tf.constant(json.loads(request_body), dtype=tf.float32)
else:
raise ValueError("Invalid content-type: {}".format(request_content_type))
return {'inputs': input_data}
def output_fn(prediction, response_content_type):
# Convert prediction to appropriate response_content_type.
# Example code:
if response_content_type == 'text/csv':
output_data = prediction.astype(dtype=np.str)
elif response_content_type == 'application/json':
output_data = {'output': prediction.tolist()}
else:
raise ValueError('Unsupported content type "{}"'.format(response_content_type))
return output_data
在本例中,如果请求的请求内容类型为text/csv
,则传入的输入数据将以浮点张量的形式传递。如果请求的请求内容类型为application/json
,则将传入的输入数据转换为浮点张量。请注意,SageMaker TensorFlow Estimator API将这些数据提供给 TensorFlow 服务的方式。
如果您的输入模型函数在转换请求数据时使用了不同的编码(比如ASCII而不是UTF-8),则API可能会抛出一个编码不匹配的错误。在这种情况下,请将您的转换方法更改为正确的编码。