Initial
This commit is contained in:
4
bert/.formatter.exs
Normal file
4
bert/.formatter.exs
Normal file
@@ -0,0 +1,4 @@
|
||||
# Used by "mix format"
|
||||
[
|
||||
inputs: ["{mix,.formatter}.exs", "{config,lib,test}/**/*.{ex,exs}"]
|
||||
]
|
||||
23
bert/.gitignore
vendored
Normal file
23
bert/.gitignore
vendored
Normal file
@@ -0,0 +1,23 @@
|
||||
# The directory Mix will write compiled artifacts to.
|
||||
/_build/
|
||||
|
||||
# If you run "mix test --cover", coverage assets end up here.
|
||||
/cover/
|
||||
|
||||
# The directory Mix downloads your dependencies sources to.
|
||||
/deps/
|
||||
|
||||
# Where third-party dependencies like ExDoc output generated docs.
|
||||
/doc/
|
||||
|
||||
# If the VM crashes, it generates a dump, let's ignore it too.
|
||||
erl_crash.dump
|
||||
|
||||
# Also ignore archive artifacts (built via "mix archive.build").
|
||||
*.ez
|
||||
|
||||
# Ignore package tarball (built via "mix hex.build").
|
||||
bert-*.tar
|
||||
|
||||
# Temporary files, for example, from tests.
|
||||
/tmp/
|
||||
21
bert/README.md
Normal file
21
bert/README.md
Normal file
@@ -0,0 +1,21 @@
|
||||
# Bert
|
||||
|
||||
**TODO: Add description**
|
||||
|
||||
## Installation
|
||||
|
||||
If [available in Hex](https://hex.pm/docs/publish), the package can be installed
|
||||
by adding `bert` to your list of dependencies in `mix.exs`:
|
||||
|
||||
```elixir
|
||||
def deps do
|
||||
[
|
||||
{:bert, "~> 0.1.0"}
|
||||
]
|
||||
end
|
||||
```
|
||||
|
||||
Documentation can be generated with [ExDoc](https://github.com/elixir-lang/ex_doc)
|
||||
and published on [HexDocs](https://hexdocs.pm). Once published, the docs can
|
||||
be found at <https://hexdocs.pm/bert>.
|
||||
|
||||
4
bert/call.exs
Normal file
4
bert/call.exs
Normal file
@@ -0,0 +1,4 @@
|
||||
#elixir --sname node1 -S mix run --no-halt
|
||||
#iex --sname node2
|
||||
Node.connect(:"node1@kittykat")
|
||||
GenServer.call({BertEmbedding, :"node1@kittykat"}, {:embed, "I'm going to embed this text"})
|
||||
3
bert/config/config.exs
Normal file
3
bert/config/config.exs
Normal file
@@ -0,0 +1,3 @@
|
||||
import Config
|
||||
|
||||
config :nx, default_backend: EXLA.Backend
|
||||
156
bert/lib/bert.ex
Normal file
156
bert/lib/bert.ex
Normal file
@@ -0,0 +1,156 @@
|
||||
defmodule Bert do
|
||||
use GenServer
|
||||
require Logger
|
||||
|
||||
Nx.global_default_backend(EXLA.Backend)
|
||||
|
||||
@moduledoc """
|
||||
Documentation for `Bert`.
|
||||
"""
|
||||
|
||||
# Client API
|
||||
@doc """
|
||||
Starts the GenServer, which loads the BERT model and tokenizer.
|
||||
"""
|
||||
def start_link(opts) do
|
||||
Logger.info("Starting GenServer")
|
||||
GenServer.start_link(__MODULE__, :ok, opts)
|
||||
end
|
||||
|
||||
@doc """
|
||||
Requests the text embedding for a given string.
|
||||
"""
|
||||
def embed(server_pid, text) do
|
||||
Logger.info("Calling embed")
|
||||
GenServer.call(server_pid, {:embed, text})
|
||||
end
|
||||
|
||||
# Server Callbacks
|
||||
@impl true
|
||||
def init(:ok) do
|
||||
Logger.info("Calling init")
|
||||
|
||||
gpu = has_gpu_access?()
|
||||
|
||||
Logger.info("Has GPU? #{gpu}")
|
||||
|
||||
# TODO: Get these working as batched runs
|
||||
|
||||
# Load the model and tokenizer from the Hugging Face repository
|
||||
model_name = "google-bert/bert-base-uncased"
|
||||
{:ok, model_info} = Bumblebee.load_model({:hf, model_name}, architecture: :base)
|
||||
{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, model_name})
|
||||
|
||||
# Create the serving function for computing embeddings
|
||||
serving =
|
||||
Bumblebee.Text.TextEmbedding.text_embedding(
|
||||
model_info,
|
||||
tokenizer,
|
||||
defn_options: [compile: EXLA]
|
||||
)
|
||||
|
||||
# Create the serving function for computing embeddings
|
||||
serving2 =
|
||||
Bumblebee.Text.TextEmbedding.text_embedding(
|
||||
model_info,
|
||||
tokenizer,
|
||||
defn_options: [compile: EXLA],
|
||||
output_attribute: :hidden_state
|
||||
)
|
||||
|
||||
# TODO: This is a bit much currently
|
||||
if gpu do
|
||||
repo = {:hf, "microsoft/Phi-3.5-mini-instruct"}
|
||||
|
||||
{:ok, model_info} = Bumblebee.load_model(repo, backend: EXLA.Backend)
|
||||
{:ok, tokenizer} = Bumblebee.load_tokenizer(repo)
|
||||
{:ok, generation_config} = Bumblebee.load_generation_config(repo)
|
||||
|
||||
generation_config =
|
||||
Bumblebee.configure(generation_config,
|
||||
max_new_tokens: 256,
|
||||
strategy: %{type: :multinomial_sampling, top_p: 0.6}
|
||||
)
|
||||
|
||||
text_gen =
|
||||
Bumblebee.Text.generation(model_info, tokenizer, generation_config,
|
||||
#compile: [batch_size: 1, sequence_length: 128_000],
|
||||
defn_options: [compiler: EXLA]
|
||||
)
|
||||
|
||||
# Store the serving function in the state
|
||||
{:ok, %{flat_serving: serving, all_serving: serving2, text_serving: text_gen}}
|
||||
else
|
||||
{:ok, %{flat_serving: serving, all_serving: serving2}}
|
||||
end
|
||||
end
|
||||
|
||||
@impl true
|
||||
def handle_call({:embed, text}, _from, state) do
|
||||
Logger.info("Handling single output embedding call")
|
||||
# Run the serving function with the text input
|
||||
result = Nx.Serving.run(state.flat_serving, text)
|
||||
embedding = result.embedding
|
||||
{:reply, embedding, state}
|
||||
end
|
||||
|
||||
@impl true
|
||||
def handle_call({:embed_all, text}, _from, state) do
|
||||
Logger.info("Handling all token output embedding call")
|
||||
# Run the serving function with the text input
|
||||
result = Nx.Serving.run(state.all_serving, text)
|
||||
embedding = result.embedding
|
||||
{:reply, embedding, state}
|
||||
end
|
||||
|
||||
@impl true
|
||||
def handle_call({:make_text, text}, _from, state) do
|
||||
Logger.info("Handling text gen call")
|
||||
|
||||
if has_gpu_access?() == false do
|
||||
{:reply, "No gpu", state}
|
||||
else
|
||||
# Run the serving function with the text input
|
||||
result = Nx.Serving.run(state.text_serving, text)
|
||||
{:reply, result, state}
|
||||
end
|
||||
end
|
||||
|
||||
@doc """
|
||||
Example using Axon to get all the output embeddings
|
||||
"""
|
||||
def get_embedding(text) do
|
||||
{:ok, model_info} =
|
||||
Bumblebee.load_model({:hf, "google-bert/bert-base-uncased"}, architecture: :base)
|
||||
|
||||
{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "google-bert/bert-base-uncased"})
|
||||
inputs = Bumblebee.apply_tokenizer(tokenizer, text)
|
||||
outputs = Axon.predict(model_info.model, model_info.params, inputs).hidden_state[0]
|
||||
outputs
|
||||
end
|
||||
|
||||
@doc """
|
||||
Return if Elixir has access to the GPU or not.
|
||||
"""
|
||||
@spec has_gpu_access? :: boolean()
|
||||
def has_gpu_access?() do
|
||||
try do
|
||||
case Nx.tensor(0) do
|
||||
# :host == CPU
|
||||
%Nx.Tensor{data: %EXLA.Backend{buffer: %EXLA.DeviceBuffer{client_name: :host}}} ->
|
||||
false
|
||||
|
||||
# :cuda == GPU
|
||||
%Nx.Tensor{data: %EXLA.Backend{buffer: %EXLA.DeviceBuffer{client_name: :cuda}}} ->
|
||||
true
|
||||
|
||||
_other ->
|
||||
false
|
||||
end
|
||||
rescue
|
||||
_exception ->
|
||||
Logger.error("Error trying to determine GPU access!")
|
||||
false
|
||||
end
|
||||
end
|
||||
end
|
||||
20
bert/lib/bert/application.ex
Normal file
20
bert/lib/bert/application.ex
Normal file
@@ -0,0 +1,20 @@
|
||||
defmodule Bert.Application do
|
||||
# See https://hexdocs.pm/elixir/Application.html
|
||||
# for more information on OTP Applications
|
||||
@moduledoc false
|
||||
|
||||
use Application
|
||||
|
||||
@impl true
|
||||
def start(_type, _args) do
|
||||
children = [
|
||||
# Starts a worker by calling: Bert.Worker.start_link(arg)
|
||||
{Bert, name: BertEmbedding}
|
||||
]
|
||||
|
||||
# See https://hexdocs.pm/elixir/Supervisor.html
|
||||
# for other strategies and supported options
|
||||
opts = [strategy: :one_for_one, name: Bert.Supervisor]
|
||||
Supervisor.start_link(children, opts)
|
||||
end
|
||||
end
|
||||
31
bert/mix.exs
Normal file
31
bert/mix.exs
Normal file
@@ -0,0 +1,31 @@
|
||||
defmodule Bert.MixProject do
|
||||
use Mix.Project
|
||||
|
||||
def project do
|
||||
[
|
||||
app: :bert,
|
||||
version: "0.1.0",
|
||||
elixir: "~> 1.18",
|
||||
start_permanent: Mix.env() == :prod,
|
||||
deps: deps()
|
||||
]
|
||||
end
|
||||
|
||||
# Run "mix help compile.app" to learn about applications.
|
||||
def application do
|
||||
[
|
||||
extra_applications: [:logger],
|
||||
mod: {Bert.Application, []}
|
||||
]
|
||||
end
|
||||
|
||||
# Run "mix help deps" to learn about dependencies.
|
||||
defp deps do
|
||||
[
|
||||
{:bumblebee, "~> 0.6.0"},
|
||||
{:exla, ">= 0.0.0"}
|
||||
# {:dep_from_hexpm, "~> 0.3.0"},
|
||||
# {:dep_from_git, git: "https://github.com/elixir-lang/my_dep.git", tag: "0.1.0"}
|
||||
]
|
||||
end
|
||||
end
|
||||
24
bert/mix.lock
Normal file
24
bert/mix.lock
Normal file
@@ -0,0 +1,24 @@
|
||||
%{
|
||||
"axon": {:hex, :axon, "0.7.0", "2e2c6d93b4afcfa812566b8922204fa022b60081e86ebd411df4db7ea30f5457", [:mix], [{:kino, "~> 0.7", [hex: :kino, repo: "hexpm", optional: true]}, {:kino_vega_lite, "~> 0.1.7", [hex: :kino_vega_lite, repo: "hexpm", optional: true]}, {:nx, "~> 0.9", [hex: :nx, repo: "hexpm", optional: false]}, {:polaris, "~> 0.1", [hex: :polaris, repo: "hexpm", optional: false]}, {:table_rex, "~> 3.1.1", [hex: :table_rex, repo: "hexpm", optional: true]}], "hexpm", "ee9857a143c9486597ceff434e6ca833dc1241be6158b01025b8217757ed1036"},
|
||||
"bumblebee": {:hex, :bumblebee, "0.6.3", "c0028643c92de93258a9804da1d4d48797eaf7911b702464b3b3dd2cc7f938f1", [:mix], [{:axon, "~> 0.7.0", [hex: :axon, repo: "hexpm", optional: false]}, {:jason, "~> 1.4.0", [hex: :jason, repo: "hexpm", optional: false]}, {:nx, "~> 0.9.0 or ~> 0.10.0", [hex: :nx, repo: "hexpm", optional: false]}, {:nx_image, "~> 0.1.0", [hex: :nx_image, repo: "hexpm", optional: false]}, {:nx_signal, "~> 0.2.0", [hex: :nx_signal, repo: "hexpm", optional: false]}, {:progress_bar, "~> 3.0", [hex: :progress_bar, repo: "hexpm", optional: false]}, {:safetensors, "~> 0.1.3", [hex: :safetensors, repo: "hexpm", optional: false]}, {:tokenizers, "~> 0.4", [hex: :tokenizers, repo: "hexpm", optional: false]}, {:unpickler, "~> 0.1.0", [hex: :unpickler, repo: "hexpm", optional: false]}, {:unzip, "~> 0.12.0", [hex: :unzip, repo: "hexpm", optional: false]}], "hexpm", "c619197787561f8e5fb2ffba269c341654accaec9d591999b7fddd55761dd079"},
|
||||
"castore": {:hex, :castore, "1.0.15", "8aa930c890fe18b6fe0a0cff27b27d0d4d231867897bd23ea772dee561f032a3", [:mix], [], "hexpm", "96ce4c69d7d5d7a0761420ef743e2f4096253931a3ba69e5ff8ef1844fe446d3"},
|
||||
"complex": {:hex, :complex, "0.6.0", "b0130086a7a8c33574d293b2e0e250f4685580418eac52a5658a4bd148f3ccf1", [:mix], [], "hexpm", "0a5fa95580dcaf30fcd60fe1aaf24327c0fe401e98c24d892e172e79498269f9"},
|
||||
"decimal": {:hex, :decimal, "2.3.0", "3ad6255aa77b4a3c4f818171b12d237500e63525c2fd056699967a3e7ea20f62", [:mix], [], "hexpm", "a4d66355cb29cb47c3cf30e71329e58361cfcb37c34235ef3bf1d7bf3773aeac"},
|
||||
"elixir_make": {:hex, :elixir_make, "0.9.0", "6484b3cd8c0cee58f09f05ecaf1a140a8c97670671a6a0e7ab4dc326c3109726", [:mix], [], "hexpm", "db23d4fd8b757462ad02f8aa73431a426fe6671c80b200d9710caf3d1dd0ffdb"},
|
||||
"exla": {:hex, :exla, "0.10.0", "93e7d75a774fbc06ce05b96de20c4b01bda413b315238cb3c727c09a05d2bc3a", [:make, :mix], [{:elixir_make, "~> 0.6", [hex: :elixir_make, repo: "hexpm", optional: false]}, {:fine, "~> 0.1.0", [hex: :fine, repo: "hexpm", optional: false]}, {:nimble_pool, "~> 1.0", [hex: :nimble_pool, repo: "hexpm", optional: false]}, {:nx, "~> 0.10.0", [hex: :nx, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.0 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}, {:xla, "~> 0.9.0", [hex: :xla, repo: "hexpm", optional: false]}], "hexpm", "16fffdb64667d7f0a3bc683fdcd2792b143a9b345e4b1f1d5cd50330c63d8119"},
|
||||
"fine": {:hex, :fine, "0.1.4", "b19a89c1476c7c57afb5f9314aed5960b5bc95d5277de4cb5ee8e1d1616ce379", [:mix], [], "hexpm", "be3324cc454a42d80951cf6023b9954e9ff27c6daa255483b3e8d608670303f5"},
|
||||
"jason": {:hex, :jason, "1.4.4", "b9226785a9aa77b6857ca22832cffa5d5011a667207eb2a0ad56adb5db443b8a", [:mix], [{:decimal, "~> 1.0 or ~> 2.0", [hex: :decimal, repo: "hexpm", optional: true]}], "hexpm", "c5eb0cab91f094599f94d55bc63409236a8ec69a21a67814529e8d5f6cc90b3b"},
|
||||
"nimble_pool": {:hex, :nimble_pool, "1.1.0", "bf9c29fbdcba3564a8b800d1eeb5a3c58f36e1e11d7b7fb2e084a643f645f06b", [:mix], [], "hexpm", "af2e4e6b34197db81f7aad230c1118eac993acc0dae6bc83bac0126d4ae0813a"},
|
||||
"nx": {:hex, :nx, "0.10.0", "128e4a094cb790f663e20e1334b127c1f2a4df54edfb8b13c22757ec33133b4f", [:mix], [{:complex, "~> 0.6", [hex: :complex, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.0 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}], "hexpm", "3db8892c124aeee091df0e6fbf8e5bf1b81f502eb0d4f5ba63e6378ebcae7da4"},
|
||||
"nx_image": {:hex, :nx_image, "0.1.2", "0c6e3453c1dc30fc80c723a54861204304cebc8a89ed3b806b972c73ee5d119d", [:mix], [{:nx, "~> 0.4", [hex: :nx, repo: "hexpm", optional: false]}], "hexpm", "9161863c42405ddccb6dbbbeae078ad23e30201509cc804b3b3a7c9e98764b81"},
|
||||
"nx_signal": {:hex, :nx_signal, "0.2.0", "e1ca0318877b17c81ce8906329f5125f1e2361e4c4235a5baac8a95ee88ea98e", [:mix], [{:nx, "~> 0.6", [hex: :nx, repo: "hexpm", optional: false]}], "hexpm", "7247e5e18a177a59c4cb5355952900c62fdeadeb2bad02a9a34237b68744e2bb"},
|
||||
"polaris": {:hex, :polaris, "0.1.0", "dca61b18e3e801ecdae6ac9f0eca5f19792b44a5cb4b8d63db50fc40fc038d22", [:mix], [{:nx, "~> 0.5", [hex: :nx, repo: "hexpm", optional: false]}], "hexpm", "13ef2b166650e533cb24b10e2f3b8ab4f2f449ba4d63156e8c569527f206e2c2"},
|
||||
"progress_bar": {:hex, :progress_bar, "3.0.0", "f54ff038c2ac540cfbb4c2bfe97c75e7116ead044f3c2b10c9f212452194b5cd", [:mix], [{:decimal, "~> 2.0", [hex: :decimal, repo: "hexpm", optional: false]}], "hexpm", "6981c2b25ab24aecc91a2dc46623658e1399c21a2ae24db986b90d678530f2b7"},
|
||||
"rustler_precompiled": {:hex, :rustler_precompiled, "0.8.3", "4e741024b0b097fe783add06e53ae9a6f23ddc78df1010f215df0c02915ef5a8", [:mix], [{:castore, "~> 0.1 or ~> 1.0", [hex: :castore, repo: "hexpm", optional: false]}, {:rustler, "~> 0.23", [hex: :rustler, repo: "hexpm", optional: true]}], "hexpm", "c23f5f33cb6608542de4d04faf0f0291458c352a4648e4d28d17ee1098cddcc4"},
|
||||
"safetensors": {:hex, :safetensors, "0.1.3", "7ff3c22391e213289c713898481d492c9c28a49ab1d0705b72630fb8360426b2", [:mix], [{:jason, "~> 1.4", [hex: :jason, repo: "hexpm", optional: false]}, {:nx, "~> 0.5", [hex: :nx, repo: "hexpm", optional: false]}], "hexpm", "fe50b53ea59fde4e723dd1a2e31cfdc6013e69343afac84c6be86d6d7c562c14"},
|
||||
"telemetry": {:hex, :telemetry, "1.3.0", "fedebbae410d715cf8e7062c96a1ef32ec22e764197f70cda73d82778d61e7a2", [:rebar3], [], "hexpm", "7015fc8919dbe63764f4b4b87a95b7c0996bd539e0d499be6ec9d7f3875b79e6"},
|
||||
"tokenizers": {:hex, :tokenizers, "0.5.1", "b0975d92b4ee5b18e8f47b5d65b9d5f1e583d9130189b1a2620401af4e7d4b35", [:mix], [{:castore, "~> 0.1 or ~> 1.0", [hex: :castore, repo: "hexpm", optional: false]}, {:rustler, ">= 0.0.0", [hex: :rustler, repo: "hexpm", optional: true]}, {:rustler_precompiled, "~> 0.6", [hex: :rustler_precompiled, repo: "hexpm", optional: false]}], "hexpm", "5f08d97cc7f2ed3d71d370d68120da6d3de010948ccf676c9c0eb591ba4bacc9"},
|
||||
"unpickler": {:hex, :unpickler, "0.1.0", "c2262c0819e6985b761e7107546cef96a485f401816be5304a65fdd200d5bd6a", [:mix], [], "hexpm", "e2b3f61e62406187ac52afead8a63bfb4e49394028993f3c4c42712743cab79e"},
|
||||
"unzip": {:hex, :unzip, "0.12.0", "beed92238724732418b41eba77dcb7f51e235b707406c05b1732a3052d1c0f36", [:mix], [], "hexpm", "95655b72db368e5a84951f0bed586ac053b55ee3815fd96062fce10ce4fc998d"},
|
||||
"xla": {:hex, :xla, "0.9.1", "cca0040ff94902764007a118871bfc667f1a0085d4a5074533a47d6b58bec61e", [:make, :mix], [{:elixir_make, "~> 0.4", [hex: :elixir_make, repo: "hexpm", optional: false]}], "hexpm", "eb5e443ae5391b1953f253e051f2307bea183b59acee138053a9300779930daf"},
|
||||
}
|
||||
1
bert/start.txt
Normal file
1
bert/start.txt
Normal file
@@ -0,0 +1 @@
|
||||
elixir --sname gpu -S mix run --no-halt
|
||||
8
bert/test/bert_test.exs
Normal file
8
bert/test/bert_test.exs
Normal file
@@ -0,0 +1,8 @@
|
||||
defmodule BertTest do
|
||||
use ExUnit.Case
|
||||
doctest Bert
|
||||
|
||||
test "greets the world" do
|
||||
assert Bert.hello() == :world
|
||||
end
|
||||
end
|
||||
1
bert/test/test_helper.exs
Normal file
1
bert/test/test_helper.exs
Normal file
@@ -0,0 +1 @@
|
||||
ExUnit.start()
|
||||
Reference in New Issue
Block a user