予測モデルのオンライン予測を取得する

Vertex AI には、トレーニング済みの予測モデルを使用して将来の値を予測するために、オンライン予測とバッチ予測の 2 つのオプションがあります。

オンライン予測は同期リクエストです。アプリケーションの入力に応じてリクエストを送信する場合、またはタイムリーな推定が必要となる状況でリクエストを送信する場合は、オンライン予測を使用します。

バッチ予測リクエストは非同期リクエストです。即時のレスポンスが必要なく、累積されたデータを 1 回のリクエストで処理する場合は、バッチ予測を使用します。

このページでは、オンライン予測を使用して将来の値を予測する方法について説明します。バッチ予測を使用して値を予測する方法については、予測モデルのバッチ予測を取得するをご覧ください。

予測に使用するには、モデルをエンドポイントにデプロイする必要があります。エンドポイントは物理リソースのセットです。

予測の代わりに説明をリクエストすることができます。説明のローカル特徴量の重要度の値は、各特徴量が予測結果に及ぼした影響の度合いを示します。コンセプトの概要については、予測の特徴アトリビューションをご覧ください。

オンライン予測の料金については、Tabular Workflows の料金をご覧ください。

始める前に

オンライン予測リクエストを行うには、まずモデルをトレーニングする必要があります。

エンドポイントを作成または選択する

関数 aiplatform.Endpoint.create() を使用してエンドポイントを作成します。エンドポイントがすでにある場合は、aiplatform.Endpoint() 関数を使用してエンドポイントを選択します。

次のコードで例を示します。

# Import required modules
from google.cloud import aiplatform
from google.cloud.aiplatform import models

PROJECT_ID = "PROJECT_ID"
REGION = "REGION"

# Initialize the Vertex SDK for Python for your project.
aiplatform.init(project=PROJECT_ID, location=REGION)
endpoint = aiplatform.Endpoint.create(display_name='ENDPOINT_NAME')

次のように置き換えます。

  • PROJECT_ID: プロジェクト ID。
  • REGION: Vertex AI を使用するリージョン。
  • ENDPOINT_NAME: エンドポイントの表示名。

トレーニング済みモデルを選択する

aiplatform.Model() 関数を使用して、トレーニング済みモデルを選択します。

# Create reference to the model trained ahead of time.
model_obj = models.Model("TRAINED_MODEL_PATH")

次のように置き換えます。

  • TRAINED_MODEL_PATH(例: projects/PROJECT_ID/locations/REGION/models/[TRAINED_MODEL_ID]

エンドポイントにモデルをデプロイする

deploy() 関数を使用して、モデルをエンドポイントにデプロイします。次のコードで例を示します。

deployed_model = endpoint.deploy(
    model_obj,
    machine_type='MACHINE_TYPE',
    traffic_percentage=100,
    min_replica_count='MIN_REPLICA_COUNT',
    max_replica_count='MAX_REPLICA_COUNT',
    sync=True,
    deployed_model_display_name='DEPLOYED_MODEL_NAME',
)

次のように置き換えます。

  • MACHINE_TYPE(例: n1-standard-8)。マシンタイプの詳細を確認してください。
  • MIN_REPLICA_COUNT: このデプロイの最小ノード数。ノード数は、予測負荷に応じてノードの最大数まで増減できますが、この数より少なくなることはありません。1 以上の値を指定してください。min_replica_count 変数が設定されていない場合、値はデフォルトで 1 に設定されます。
  • MAX_REPLICA_COUNT: このデプロイの最大ノード数。ノード数は、予測負荷に応じてこのノード数まで増減できますが、最小ノード数より少なくなることはありません。max_replica_count 変数を設定しない場合、ノードの最大数は min_replica_count の値に設定されます。
  • DEPLOYED_MODEL_NAME: DeployedModel の名前。DeployedModelModel の表示名を使用することもできます。

モデルのデプロイに約 10 分を要する場合があります。

オンライン予測を取得する

予測を取得するには、predict() 関数を使用して、1 つ以上の入力インスタンスを指定します。次のコードは例を示しています。

predictions = endpoint.predict(instances=[{...}, {...}])

各入力インスタンスは、モデルのトレーニングに使用されたスキーマと同じスキーマを持つ Python 辞書です。時間列に対応する「予測時に利用可能」の Key-Value ペアと、ターゲット予測列の過去の値を含む「予測時に使用不可」の Key-Value ペアが含まれている必要があります。Vertex AI では、各入力インスタンスが単一の時系列に属していることが想定されます。インスタンス内の Key-Value ペアの順序は重要ではありません。

入力インスタンスには次の制約があります。

  • 「予測時に利用可能」な Key-Value ペアのデータポイントの数は、すべて同じであることが必要です。
  • 「予測時に利用不可」の Key-Value ペアのデータポイントの数は、すべて同じであることが必要です。
  • 「予測時に利用可能」な Key-Value ペアには、少なくとも「予測時に利用不可」の Key-Value ペアと同じ数のデータポイントが設定されている必要があります。

予測に使用される列の種類の詳細については、特徴タイプと予測時の可用性をご覧ください。

次のコードは、2 つの入力インスタンスのセットを示しています。Category 列には属性データが含まれています。Timestamp 列には、予測時に利用可能なデータが含まれています。3 つのポイントはコンテキスト データ、2 つのポイントはホライズン データです。Sales 列には、予測時に使用できないデータが含まれています。3 つのポイントはすべてコンテキスト データです。予測でコンテキストとホライズンがどのように使用されるかについては、予測ホライズン、コンテキスト ウィンドウ、予測ウィンドウをご覧ください。

instances=[
  {
    # Attribute
    "Category": "Electronics",
    # Available at forecast: three days of context, two days of horizon
    "Timestamp": ['2023-08-03', '2023-08-04', '2023-08-05', '2023-08-06', '2023-08-07'],
    # Unavailable at forecast: three days of context
    "Sales": [490.50, 325.25, 647.00],
  },
  {
    # Attribute
    "Category": "Food",
    # Available at forecast: three days of context, two days of horizon
    "Timestamp": ['2023-08-03', '2023-08-04', '2023-08-05', '2023-08-06', '2023-08-07'],
    # Unavailable at forecast: three days of context
    "Sales": [190.50, 395.25, 47.00],
  }
])

各インスタンスについて、Vertex AI は 2 つのホライズン タイムスタンプ(「2023-08-06」と「2023-08-07」)に対応する Sales の 2 つの予測を返します。

最適なパフォーマンスを得るには、各入力インスタンスのコンテキスト データポイントの数とホライズン データポイントの数をモデルをトレーニングした際のコンテキストおよびホライズンの長さと一致させる必要があります。不一致がある場合、Vertex AI はモデルのサイズに合わせてインスタンスをパディングするか、切り捨てます。

入力インスタンス内のコンテキスト データポイントの数が、モデルのトレーニングに使用されるコンテキスト データポイントの数よりも少ない、または多い場合は、このポイント数が、すべての予測時に利用可能な Key-Value ペアと、予測時に利用不可の Key-Value ペアすべてで一貫していることを確認します。

たとえば、4 日間のコンテキスト データと 2 日間のホライズン データでトレーニングされたモデルについて考えてみましょう。予測リクエストは、わずか 3 日間のコンテキスト データを使用して行うことができます。この場合、予測時に使用不可の Key-Value ペアには 3 つの値が含まれます。予測時に利用可能な Key-Value ペアには 5 つの値が含まれている必要があります。

オンライン予測の出力

Vertex AI は、value フィールドにオンライン予測の出力を提供します。

{
  'value': [...]
}

予測レスポンスの長さは、モデルのトレーニングで使用されるホライズンと入力インスタンスのホライズンによって異なります。予測レスポンスの長さは、この 2 つの値の最小値です。

以下の例を考えてみましょう。

  • context = 15horizon = 50 でモデルをトレーニングします。入力インスタンスには context = 15horizon = 20 が設定されています。予測レスポンスの長さは 20 です。
  • context = 15horizon = 50 でモデルをトレーニングします。入力インスタンスには context = 15horizon = 100 が設定されています。予測レスポンスの長さは 50 です。

TFT モデルのオンライン予測の出力

Temporal Fusion Transformer(TFT)でトレーニングされたモデルの場合、Vertex AI は value フィールドの予測に加えて、TFT 解釈可能性 tft_feature_importance を提供します。

{
  "tft_feature_importance": {
    "attribute_weights": [...],
    "attribute_columns": [...],
    "context_columns": [...],
    "context_weights": [...],
    "horizon_weights": [...],
    "horizon_columns": [...]
  },
  "value": [...]
}
  • attribute_columns: 時間不変である予測特徴。
  • attribute_weights: 各 attribute_columns に関連付けられた重み。
  • context_columns: コンテキスト ウィンドウの値が TFT 長・短期記憶(LSTM)エンコーダへの入力として機能する予測特徴。
  • context_weights: 予測インスタンスの各 context_columns に関連付けられた特徴の重要度の重み。
  • horizon_columns: 予測ホライズン値が TFT 長・短期記憶(LSTM)デコーダへの入力として機能する予測特徴。
  • horizon_weights: 予測インスタンスの各 horizon_columns に関連付けられた特徴の重要度の重み。

分位点損失に最適化されたモデルのオンライン予測の出力

分位点損失に最適化されたモデルの場合、Vertex AI は次のオンライン予測出力を提供します。

{
  "value": [...],
  "quantile_values": [...],
  "quantile_predictions": [...]
}
  • value: 分位のセットに中央値が含まれている場合、value は中央値における予測値です。それ以外の場合、value はセット内の最小分位の予測値です。たとえば、分位のセットが [0.1, 0.5, 0.9] の場合、value は分位 0.5 の予測です。分位のセットが [0.1, 0.9] の場合、value は分位 0.1 の予測です。
  • quantile_values: モデルのトレーニング中に設定される分位の値。
  • quantile_predictions: quantile_values に関連付けられた予測値。

たとえば、ターゲット列が売上額であるモデルについて考えてみましょう。分位点の値は [0.1, 0.5, 0.9] として定義されます。Vertex AI は、分位予測 [4484, 5615, 6853] を返します。ここでは、分位のセットに中央値が含まれているため、value は分位 0.5 の予測です(5615)。分位予測は次のように解釈できます。

  • P(売上値 < 4484)= 10%
  • P(売上値 < 5615)= 50%
  • P(売上値 < 6,853)= 90%

確率的推論を使用するモデルのオンライン予測の出力

モデルで確率的推論を使用する場合、value フィールドには最適化目標の最小化値が含まれます。たとえば、最適化目標が minimize-rmse の場合、value フィールドには平均値が含まれます。minimize-mae の場合、value フィールドには中央値が含まれます。

モデルで分位数を使用した確率論的推論を使用する場合、Vertex AI は最適化目標の最小化値に加えて、分位点の値と予測も提供します。分位点の値はモデルのトレーニング時に設定されます。分位点の予測は、分位点の値に関連付けられた予測値です。

オンライン説明を取得する

説明を取得するには、explain() 関数を使用して、1 つ以上の入力インスタンスを指定します。次のコードは例を示しています。

explanations = endpoint.explain(instances=[{...}, {...}])

入力インスタンスの形式は、オンライン予測とオンライン説明で同じです。詳細については、オンライン予測を取得するをご覧ください。

特徴アトリビューションのコンセプトの概要については、予測の特徴アトリビューションをご覧ください。

オンライン説明の出力

次のコードは、説明の結果を出力する方法を示しています。

# Import required modules
import json
from google.protobuf import json_format

def explanation_to_dict(explanation):
  """Converts the explanation proto to a human-friendly json."""
  return json.loads(json_format.MessageToJson(explanation._pb))

for response in explanations.explanations:
  print(explanation_to_dict(response))

説明の結果の形式は次のとおりです。

{
  "attributions": [
    {
      "baselineOutputValue": 1.4194682836532593,
      "instanceOutputValue": 2.152980089187622,
      "featureAttributions": {
        ...
        "store_id": [
          0.007947325706481934
        ],
        ...
        "dept_id": [
          5.960464477539062e-08
        ],
        "item_id": [
          0.1100526452064514
        ],
        "date": [
          0.8525647521018982
        ],
        ...
        "sales": [
          0.0
        ]
      },
      "outputIndex": [
        2
      ],
      "approximationError": 0.01433318599207033,
      "outputName": "value"
    },
    ...
  ]
}

attributions 要素の数は、モデル トレーニングで使用されるホライズンと入力インスタンスのホライズンによって異なります。要素の数はこの 2 つの値の最小値です。

attributions 要素の featureAttributions フィールドには、入力データセットの各列に 1 つの値が含まれます。Vertex AI は、属性、予測時に利用可能、予測時に利用不可など、すべてのタイプの特徴の説明を生成します。attributions 要素のフィールドについて詳しくは、アトリビューションをご覧ください。

エンドポイントを削除する

undeploy_all() 関数と delete() 関数を使用して、エンドポイントを削除します。次のコードは例を示しています。

endpoint.undeploy_all()
endpoint.delete()

次のステップ