europeana-5TK1F5VfdIk-unsplash.jpg

型レベル多項式

 
0
このエントリーをはてなブックマークに追加
Kazuki Moriyama
Kazuki Moriyama (森山 和樹)

[mathjax]

dottyの柔軟な型を利用して型レベルの多項式とその一般の足し算を定義する。

多項式の定式化

まずはenumを用いて多項式の要素を型に落とし込む。
必要な要素は項とその足し算だ。

enum Polynomial:
  case X[Coef <: Int, N <: Int]()
  case Plus[XX <: X[?, ?], YY <: Polynomial]()

Xが項、Plusがその足し算を表している。

Xの型変数Coefは係数、Nは右の肩に乗数を表している。
つまりX[2, 3]は[2x^3]のことである。
infix記法を使って2 X 3のように書けばより普通の書き方の様に見える。
定数項は2 X 0の様にして表せる。

3 X 0 // 3
2 X 1 // 2x
(2 X 1) Plus (3 X 0) // 2x + 3
(1 X 2) Plus ((2 X 1) Plus (3 X 0)) // x^2 + 2x + 3

ヘルパー型

後で使いやすいようにXに対してヘルパー型を定義する。

type N[XX <: X[?, ?]] = XX match
  case ? X n => n

type C[XX <: X[?, ?]] = XX match
  case c X ? => c

それぞれXの型パラメータを抜き出す便利型である。
こんな便利型も昔はむちゃくちゃ頑張らないと書けなかったしいい時代になった。

summon[N[1 X 2] =:= 2]
summon[C[1 X 2] =:= 1]

一般の多項式同士の足し算

Plusでは一般の足し算が表現できないことは上で述べた。
ここで一般の足し算を定義したい。
更に項は乗数のオーダーでソートされてほしい。
そのような型演算は以下のようにして定義できる。

import scala.compiletime.ops.int
import scala.compiletime.ops.any.*
import Polynomial.*

type +[A <: Polynomial, B <: Polynomial] <: Polynomial = A match
  case c X n =>
    B match
      case cc X nn =>
        n == nn match
          case true => int.+(c, cc) X n
          case false =>
            int.>(n, nn) match
              case true  => A Plus B
              case false => B Plus A
      case xx Plus yy =>
        int.>(n, N[xx]) match
          case true => A Plus B
          case false =>
            n == N[xx] match
              case true  => (int.+(c, C[xx]) X n) Plus yy
              case false => xx Plus (A + yy)
  case xx Plus yy =>
    B match
      case cc X nn =>
        N[xx] == nn match
          case true => (int.+(C[xx], cc) X nn) Plus yy
          case false =>
            int.>(N[xx], nn) match
              case true  => xx Plus (B + yy)
              case false => B Plus A
      case xxx Plus yyy =>
        N[xx] == N[xxx] match
          case true => (int.+(C[xx], C[xxx]) X N[xx]) Plus (yy + yyy)
          case false =>
            int.>(N[xx], N[xxx]) match
              case true  => xx Plus (yy + B)
              case false => xxx Plus (yyy + A)

Match Typeが暴れている。
これを使えば以下のように一般の形での多項式の足し算ができる。

summon[(3 X 1) + (4 X 1) =:= (7 X 1)]
summon[(3 X 1) + (2 X 0) =:= ((3 X 1) Plus (2 X 0))]
summon[(3 X 1) + (4 X 2) =:= (4 X 2) + (3 X 1)]
summon[(3 X 2) + (4 X 1) + (2 X 0) =:= ((3 X 2) Plus ((4 X 1) Plus (2 X 0)))]
summon[(3 X 2) + ((4 X 1) + (2 X 0)) =:= (3 X 2) + (4 X

 1) + (2 X 0)]
summon[(3 X 1) + ((4 X 1) + (2 X 0)) =:= (7 X 1) + (2 X 0)]
summon[(3 X 0) + ((4 X 1) + (2 X 0)) =:= (4 X 1) + (5 X 0)]
summon[((3 X 1) + (2 X 0)) + (4 X 1) =:= (7 X 1) + (2 X 0)]
summon[((3 X 1) + (2 X 0)) + (4 X 2) =:= (4 X 2) + (3 X 1) + (2 X 0)]
summon[(3 X 1) + ((4 X 2) + (2 X 0)) =:= ((4 X 2) + (3 X 1) + (2 X 0))]
summon[
  ((5 X 3) + (3 X 1)) + ((4 X 2) + (2 X 0)) =:= (5 X 3) + (4 X 2) + (3 X 1) + (2 X 0)
]

多分あってる。

終わり

dottyの登場によって型レベル演算はscalaコンパイラを熟知したhackyなものではなく、どれだけ型レベルのifを書けるかの勝負になっている感がある。
最後に全コード。

import scala.compiletime.ops.int
import scala.compiletime.ops.any.*
import Polynomial.*

enum Polynomial:
  case X[Coef <: Int, N <: Int]()
  case Plus[XX <: X[?, ?], YY <: Polynomial]()

type N[XX <: X[?, ?]] = XX match
  case ? X n => n

type C[XX <: X[?, ?]] = XX match
  case c X ? => c

type +[A <: Polynomial, B <: Polynomial] <: Polynomial = A match
  case c X n =>
    B match
      case cc X nn =>
        n == nn match
          case true => int.+(c, cc) X n
          case false =>
            int.>(n, nn) match
              case true  => A Plus B
              case false => B Plus A
      case xx Plus yy =>
        int.>(n, N[xx]) match
          case true => A Plus B
          case false =>
            n == N[xx] match
              case true  => (int.+(c, C[xx]) X n) Plus yy
              case false => xx Plus (A + yy)
  case xx Plus yy =>
    B match
      case cc X nn =>
        N[xx] == nn match
          case true => (int.+(C[xx], cc) X nn) Plus yy
          case false =>
            int.>(N[xx], nn) match
              case true  => xx Plus (B + yy)
              case false => B Plus A
      case xxx Plus yyy =>
        N[xx] == N[xxx] match
          case true => (int.+(C[xx], C[xxx]) X N[xx]) Plus (yy + yyy)
          case false =>
            int.>(N[xx], N[xxx]) match
              case true  => xx Plus (yy + B)
              case false => xxx Plus (yyy + A)

object Polynomial:
  summon[N[1 X 2] =:= 2]
  summon[C[1 X 2] =:= 1]

  summon[(3 X 1) + (4 X 1) =:= (7 X 1)]
  summon[(3 X 1) + (2 X 0) =:= ((3 X 1) Plus (2 X 0))]
  summon[(3 X 1) + (4 X 2) =:= (4 X 2) + (3 X 1)]
  summon[(3 X 2) + (4 X 1) + (2 X 0) =:= ((3 X 2) Plus ((4 X 1) Plus (2 X 0)))]
  summon[(3 X 2) + ((4 X 1) + (2 X 0)) =:= (3 X 2) + (4 X 1) + (2 X 0)]
  summon[(3 X 1) + ((4 X 1) + (2 X 0)) =:= (7 X 1) + (2 X 0)]
  summon[(3 X 0) + ((4 X 1) + (2 X 0)) =:= (4 X 1) + (5 X 0)]
  summon[((3 X 1) + (2 X 0)) + (4 X 1) =:= (7 X 1) + (2 X 0)]
  summon[((3 X 1) + (2 X 0)) + (4 X 2) =:= (4 X 2) + (3 X 1) + (2 X 0)]
  summon[(3 X 1) + ((4 X 2) + (2 X 0)) =:= ((4 X 2) + (3 X 1) + (2 X 0))]
  summon[
    ((5 X 3) + (3 X 1)) + ((4 X 2) + (2 X 0)) =:= (5 X 3) + (4 X 2) + (3 X 1) + (2 X 0)
  ]
info-outline

お知らせ

K.DEVは株式会社KDOTにより運営されています。記事の内容や会社でのITに関わる一般的なご相談に専門の社員がお答えしております。ぜひお気軽にご連絡ください。